# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from EDTAccess import ( DutAccessorBase, SshSession,
                        defaultRootUser, defaultRootPasswd )

class EDTPacketIO( DutAccessorBase ):
   """Packet injection (ethxmit) and capture (tcpdump) helpers for a DUT.

   Packets are sent/captured on the MPC interface that peers a DUT interface,
   or (tcpdump only) on the DUT supervisor by mirroring the DUT interface to
   the CPU. Helpers such as getDutMpcDevIntf, intfNameLong, intfNameLinux,
   dutCli, dutName and waitFor are presumably provided by DutAccessorBase
   (not visible here).
   """

   def __init__( self ):
      super().__init__()
      self.mpcUsername = defaultRootUser
      self.mpcPassword = defaultRootPasswd
      self.dutUsername = defaultRootUser
      self.dutPassword = defaultRootPasswd
      self.defaultEthxmitTimeout = 2 * 24 * 60 * 60 # Number of seconds in 2 days
      self.defaultTcpdumpTimeout = 30 * 24 * 60 * 60 # Number of seconds in 30 days

   def _getMpcPeer( self, dutIntf ):
      """Return ( mpcDev, mpcIntf ) that peers dutIntf.

      Raises AssertionError when no peer is found. The check is an explicit
      raise (not an assert statement) so it still fires under `python -O`.
      """
      mpcDev, mpcIntf = self.getDutMpcDevIntf( dutIntf )
      if None in ( mpcDev, mpcIntf ):
         raise AssertionError( f"no MPC peer interface found for '{dutIntf}'" )
      return mpcDev, mpcIntf

   def ethxmitOptions( self, dstMac=None, srcMac=None, dstIp=None, srcIp=None,
                       vlan=None, vpri=None, tpid=None, vlanInner=None,
                       vpriInner=None, tpidInner=None, gap=None,
                       count=0, cont=None ):
      """Build the ethxmit option string for doEthxmit.

      MAC/IP addresses and gap are emitted only when truthy. VLAN/priority/
      TPID fields are emitted whenever they are not None, so 0 is a valid
      value. A non-zero count ("-n") takes precedence over cont ("-c").

      Returns every option preceded by a single space
      (e.g. " -D 00:11:22:33:44:55 -n 10"), or "" when nothing is set.
      """
      opts = []
      if dstMac:
         opts.append( f"-D {dstMac}" )
      if srcMac:
         opts.append( f"-S {srcMac}" )
      # Tag fields compare against None so that 0 is still passed through.
      if vlan is not None:
         opts.append( f"--vlan {vlan}" )
      if vpri is not None:
         opts.append( f"--vpri {vpri}" )
      if tpid is not None:
         opts.append( f"--vlan-tpid {tpid}" )

      if vlanInner is not None:
         opts.append( f"--inner-vlan {vlanInner}" )
      if vpriInner is not None:
         opts.append( f"--inner-vpri {vpriInner}" )
      if tpidInner is not None:
         opts.append( f"--inner-vlan-tpid {tpidInner}" )

      if dstIp:
         opts.append( f"--ip-dst {dstIp}" )
      if srcIp:
         opts.append( f"--ip-src {srcIp}" )
      if gap:
         opts.append( f"--sleep {gap}" )
      if count:
         opts.append( f"-n {count}" )
      elif cont:
         opts.append( "-c" )

      return "".join( f" {opt}" for opt in opts )

   def doEthxmitOnMpc( self, dutIntf, options=None, echo=False, timeout=None ):
      """Run ethxmit on the MPC device/interface that peers dutIntf.

      options: extra ethxmit arguments (e.g. from ethxmitOptions), or None.
      echo, timeout: forwarded to SshSession.sendCmd.
      Prints the command and any output returned by sendCmd (joined with
      newlines; presumably a list of lines -- verify against SshSession).
      """
      mpcDev, mpcIntf = self._getMpcPeer( dutIntf )
      mpcCmd = f"sudo ethxmit {mpcIntf}"
      if options:
         mpcCmd += " " + options
      print( "using mpcDev:", mpcDev )
      print( mpcCmd )
      mpcSsh = SshSession( mpcDev, username=self.mpcUsername,
                           password=self.mpcPassword, shellPrompt=".*# " )
      try:
         out = mpcSsh.sendCmd( mpcCmd, echo=echo, timeout=timeout )
      finally:
         # Always close the session, even if the command fails or times out.
         mpcSsh.logout()
      if out:
         print( '\n'.join( out ) )

   def doEthxmit( self, dutIntf, options=None, echo=False, timeout=None,
                  useMpc=True ):
      """Run ethxmit to inject packet(s) on dutIntf.

      timeout defaults to self.defaultEthxmitTimeout when None.
      Raises AssertionError when useMpc is False; only MPC injection is
      supported. This is an explicit raise, so it also fires under `python -O`.
      """
      dutIntf = self.intfNameLong( dutIntf )
      if timeout is None:
         timeout = self.defaultEthxmitTimeout
      if not useMpc:
         raise AssertionError( "ethxmit only supported on MPC" )
      self.doEthxmitOnMpc( dutIntf, options=options, echo=echo, timeout=timeout )

   def mirrorDutIntfToCpu( self, dutIntf, direction, enable=True ):
      """Set up or tear down a monitor session mirroring dutIntf to the CPU.

      Any existing session named "<linux intf name>_<direction>" is removed
      first (errors ignored). When enable is True, a new session is created
      with dest cpu, and the method polls until its mirrorDeviceName appears.

      direction: one of "rx", "tx", "both"; AssertionError otherwise.
      Returns the mirror device name, or None when enable is False (or, if
      waitFor does not raise on timeout, when the name never appeared).
      """
      if direction not in ( "rx", "tx", "both" ):
         raise AssertionError( f"invalid mirror direction '{direction}'" )
      sessionName = self.intfNameLinux( dutIntf ) + "_" + direction
      self.dutCli.configCmd( f"no monitor session {sessionName}", ignoreErrors=True )
      mirrorDevName = None
      if enable:
         # Setup new monitor session with dest=cpu.
         self.dutCli.configCmd( f"monitor session {sessionName} source"
                                f" { self.intfNameLong( dutIntf ) } { direction }" )
         self.dutCli.configCmd( f"monitor session {sessionName} dest cpu" )

         def hasMirrorDeviceName():
            # Poll predicate: record the mirror device name once it is shown.
            nonlocal mirrorDevName
            out = self.dutCli.showCmd( f"show monitor session {sessionName}",
                                       jsonOutput=True )
            sessions = out.get( "sessions" )
            if sessions is not None:
               mySession = sessions.get( sessionName )
               if mySession is not None:
                  mirrorDevName = mySession.get( "mirrorDeviceName" )
            return mirrorDevName is not None

         # Wait for the mirror device name to become available.
         self.waitFor( hasMirrorDeviceName, timeout=5, minDelay=0.2 )
      return mirrorDevName

   def doTcpdumpOnMpc( self, dutIntf, options=None, echo=False, timeout=None ):
      """Run tcpdump on the MPC interface that peers dutIntf.

      options: extra tcpdump arguments, or None.
      timeout is passed to the SshSession constructor; echo to sendCmd.
      """
      mpcDev, mpcIntf = self._getMpcPeer( dutIntf )
      print( f"using mpc {mpcDev}" )
      cmd = f"sudo tcpdump -i {mpcIntf}"
      if options:
         cmd += " " + options
      print( cmd )
      mpcSsh = SshSession( mpcDev, username=self.mpcUsername,
                           password=self.mpcPassword,
                           shellPrompt=".*# ", trace=True, timeout=timeout )
      try:
         mpcSsh.sendCmd( cmd, echo=echo )
      finally:
         # Always close the session, even if the capture fails or times out.
         mpcSsh.logout()

   def doTcpdumpOnSup( self, dutIntf, options=None, echo=False, timeout=None,
                       direction="both" ):
      """Capture dutIntf traffic on the DUT supervisor via a CPU mirror session.

      A monitor session mirroring dutIntf (in `direction`) to the CPU is set
      up, tcpdump runs on the resulting mirror netdev over SSH, and the
      monitor session is removed afterwards -- including when setup or the
      capture fails, so no stale session is left behind.
      Raises AssertionError if the mirror device could not be set up.
      """
      try:
         mirrorName = self.mirrorDutIntfToCpu( dutIntf, direction, enable=True )
         if not mirrorName:
            raise AssertionError( f"failed to setup monitor session for {dutIntf} "
                                  f"direction {direction}" )
         # Run tcpdump on supervisor netdev mapped to dutIntf.
         cmd = f"sudo tcpdump -i {mirrorName}"
         if options:
            cmd += " " + options
         dutSsh = SshSession( self.dutName, username=self.dutUsername,
                              password=self.dutPassword, shellPrompt=".*# ",
                              trace=True, timeout=timeout )
         try:
            dutSsh.sendCmd( cmd, echo=echo )
         finally:
            dutSsh.logout()
      finally:
         # Remove monitor session for dut intf; teardown ignores CLI errors,
         # so it is safe even if setup only partially succeeded.
         self.mirrorDutIntfToCpu( dutIntf, direction, enable=False )

   def doTcpdump( self, dutIntf, options=None, echo=False,
                  timeout=None, direction="tx", useMpc=True ):
      """Run tcpdump to capture packets on a DUT interface.

      With useMpc (default) the capture runs on the MPC peer interface.
      Otherwise it runs on the DUT supervisor through a CPU mirror session
      in `direction`. timeout defaults to self.defaultTcpdumpTimeout when None.
      Returns whatever the selected helper returns (None as written here).
      """
      dutIntf = self.intfNameLong( dutIntf )
      if timeout is None:
         timeout = self.defaultTcpdumpTimeout
      if useMpc:
         return self.doTcpdumpOnMpc( dutIntf, options=options, echo=echo,
                                     timeout=timeout )
      return self.doTcpdumpOnSup( dutIntf, options=options, echo=echo,
                                  timeout=timeout, direction=direction )
