# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import AaaPluginLib
from AaaPluginLib import TR_INFO, TR_DEBUG
from Tracing import traceX
import Tac
import Cell
import Toggles.GnsiToggleLib as GnsiToggle

import threading

# Holds the GnsiConfigReactor created once the plugin's mounts complete
# (presumably kept at module level so the reactor stays alive).
gnsiConfigReactor_ = None
# Accounting method name of this plugin; it is also the entry added to / removed
# from aaaStatus.extraAcctMethods when gNSI.Acctz is enabled / disabled.
gnsiAcctzMethodName = "GnsiAcctz"
# Privilege level passed by callers that do not know the session's real level;
# the plugin then looks the level up itself.
DEFAULT_PRIVILEGE_LEVEL = 0
# A host given as this literal name is recorded as loopbackIp instead of being
# parsed as an address.
localHost = "localhost"
loopbackIp = "127.1.1.1"

authType = Tac.Type( "Gnsi::Acctz::AuthType" )
gRPCType = Tac.Type( "Gnsi::Acctz::GrpcServiceType" )
authenSourceEnum = Tac.Type( "Aaa::AuthenSource" )

# Aaa::AuthenSource -> gNSI.Acctz AuthType conversion map
authTypeMap = {
   Tac.enumValue( "Aaa::AuthenSource", aaaSource ): acctzAuthType
   for aaaSource, acctzAuthType in (
      ( authenSourceEnum.authenSourcePassword, authType.AUTHN_TYPE_PASSWORD ),
      ( authenSourceEnum.authenSourceSshkey, authType.AUTHN_TYPE_SSHKEY ),
      ( authenSourceEnum.authenSourceSshcert, authType.AUTHN_TYPE_SSHCERT ),
      ( authenSourceEnum.authenSourceTlscert, authType.AUTHN_TYPE_TLSCERT ),
   )
}

# gRPC service name -> gNSI.Acctz GrpcServiceType conversion map
gRPCTypeMap = {
   "gnmi": gRPCType.GRPC_SERVICE_TYPE_GNMI,
   "gnoi": gRPCType.GRPC_SERVICE_TYPE_GNOI,
   "gnsi": gRPCType.GRPC_SERVICE_TYPE_GNSI,
   "gribi": gRPCType.GRPC_SERVICE_TYPE_GRIBI,
   "p4Runtime": gRPCType.GRPC_SERVICE_TYPE_P4RT,
}

def getIpAddr( connInfo ):
   '''
      Return only the address part of a connection string, dropping an optional
      port number and/or enclosing square brackets.

      IPv4 forms:
         a.b.c.d
         a.b.c.d:<port>
         [a.b.c.d]
         [a.b.c.d]:<port>
      IPv6 forms (a port is only unambiguous inside brackets):
         a:b:c:d:e:f
         [a:b:c:d:e:f]:<port>
         [a:b:c:d:e:f]
         [a::b]
         [a::b]:<port>
      Ref: https://datatracker.ietf.org/doc/html/rfc2732#section-2

      Rules, applied in order:
      1. If there is a closing ']', the address is bracketed: return the text
         between the first character (assumed to be '[') and the last ']'.
      2. If there is exactly one ':', it is IPv4 with a port: return the part
         before the colon.
      3. Otherwise (no ':' at all, or several of them) the string is a bare
         IPv4 or IPv6 address and is returned unchanged.
   '''
   closingBracket = connInfo.rfind( ']' )
   if closingBracket >= 0:
      return connInfo[ 1 : closingBracket ]

   if connInfo.count( ':' ) == 1:
      host, _, _ = connInfo.partition( ':' )
      return host

   return connInfo

