#! /usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import socket
import Tac
import Acons
import PyClientBase
from collections import OrderedDict
from EDTAccess import traceMsg, TRACE_INFO

# Module-level alias for Tac.Value.  AgentAccessor.convertStrToVal() eval()s
# output strings that start with "Value(", and those expressions resolve the
# name `Value` through this module's globals.
Value = Tac.Value

###################################################################################
# Execute Python commands via PyServer in an EOS agent.  This is intended for
# reading or modifying agent state from debug scripts.
###################################################################################

class AgentAccessorException( Exception ):
   """Error reported by an AgentAccessor operation on the remote agent."""

class AgentAccessor:
   """Run Python commands in an agent's PyServer shell through Acons.

   Use as a context manager: ``__enter__`` opens the remote console socket and
   ``__exit__`` closes it.  Commands are sent as source text; their printed
   output comes back as a list of lines.
   """

   def __init__( self, agentName, echoCmd=False ):
      self.agentName = agentName
      # When True, every command is printed locally before it is sent.
      self.echoCmd = echoCmd
      self.sock = None
      self.remotePythonConsole = None
      # Marker printed by generated loops; output lines that do not split into
      # exactly two parts on it are ignored when parsing collection data.
      self.kvSep = "|||"
      # Indentation unit for generated multi-line remote commands.
      self.indent = '  '

   def __enter__( self ):
      """Open the agent's remote Python console and return self.

      If any step after creating the socket fails, the socket is closed before
      the exception propagates (``__exit__`` is not called when ``__enter__``
      raises, so it would otherwise leak).
      """
      agent = Acons.findAgentByName( self.agentName )
      sockname = agent[ 'sockname' ]
      sock = socket.socket( socket.AF_UNIX, socket.SOCK_STREAM )
      try:
         sock.connect( sockname )
         sock.setblocking( True )
         # Select exec mode (thread per connection) and the shell service.
         execMode = PyClientBase.Rpc.execModeThreadPerConnection
         serviceIdChar = PyClientBase.Rpc.serviceIdCharShell
         sock.sendall( execMode + serviceIdChar )
         # Read and discard the server's initial reply.
         sock.recv( 8192 )
         console = Acons.RemotePythonConsole( sock, shell=True )
      except BaseException:
         sock.close()
         raise
      self.sock = sock
      self.remotePythonConsole = console
      return self

   def __exit__( self, exc_type, exc_value, exc_traceback ):
      """Close the remote console; exceptions are not suppressed."""
      if self.sock is not None:
         self.sock.close()
      self.sock = None
      self.remotePythonConsole = None

   def runCmd( self, cmd ):
      """Execute ``cmd`` in the agent and return its output lines.

      Raises AgentAccessorException (with the full output) when the first
      output line contains "Exception raised".
      """
      assert self.remotePythonConsole is not None
      if self.echoCmd:
         print( cmd )
      out = self.remotePythonConsole.getConsoleOutput( cmd )
      if out and "Exception raised" in out[ 0 ]:
         raise AgentAccessorException( '\n'.join( out ) )
      return out

   def setEntityPath( self, entityPath ):
      """Make ``entityPath`` the console's current entity (``cd``).

      Raises AgentAccessorException if the reply starts with "Directory "
      (presumably a path-not-found message -- confirm against Acons).
      """
      out = self.runCmd( "cd " + entityPath )
      if out and out[ 0 ].startswith( "Directory " ):
         raise AgentAccessorException( out[ 0 ] )

   def aconsFilterStmts( self, kvFilter, kvPrefix ):
      """Return loop-body lines that ``continue`` unless the filter holds.

      ``kvFilter`` is one criterion or a list of criteria (ANDed).  A
      criterion starting with '_' has the '_' replaced by ``kvPrefix``
      (e.g. '_>1' -> 'k>1'); any other criterion is treated as an attribute
      expression of it (e.g. 'foo' -> 'k.foo').
      """
      def _makeCond( filtCrit ):
         if filtCrit.startswith( '_' ):
            return kvPrefix + filtCrit[ 1 : ] # replace '_' with kvPrefix
         else:
            return kvPrefix + '.' + filtCrit

      if isinstance( kvFilter, list ):
         kvCond = ' and '.join( [ _makeCond( x ) for x in kvFilter ] )
      else:
         kvCond = _makeCond( kvFilter )
      return [ self.indent + f"if not ( {kvCond} ):",
               2 * self.indent + "continue" ]

   def convertStrToVal( self, x ):
      """Convert a printed value back into a Python value when recognizable.

      Digit-only strings become ints, strings starting with "Value(" are
      evaluated (resolving the module-level ``Value`` name), everything else is
      returned unchanged.
      """
      if x.isdigit():
         x = int( x )
      elif x.startswith( "Value(" ):
         # SECURITY NOTE: eval() of text produced by the agent.  Acceptable for
         # a debug tool talking to a local agent; never feed untrusted input.
         # pylint: disable-next=W0123
         x = eval( x )
      return x

   def splitLabelName( self, valStr ):
      """Split 'label:name' into [label, name]; a bare 'name' gives [name, name]."""
      comps = valStr.split( ':', 1 )
      if len( comps ) == 1:
         comps.append( comps[ 0 ] )
      return comps

   def convertKwArgsToFilter( self, kwargs, filterLabels, filterPrefix="" ):
      """Convert kwargs into filter criteria such as ``fieldName==val``.

      ``filterLabels`` holds 'keyword:fieldName' (or bare 'fieldName') entries;
      kwargs whose keyword is not listed are ignored.  String arguments are
      emitted as Python literals via repr() so embedded quotes stay valid.
      """
      fields = dict( self.splitLabelName( x ) for x in filterLabels )
      kvFilter = []
      for keyword, arg in kwargs.items():
         fieldName = fields.get( keyword )
         if fieldName is None:
            continue
         assert arg is not None
         if isinstance( arg, str ):
            arg = repr( arg )
         kvFilter.append( f"{filterPrefix}{fieldName}=={arg}" )
      return kvFilter

   def getEntityCollSize( self, collName ):
      """Return ``len()`` of collection ``collName`` of the current entity."""
      out = self.runCmd( f"print(len(_.{collName}))" )
      assert out
      return int( out[ 0 ] )

   def hasTruncatedCollData( self, out ):
      """Return True if the output's last line ends with '...' (truncated)."""
      return bool( out ) and out[ -1 ].endswith( "..." )

   def batchSizeForCollData( self, out ):
      """Estimate how many lines fit: total characters // longest line."""
      totSize = 0
      maxLineSize = 0
      for line in out:
         size = len( line )
         totSize += size
         maxLineSize = max( size, maxLineSize )
      batchSize = totSize // maxLineSize
      assert batchSize > 0
      return batchSize

   def _iterCollOutput( self, entityPath, collName, loopVars, iterExpr,
                        loopBody, batchSize ):
      """Yield output lines of ``for <loopVars> in <iterExpr>[a:b]:<loopBody>``.

      The collection is walked in slices of ``batchSize`` entries (the whole
      collection when None) because of the PyServer max output size.  A batch
      whose output is truncated is retried with a strictly smaller batch size.
      The walk ends once the slices cover the collection size read up front,
      so a slice whose entries are all filtered out does not end it early.

      Raises AgentAccessorException if a single-entry batch is still truncated.
      """
      collSize = self.getEntityCollSize( collName )
      if batchSize is None:
         batchSize = collSize
      batchSize = max( batchSize, 1 )
      start = 0
      while start < collSize:
         end = start + batchSize
         out = self.runCmd( f"for {loopVars} in {iterExpr}[{start}:{end}]:" +
                            loopBody )
         if self.hasTruncatedCollData( out ):
            if batchSize == 1:
               raise AgentAccessorException(
                  f"output truncated for entry {start} of "
                  f"{entityPath} {collName}" )
            # Shrink strictly so the retry can never repeat the same size.
            batchSize = min( self.batchSizeForCollData( out ), batchSize - 1 )
            traceMsg( TRACE_INFO,
                      f"batchSize={batchSize} for {entityPath} {collName}" )
            continue
         yield from out
         start = end

   def getEntityCollMap( self, entityPath, collName, keyFilter=None,
                         valFilter=None, initCmds=None, batchSize=None ):
      """Return {key: value} for collection ``collName`` at ``entityPath``.

      keyFilter / valFilter: criteria for aconsFilterStmts() on 'k' / 'v'.
      initCmds: command string run after changing path, before reading.
      batchSize: entries fetched per remote command (default: all at once).
      Keys and values are converted with convertStrToVal().
      """
      self.setEntityPath( entityPath )
      if initCmds:
         self.runCmd( initCmds )
      body = []
      if keyFilter:
         body.extend( self.aconsFilterStmts( keyFilter, 'k' ) )
      if valFilter:
         body.extend( self.aconsFilterStmts( valFilter, 'v' ) )
      body.append( self.indent + f"print( str(k) + '{self.kvSep}' + str(v) )" )
      loopBody = '\n' + '\n'.join( body ) + '\n'

      collMap = {}
      for line in self._iterCollOutput( entityPath, collName, 'k, v',
                                        f"list(_.{collName}.items())",
                                        loopBody, batchSize ):
         comps = line.split( self.kvSep )
         if len( comps ) != 2:
            continue   # not a line printed by the generated loop body
         key = self.convertStrToVal( comps[ 0 ] )
         if not isinstance( key, ( int, str ) ):
            # NOTE(review): presumably turns an eval'd Value into a constant
            # usable as a dict key -- confirm Tac.const semantics.
            key = Tac.const( key )
         collMap[ key ] = self.convertStrToVal( comps[ 1 ] )
      return collMap

   def getEntityCollList( self, entityPath, collName, kvFilter=None, isKey=False,
                          initCmds=None, batchSize=None ):
      """Return the keys (``isKey``) or values of collection ``collName``.

      kvFilter: criteria for aconsFilterStmts() applied to each item 'x'.
      initCmds: command string run after changing path, before reading.
      batchSize: entries fetched per remote command (default: all at once).
      """
      self.setEntityPath( entityPath )
      if initCmds:
         self.runCmd( initCmds )
      listType = "keys" if isKey else "values"
      body = []
      if kvFilter:
         body.extend( self.aconsFilterStmts( kvFilter, 'x' ) )
      body.append( self.indent + f"print( '{self.kvSep}' + str( x ) )" )
      loopBody = '\n' + '\n'.join( body ) + '\n'

      collList = []
      for line in self._iterCollOutput( entityPath, collName, 'x',
                                        f"list(_.{collName}.{listType}())",
                                        loopBody, batchSize ):
         comps = line.split( self.kvSep )
         if len( comps ) == 2:
            collList.append( self.convertStrToVal( comps[ 1 ] ) )
      return collList

   def getEntityCollKeys( self, entityPath, collName, keyFilter=None,
                          initCmds=None, batchSize=None ):
      """Return the (optionally filtered) keys of collection ``collName``."""
      return self.getEntityCollList( entityPath, collName, kvFilter=keyFilter,
                                     isKey=True, initCmds=initCmds,
                                     batchSize=batchSize )

   def getEntityCollValues( self, entityPath, collName, valFilter=None,
                            initCmds=None, batchSize=None ):
      """Return the (optionally filtered) values of collection ``collName``."""
      return self.getEntityCollList( entityPath, collName, kvFilter=valFilter,
                                     isKey=False, initCmds=initCmds,
                                     batchSize=batchSize )

   def deleteEntityCollEntry( self, entityPath, collName, key=None, keyFilter=None,
                              valFilter=None, initCmds=None ):
      """Delete entries of collection ``collName`` and return the command output.

      key: deletes that single entry; it is inserted verbatim (via str()) into
         the command, so it must be a valid Python expression for the key.
      keyFilter / valFilter: otherwise, delete entries matching the criteria.
         WARNING: with neither key nor filters, every entry is deleted.
      initCmds: a command string or a list of command lines run first.
      """
      self.setEntityPath( entityPath )
      cmds = []
      if initCmds:
         if isinstance( initCmds, str ):
            cmds.append( initCmds )
         else:
            cmds.extend( initCmds )
      if key is not None:
         cmds.append( f"del _.{collName}[{str(key)}]" )
      else:
         # Iterate a snapshot so deleting does not mutate the iterated items.
         cmds.append( f"for k,v in list(_.{collName}.items()):" )
         if keyFilter:
            cmds.extend( self.aconsFilterStmts( keyFilter, 'k' ) )
         if valFilter:
            cmds.extend( self.aconsFilterStmts( valFilter, 'v' ) )
         cmds.append( self.indent + f"del _.{collName}[ k ]" )
      return self.runCmd( '\n'.join( cmds ) + '\n' )

   def updateEntityCollEntry( self, entityPath, collName, myMap, instColl=False ):
      """Not implemented yet; does nothing."""
      # TODO

   def checkEntityAttrError( self, results, path, stopOnError=True ):
      """Handle an 'AttributeError:' line in ``results``, if present.

      With stopOnError, raise AgentAccessorException with that line; otherwise
      clear ``results`` in place.  ``path`` is currently unused.
      """
      for line in results:
         if line.startswith( 'AttributeError:' ):
            if stopOnError:
               raise AgentAccessorException( line )
            del results[ : ]
            return

   def getSingleEntityAttrs( self, entityPath, names=None, prereq=None,
                             stopOnError=True ):
      """Read attributes of the entity at ``entityPath``.

      names: 'label:expr' or bare 'expr' entries; an expr without '_.' gets
         '_.' prepended.  Without names, 'ls -l' output is parsed instead.
      prereq: commands executed before reading.
      Returns an OrderedDict label -> value string, in output order (empty if
      an AttributeError occurred and stopOnError is False).
      """
      self.setEntityPath( entityPath )
      for cmd in prereq or []:
         self.runCmd( cmd )
      cmds = []
      if names:
         for label, name in ( self.splitLabelName( n ) for n in names ):
            if "_." not in name:
               name = "_." + name
            # Emit for every name, including ones that already contain '_.'.
            cmds.append( f"print( '{label}: ', {name} )" )
      else:
         cmds.append( "ls -l" )

      results = []
      for cmd in cmds:
         results += self.runCmd( cmd )
      self.checkEntityAttrError( results, entityPath, stopOnError=stopOnError )

      # OrderedDict keeps the order of `names`, which simplifies formatting.
      attrVals = OrderedDict()
      for line in results:
         label, val = line.split( ":", 1 )
         attrVals[ label.strip() ] = val.strip()
      return attrVals

   def setSingleEntityAttrs( self, entityPath, attrVals, prereq=None,
                             stopOnError=True ):
      """Assign ``_.<label> = <val>`` for each item of ``attrVals``.

      Values are inserted verbatim, so they must be Python expressions.
      Returns the collected command output (cleared on a non-fatal
      AttributeError).
      """
      self.setEntityPath( entityPath )
      for cmd in prereq or []:
         self.runCmd( cmd )
      cmds = [ f"_.{label} = {val}" for label, val in attrVals.items() ]
      results = []
      for cmd in cmds:
         results += self.runCmd( cmd )
      self.checkEntityAttrError( results, entityPath, stopOnError=stopOnError )
      return results

   def callEntityFunc( self, entityPath, funcName, funcArgs ):
      """Not implemented yet; does nothing."""
      # TODO

   def fmtAttrValues( self, attrVals, brief=False, sortNames=False, useHex=False,
                      labelSep=": ", labelFmt=None, attrSep=None, attrValFmt=None,
                      prefix='' ):
      """Format attribute labels and values into one string.

      brief: labels unpadded and items joined by spaces; otherwise labels are
         left-aligned to the widest one and items joined by newlines.
      attrValFmt: optional mapping of pattern objects (with ``match``, e.g.
         compiled regexes) to formatter callables; the first pattern matching
         a label formats its value.
      useHex: ints without a custom formatter are shown with hex().
      An empty ``attrVals`` yields just ``prefix``.
      """
      labels = attrVals.keys()
      if sortNames:
         labels = sorted( labels )
      if labelFmt is None:
         if brief:
            labelFmt = "%s"
         else:
            labelWidth = max( ( len( str( label ) ) for label in labels ),
                              default=0 )
            labelFmt = f"%-{labelWidth}s"
      if attrSep is None:
         attrSep = " " if brief else "\n"

      def _fmtLabelVal( _label, _val ):
         # Wrap in a 1-tuple so tuple labels are formatted, not unpacked.
         labelStr = labelFmt % ( _label, )
         valStr = None
         if attrValFmt is not None:
            for lpat, fmtFunc in attrValFmt.items():
               if lpat.match( _label ):
                  valStr = fmtFunc( _val )
                  break
         if valStr is None:
            if useHex and isinstance( _val, int ):
               valStr = hex( _val )
            else:
               valStr = str( _val )
         return labelStr + labelSep + valStr

      return prefix + attrSep.join( [ _fmtLabelVal( label, attrVals[ label ] )
                                      for label in labels ] )

   def showSingleEntityAttrs( self, entityPath, names=None, prereq=None, brief=False,
                              labelSep=": ", prefix="", stopOnError=True ):
      """Read attributes with getSingleEntityAttrs() and print them formatted."""
      attrVals = self.getSingleEntityAttrs( entityPath, names, prereq=prereq,
                                            stopOnError=stopOnError )
      if attrVals:
         print( prefix + self.fmtAttrValues( attrVals, brief=brief,
                                             labelSep=labelSep ) )
