#!/usr/bin/env python3
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import datetime
import os
import binascii
import socket
import time
import operator
import re
import sys
import json
from DeviceNameLib import eosIntfToKernelIntf

def log( msg ):
   '''Print msg to stdout, prefixed with the current local date/time.

   stdout is flushed on every call so each log line appears immediately even
   when stdout is redirected to a file or pipe.
   '''
   timestamp = str( datetime.datetime.now() )
   print( timestamp, msg, flush=True )

# pylint: disable-msg=pointless-string-statement
# The string below is documentation only; nothing reads it at runtime.
# Each file in the config directory is expected to hold one such JSON document:
#  - "content" is the packet as a hex string; it is hex-decoded and the resulting
#    bytes are transmitted as-is on each listed interface.
#  - "interfaces" is a list of EOS interface names.
#  - "intervalInMs" must convert to an integer number of milliseconds; a value of
#    "0" makes the packet a one-shot packet, sent once before periodic sending
#    starts.
'''Expected JSON Config file format:
{
      "Packets": {
         "Packet1": {
            "content": "00aabbccddeeff",
            "interfaces": [ "Ethernet2/1", "Ethernet3/4" ],
            "intervalInMs": "100"
            },
         "Packet2": {
            "content": "00aabbccddeeff",
            "interfaces": [ "Ethernet2/1", "Ethernet3/4" ],
            "intervalInMs": "100"
            },
         "Packet3": {
            "content": "00aabbccddeeff",
            "interfaces": [ "Ethernet2/1", "Ethernet3/4" ],
            "intervalInMs": "100"
            }
         }
}
'''

class ConfigParserException( Exception ):
   '''Raised for invalid packet configuration (e.g. by ConfigParser when it was
   constructed with raiseErrors=True).

   The message is forwarded to Exception.__init__ so that ``args`` is populated;
   without it the exception cannot be unpickled (unpickling calls cls(*args)) and
   default tracebacks lose the message.
   '''
   def __init__( self, msg ):
      super().__init__( msg )
      self.msg = msg

   def __str__( self ):
      return repr( self.msg )

