# Copyright (c) 2008-2010, 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=use-dict-literal
# pylint: disable=consider-using-f-string

"""
Provides a mechanism for other packages to extend the list of authentication,
authorization and accounting methods that can be used in the Aaa CLI.
For example, the Tacacs package can provide the "group tacacs+" method by
calling registerGroup with appropriate parameters.
"""

import BasicCli
import CliCommand
import CliGlobal
import CliMatcher
import CliParser
import ConfigMount
import HostnameCli
import LazyMount
import collections
from CliMode.Aaa import ServerGroupMode
from CliPlugin.VrfCli import VrfExprFactory, DEFAULT_VRF
import Tac
import Tracing
import sys
import termios
from AaaPluginLib import hostProtocol
import Cell

# Tracing handle for this module ("AaaCliLib"); t0 is an alias for
# Tracing.trace0, presumably the level-0 trace function -- confirm in Tracing.
traceHandle_ = Tracing.Handle( "AaaCliLib" )
t0 = Tracing.trace0

def getHostIndex( hosts ):
   """ An "index" member is added to both Tacacs and Radius host entry to
   record the order in which servers are configured. This method gives the
   next index to be assigned to a new server entry.

   hosts: a mapping whose values carry an integer ``index`` attribute, or
          None when there is no host collection.
   Returns 1 when there are no hosts (None or empty), otherwise one more than
   the largest index currently in use.
   """
   if hosts is None:
      return 1
   # max() with a default also covers an empty (but non-None) collection,
   # where taking the last element of a sorted list would raise IndexError.
   return max( ( host.index for host in hosts.values() ), default=0 ) + 1

class ActionType:
   """Bit flags naming the AAA action types a method or server group supports.

   Each individual action occupies its own bit.  AUTHN, AUTHZ and ACCT are
   convenience unions; the dot1x bits are NOT part of AUTHN or ACCT and have
   to be OR-ed in explicitly.
   """
   AUTHN_LOGIN   = 1 << 0
   AUTHN_ENABLE  = 1 << 1
   AUTHZ_EXEC    = 1 << 2
   AUTHZ_COMMAND = 1 << 3
   ACCT_EXEC     = 1 << 4
   ACCT_COMMAND  = 1 << 5
   ACCT_SYSTEM   = 1 << 6
   AUTHN_DOT1X   = 1 << 7
   ACCT_DOT1X    = 1 << 8
   # convenience unions of the single-bit flags above
   AUTHN = AUTHN_LOGIN | AUTHN_ENABLE
   AUTHZ = AUTHZ_EXEC | AUTHZ_COMMAND
   ACCT = ACCT_EXEC | ACCT_COMMAND | ACCT_SYSTEM

# Keyword -> help-text tables read by the dynamic method matchers.
# "NonGroup" tables hold plain method keywords (filled via addMethod(), e.g.
# 'local', 'none', 'logging'); "Group" tables hold server-group protocol
# tokens, filled by registerGroup().
authnMethodListNonGroup = {}
authnMethodListGroup = {}
authnDot1xMethodListNonGroup = {}
authnDot1xMethodListGroup = {}
authzExecMethodListNonGroup = {}
authzExecMethodListGroup = {}
authzCommandMethodListNonGroup = {}
authzCommandMethodListGroup = {}
acctExecMethodListNonGroup = {}
acctExecMethodListGroup = {}
acctCommandMethodListNonGroup = {}
acctCommandMethodListGroup = {}
acctSystemMethodListNonGroup = {}
acctSystemMethodListGroup = {}
acctDot1xMethodListNonGroup = {}
acctDot1xMethodListGroup = {}

# The 'group' keyword and the non-group method keywords (see
# methodNonGroupMatcher) are all built with this same sharedMatchObj and
# maxMatches.  Presumably that caps the combined number of methods in one
# method list at Aaa::MethodIndex max -- confirm against CliCommand.Node.
methodMatchObj = object()
methodMatchMax = Tac.Type( "Aaa::MethodIndex" ).max

groupKwMatcher = CliCommand.Node(
   matcher=CliMatcher.KeywordMatcher( 'group',
                                      helpdesc='Specify server group' ),
   sharedMatchObj=methodMatchObj,
   maxMatches=methodMatchMax )