class GnsiAcctzPlugin( AaaPluginLib.Plugin ):
   '''
      AAA accounting plugin that stores shell-session and command accounting
      events as Gnsi::Acctz::AccountingRecord entries in acctzData.

      Each record is keyed by the next sequence number (acctzData.nextSeqNum,
      incremented after every insert). When gnsiConfig.acctzConfig.historyLimit
      records are stored, the oldest ones are purged before a new one is added;
      acctzData.lastPurgedSeqNum tracks the newest purged sequence number.
   '''
   def __init__( self, aaaConfig, gnsiConfig, acctzData, localUserConfig ):
      traceX( TR_INFO, "GnsiAcctzPlugin initialized" )
      AaaPluginLib.Plugin.__init__( self, aaaConfig, gnsiAcctzMethodName )
      self.aaaConfig_ = aaaConfig
      self.acctzData_ = acctzData
      self.gnsiConfig_ = gnsiConfig
      self.localUserConfig_ = localUserConfig
      # Set by setupStateMachine_ and reported through warm().
      self.warmup_ = False
      # Serializes sequence-number allocation, purging and record insertion.
      self.recordLock_ = threading.Lock()

   def setupStateMachine_( self ):
      '''
         This is called once the plugin's mounts are completed.
         Rebuild the sequence-number bookkeeping from the records already
         present in Sysdb, then mark the plugin warm.
      '''
      traceX( TR_INFO, "GnsiAcctzPlugin: setupStateMachine_" )
      records = self.acctzData_.accountingRecords
      # Everything below the oldest stored record counts as already purged.
      self.acctzData_.lastPurgedSeqNum = min( records, default=1 ) - 1
      # The next record must go one past the newest stored record; reusing the
      # newest key itself would overwrite that record.
      self.acctzData_.nextSeqNum = max( records, default=0 ) + 1
      self.warmup_ = True

   def maybePerformPurge_( self ):
      '''
         Make room for one new record: if adding it would exceed historyLimit,
         delete the oldest record(s) first, in sequence order starting right
         after lastPurgedSeqNum.

         More than one record is purged when historyLimit was lowered while
         more records than the new limit were stored. A single "== limit" check
         would never fire again in that case and the history would grow
         without bound.
      '''
      limit = self.gnsiConfig_.acctzConfig.historyLimit
      if limit <= 0:
         # NOTE(review): a non-positive limit is not expected from config;
         # leave the history untouched rather than guess its meaning.
         return
      records = self.acctzData_.accountingRecords
      # After the caller adds one record, at most `limit` records remain.
      for _ in range( len( records ) - limit + 1 ):
         del records[ self.acctzData_.lastPurgedSeqNum + 1 ]
         self.acctzData_.lastPurgedSeqNum += 1

   def postProcessRecords_( self ):
      '''
         This function is called after the record is added,
         1. Increment the next sequence number.
      '''
      self.acctzData_.nextSeqNum += 1

   def getNextSeqNum_( self ):
      '''
         Return the next sequence number
      '''
      return self.acctzData_.nextSeqNum

   def storeAcctzRecord_( self, timestamp, acctInfo, sessionInfo,
                          serviceInfo=None ):
      '''
         Build an AccountingRecord under the next sequence number and add it
         to acctzData, purging the oldest record(s) first if needed.

         Everything runs under recordLock_, so concurrent callers cannot
         interleave sequence-number allocation and insertion.
      '''
      with self.recordLock_:
         ar = Tac.Value( "Gnsi::Acctz::AccountingRecord", self.getNextSeqNum_() )
         ar.ts = timestamp
         ar.accountInfo = acctInfo
         ar.sessionInfo = sessionInfo
         if serviceInfo is not None:
            ar.serviceInfo = serviceInfo

         # Check if we need to purge records
         self.maybePerformPurge_()

         self.acctzData_.accountingRecords.addMember( ar )

         # Post processing records
         self.postProcessRecords_()

   def ready( self ):
      # The plugin is always ready, it's not dependent on any additional
      # configuration like external RADIUS/TACACS servers etc.
      return True

   def warm( self ):
      return self.warmup_

   def getAuthenType_( self, authenType ):
      '''
         Map an Aaa::AuthenSource value to a gNSI.Acctz AuthType; unknown
         sources map to AUTHN_TYPE_UNSPECIFIED.
      '''
      return authTypeMap.get( authenType, authType.AUTHN_TYPE_UNSPECIFIED )

   def fillSessionData_( self, user, localAddr, localPort, remoteHost, remotePort,
                         tty, privLevel, sshPrincipal, authenType, ipProto,
                         channelId, authStatus, authCause=None, taskIds=None ):
      '''
         Helper function to fill session data, returns a Gnsi::Acctz::SessionInfo.

         Addresses, ports, ipProto, sshPrincipal, authCause and taskIds are only
         set when truthy. A host given as "localhost" is recorded as loopbackIp;
         any other address has its port/brackets stripped by getIpAddr.
         privLevel is accepted for signature compatibility but not used here.
      '''
      def toIpGenAddr( host ):
         addr = loopbackIp if host == localHost else getIpAddr( host )
         return Tac.Value( 'Arnet::IpGenAddr', addr )

      sessionInfo = Tac.Value( "Gnsi::Acctz::SessionInfo" )
      if localAddr:
         sessionInfo.localAddr = toIpGenAddr( localAddr )
      if localPort:
         sessionInfo.localPort = Tac.Value( 'Arnet::Port', localPort )
      if remoteHost:
         sessionInfo.remoteAddr = toIpGenAddr( remoteHost )
      if remotePort:
         sessionInfo.remotePort = Tac.Value( 'Arnet::Port', remotePort )
      if ipProto:
         sessionInfo.ipProto = Tac.Value( 'Arnet::IpProto', ipProto )
      if sshPrincipal:
         sessionInfo.sshPrincipal = sshPrincipal
      sessionInfo.channelId = channelId
      sessionInfo.tty = tty
      sessionInfo.identity = user
      sessionInfo.authType = self.getAuthenType_( authenType )
      sessionInfo.authStatus = authStatus
      if authCause:
         sessionInfo.authCause = authCause
      if taskIds:
         sessionInfo.taskIds = str( taskIds )
      return sessionInfo

   def getGrpcType_( self, **kwargs ):
      '''
         Map kwargs["gRPCType"] to a gNSI.Acctz GrpcServiceType (unknown names
         map to GRPC_SERVICE_TYPE_UNSPECIFIED).

         Raises AssertionError when gRPCType is missing or empty. It is raised
         explicitly rather than with `assert`, so it is not skipped under -O.
      '''
      gType = kwargs.get( "gRPCType" )
      if not gType:
         raise AssertionError( f"does not have gRPCType {kwargs}" )
      return gRPCTypeMap.get( gType, gRPCType.GRPC_SERVICE_TYPE_UNSPECIFIED )

   def cmdAuthzEnabled( self, privlevel ):
      '''
         Utility function to check if command AAA authorization is enabled for
         the given privilege level.

         privlevel is converted with int() first. When it comes from a remote
         server's session attribute (see getPrivilegeLevel_), it may not be an
         int (presumably a numeric string). Formatting that directly with
         ":02d" would raise ValueError. A value that is not a valid integer is
         treated as "not enabled".
      '''
      try:
         level = int( privlevel )
      except ( TypeError, ValueError ):
         return False
      ml = self.aaaConfig_.authzMethod.get( f"command{level:02d}" )
      if not ml:
         return False
      return not ( len( ml.defaultMethod ) == 1 and ml.defaultMethod[ 0 ] == "None" )

   def gNSIAuthzEnabled( self ):
      '''
         Utility function to check if gNSI.Authz is enabled.
      '''
      return self.gnsiConfig_.service.authz

   def fillCommandAcct_( self, user, session, privLevel, tokens, cmdType,
                         timestamp, cmdSuccess=True, authzDetail=None, **kwargs ):
      '''
         Helper function to populate AcctzRecord for success/failed
         command invocation. Populate the below in an accounting record
         1. AccountingInfo
         2. SessionInfo
         3. CommandServiceInfo or GrpcServiceInfo

         cmdType "cli" produces a CommandServiceInfo and "gRPC" a
         GrpcServiceInfo. Any other cmdType is traced and no record is
         written. For gRPC, kwargs may carry gRPCType (required), gRPCName,
         gRPCPayload, gRPCPayloadTruncated and gRPCAuthz.
      '''
      sessionStatusType = Tac.Type( "Gnsi::Acctz::SessionStatus" )
      authzStatusType = Tac.Type( "Gnsi::Acctz::AuthzStatus" )

      # AccountingInfo
      acctInfo = Tac.Value( "Gnsi::Acctz::AccountingInfo" )
      acctInfo.sessionStatus = sessionStatusType.SESSION_STATUS_OPERATION
      acctInfo.role = str( privLevel )

      # SessionInfo
      authStatus = Tac.Type( "Gnsi::Acctz::AuthStatus" ).AUTHN_STATUS_SUCCESS

      # Command service or Grpc service
      if cmdType == "cli":
         commandServiceInfo = Tac.Value( "Gnsi::Acctz::CommandServiceInfo" )
         commandServiceInfo.cmdType = Tac.Type(
               "Gnsi::Acctz::CommandServiceType" ).CMD_SERVICE_TYPE_CLI
         # For a CLI command, tokens is the list of all keywords of the command:
         # record the first one as the cmd and the rest as its arguments.
         # An empty token list is recorded as an empty cmd instead of raising.
         cmd = tokens[ 0 ] if tokens else ""
         commandServiceInfo.cmd = cmd
         commandServiceInfo.isCmdTruncated = False
         commandServiceInfo.args = " ".join( tokens[ 1 : ] ) if tokens else ""
         commandServiceInfo.isArgsTruncated = False

         # If privilege escalation happens via the "enable" command, flag it
         if cmd == "enable":
            acctInfo.sessionStatus = sessionStatusType.SESSION_STATUS_ENABLE

         # Populate authz information ONLY if it's configured
         if self.cmdAuthzEnabled( privLevel ):
            # Result of the command
            if cmdSuccess:
               commandServiceInfo.authzStatus = authzStatusType.AUTHZ_STATUS_PERMIT
            else:
               commandServiceInfo.authzStatus = authzStatusType.AUTHZ_STATUS_DENY
               if authzDetail:
                  commandServiceInfo.authzDetail = authzDetail

         serviceInfo = Tac.Value( "Gnsi::Acctz::ServiceInfo" )
         serviceInfo.commandServiceInfo = commandServiceInfo
      elif cmdType == "gRPC":
         grpcServiceInfo = Tac.Value( "Gnsi::Acctz::GrpcServiceInfo" )
         grpcServiceInfo.grpcType = self.getGrpcType_( **kwargs )

         if gRpcName := kwargs.get( "gRPCName" ):
            grpcServiceInfo.grpcName = gRpcName
         if gRpcPayload := kwargs.get( "gRPCPayload" ):
            grpcServiceInfo.payload = gRpcPayload
         if gRpcPayloadTruncated := kwargs.get( "gRPCPayloadTruncated" ):
            grpcServiceInfo.isTruncated = gRpcPayloadTruncated

         # Populate authz information ONLY if it's configured, either as
         # reported by the caller (gRPCAuthz) or through gNSI.Authz.
         if kwargs.get( "gRPCAuthz" ) or self.gNSIAuthzEnabled():
            # Result of the command
            if cmdSuccess:
               grpcServiceInfo.authzStatus = authzStatusType.AUTHZ_STATUS_PERMIT
            else:
               grpcServiceInfo.authzStatus = authzStatusType.AUTHZ_STATUS_DENY
               if authzDetail:
                  grpcServiceInfo.authzDetail = authzDetail

         serviceInfo = Tac.Value( "Gnsi::Acctz::ServiceInfo" )
         serviceInfo.grpcServiceInfo = grpcServiceInfo
      else:
         # Unsupported command type: nothing to record, but leave a trace
         # instead of dropping it silently.
         traceX( TR_DEBUG,
                 f"GnsiAcctzPlugin: ignoring unsupported cmdType {cmdType!r}" )
         return

      sessionInfo = self.fillSessionData_( user, session.localAddr,
                                           session.localPort,
                                           session.remoteHost, session.remotePort,
                                           session.tty, session.privilegeLevel,
                                           session.sshPrincipal,
                                           session.authenSource,
                                           6, # Proto=TCP
                                           "0", # channel-id
                                           authStatus, taskIds=session.id )

      self.storeAcctzRecord_( timestamp, acctInfo, sessionInfo, serviceInfo )

   def sendCommandAcct( self, method, user, session, privlevel, timestamp, tokens,
                        cmdType=None, **kwargs ):
      '''
         Record a successful command invocation. A privlevel equal to
         DEFAULT_PRIVILEGE_LEVEL is replaced by the user's looked-up level.
      '''
      if not GnsiToggle.toggleOCGNSIAcctzEnabled():
         return

      if privlevel == DEFAULT_PRIVILEGE_LEVEL:
         privlevel = self.getPrivilegeLevel_( user, session )

      self.fillCommandAcct_( user, session, privlevel, tokens, cmdType, timestamp,
                             cmdSuccess=True, **kwargs )

   def fillShellAcct_( self, user, localAddr, localPort, remoteHost, remotePort,
                       tty, privLevel, sshPrincipal, authenType, action, timestamp,
                       loginStatus=True, authenCause=None, taskIds=None ):
      '''
         Helper function to fill Shell accounting. Populate AccountingInfo and
         SessionInfo of an accounting record and store it.

         action must be "start" (login) or "stop" (logout); anything else
         raises AssertionError.
      '''
      sessionStatusType = Tac.Type( "Gnsi::Acctz::SessionStatus" )

      # AccountingInfo
      acctInfo = Tac.Value( "Gnsi::Acctz::AccountingInfo" )
      if action == "start":
         acctInfo.sessionStatus = sessionStatusType.SESSION_STATUS_LOGIN
      elif action == "stop":
         acctInfo.sessionStatus = sessionStatusType.SESSION_STATUS_LOGOUT
      else:
         raise AssertionError( f"Invalid action: {action}" )
      acctInfo.role = str( privLevel )

      # SessionInfo
      authStatusType = Tac.Type( "Gnsi::Acctz::AuthStatus" )
      if loginStatus:
         authStatus = authStatusType.AUTHN_STATUS_SUCCESS
      else:
         authStatus = authStatusType.AUTHN_STATUS_FAIL
      sessionInfo = self.fillSessionData_( user, localAddr, localPort, remoteHost,
                                           remotePort, tty, privLevel,
                                           sshPrincipal, authenType,
                                           6, # Proto=TCP
                                           "0", # channel-id
                                           authStatus, authenCause, taskIds=taskIds )

      self.storeAcctzRecord_( timestamp, acctInfo, sessionInfo )

   def localUserPrivilegeLevel_( self, user, default=None ):
      '''
         Return the privilege level configured for the local account `user`,
         or `default` when there is no such local account.
      '''
      localUser = self.localUserConfig_.acct.get( user )
      return localUser.privilegeLevel if localUser else default

   def getPrivilegeLevel_( self, user, session ):
      '''
         Extract the initial privilege-level from local user config if present,
         else look at SessionData which will have privilege-level from remote
         servers, else fall back to 1. A remote value is returned exactly as
         stored in the session attribute (not converted).
      '''
      localLevel = self.localUserPrivilegeLevel_( user )
      if localLevel is not None:
         return localLevel

      # Check if privilege level is provided by remote server in sessionData.
      # There can be multiple methods configured TACACS/RADIUS/local etc
      for prop in session.property.values():
         # TACACS/RADIUS send several variants of the privilege-level key; check
         # all of them until a unified key is used.
         for key in ( AaaPluginLib.privilegeLevel, 'priv-lvl', 'priv_lvl' ):
            priv = prop.attr.get( key )
            if priv is not None:
               return priv
      return 1  # default privilege level

   def sendShellAcct( self, method, user, session, action, startTime,
                      elapsedTime=None ):
      '''
         Populate AccountingInfo and SessionInfo of an accounting record for a
         shell session start/stop.
      '''
      if not GnsiToggle.toggleOCGNSIAcctzEnabled():
         return

      self.fillShellAcct_( user, session.localAddr, session.localPort,
                           session.remoteHost, session.remotePort,
                           session.tty, self.getPrivilegeLevel_( user, session ),
                           session.sshPrincipal, session.authenSource, action,
                           startTime, loginStatus=True, taskIds=session.id )

   def sendFailedShellAcct( self, user, localAddr, localPort, remoteHost, remotePort,
                            tty, privLevel, sshPrincipal, authenSource, failureCause,
                            timestamp ):
      '''
         Populate AccountingInfo and SessionInfo of an accounting record.
         Action is always start since this indicates a failed login attempt.
      '''
      if not GnsiToggle.toggleOCGNSIAcctzEnabled():
         return

      # A privLevel of DEFAULT_PRIVILEGE_LEVEL means the failed login (e.g. over
      # SSH) carried no level: use the local account's configured privilege if
      # there is one, else 1. This is not important for authentication but is
      # relevant for authorization.
      if privLevel == DEFAULT_PRIVILEGE_LEVEL:
         privLevel = self.localUserPrivilegeLevel_( user, default=1 )

      self.fillShellAcct_( user, localAddr, localPort, remoteHost, remotePort, tty,
                           privLevel, sshPrincipal, authenSource, "start", timestamp,
                           loginStatus=False, authenCause=failureCause )

   def sendFailedCommandAcct( self, user, privLevel, session, tokens, cmdType,
                              timestamp, authzDetail, **kwargs ):
      '''
         Record a failed command invocation. A privLevel equal to
         DEFAULT_PRIVILEGE_LEVEL is replaced by the user's looked-up level.
      '''
      if not GnsiToggle.toggleOCGNSIAcctzEnabled():
         return

      if privLevel == DEFAULT_PRIVILEGE_LEVEL:
         privLevel = self.getPrivilegeLevel_( user, session )

      self.fillCommandAcct_( user, session, privLevel, tokens, cmdType, timestamp,
                             cmdSuccess=False, authzDetail=authzDetail, **kwargs )

   def authorizeShell( self, method, user, session ):
      '''
         Not used, providing a default implementation since its abstract in Plugin
      '''
      pass # pylint: disable=unnecessary-pass

   def authorizeShellCommand( self, method, user, session, mode, privlevel, tokens ):
      '''
         Not used, providing a default implementation since its abstract in Plugin
      '''
      pass # pylint: disable=unnecessary-pass

   # pylint: disable-next=redefined-builtin
   def createAuthenticator( self, method, type, service, remoteHost, remoteUser,
                            tty, user=None, privLevel=0 ):
      '''
         Not used, providing a default implementation since its abstract in Plugin
      '''
      pass # pylint: disable=unnecessary-pass