class ConfigParser:
   '''Responsible for reading in JSON config files and converting the information
   into format usable by AsuFastPktTransmit without further manipulation (aside
   from opening sockets). Tasks include format validation, converting content
   from a hex string to binary, and converting intervalInMs from a string (or
   number) to a non-negative integer number of milliseconds.

   By default invalid files and packets are logged and skipped. When constructed
   with raiseErrors=True the problem is raised as an exception instead.
   '''
   PACKETS    = "Packets"
   PKTNAME    = "pktName"
   CONTENT    = "content"
   INTERFACES = "interfaces"
   MS_IVAL    = "intervalInMs"

   def __init__( self, raiseErrors=False ):
      # Result of the most recent read(): cfg file path -> list of packet dicts
      self.asuCfgByPath_ = {}
      # When True, raise on invalid input instead of logging and skipping it
      self.raiseErrors_ = raiseErrors

   def _skip( self, msg, exc=None ):
      '''Logs why something is being skipped. In raiseErrors mode raises exc, or a
      ConfigParserException carrying msg when no exception is supplied.
      '''
      log( msg )
      if self.raiseErrors_:
         raise exc if exc is not None else ConfigParserException( msg )

   def _readJsonFiles( self, cfgDir ):
      '''Imports all JSON config files from cfgDir and puts their contents into
      a dictionary with the full path of the config file as the key. Skips
      subdirectories, and logs and skips any file that cannot be read or is not
      valid UTF-8 JSON. Does not do any massaging of the packets described in the
      config files.

      Raises OSError if cfgDir itself cannot be listed.
      '''
      jsonConfigs = {}
      # let OSError from listdir bubble up
      for cfgFile in os.listdir( cfgDir ):
         cfgPath = os.path.join( cfgDir, cfgFile )
         if os.path.isdir( cfgPath ):
            log( "Encountered unexpected subdirectory %s - ignoring." %
                 ( cfgPath ) )
            continue
         log( "Parsing config file %s" % ( cfgPath ) )
         try:
            # JSON text is UTF-8; don't depend on the locale's default encoding
            with open( cfgPath, encoding="utf-8" ) as f:
               jsonConfig = json.load( f )
         except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError both derive from
            # ValueError
            self._skip( "File %s not valid JSON format. Skipping." % ( cfgPath ),
                        e )
         except OSError as e:
            self._skip( f"Could not read {cfgPath}: {e}. Skipping.", e )
         else:
            jsonConfigs[ cfgPath ] = jsonConfig
      return jsonConfigs

   def _parsePacket( self, cfgPath, pktName, packet ):
      '''Validates a single packet entry and converts it into the form used by
      AsuFastPktTransmit.

      Returns a dict with PKTNAME, CONTENT (bytes), MS_IVAL (int >= 0) and
      INTERFACES (list of str) keys, or None if the entry is invalid (after
      logging why). Each attribute is converted on its own so the log message
      names the attribute that is actually wrong.
      '''
      prefix = "%s: packet '%s':" % ( cfgPath, pktName )
      if not isinstance( packet, dict ):
         self._skip( "%s not a dictionary. Skipping." % prefix )
         return None

      try:
         hexContent = packet[ ConfigParser.CONTENT ]
         rawIval = packet[ ConfigParser.MS_IVAL ]
         eosIntfs = packet[ ConfigParser.INTERFACES ]
      except KeyError as e:
         self._skip( "%s missing attribute '%s'. Skipping." % ( prefix, e.args[ 0 ] ),
                     e )
         return None

      try:
         content = binascii.unhexlify( hexContent )
      except TypeError as e:
         # e.g. a JSON number, list or null instead of a string
         self._skip( "%s '%s' not a hex string. Skipping." %
                     ( prefix, ConfigParser.CONTENT ), e )
         return None
      except ValueError as e:
         # binascii.Error (odd length, non-hex digit) is a ValueError subclass, as
         # is the error raised for non-ASCII characters
         self._skip( "%s '%s' not a valid hex string (%s). Skipping." %
                     ( prefix, ConfigParser.CONTENT, e ), e )
         return None

      try:
         ivalInMs = int( rawIval )
      except ( TypeError, ValueError, OverflowError ) as e:
         # OverflowError: json accepts Infinity, and int(inf) overflows
         self._skip( "%s '%s' not an integer. Skipping." %
                     ( prefix, ConfigParser.MS_IVAL ), e )
         return None
      if ivalInMs < 0:
         # A negative interval moves nextSendTime backwards on every send, so the
         # packet would stay permanently overdue and be transmitted in a tight loop
         self._skip( "%s '%s' must not be negative. Skipping." %
                     ( prefix, ConfigParser.MS_IVAL ) )
         return None

      if not isinstance( eosIntfs, list ) or \
            not all( isinstance( intf, str ) for intf in eosIntfs ):
         self._skip( "%s '%s' is not a list of EOS interfaces" %
                     ( prefix, ConfigParser.INTERFACES ) )
         return None

      return { ConfigParser.PKTNAME: pktName,
               ConfigParser.CONTENT: content,
               ConfigParser.MS_IVAL: ivalInMs,
               ConfigParser.INTERFACES: eosIntfs }

   def _createCfgFromJson( self, jsonConfigs ):
      '''Creates packets dictionary from configs that were read in by
      _readJsonFiles(). Each config file results in a list of packets that
      is inserted in top level dictionary (self.asuCfgByPath_), with the full path
      of the config file used as the key. Each packet is validated and converted
      by _parsePacket(): content becomes bytes ready to be transmitted without
      further manipulation, interfaces is checked to be a list of strings, and
      intervalInMs becomes a non-negative integer (still in milliseconds).

      Any packets that have one or more missing attributes or whose attribute(s)
      are invalid are not included. If all packets are invalid nothing is entered
      into the dictionary for that config file.
      '''
      self.asuCfgByPath_ = {}
      for cfgPath, jsonConfig in jsonConfigs.items():
         if not isinstance( jsonConfig, dict ):
            # e.g. a file holding a bare JSON list or string
            self._skip( "%s: top level is not a dictionary. Skipping." % ( cfgPath ) )
            continue
         try:
            pktItems = jsonConfig[ ConfigParser.PACKETS ].items()
         except KeyError as e:
            self._skip( f"{cfgPath}: missing '{e.args[ 0 ]}'. Skipping.", e )
            continue
         except AttributeError as e:
            self._skip( "%s: '%s' not a dictionary. Skipping." %
                        ( cfgPath, ConfigParser.PACKETS ), e )
            continue

         asuPackets = []
         for pktName, packet in pktItems:
            asuPacket = self._parsePacket( cfgPath, pktName, packet )
            if asuPacket is not None:
               asuPackets.append( asuPacket )

         if asuPackets:
            self.asuCfgByPath_[ cfgPath ] = asuPackets
         else:
            self._skip( "%s: no valid packets in cfg file. Skipping." % ( cfgPath ) )

   def read( self, cfgDir ):
      '''Parses every config file in cfgDir and returns a dict mapping each config
      file path to its list of valid packet dicts. OSError from listing cfgDir
      propagates to the caller.
      '''
      jsonConfigs = self._readJsonFiles( cfgDir )
      self._createCfgFromJson( jsonConfigs )
      return self.asuCfgByPath_