def _methodMatcher( keywordDict ):
   """Return a dynamic keyword matcher whose keywords come from keywordDict.

   The callback hands back the live dict rather than a copy, so entries
   added to it after the matcher is built are still visible to the matcher.
   """
   def _keywords( mode ):
      return keywordDict
   return CliMatcher.DynamicKeywordMatcher( _keywords )

# Group method names follow the 'group' keyword, which already carries the
# maxMatches limit, so they need no Node wrapper of their own.
methodGroupMatcher = _methodMatcher

def methodNonGroupMatcher( keywordDict ):
   """Wrap a dynamic keyword matcher over keywordDict in a CliCommand.Node
   that shares methodMatchObj / methodMatchMax with the 'group' keyword."""
   keywordMatcher = _methodMatcher( keywordDict )
   return CliCommand.Node( matcher=keywordMatcher,
                           sharedMatchObj=methodMatchObj,
                           maxMatches=methodMatchMax )

# One matcher per method-list table.  The NonGroup matchers are
# CliCommand.Nodes carrying the shared methodMatchObj/methodMatchMax; the
# Group matchers are plain dynamic keyword matchers that rely on the
# preceding groupKwMatcher for that limit.
authnMethodListNonGroupMatcher = methodNonGroupMatcher( authnMethodListNonGroup )
authnMethodListGroupMatcher = methodGroupMatcher( authnMethodListGroup )
authnDot1xMethodListNonGroupMatcher = \
   methodNonGroupMatcher( authnDot1xMethodListNonGroup )
authnDot1xMethodListGroupMatcher = methodGroupMatcher( authnDot1xMethodListGroup )
authzExecMethodListNonGroupMatcher = \
   methodNonGroupMatcher( authzExecMethodListNonGroup )
authzExecMethodListGroupMatcher = methodGroupMatcher( authzExecMethodListGroup )
authzCommandMethodListNonGroupMatcher = \
   methodNonGroupMatcher( authzCommandMethodListNonGroup )
authzCommandMethodListGroupMatcher = \
   methodGroupMatcher( authzCommandMethodListGroup )
acctExecMethodListNonGroupMatcher = \
   methodNonGroupMatcher( acctExecMethodListNonGroup )
acctExecMethodListGroupMatcher = methodGroupMatcher( acctExecMethodListGroup )
acctCommandMethodListNonGroupMatcher = \
   methodNonGroupMatcher( acctCommandMethodListNonGroup )
acctCommandMethodListGroupMatcher = methodGroupMatcher( acctCommandMethodListGroup )
acctSystemMethodListNonGroupMatcher = \
   methodNonGroupMatcher( acctSystemMethodListNonGroup )
acctSystemMethodListGroupMatcher = methodGroupMatcher( acctSystemMethodListGroup )
acctDot1xMethodListNonGroupMatcher = \
   methodNonGroupMatcher( acctDot1xMethodListNonGroup )
acctDot1xMethodListGroupMatcher = methodGroupMatcher( acctDot1xMethodListGroup )

# Per-action registration helpers: each records one non-group method keyword
# (name) with its help text (desc) in the matching method-list table,
# overwriting any previous help text for that name.
def addAuthnMethod( name, desc ):
   """Offer name in the login/enable authentication method list."""
   authnMethodListNonGroup.update( { name : desc } )

def addAuthnDot1xMethod( name, desc ):
   """Offer name in the dot1x authentication method list."""
   authnDot1xMethodListNonGroup.update( { name : desc } )

def addAuthzExecMethod( name, desc ):
   """Offer name in the exec authorization method list."""
   authzExecMethodListNonGroup.update( { name : desc } )

def addAuthzCommandMethod( name, desc ):
   """Offer name in the command authorization method list."""
   authzCommandMethodListNonGroup.update( { name : desc } )

def addAcctExecMethod( name, desc ):
   """Offer name in the exec accounting method list."""
   acctExecMethodListNonGroup.update( { name : desc } )

def addAcctCommandMethod( name, desc ):
   """Offer name in the command accounting method list."""
   acctCommandMethodListNonGroup.update( { name : desc } )

def addAcctSystemMethod( name, desc ):
   """Offer name in the system accounting method list."""
   acctSystemMethodListNonGroup.update( { name : desc } )

