# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import subprocess
import socket
import re
from string import Template
import struct
import datetime
import time
import traceback

class FileLogger:
   """Minimal logger that appends timestamped messages to a text file."""

   def __init__( self, logfile ):
      """Remember the path of the log file; nothing is opened yet."""
      # path that log() reopens in append mode on every call
      self.logfile_ = logfile

   def log( self, message ):
      """Append one line of the form '<timestamp> : <message>' to the file."""
      with open( self.logfile_, 'a' ) as logFile:
         stamp = datetime.datetime.now().strftime( '%Y-%m-%d %H:%M:%S.%f ' )
         # the format string above already ends in a space, hence ' : '
         entry = stamp + ': ' + message + '\n'
         logFile.write( entry )

class GratuitousNeighborAdv:
   """Build and send IPv6 Neighbor Advertisements to every known neighbor.

   doNeighborAdvertisements() drives the whole flow:
     1. intfInfoIs() collects the IPv6 and MAC addresses of local interfaces.
     2. neighborCacheIs() reads the kernel IPv6 neighbor table.
     3. neighborAdvertisementsIs() pre-builds one NA frame per neighbor.
     4. sendNAPkts() transmits those frames on a raw AF_PACKET socket.

   Addresses are carried around as hex strings without separators so they can
   be substituted straight into the packet templates.
   """

   # interfaces whose addresses and neighbors are always ignored
   skipIntfs_ = ( 'lo', 'ma1' )

   def __init__( self, logger_ ):
      """logger_ is any object exposing log( message )."""
      # list of dicts with keys intf, ip, fullip, mac, macHex
      self.ncache_ = []
      # interface name -> dict with keys ip, plen, llAddr, mac, macHex
      self.intfDict_ = {}
      self.logger_ = logger_
      # list of dicts with keys dmac, smac, intf, ip, rawPkt, eth
      self.txPktList_ = []

   @staticmethod
   def _runCmd( cmd ):
      """Run cmd (an argv list) and return its stdout as text."""
      proc = subprocess.Popen( cmd, stdout=subprocess.PIPE )
      out = proc.communicate()[ 0 ]
      # communicate() hands back bytes here; decode once at the boundary so
      # that all parsing below works on str
      return out.decode( 'utf-8', 'replace' )

   @staticmethod
   def _macToHex( mac ):
      """Return 'aa:bb:cc:dd:ee:ff' as 12 hex digits, or None if invalid."""
      macHex = "".join( mac.split( ":" ) )
      try:
         valid = len( bytes.fromhex( macHex ) ) == 6
      except ValueError:
         valid = False
      return macHex if valid else None

   @staticmethod
   def _icmpV6Checksum( sipHex, dipHex, icmpHex ):
      """Return the ICMPv6 checksum as an int in the range 0..0xffff.

      sipHex and dipHex are the source and destination addresses and icmpHex
      is the ICMPv6 message (with its checksum field zeroed), all as hex
      strings. Raises ValueError if any of them is not valid hex.
      """
      icmpBytes = bytes.fromhex( icmpHex )
      # pseudo header: src, dst, 32-bit upper-layer length, 32-bit next hdr
      packet = ( bytes.fromhex( sipHex ) + bytes.fromhex( dipHex ) +
                 struct.pack( ">I", len( icmpBytes ) ) +
                 struct.pack( ">I", 0x3a ) + icmpBytes )
      # a trailing odd byte counts as the high half of a final 16-bit word
      if len( packet ) % 2:
         packet += b'\x00'
      total = sum( struct.unpack( "!%dH" % ( len( packet ) // 2 ), packet ) )
      # fold the carries back in (16-bit one's complement addition)
      while total >> 16:
         total = ( total >> 16 ) + ( total & 0xffff )
      return ~total & 0xffff

   def _parseAddrOutput( self, text ):
      """Parse the output of 'ip -o -6 addr show'.

      Returns ( addrs, llAddr ). addrs maps interface name to
      [ fullHexIp, prefixLen ] for its address whose scope is not 'link'
      (the last one listed wins); llAddr maps interface name to its
      link-scope address in full hex form. Unparsable lines are logged and
      skipped.
      """
      ipRe = r'\d+: (?P<intf>\S+)\s+inet6 (?P<ip>\S+)/(?P<plen>\d+)'
      ipRe += r'\s+scope (?P<scope>\S+)'

      addrs = {}
      llAddr = {}
      # splitlines() keeps the last line even without a trailing newline
      for line in text.splitlines():
         if not line.strip():
            continue
         m1 = re.match( ipRe, line )
         if m1 is None:
            self.log( f"Invalid output in ip -6 command: {line}" )
            continue
         intf = m1.group( 'intf' )
         if intf in self.skipIntfs_:
            continue
         if intf == 'macvlan':
            intf = 'macvlan-bond0'
         try:
            ip = self.fullV6AddrIs( m1.group( 'ip' ) )
         except ( OSError, ValueError ):
            self.log( f"invalid intf entry {line}" )
            continue
         if m1.group( 'scope' ) == 'link':
            llAddr[ intf ] = ip
         else:
            addrs[ intf ] = [ ip, m1.group( 'plen' ) ]
      return addrs, llAddr

   def _parseNeighborOutput( self, text ):
      """Parse the output of 'ip -6 neigh show'.

      Returns a list of dicts with keys intf, ip, fullip, mac and macHex, one
      per line of the form '<ip> dev <intf> lladdr <mac> ...'. Lines for
      interfaces in skipIntfs_ are dropped silently; any other line that
      lacks a usable address or MAC is logged and skipped.
      """
      entries = []
      for line in text.splitlines():
         nbr = line.split()
         if not nbr:
            continue
         if len( nbr ) >= 3 and nbr[ 2 ] in self.skipIntfs_:
            continue
         # entries without a resolved MAC carry no 'lladdr' token
         if len( nbr ) < 5 or nbr[ 3 ] != 'lladdr':
            self.log( "invalid nbr entry %s" % line )
            continue
         hostMac = nbr[ 4 ]
         hostMacHex = self._macToHex( hostMac )
         try:
            ip = self.fullV6AddrIs( nbr[ 0 ] )
         except ( OSError, ValueError ):
            ip = None
         if ip is None or hostMacHex is None:
            self.log( "invalid nbr entry %s" % line )
            continue
         entries.append( { 'intf' : nbr[ 2 ],
                           'ip' : nbr[ 0 ],
                           'fullip' : ip,
                           'mac' : hostMac,
                           'macHex' : hostMacHex } )
      return entries

   def intfInfoIs( self ):
      """Populate intfDict_ from the 'ip' command.

      Only interfaces that have a non-link-scope IPv6 address and an Ethernet
      MAC end up in intfDict_; interfaces whose 'ip link' output cannot be
      parsed are logged and left out.
      """
      # get all intfs that have ip addresses configured
      addrs, llAddr = self._parseAddrOutput(
            self._runCmd( [ 'ip', '-o', '-6', 'addr', 'show' ] ) )

      # get all mac addresses
      intfRe = r'\d+: (?P<intf>\S+).*link\/ether (?P<mac>\S+)'
      for intf, ( ip, plen ) in addrs.items():
         out = self._runCmd( [ 'ip', '-o', 'link', 'show', intf ] )
         m1 = re.match( intfRe, out )
         mac = m1.group( 'mac' ) if m1 else None
         macHex = self._macToHex( mac ) if mac else None
         if macHex is None:
            self.log( f"Invalid output in ip link command: {out}" )
            continue
         # store only complete entries so later lookups never see a
         # half-built record
         self.intfDict_[ intf ] = { 'ip' : ip,
                                    'plen' : plen,
                                    'llAddr' : llAddr.get( intf ),
                                    'mac' : mac, 'macHex' : macHex }

   def fullV6AddrIs( self, addr ):
      """Return addr expanded to its 32-digit hex form.

      Raises OSError if addr is not a valid IPv6 address string.
      """
      return socket.inet_pton( socket.AF_INET6, addr ).hex()

   def neighborCacheIs( self ):
      """Append every usable entry of the IPv6 neighbor table to ncache_."""
      # get all nbr entries in the cache
      out = self._runCmd( [ 'ip', '-6', 'neigh', 'show' ] )
      self.ncache_.extend( self._parseNeighborOutput( out ) )

   def neighborAdvertisementsIs( self ):
      """Build one NA frame per cached neighbor and append it to txPktList_.

      Neighbors on interfaces missing from intfDict_, and fe80 neighbors on
      interfaces without a link-local address, are logged and skipped.
      """
      # V6 hdr: ver(4b)6, TC(8b)0, flow label(20b)0, len(16b)0x20,
      # next hdr(8b)0x3a, hop lim(8b)255, srcaddr(128b), dstaddr(128b)
      v6Template = Template( "6000000000203aff$saddr$daddr" )
      # Icmp hdr: type(8b) 0x88, code(8b) 00, cksum(16b), Router(1b),
      # Solicited(1b), Override(1b), Reserved(29b), tgt addr(128b),
      # tgt link layer(8b) 02, length(8b) 01, eth mac (48b).
      # The flag byte 0xc0 sets Router and Solicited and leaves Override clear.
      icmpTemplate = Template( "8800${csum}c0000000${tgtaddr}0201$smac" )

      for host in self.ncache_:
         intf = host[ 'intf' ]
         if intf not in self.intfDict_:
            self.log( "Interface info not found %s" % intf )
            continue
         intfInfo = self.intfDict_[ intf ]

         # answer link-local neighbors from our link-local address
         dipHex = host[ 'fullip' ]
         if dipHex.startswith( "fe80" ):
            if not intfInfo[ 'llAddr' ]:
               self.log( "llAddr not found %s" % intf )
               continue
            sip = intfInfo[ 'llAddr' ]
         else:
            sip = intfInfo[ 'ip' ]

         smac = intfInfo[ 'mac' ]
         smacHex = intfInfo[ 'macHex' ]
         v6Hdr = v6Template.substitute( saddr=sip, daddr=dipHex )
         # the checksum is computed over the message with a zeroed checksum
         icmpHdr = icmpTemplate.substitute( csum='0000', tgtaddr=sip,
                                            smac=smacHex )
         csum = "%04x" % self._icmpV6Checksum( sip, dipHex, icmpHdr )
         icmpHdr = icmpTemplate.substitute( csum=csum, tgtaddr=sip,
                                            smac=smacHex )

         # ethernet header: dst mac, src mac, ethertype 0x86dd
         eth = host[ 'macHex' ] + smacHex + '86dd'
         self.txPktList_.append( { 'dmac' : host[ 'mac' ],
                                   'smac' : smac,
                                   'intf' : intf,
                                   'ip' : host[ 'ip' ],
                                   'rawPkt' : v6Hdr + icmpHdr,
                                   'eth' : eth,
                                 } )
      self.log( "Found %d neighbors" % len( self.txPktList_ ) )

   def sendNAPkts( self ):
      """Transmit every frame in txPktList_ once.

      Does nothing beyond logging when there is nothing to send. A frame that
      is not valid hex, or that cannot be sent on its interface, is logged
      and skipped so the remaining frames still go out. Errors creating the
      socket propagate to the caller.
      """
      if not self.txPktList_:
         self.log( "No packets to send" )
         return

      # the context manager closes the socket even if a send blows up
      with socket.socket( socket.AF_PACKET, socket.SOCK_RAW ) as s:
         for pktData in self.txPktList_:
            intf = pktData[ 'intf' ]
            hexPkt = pktData[ 'eth' ] + pktData[ 'rawPkt' ]
            try:
               pkt = bytes.fromhex( hexPkt )
            except ValueError:
               self.log( "Error in packet %s" % hexPkt )
               continue
            try:
               s.bind( ( intf, 0 ) )
               s.send( pkt )
            except OSError as e:
               self.log( "Failed to send on %s: %s" % ( intf, e ) )

   def doNeighborAdvertisements( self ):
      """Gather state, build the frames, then resend them for 3 minutes."""
      self.log( "=== Running v6 ND processing" )
      self.intfInfoIs()
      # get all the known v6 neighbors
      self.neighborCacheIs()
      self.neighborAdvertisementsIs()

      frequency = 15
      # monotonic clock: wall-clock adjustments must not stretch or cut
      # the run short
      now = time.monotonic()
      end = now + 180  # run script for 3 minutes

      i = 1
      while now < end:
         self.sendNAPkts()
         t = time.monotonic() - now
         self.log( f"Finished iteration {i}" )
         # sleep for whatever is left of this period, never a negative time
         t = min( t, frequency )
         time.sleep( frequency - t )
         now = time.monotonic()
         i += 1

      self.log( "=== Completed v6 ND processing" )

   def log( self, message ):
      """Forward message to the logger given at construction."""
      self.logger_.log( message )

if __name__ == "__main__":
   # Script entry point: run the neighbor advertisement flow once and record
   # any failure in the log file rather than letting it escape.
   logger = FileLogger( "/mnt/flash/asu-grat-na.log" )
   logger.log( "Running upgrade policy" )
   try:
      GratuitousNeighborAdv( logger ).doNeighborAdvertisements()
   except Exception as err:         # pylint: disable-msg=W0703
      # first the exception text, then the full traceback
      summary = "Exception:\n%s" % str( err )
      logger.log( summary )
      logger.log( traceback.format_exc() )