class RawSockFactory:
   '''Caches one raw AF_PACKET socket per EOS interface, so all packets sent on the
   same interface share a socket. All state lives in the class attribute socks_.
   '''
   # EOS interface name -> raw socket bound to the matching kernel interface
   socks_ = {}

   @classmethod
   def getSock( cls, eosIntf ):
      '''Returns the cached raw socket for eosIntf, creating and binding it on
      first use. Errors from socket creation or bind propagate (typically OSError);
      a socket whose bind fails is closed instead of being leaked.
      '''
      sock = RawSockFactory.socks_.get( eosIntf )
      if sock is None:
         krnIntf = eosIntfToKernelIntf( eosIntf )
         sock = socket.socket( socket.AF_PACKET, socket.SOCK_RAW, 0 )
         try:
            sock.bind( ( krnIntf, 0 ) )
         except Exception:
            # Don't leak the file descriptor of a socket we never hand out
            sock.close()
            raise
         RawSockFactory.socks_[ eosIntf ] = sock
      return sock

   @classmethod
   def shutdownSock( cls, eosIntf ):
      '''
      Test only method - used to validate operation of AsuFastPktTransmit
      when the socket disappears.
      '''
      if eosIntf in RawSockFactory.socks_:
         RawSockFactory.socks_[ eosIntf ].shutdown( socket.SHUT_RDWR )

   @classmethod
   def closeAllSocks( cls ):
      '''Closes every cached socket (logging, not raising, on close failure) and
      empties the cache.'''
      for intf, sock in RawSockFactory.socks_.items():
         try:
            sock.close()
         except OSError as e:
            log( f"Unable to close sock {intf}: {e}" )
      RawSockFactory.socks_ = {}

   def __init__( self ):
      pass