def addAcctDot1xMethod( name, desc ):
   """Offer name in the dot1x accounting method list."""
   acctDot1xMethodListNonGroup.update( { name : desc } )

def addMethod( name, descTemplate,
               suppType=( ActionType.AUTHN | ActionType.AUTHZ | ActionType.ACCT |
                          ActionType.AUTHN_DOT1X | ActionType.ACCT_DOT1X ) ):
   """Register a non-group method keyword for every action type in suppType.

   name:         keyword offered in the method lists, e.g. 'local'
   descTemplate: help text containing one '%s', filled with "authentication",
                 "authorization" or "accounting" depending on the list
   suppType:     bitmap of ActionType flags; the default covers every flag
   """
   registrations = (
      ( ActionType.AUTHN, addAuthnMethod, "authentication" ),
      ( ActionType.AUTHN_DOT1X, addAuthnDot1xMethod, "authentication" ),
      ( ActionType.AUTHZ_EXEC, addAuthzExecMethod, "authorization" ),
      ( ActionType.AUTHZ_COMMAND, addAuthzCommandMethod, "authorization" ),
      ( ActionType.ACCT_EXEC, addAcctExecMethod, "accounting" ),
      ( ActionType.ACCT_COMMAND, addAcctCommandMethod, "accounting" ),
      ( ActionType.ACCT_SYSTEM, addAcctSystemMethod, "accounting" ),
      ( ActionType.ACCT_DOT1X, addAcctDot1xMethod, "accounting" ),
   )
   for flag, register, action in registrations:
      # format the help text only for lists that are actually selected
      if suppType & flag:
         register( name, descTemplate % action )

# Built-in non-group methods: 'local' and 'none' are offered for login/enable
# authentication and for exec/command authorization; 'logging' is offered for
# exec/command/system accounting and for dot1x accounting.
addMethod( 'local', "Use local database for %s",
           suppType=ActionType.AUTHN | ActionType.AUTHZ )
addMethod( 'none', "No %s (always succeeds)",
           suppType=ActionType.AUTHN | ActionType.AUTHZ )
addMethod( 'logging', "Use syslog for %s",
           suppType=ActionType.ACCT | ActionType.ACCT_DOT1X )

# The map from groupType to various group-specific values (one ServerGroupInfo
# per registered protocol), populated by registerGroup()
_serverGroupDB = { }

#-------------------------------------------------------------------------------
#     aaa group server [tacacs+|radius|...] <server-group-name>
#-------------------------------------------------------------------------------

