# Copyright (c) 2008-2011, 2013-2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Cell
import CliDynamicSymbol
import CliGlobal
from CliPlugin.VrfCli import DEFAULT_VRF, DEFAULT_VRF_OLD
import AaaCliLib
from AaaPluginLib import hostProtocol
import Ark
import ConfigMount
import DscpCliLib
import HostnameCli
import Intf.Log
import LazyMount
import Radius
import RadiusGroup
import ReversibleSecretCli
import Tac

RadiusModel = CliDynamicSymbol.CliDynamicPlugin( "RadiusModel" )

# Module-wide handles shared by the command handlers below. Every handle
# starts out as None; Plugin() assigns each of them when the CLI loads.
gv = CliGlobal.CliGlobal( **dict.fromkeys( (
   "aaaConfig",
   "radiusConfig",
   "radiusStatus",
   "radiusCounterConfig",
   "radiusInputStatus",
   "dscpConfig",
   "identityDot1xStatus",
) ) )

def radiusHost( mode, hostname, vrf, port, acctPort, tlsEnabled, create=False ):
   """Look up, and optionally create, a host entry in gv.radiusConfig.host.

   The entry is keyed by an Aaa::HostSpec built from hostname, port,
   acctPort, vrf and a protocol. The protocol is RadSec when tlsEnabled is
   truthy and plain RADIUS otherwise. If the spec is absent and `create` is
   set, a new member is added. It gets the index returned by
   AaaCliLib.getHostIndex.

   Returns the host entry, or None when it does not exist and `create` is
   False. `mode` is not used here.
   """
   hostColl = gv.radiusConfig.host
   assert vrf != ""

   proto = ( hostProtocol.protoRadsec if tlsEnabled
             else hostProtocol.protoRadius )
   spec = Tac.Value( "Aaa::HostSpec", hostname=hostname, port=port,
                     acctPort=acctPort, vrf=vrf, protocol=proto )

   entry = None
   if spec in hostColl:
      entry = hostColl[ spec ]
   elif create:
      entry = hostColl.newMember( spec )
      if entry:
         entry.index = AaaCliLib.getHostIndex( hostColl )

   if entry is None:
      return None

   # Sanity check: the entry must mirror every field of the spec it is keyed by.
   assert entry.hostname == hostname
   assert entry.vrf == vrf
   assert entry.acctPort == acctPort
   assert entry.protocol == proto
   assert entry.port == port
   return entry

#-------------------------------------------------------------------------------
# "show radius" in enable mode
#-------------------------------------------------------------------------------

def showRadiusHost( host, statusCounters ):
   """Build the RadiusStats model for a single configured RADIUS host.

   host: an entry of gv.radiusConfig.host.
   statusCounters: that host's counters. showRadius obtains them from
   Radius.getHostCounters.

   The host's own ssl-profile is used when set; otherwise the global one is
   used. For RadSec hosts the configured port is reported as the TLS port and
   the auth port is 0. For plain RADIUS hosts it is the other way round.
   """
   stats = RadiusModel.RadiusStats()
   sslProfile = host.sslProfile or gv.radiusConfig.sslProfile
   if host.protocol == hostProtocol.protoRadsec:
      authPort, tlsPort = 0, host.port
   else:
      authPort, tlsPort = host.port, 0
   stats.serverInfo = RadiusModel.ServerInfo(
      hostname=host.hostname, authport=authPort, acctport=host.acctPort,
      dynAuthPort=gv.radiusConfig.dynAuthPort,
      dot1xDynAuthEnabled=gv.identityDot1xStatus.dot1xDynAuthEnabled,
      tlsport=tlsPort, sslProfile=sslProfile )

   # The VRF is only reported for non-default VRFs.
   if host.vrf != DEFAULT_VRF:
      stats.serverInfo.vrf = host.vrf

   # ( model field, counter attribute ) pairs, copied in this order.
   counterFields = (
      ( 'messagesSent', 'authnMessagesSent' ),
      ( 'messagesReceived', 'authnMessagesReceived' ),
      ( 'requestsAccepted', 'authnAcceptsReceived' ),
      ( 'requestsRejected', 'authnRejectsReceived' ),
      ( 'requestsTimeout', 'authnRequestsTimeout' ),
      ( 'requestsRetransmitted', 'authnRequestsRetransmitted' ),
      ( 'badResponses', 'authnBadResponses' ),
      ( 'connectionErrors', 'authnConnectionErrors' ),
      ( 'dnsErrors', 'authnHostUnresolvable' ),
      ( 'coaRequestsReceived', 'coaRequestsReceived' ),
      ( 'dmRequestsReceived', 'dmRequestsReceived' ),
      ( 'coaAckSent', 'coaAckResponses' ),
      ( 'dmAckSent', 'dmAckResponses' ),
      ( 'coaNakSent', 'coaNakResponses' ),
      ( 'dmNakSent', 'dmNakResponses' ),
      ( 'acctStartsSent', 'acctStartsSent' ),
      ( 'acctInterimUpdatesSent', 'acctInterimUpdatesSent' ),
      ( 'acctStopsSent', 'acctStopsSent' ),
   )
   for modelField, counterAttr in counterFields:
      setattr( stats, modelField, getattr( statusCounters, counterAttr ) )
   return stats