class AsuPacket:
   '''A single configured packet: its payload, the raw sockets it goes out on, and
   when it is next due.

   An intervalInMs of 0 makes the packet one-shot (see oneShot()). Interfaces
   whose socket cannot be obtained are logged and left out of socks_.
   '''
   def __init__( self, cfgFileName, pktName, content, interfaces, intervalInMs,
                 startTime ):
      self.cfgFileName = cfgFileName
      self.pktName_ = pktName
      self.socks_ = {}
      self.content_ = content
      self.interfaces_ = interfaces
      self.interval_ = datetime.timedelta( milliseconds=intervalInMs )
      # Initial value only, so packets can be sorted before anything is sent;
      # send() re-anchors it to the wall-clock time of the first transmission
      self.nextSendTime = startTime
      self.firstSend_ = True
      for intf in interfaces:
         try:
            self.socks_[ intf ] = RawSockFactory.getSock( intf )
         except OSError as err:
            log( "%s packet '%s': could not create socket for iface '%s': %s" %
                 ( cfgFileName, pktName, intf, err ) )

   def oneShot( self ):
      '''True when the packet has a zero interval and is sent only once.'''
      return self.interval_ == datetime.timedelta( 0 )

   def send( self, traceMsg ):
      '''Transmits the payload once on each remaining interface, then moves
      nextSendTime forward by one interval. Any interface whose send fails is
      logged and permanently dropped from this packet.
      '''
      if self.firstSend_:
         self.firstSend_ = False
         self.nextSendTime = datetime.datetime.now()

      failed = []
      for intf, sock in self.socks_.items():
         try:
            sock.sendall( self.content_ )
            if traceMsg:
               log( "Sent cfgFile %s packet %s on %s" %
                    ( self.cfgFileName, self.pktName_, intf ) )
         except OSError as err:
            log( "%s packet '%s': failed to send on %s: %s. Dropping iface" %
                 ( self.cfgFileName, self.pktName_, intf, err ) )
            failed.append( intf )
      # Removed after the loop: the dict must not change size while iterating
      for intf in failed:
         self.socks_.pop( intf )

      self.nextSendTime = self.nextSendTime + self.interval_

   def numIntfs( self ):
      '''Number of interfaces this packet can still be sent on.'''
      return len( self.socks_ )