class ServerGroupConfigMode( ServerGroupMode, BasicCli.ConfigModeBase ):
   #----------------------------------------------------------------------------
   # This is meant to be a base class only.
   # Attributes required for mode class should be defined by the subclass
   #----------------------------------------------------------------------------
   """Base class for the per-protocol 'aaa group server <type> <name>' modes.

   Concrete subclasses are generated by _createServerGroupConfigModeType().
   The mode's 'server' command adds members to, or removes members from, the
   Aaa::Config hostgroup named groupName.
   """

   def getHostProto( self, groupType ):
      """Return the hostProtocol value used for members of a groupType group.

      Raises ValueError for a group type with no protocol mapping here.
      """
      assert groupType != 'unknown'
      if groupType == "ldap":
         return hostProtocol.protoLdap
      if groupType == "tacacs":
         return hostProtocol.protoTacacs
      if groupType == "radius":
         return hostProtocol.protoRadius
      # Fail with a clear message instead of returning an unbound local.
      raise ValueError( "no host protocol for server group type %s" % groupType )

   def __init__( self, parent, session, group, groupName ):
      self.groupName = groupName
      info = _serverGroupDB[ group.groupType ]
      self.defaultPort = info.authport.defaultPort
      self.protocol = self.getHostProto( group.groupType )
      # 0 is used when the protocol registered no accounting / TLS port.
      self.defaultAcctPort = info.acctport.defaultPort if info.acctport else 0
      self.defaultTlsPort = info.tlsport.defaultPort if info.tlsport else 0
      ServerGroupMode.__init__( self, info.cliToken, groupName )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

   def _serverHostSpec( self, hostname, port, vrf, acctPort, tlsEnabled ):
      """Build the Aaa::HostSpec named by a 'server' command, filling defaults.

      vrf None means DEFAULT_VRF (an empty string is rejected).  A TLS server
      uses protoRadsec, the default TLS port, and accounting port 0, ignoring
      any acctPort given; otherwise the mode's protocol and its default
      authentication/accounting ports apply.
      """
      assert vrf != ''
      if vrf is None:
         vrf = DEFAULT_VRF
      if tlsEnabled:
         acctPort = 0
         if port is None:
            port = self.defaultTlsPort
         proto = hostProtocol.protoRadsec
      else:
         if port is None:
            port = self.defaultPort
         if acctPort is None:
            acctPort = self.defaultAcctPort
         proto = self.protocol
      return Tac.Value( "Aaa::HostSpec", hostname=hostname, port=port,
                        acctPort=acctPort, vrf=vrf, protocol=proto )

   def setServer( self, hostname, port=None, vrf=DEFAULT_VRF,
                  acctPort=None, tlsEnabled=None ):
      """Add the server to this group unless an identical spec is present."""
      spec = self._serverHostSpec( hostname, port, vrf, acctPort, tlsEnabled )
      members = configAaa( self ).hostgroup[ self.groupName ].member
      if not any( m.spec == spec for m in members.values() ):
         members.enq( Tac.Value( "Aaa::HostGroupMember", spec ) )

   def noServer( self, hostname, port=None, vrf=DEFAULT_VRF,
                 acctPort=None, tlsEnabled=None ):
      """Remove the first member of this group whose spec matches; no-op if
      none matches."""
      spec = self._serverHostSpec( hostname, port, vrf, acctPort, tlsEnabled )
      members = configAaa( self ).hostgroup[ self.groupName ].member
      for i, m in members.items():
         if m.spec == spec:
            # break right away: the collection was mutated during iteration
            del members[ i ]
            break

# Alias for the global config mode.  No use of it is visible in this module --
# presumably other AAA CLI plugins import it; confirm before removing.
configMode = BasicCli.GlobalConfigMode

# Shared 'aaa' keyword matcher, used by the 'aaa group server' command.
aaaKwMatcher = CliMatcher.KeywordMatcher(
   'aaa',
   helpdesc="Authentication, Authorization and Accounting" )

# Module-wide mount handles: set by configAaaIs()/counterConfigAaaIs()/
# statusAaaIs() (all called from initLibrary()) and returned by configAaa(),
# counterConfigAaa() and statusAaa().
gv = CliGlobal.CliGlobal( dict( config=None,
                                counterConfig=None,
                                status=None ) )

def counterConfigAaaIs( entityManager ):
   """Mount Aaa::CounterConfig (mode "w", via LazyMount) into gv.counterConfig."""
   mountPath = "security/aaa/counterConfig"
   gv.counterConfig = LazyMount.mount( entityManager, mountPath,
                                       "Aaa::CounterConfig", "w" )

def counterConfigAaa( mode ):
   """Return the mounted Aaa::CounterConfig; mode is not used."""
   return gv.counterConfig

def configAaaIs( entityManager ):
   """Mount Aaa::Config (mode "w", via ConfigMount) into gv.config."""
   mountPath = "security/aaa/config"
   gv.config = ConfigMount.mount( entityManager, mountPath,
                                  "Aaa::Config", "w" )

def configAaa( mode ):
   """Return the mounted Aaa::Config; mode is not used."""
   return gv.config

def statusAaaIs( entityManager ):
   """Mount Aaa::Status (mode "r", via LazyMount) at the Cell.path() of
   'security/aaa/status' into gv.status."""
   mountPath = Cell.path( 'security/aaa/status' )
   gv.status = LazyMount.mount( entityManager, mountPath,
                                "Aaa::Status", "r" )

def statusAaa( mode ):
   """Return the mounted Aaa::Status; mode is not used."""
   return gv.status

def initLibrary( entityManager ):
   """Mount the AAA config, counter config and status entities, in that
   order, into gv."""
   for mountFn in ( configAaaIs, counterConfigAaaIs, statusAaaIs ):
      mountFn( entityManager )

def getGroupTypeFromToken( groupToken ):
   """Return the first registered groupType whose cliToken equals groupToken,
   or 'unknown' when no registered group uses that token."""
   return next( ( groupType for groupType, info in _serverGroupDB.items()
                  if info.cliToken == groupToken ),
                'unknown' )