def showRadius( mode, args ):
   """Build the ShowRadius model for the "show radius" command.

   The model contains:
   - one RadiusStats entry per configured host, ordered by host index;
   - the per-VRF source interfaces;
   - every AAA host group of type 'radius' with its members, groups visited
     in sorted key order;
   - the time of the last counter clear.
   """
   result = RadiusModel.ShowRadius()

   hostsByIndex = sorted( gv.radiusConfig.host.values(),
                          key=lambda entry: entry.index )
   for entry in hostsByIndex:
      counters = Radius.getHostCounters( entry.spec,
                                         gv.radiusStatus,
                                         gv.radiusInputStatus,
                                         useCheckpoint=True )
      result.radiusServers.append( showRadiusHost( entry, counters ) )

   result.srcIntf = dict( gv.radiusConfig.srcIntfName )

   hostgroups = gv.aaaConfig.hostgroup
   for groupKey in sorted( hostgroups.keys() ):
      group = hostgroups[ groupKey ]
      if group.groupType != 'radius':
         continue
      display = AaaCliLib.getCliDisplayFromGroup( group.groupType )
      result.groups[ group.name ] = RadiusModel.ServerGroup()
      groupModel = result.groups[ group.name ]
      groupModel.serverGroup = display
      for member in group.member.values():
         spec = member.spec
         # RadSec members report their port as the TLS port; plain RADIUS
         # members report it as the auth port.
         if spec.protocol == hostProtocol.protoRadsec:
            authPort, tlsPort = 0, spec.port
         else:
            authPort, tlsPort = spec.port, 0
         info = RadiusModel.ServerInfo()
         info.hostname = spec.hostname
         info.authport = authPort
         info.acctport = spec.acctPort
         info.tlsport = tlsPort
         if spec.vrf != DEFAULT_VRF:
            info.vrf = spec.vrf
         groupModel.members.append( info )

   result.lastCounterClearTime = Ark.switchTimeToUtc(
      gv.radiusStatus.lastClearTime )
   return result

#-------------------------------------------------------------------------------
# "clear aaa counters radius" in enable mode
#-------------------------------------------------------------------------------
def clearCounters( mode, args ):
   """Request a RADIUS counter reset and wait briefly for it to be applied.

   The current time is written to radiusCounterConfig.clearCounterRequestTime.
   The handler then polls for up to 5 seconds until
   radiusStatus.lastClearTime reaches that time. On timeout a warning is
   added instead of failing the command.
   """
   Intf.Log.logClearCounters( "radius" )
   requestTime = Tac.now()
   gv.radiusCounterConfig.clearCounterRequestTime = requestTime

   def clearApplied():
      return gv.radiusStatus.lastClearTime >= requestTime

   try:
      Tac.waitFor( clearApplied,
                   description='RADIUS clear counter request to complete',
                   warnAfter=None, sleep=True, maxDelay=0.5, timeout=5 )
   except Tac.Timeout:
      mode.addWarning( "RADIUS counters may not have been reset yet" )