class AsuFastPktTransmit:
   '''
   Responsible for sending out defined packets at predefined intervals, based
   on information in config files, while agents normally responsible for these tasks
   are restarting. Goes to sleep between intervals. At each run interval, if the
   associated config file is missing from the config directory the service prunes
   the affected packets. If there are no more packets the service exits. The service
   also exits if predetermined runtime has been exhausted - to guard against agents
   that don't start or forget to remove their config files.

   Refer to design document AID4656 for more information.
   '''
   def __init__( self, cfgDir, backupDir, maxRunTimeInS, tracePkts=False ):
      '''
      Parameters
      ----------
      cfgDir : Directory where AsuFastPktTransmit should search for packet config
         files.
      backupDir : Directory where AsuFastPktTransmit should move any config files
         after maxRunTimeInS that were not already moved by their associated agent.
      maxRunTimeInS : Maximum duration in seconds which AsuFastPktTransmit.run() is
         allowed to execute. Note that this is upper bound. If all services making
         use of AsuFastPktTransmit take ownership by removing their associated
         configuration files AsuFastPktTransmit will exit early.
      tracePkts : Set to True to log a line for every packet sent. Defaults to
         False.

      Example Usage:
      asuFastPktTransmit = AsuFastPktTransmit( "/mnt/flash/fastpkttx",
                                               "/mnt/flash/fastpkttx.backup",
                                               300 )

      Corresponds to the configuration information listed in the ASU Fast 'Packet
      Restore' design document of config files found in /mnt/flash/fastpkttx, backup
      files copied to /mnt/flash/fastpkttx.backup, run time of 5 minutes
      (5 * 60 = 300).
      '''
      log( "Config path: %s Max Run Time: %d Seconds" % ( cfgDir, maxRunTimeInS ) )
      self.cfgDir_ = cfgDir
      self.backupDir_ = backupDir
      self.maxRunTimeInS_ = maxRunTimeInS
      # cfg file path -> periodic AsuPackets, kept sorted by nextSendTime
      self.asuPacketsByPath_ = {}
      # cfg file path -> one-shot AsuPackets (sent once, then discarded)
      self.oneShotAsuPacketsByPath_ = {}
      self.tracePkts_ = tracePkts

      # Tracks next packet to be sent for each config file
      # Used to determine how long to sleep before sending next
      # packet
      self.nextPacketByCfg_ = []

   def _sortPacketList( self, packets ):
      '''Sorts packets in place by nextSendTime, earliest first.'''
      packets.sort( key=operator.attrgetter( "nextSendTime" ) )

   def _createPacketsFromCfg( self, asuConfig, startTime ):
      '''Builds AsuPackets from the parsed config (cfg path -> packet dicts),
      splitting them into periodic and one-shot packets per config file. Packets
      with no usable interface are dropped. Finally seeds nextPacketByCfg_ with the
      earliest periodic packet of each config file.
      '''
      for cfgFileName, pktConfigs in asuConfig.items():
         asuPackets = []
         oneShotPackets = []
         for pktConfig in pktConfigs:
            asuPacket = AsuPacket( cfgFileName, pktConfig[ ConfigParser.PKTNAME ],
                                   pktConfig[ ConfigParser.CONTENT ],
                                   pktConfig[ ConfigParser.INTERFACES ],
                                   pktConfig[ ConfigParser.MS_IVAL ], startTime )
            if asuPacket.numIntfs():
               if asuPacket.oneShot():
                  oneShotPackets.append( asuPacket )
               else:
                  asuPackets.append( asuPacket )
            else:
               log( "%s: packet %s has no valid interfaces. Skipping." %
                    ( cfgFileName, pktConfig[ ConfigParser.PKTNAME ] ) )
         if not asuPackets and not oneShotPackets:
            log( "%s: no valid packets. Skipping config file entirely." %
                 ( cfgFileName ) )
         else:
            if asuPackets:
               self._sortPacketList( asuPackets )
               self.asuPacketsByPath_[ cfgFileName ] = asuPackets
            if oneShotPackets:
               self.oneShotAsuPacketsByPath_[ cfgFileName ] = oneShotPackets

      for asuPackets in self.asuPacketsByPath_.values():
         # self.nextPacketByCfg_ only tracks next packet sent for each config file
         self.nextPacketByCfg_.append( asuPackets[ 0 ] )
      self._sortPacketList( self.nextPacketByCfg_ )

   def _sendOneShotPackets( self ):
      '''Sends each one-shot packet once, provided its config file still exists,
      then forgets all one-shot packets.'''
      for cfgFileName, asuPackets in self.oneShotAsuPacketsByPath_.items():
         if os.path.exists( cfgFileName ):
            for asuPacket in asuPackets:
               asuPacket.send( self.tracePkts_ )
         else:
            log( "Config file %s removed. Skipping associated one shot packets" %
                 ( cfgFileName ) )
      self.oneShotAsuPacketsByPath_ = {}

   def _sendPackets( self, currTime ):
      '''Sends every periodic packet due at currTime and stops tracking config
      files that have been deleted.

      nextPacketByCfg_ holds the earliest-due packet of each config file, sorted,
      so the scan stops at the first config file whose next packet is not due.
      Entries handled here always form a prefix of the list and are replaced by
      the config file's new earliest packet.
      '''
      packetsToPop = 0
      pktsToAppend = []
      for nextPacket in self.nextPacketByCfg_:
         if not os.path.exists( nextPacket.cfgFileName ):
            log( "Config file %s removed. Discontinuing transmission." %
                 ( nextPacket.cfgFileName ) )
            packetsToPop += 1
            del self.asuPacketsByPath_[ nextPacket.cfgFileName ]
         elif nextPacket.nextSendTime <= currTime:
            packetsToPop += 1
            packets = self.asuPacketsByPath_[ nextPacket.cfgFileName ]
            # packets is sorted, so the due ones are a prefix
            for packet in packets:
               if packet.nextSendTime <= currTime:
                  packet.send( self.tracePkts_ )
               else:
                  break
            self._sortPacketList( packets )
            pktsToAppend.append( packets[ 0 ] )
         else:
            break
      del self.nextPacketByCfg_[ :packetsToPop ]
      self.nextPacketByCfg_.extend( pktsToAppend )
      self._sortPacketList( self.nextPacketByCfg_ )

   def _asuReboot( self ):
      '''True if the kernel command line marks this boot as an ASU reboot.'''
      with open( "/proc/cmdline" ) as cmdFile:
         cmdline = cmdFile.read()
      return re.search( r"arista\.asu_(reboot|hitless)", cmdline ) is not None

   def _moveCfgFilesToBackup( self, suffix="", warning="" ):
      '''Renames every non-directory entry of cfgDir into backupDir, appending
      suffix to the file name and logging warning for each file moved.'''
      cfgFiles = os.listdir( self.cfgDir_ )
      for cfgFile in cfgFiles:
         cfgPath = os.path.join( self.cfgDir_, cfgFile )
         if os.path.isdir( cfgPath ):
            # silently ignore subdirectories - already squawked about them during
            # startup
            continue

         log( f"Agent associated with {cfgPath} {warning}." )
         dstPath = os.path.join( self.backupDir_, cfgFile + suffix )
         os.rename( cfgPath, dstPath )

   def _moveTimedOutCfgFiles( self ):
      self._moveCfgFilesToBackup( suffix=".timedout",
                                  warning="did not move file before max run time" )

   def _coldBootFlushCfgFiles( self ):
      self._moveCfgFilesToBackup( suffix=".stale.cold",
                                  warning="left stale config file for cold boot" )

   def run( self, unitTest=False, simulateUncaughtExc="False" ):
      '''Main entry point.

      Sends one-shot packets once, then periodic packets on schedule until no
      periodic packets remain (their config files were removed) or the next send
      would fall at or after maxRunTimeInS. Afterwards closes all sockets and
      moves any remaining config files to the backup directory.

      unitTest : when True, skip the /proc/cmdline ASU-reboot check.
      simulateUncaughtExc : the *string* "True" makes run() raise immediately.
      '''
      # If simulateUncaughtExc is True, raise an exception. Used to verify calling
      # script catches exceptions not dealt with by AFPT
      if simulateUncaughtExc == "True":
         raise ConfigParserException( "This is a test" )

      if not os.path.exists( self.cfgDir_ ):
         log( "WARNING: %s does not exist. Exiting." % ( self.cfgDir_ ) )
         return

      if not os.path.exists( self.backupDir_ ):
         log( "WARNING: %s does not exist. Exiting." % ( self.backupDir_ ) )
         return

      # Don't run if not ASU reboot - don't depend on lack of presence of
      # Agent packet files as marker of whether AsuFastPktTransmit should
      # be running.
      if not unitTest and not self._asuReboot():
         log( "Not an ASU reboot. Exiting." )
         self._coldBootFlushCfgFiles()
         return

      cfgParser = ConfigParser()
      log( "Parsing config files" )
      try:
         asuPktCfgByPath = cfgParser.read( self.cfgDir_ )
      except OSError as e:
         log( f"ERROR: unable to access {self.cfgDir_}: {e}. Exiting." )
         # sys.exit rather than the site-module exit() builtin, which is not
         # guaranteed to exist (e.g. under python -S)
         sys.exit( 1 )

      startTime = datetime.datetime.now()
      stopTime = startTime + datetime.timedelta( seconds=self.maxRunTimeInS_ )

      log( "Creating and sorting packets from config" )
      self._createPacketsFromCfg( asuPktCfgByPath, startTime )

      log( "Sending one-shot packets" )
      self._sendOneShotPackets()

      currTime = datetime.datetime.now()
      while currTime < stopTime:
         self._sendPackets( currTime )
         if not self.asuPacketsByPath_:
            log( "No more packets to send. Exiting" )
            break

         nextPacket = self.nextPacketByCfg_[ 0 ]
         if nextPacket.nextSendTime >= stopTime:
            log( "No more packets to send before max run time. Exiting" )
            break

         currTime = datetime.datetime.now()
         while currTime < nextPacket.nextSendTime:
            # total_seconds() includes the days component that .seconds drops
            time.sleep( ( nextPacket.nextSendTime - currTime ).total_seconds() )
            currTime = datetime.datetime.now()

      # Explicitly shutdown and close all sockets before exiting
      log( "Shutting down all open sockets" )
      RawSockFactory.closeAllSocks()

      self._moveTimedOutCfgFiles()

      log( "Done" )