class GnsiConfigReactor( Tac.Notifiee ):
   '''
      Reacts to changes of the Gnsi::Config 'service' attribute.

      While gNSI.Acctz is enabled, the GnsiAcctz method is listed in
      aaaStatus.extraAcctMethods. When it is disabled, the method is removed
      and all stored accounting records are wiped.
   '''
   notifierTypeName = "Gnsi::Config"

   def __init__( self, gnsiConfig, aaaStatus, acctzData ):
      traceX( TR_INFO, "GnsiConfig reactor initialized" )
      self.gnsiConfig_ = gnsiConfig
      self.aaaStatus_ = aaaStatus
      self.acctzData_ = acctzData
      Tac.Notifiee.__init__( self, gnsiConfig )
      # Apply configuration that was already in place before this reactor
      # existed (no change notification will arrive for it).
      self.handleGnsiConfig()

   @Tac.handler( 'service' )
   def handleGnsiConfig( self ):
      traceX( TR_DEBUG, "Gnsi configuration updated" )
      acctMethods = self.aaaStatus_.extraAcctMethods

      if not self.gnsiConfig_.service.acctz:
         traceX( TR_DEBUG, "Gnsi acctz disabled" )
         acctMethods.remove( gnsiAcctzMethodName )
         # gNSI.Acctz was turned off: drop every record in Sysdb and restart
         # sequence numbering from the beginning.
         data = self.acctzData_
         data.nextSeqNum = 1
         data.lastPurgedSeqNum = 0
         data.accountingRecords.clear()
         return

      traceX( TR_DEBUG, "Gnsi acctz enabled" )
      acctMethods.add( gnsiAcctzMethodName )