#-------------------------------------------------------------------------------
# "[no] radius-server key <KEY>" in config mode
#-------------------------------------------------------------------------------
def checkKeySize( mode, key ):
   """Validate the clear-text length of a RADIUS key.

   A missing or empty key is always accepted. A key longer than
   Radius.MAX_KEY_SIZE causes a CLI error to be added to `mode`.

   Returns True when the key is acceptable, False otherwise.
   """
   if not key or key.clearTextSize() <= Radius.MAX_KEY_SIZE:
      return True
   mode.addError( f"Maximum key size is {Radius.MAX_KEY_SIZE}" )
   return False

def setRadiusServerKey( mode, args ):
   """Configure the global RADIUS key or shared-secret profile.

   A key and a secret profile are mutually exclusive. Setting a valid key
   clears the profile name. Setting a profile resets the key to the default
   secret. A key that fails checkKeySize leaves the config untouched.
   """
   key = args.get( '<KEY>' )
   profileName = args.get( '<SECRET_PROFILE>' )
   cfg = gv.radiusConfig

   if key:
      if not checkKeySize( mode, key ):
         return
      cfg.key = key
      cfg.secretProfileName = ''
      return

   if profileName:
      cfg.secretProfileName = profileName
      cfg.key = ReversibleSecretCli.getDefaultSecret()

def noRadiusServerKey( mode, args ):
   """Undo the global RADIUS key or shared-secret configuration.

   The 'key' form resets the key to the default secret. The 'shared-secret'
   form clears the secret profile name.
   """
   cfg = gv.radiusConfig
   if 'key' in args:
      cfg.key = ReversibleSecretCli.getDefaultSecret()
      return
   if 'shared-secret' in args:
      cfg.secretProfileName = ''

#-------------------------------------------------------------------------------
# "[no] radius-server attribute 32 include-in-access-req format ( ( format NAS )
# | fqdn | hostname | disabled )"
# in config mode
#-------------------------------------------------------------------------------
def setRadiusServerAttribute( mode, args ):
   """Choose how the NAS-Identifier (attribute 32) value is derived.

   Behaviour depends on the arguments:
   - A custom NAS string selects type "custom". A string longer than
     Radius.MAX_NAS_ID_SIZE is rejected with a CLI error.
   - The 'hostname' keyword selects type 'hostname'.
   - Otherwise type 'fqdn' is used.
   The stored nasId is only non-empty for the custom case.
   """
   nasId = args.get( 'NAS' )
   if nasId and len( nasId ) > Radius.MAX_NAS_ID_SIZE:
      mode.addError( f"Maximum NAS-Identifier length is {Radius.MAX_NAS_ID_SIZE}" )
      return

   cfg = gv.radiusConfig
   if nasId:
      cfg.nasIdType, cfg.nasId = "custom", nasId
   elif args.get( 'hostname' ):
      cfg.nasIdType, cfg.nasId = 'hostname', ""
   else:
      cfg.nasIdType, cfg.nasId = 'fqdn', ""

def noRadiusServerAttribute( mode, args ):
   """Stop including a configured NAS-Identifier by marking it disabled."""
   cfg = gv.radiusConfig
   cfg.nasIdType, cfg.nasId = "disabled", ""

#-------------------------------------------------------------------------------
# "[no] radius-server qos dscp <0-63>" in config mode
#-------------------------------------------------------------------------------
def updateDscpRules():
   """Regenerate the 'radius' DSCP protocol config from the configured hosts.

   When gv.radiusConfig.dscpValue is falsy, the 'radius' entry is removed
   from the DSCP config. Otherwise its rule collection is cleared and rebuilt
   with an IPv4 and an IPv6 rule per destination port:

   * plain RADIUS hosts: UDP rules for the auth port and the accounting port;
   * RadSec hosts (RADIUS over TLS, which runs on TCP): TCP rules for the
     TLS port only. setRadiusServerTls stores these hosts with acctPort 0,
     so an accounting-port rule or a UDP rule would never match their
     traffic.
   """
   dscpValue = gv.radiusConfig.dscpValue

   if not dscpValue:
      del gv.dscpConfig.protoConfig[ 'radius' ]
      return

   protoConfig = gv.dscpConfig.newProtoConfig( 'radius' )
   ruleColl = protoConfig.rule
   ruleColl.clear()

   for spec in gv.radiusConfig.host:
      if spec.protocol == hostProtocol.protoRadsec:
         # RadSec: a single TLS/TCP port, no separate accounting port.
         l4Proto = 'tcp'
         ports = ( spec.port, )
      else:
         # Classic RADIUS: traffic to the external server's UDP auth port and
         # UDP accounting port.
         l4Proto = 'udp'
         ports = ( spec.port, spec.acctPort )
      for port in ports:
         for v6 in ( False, True ):
            DscpCliLib.addDscpRule( ruleColl, spec.hostname,
                                    port, False, spec.vrf,
                                    l4Proto, dscpValue, v6=v6 )