def getCliDisplayFromGroup( groupType ):
   """Return the display name (such as 'TACACS+') registered for groupType."""
   info = _serverGroupDB[ groupType ]
   return info.cliDisplay

# This is for the new parser
# Matches/completes every configured server-group name, whatever its protocol.
# NOTE(review): PRIO_LOW presumably lets keyword matches take precedence over
# this free-form name -- confirm against CliParser.
serverGroupNameMatcher = CliMatcher.DynamicNameMatcher(
   lambda mode: configAaa( mode ).hostgroup,
   helpdesc='Server-group name',
   priority=CliParser.PRIO_LOW )

# Radius server group name matcher: only groups whose groupType is 'radius'.
radiusServerGroupNameMatcher = CliMatcher.DynamicNameMatcher(
   lambda mode: [ k for ( k, v ) in configAaa( mode ).hostgroup.items()
                  if v.groupType == 'radius' ],
   helpdesc='RADIUS server-group name' )

def _createServerGroupConfigModeType( groupToken ):
   """Create a ServerGroupConfigMode subclass for one protocol token.

   The new class is named '<groupToken>ServerGroupConfigMode' and gets the
   class attribute name='Server-group <groupToken>'.
   """
   classAttrs = { 'name' : 'Server-group {}'.format( groupToken ) }
   return type( '{}ServerGroupConfigMode'.format( groupToken ),
                ( ServerGroupConfigMode, ), classAttrs )

def _gotoServerGroupConfigMode( mode, groupToken, groupName ):
   """Enter the config mode for server group groupName, creating the group
   with groupToken's protocol if it does not exist yet.

   Adds an error and does nothing else when groupName equals the protocol
   token itself, or when a group of that name exists for another protocol.
   """
   # the protocol token itself is reserved and may not name a group
   if groupName == groupToken:
      mode.addError( "group \'%s\' is reserved" % groupToken )
      return

   groups = configAaa( mode ).hostgroup
   groupType = getGroupTypeFromToken( groupToken )
   assert groupType != 'unknown'

   if groupName not in groups:
      group = groups.newMember( groupName )
      group.groupType = groupType
   else:
      group = groups[ groupName ]
      if group.groupType != groupType:
         existingToken = _serverGroupDB[ group.groupType ].cliToken
         mode.addError(
            f"Group \'{groupName}\' already exists for {existingToken}" )
         return

   modeType = _serverGroupDB[ groupType ].configModeType
   childMode = mode.childMode( modeType, group=group, groupName=groupName )
   mode.session_.gotoChildMode( childMode )

def _noServerGroup( mode, groupToken, groupName ):
   """Delete server group groupName if it exists and belongs to groupToken's
   protocol; otherwise leave the config untouched."""
   groups = configAaa( mode ).hostgroup
   groupType = getGroupTypeFromToken( groupToken )
   assert groupType != 'unknown'

   if groupName not in groups:
      return
   if groups[ groupName ].groupType != groupType:
      return
   del groups[ groupName ]

# Port description handed to registerGroup() by protocol plugins (see
# [Tacacs|Radius]Group.py): the CLI keyword, its help text, and the port used
# when the keyword is omitted from the 'server' command.
AaaPortInfo = collections.namedtuple(
   'AaaPortInfo', [ 'portToken', 'portHelp', 'defaultPort' ] )

# Everything registerGroup() records about one protocol; stored per groupType
# in _serverGroupDB.
ServerGroupInfo = collections.namedtuple(
   'ServerGroupInfo',
   [ 'cliToken', 'cliDisplay', 'authport', 'acctport', 'tlsport',
     'configModeType', 'suppType', 'vrfSupport' ] )