def Plugin( ctx ):
   '''
      Plugin entry point (presumably invoked by the Aaa agent's plugin loader).

      Mounts the gNSI config, local-user config and the gNSI.Acctz accounting
      data, and returns a GnsiAcctzPlugin. Once the mounts complete, the
      config reactor is created and the plugin's state machine is set up.
   '''
   mg = ctx.entityManager.mountGroup()
   gnsiConfig = mg.mount( "mgmt/gnsi/config", "Gnsi::Config", "r" )
   localUserConfig = mg.mount( 'security/aaa/local/config',
                               'LocalUser::Config', 'r' )
   aaaAgent = ctx.aaaAgent
   aaaConfig, aaaStatus = aaaAgent.config, aaaAgent.status
   acctzData = mg.mount( Cell.path( 'security/aaa/gnsiAcctz' ),
                         "Gnsi::Acctz::AccountingData", "wf" )
   acctzPlugin = GnsiAcctzPlugin( aaaConfig, gnsiConfig, acctzData,
                                  localUserConfig )

   def onMountsComplete():
      global gnsiConfigReactor_
      # The reactor must exist before the state machine reads acctzData, since
      # it may wipe the records when gNSI.Acctz is disabled.
      gnsiConfigReactor_ = GnsiConfigReactor( gnsiConfig, aaaStatus, acctzData )
      acctzPlugin.setupStateMachine_()

   mg.close( onMountsComplete )
   return acctzPlugin