def setDscp( mode, args ):
   """Set the DSCP value used for RADIUS traffic and rebuild the DSCP rules."""
   dscp = args[ 'DSCP' ]
   gv.radiusConfig.dscpValue = dscp
   updateDscpRules()

def noDscp( mode, args ):
   """Restore the default RADIUS DSCP value and rebuild the DSCP rules."""
   cfg = gv.radiusConfig
   cfg.dscpValue = cfg.dscpValueDefault
   updateDscpRules()

def setRadiusServerTls( mode, args ):
   """Create or update a RADIUS or RadSec server host from CLI arguments.

   Port handling:
   - With the 'tls' keyword the host is keyed by its TLS port and the
     accounting port is forced to 0.
   - Otherwise the auth/acct ports are used and the TLS port is 0.

   Timeout and retries fall back to the global defaults when not given. The
   use* flags record whether a per-host value was explicitly configured. A key
   failing checkKeySize aborts the command before the host is created. After
   the host is written, the DSCP rules are regenerated.
   """
   hostname = args[ '<HOSTNAME>' ]
   vrf = args.get( 'VRF', DEFAULT_VRF )
   port = args.get( '<AUTHPORT>', RadiusGroup.defaultPort )
   acctPort = args.get( '<ACCTPORT>', RadiusGroup.defaultAcctPort )
   timeout = args.get( '<TIMEOUT>' )
   retries = args.get( '<RETRIES>' )
   key = args.get( '<KEY>' )
   secretProfileName = args.get( '<SECRET_PROFILE>' )
   tlsEnabled = args.get( 'tls', None )
   sslProfile = args.get( '<PROFILENAME>', '' )
   tlsPort = args.get( '<TLSPORT>', RadiusGroup.defaultTlsPort )
   if tlsEnabled:
      # RadSec hosts are keyed by the TLS port and have no accounting port.
      acctPort = 0
      port = tlsPort
   else:
      tlsPort = 0
   HostnameCli.resolveHostname( mode, hostname, doWarn=True )
   assert vrf != ''

   cfg = gv.radiusConfig
   timeoutVal = cfg.defaultTimeout if timeout is None else timeout
   retriesVal = cfg.defaultRetries if retries is None else retries

   if not checkKeySize( mode, key ):
      return
   host = radiusHost( mode, hostname, vrf, port, acctPort,
                      tlsEnabled, create=True )
   # A per-host secret is in effect if either a key or a profile was given.
   host.useKey = ( key is not None or secretProfileName is not None )
   host.secretProfileName = secretProfileName or ''
   host.key = key or ReversibleSecretCli.getDefaultSecret()
   host.useTimeout = ( timeout is not None )
   host.timeout = timeoutVal
   host.useRetries = ( retries is not None )
   host.retries = retriesVal
   host.sslProfile = sslProfile
   if mode.session_.interactive_:
      vrfString = f" in vrf {vrf}" if vrf != DEFAULT_VRF else ""
      if not tlsEnabled:
         mode.addMessage( f"RADIUS host {hostname} with auth-port {port} "
                          f"and acct-port {acctPort} created{vrfString}" )
      else:
         profileName = sslProfile or None
         mode.addMessage( f"RADIUS host {hostname} with ssl-profile {profileName} "
                          f"and TLS port {tlsPort} created{vrfString}" )
   updateDscpRules()