def registerGroup( groupType, cliToken, cliDisplay, authport,
                   acctport, tlsport, suppType, vrfSupport ):
   """
   Register a Cli plugin with its method list and various Aaa group CLI commands.

   groupType:   The HostGroupType enum defined in Aaa.tac to uniquely identify
                this protocol
   cliToken:    The token used in the CLI commands, such as 'tacacs+'
   cliDisplay:  How to display the protocol in CLI, such as 'TACACS+'
   authport/acctport: Authentication and accounting port information
                (AaaPortInfo; acctport may be None)
   tlsport:     TLS port information for RADIUS (AaaPortInfo or None); when it
                is given, authport and acctport must be given too
   suppType:    Which types of authn/authz/acct are supported (ActionType bits)
   vrfSupport:  Is VRF supported, True/False

   Each groupType may be registered only once.  Registration records a
   ServerGroupInfo in _serverGroupDB, offers cliToken in the group method
   lists, adds 'aaa group server <cliToken> NAME' to global config mode and
   adds 'server HOSTNAME ...' to the new server-group config mode.
   """
   assert groupType not in _serverGroupDB
   configModeType = _createServerGroupConfigModeType( cliToken )
   _serverGroupDB[ groupType ] = ServerGroupInfo( cliToken,
                                                  cliDisplay,
                                                  authport,
                                                  acctport,
                                                  tlsport,
                                                  configModeType,
                                                  suppType,
                                                  vrfSupport )

   # register for method list
   helpdesc = "Use list of all defined %s hosts" % cliDisplay

   # NOTE(review): unlike every other list, the login/enable authentication
   # list is not gated on suppType -- confirm that this is intentional.
   authnMethodListGroup[ cliToken ] = helpdesc

   gatedMethodLists = (
      ( ActionType.AUTHN_DOT1X, authnDot1xMethodListGroup ),
      ( ActionType.AUTHZ_EXEC, authzExecMethodListGroup ),
      ( ActionType.AUTHZ_COMMAND, authzCommandMethodListGroup ),
      ( ActionType.ACCT_EXEC, acctExecMethodListGroup ),
      ( ActionType.ACCT_COMMAND, acctCommandMethodListGroup ),
      ( ActionType.ACCT_SYSTEM, acctSystemMethodListGroup ),
      ( ActionType.ACCT_DOT1X, acctDot1xMethodListGroup ),
   )
   for flag, methodList in gatedMethodLists:
      if suppType & flag:
         methodList[ cliToken ] = helpdesc

   # We cannot use serverGroupNameMatcher as it auto-completes all server groups.
   # We need to only auto-complete on server groups specific to our type.
   myServerGroupNameMatcher = CliMatcher.DynamicNameMatcher(
      lambda mode: [ k for ( k, v ) in configAaa( mode ).hostgroup.items()
                     if v.groupType == groupType ],
      helpdesc='Server-group name' )

   # the server-group config mode
   class ServerGroupCmd( CliCommand.CliCommandClass ):
      syntax = "aaa group server %s NAME" % cliToken
      noOrDefaultSyntax = syntax
      data = {
         'aaa' : aaaKwMatcher,
         'group' : 'Group definitions',
         'server' : 'AAA server-group definitions',
         cliToken : '%s server-group definition' % cliDisplay,
         'NAME' : myServerGroupNameMatcher }

      @staticmethod
      def handler( mode, args ):
         _gotoServerGroupConfigMode( mode, cliToken, args[ 'NAME' ] )

      @staticmethod
      def noOrDefaultHandler( mode, args ):
         _noServerGroup( mode, cliToken, args[ 'NAME' ] )

   BasicCli.GlobalConfigMode.addCommandClass( ServerGroupCmd )

   #-------------------------------------------------------------------------------
   # In the server-group config mode of a protocol without a tlsport:
   #
   #    [no] server <ip-addr-or-hostname> [VRF] [<auth-port-token> <1-65535>]
   #                [<acct-port-token> <1-65535>]
   #
   # and of a protocol with a tlsport (RADIUS):
   #
   #    [no] server <ip-addr-or-hostname> [VRF]
   #                [ ( [<auth-port-token> <1-65535>] [<acct-port-token> <1-65535>] )
   #                  | ( tls [<tls-port-token> <1-65535>] ) ]
   #
   # [VRF] only exists when vrfSupport is True.  The port tokens are the
   # AaaPortInfo.portToken values, e.g. 'port' for tacacs+ or 'auth-port' and
   # 'acct-port' for radius.
   #-------------------------------------------------------------------------------
   def _serverKwargs( args ):
      """Map parsed 'server' command args to setServer()/noServer() keyword
      arguments; the port is TLSPORT when 'tls' was given, else AUTHPORT."""
      tlsEnabled = args.get( 'tls' )
      return dict( port=args.get( 'TLSPORT' if tlsEnabled else 'AUTHPORT' ),
                   vrf=args.get( 'VRF', DEFAULT_VRF ),
                   acctPort=args.get( 'ACCTPORT' ),
                   tlsEnabled=tlsEnabled )

   class ServerCmd( CliCommand.CliCommandClass ):
      syntax = "server HOSTNAME"
      data = {
         'server' : 'Add a %s server to the server-group' % cliDisplay,
         'HOSTNAME' : HostnameCli.IpAddrOrHostnameMatcher(
            helpname='WORD',
            helpdesc='Hostname or IP address of %s server' % cliDisplay,
            ipv6=True )
      }
      if vrfSupport:
         syntax += ' [ VRF ]'
         data[ 'VRF' ] = VrfExprFactory( helpdesc='VRF for this server' )

      _portNumber = CliMatcher.IntegerMatcher( 1, 65535,
                                               helpdesc="Number of the port to use" )
      if not tlsport:
         if authport is not None:
            syntax += " [ %s AUTHPORT ]" % authport.portToken
            data[ authport.portToken ] = authport.portHelp
            data[ 'AUTHPORT' ] = _portNumber

         if acctport is not None:
            syntax += " [ %s ACCTPORT ]" % acctport.portToken
            data[ acctport.portToken ] = acctport.portHelp
            data[ 'ACCTPORT' ] = _portNumber
      else:
         # Explicit auth/acct ports and 'tls [TLSPORT]' are mutually exclusive;
         # both alternatives need the auth and acct port descriptions.
         assert authport is not None and acctport is not None
         syntax += " [ ( [ %s AUTHPORT ] [ %s ACCTPORT ] ) |" \
                   " ( tls [ %s TLSPORT ] ) ]" % (
                      authport.portToken, acctport.portToken, tlsport.portToken )
         data[ authport.portToken ] = authport.portHelp
         data[ 'AUTHPORT' ] = _portNumber
         data[ acctport.portToken ] = acctport.portHelp
         data[ 'ACCTPORT' ] = _portNumber
         data[ 'tls' ] = 'TLS radius server'
         data[ tlsport.portToken ] = tlsport.portHelp
         data[ 'TLSPORT' ] = _portNumber

      noOrDefaultSyntax = syntax

      @staticmethod
      def handler( mode, args ):
         mode.setServer( args[ 'HOSTNAME' ], **_serverKwargs( args ) )

      @staticmethod
      def noOrDefaultHandler( mode, args ):
         mode.noServer( args[ 'HOSTNAME' ], **_serverKwargs( args ) )

   configModeType.addCommandClass( ServerCmd )

def groupTypeSupported( groupType ):
   """Return the ActionType bitmap (suppType) groupType was registered with.

   Raises KeyError when groupType was never registered.
   """
   info = _serverGroupDB[ groupType ]
   return info.suppType

# Tests monkey-patch getpass.getpass() with this function (see patchGetpass),
# as ConfigAgent in workspace has a controlling terminal and the real
# getpass.getpass() won't work there.
def myGetpass( prompt="Password: " ):
   """Prompt via input() and read one line from stdin with terminal echo off.

   The saved terminal attributes are restored even if reading fails.
   """
   stdinFd = sys.stdin.fileno()
   savedAttrs = termios.tcgetattr( stdinFd )
   noEchoAttrs = list( savedAttrs )
   lflagIndex = 3  # tcgetattr list: [iflag, oflag, cflag, lflag, ...]
   noEchoAttrs[ lflagIndex ] = noEchoAttrs[ lflagIndex ] & ~termios.ECHO
   try:
      termios.tcsetattr( stdinFd, termios.TCSADRAIN, noEchoAttrs )
      password = input( prompt )
      # echo was off, so the user's Enter was not shown; emit the newline
      print()
   finally:
      termios.tcsetattr( stdinFd, termios.TCSADRAIN, savedAttrs )
   return password

def patchGetpass():
   """Replace getpass.getpass process-wide with myGetpass (used by tests)."""
   import getpass as getpassModule # pylint: disable=import-outside-toplevel
   getpassModule.getpass = myGetpass