def noRadiusServerTls( mode, args ):
   """Remove one RADIUS/RadSec host, or every host when no hostname is given.

   The host to delete is identified by rebuilding its Aaa::HostSpec from the
   arguments:
   - With 'tls' the TLS port is used and the accounting port is 0.
   - Otherwise the auth/acct ports are used.
   If the host does not exist, interactive sessions get a warning. The DSCP
   rules are regenerated in every case.
   """
   hostname = args.get( '<HOSTNAME>' )
   hosts = gv.radiusConfig.host

   if not hostname:
      # Delete all hosts since no hostname was specified
      hosts.clear()
      updateDscpRules()
      return

   vrf = args.get( 'VRF', DEFAULT_VRF )
   tlsEnabled = args.get( 'tls', None )
   if tlsEnabled:
      port = args.get( '<TLSPORT>', RadiusGroup.defaultTlsPort )
      acctPort = 0
      proto = hostProtocol.protoRadsec
   else:
      port = args.get( '<AUTHPORT>', RadiusGroup.defaultPort )
      acctPort = args.get( '<ACCTPORT>', RadiusGroup.defaultAcctPort )
      proto = hostProtocol.protoRadius

   spec = Tac.Value( "Aaa::HostSpec", hostname=hostname, port=port,
                     acctPort=acctPort, vrf=vrf, protocol=proto )
   if spec in hosts:
      del hosts[ spec ]
   elif mode.session_.interactive_:
      vrfString = "" if vrf == DEFAULT_VRF else f" in vrf {vrf}"
      if tlsEnabled:
         warningMessage = ( f"RADIUS host {hostname} with TLS port {port} "
                            f"not found{vrfString}" )
      else:
         warningMessage = ( f"RADIUS host {hostname} with auth-port {port} and"
                            f" acct-port {acctPort} not found{vrfString}" )
      mode.addWarning( warningMessage )

   updateDscpRules()

#-------------------------------------------------------------------------------
# "[no] ip radius [ VRF ] source-interface <interface-name>"
# in config mode
#-------------------------------------------------------------------------------
def setRadiusSrcIntf( mode, args ):
   """Record the source interface to use for RADIUS in a VRF.

   A missing/empty VRF and the legacy DEFAULT_VRF_OLD name both map to
   DEFAULT_VRF.
   """
   intfName = args[ "INTF" ].name
   vrf = args.get( "VRF" ) or DEFAULT_VRF
   if vrf == DEFAULT_VRF_OLD:
      vrf = DEFAULT_VRF
   gv.radiusConfig.srcIntfName[ vrf ] = intfName

def noRadiusSrcIntf( mode, args ):
   """Remove the RADIUS source-interface setting for a VRF.

   A missing/empty VRF and the legacy DEFAULT_VRF_OLD name both map to
   DEFAULT_VRF.
   """
   vrf = args.get( "VRF" ) or DEFAULT_VRF
   if vrf == DEFAULT_VRF_OLD:
      vrf = DEFAULT_VRF
   del gv.radiusConfig.srcIntfName[ vrf ]

#-------------------------------------------------------------------------------
# Have the Cli Agent mount all needed state from sysdb
#-------------------------------------------------------------------------------
def Plugin( entityManager ):
   """Mount the Sysdb state used by the RADIUS CLI handlers into `gv`.

   radiusConfig and dscpConfig are mounted writable through ConfigMount.
   All other handles are mounted through LazyMount. The status paths are
   resolved with Cell.path.
   """
   em = entityManager
   gv.aaaConfig = LazyMount.mount(
      em, "security/aaa/config", "Aaa::Config", "r" )
   gv.radiusConfig = ConfigMount.mount(
      em, "security/aaa/radius/config", "Radius::Config", "w" )
   gv.radiusCounterConfig = LazyMount.mount(
      em, "security/aaa/radius/counterConfig", "AaaPlugin::CounterConfig", "w" )
   gv.radiusStatus = LazyMount.mount(
      em, Cell.path( "security/aaa/radius/status" ), "Radius::Status", "r" )
   gv.radiusInputStatus = LazyMount.mount(
      em, Cell.path( "security/aaa/radius/input/status" ), "Tac::Dir", "ri" )
   gv.dscpConfig = ConfigMount.mount(
      em, "mgmt/dscp/config", "Mgmt::Dscp::Config", "w" )
   gv.identityDot1xStatus = LazyMount.mount(
      em, 'identity/dot1x/status', 'Identity::Dot1x::Status', 'r' )
