# Copyright (c) 2009-2010, 2013-2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-merging-isinstance
# pylint: disable=consider-using-f-string

from freeradiusclient import * # pylint: disable=wildcard-import
import socket
import threading

import AaaPluginLib
from AaaPluginLib import TR_ERROR, TR_WARN, TR_AUTHEN, TR_INFO
from Arnet.NsLib import DEFAULT_NS
from BothTrace import traceX as bt
from BothTrace import Var as bv
from collections import namedtuple
from math import ceil
from Tracing import traceX
from TypeFuture import TacLazyType
import Tac

# Immutable record returned by Session.sendAuthenReq describing the outcome of
# one RADIUS authentication exchange.
AuthenStatus = namedtuple(
   "AuthenStatus",
   ( "status", "failText", "privLevel", "roles", "rules", "classAttr" ) )

# Upper limits on configurable sizes. 253 presumably mirrors the largest value
# a single RADIUS attribute can carry (RFC 2865) -- confirm before relying on it.
MAX_KEY_SIZE = 128
MAX_NAS_ID_SIZE = 253

# Arista's SMI Network Management Private Enterprise Code; used as the vendor
# id of every Arista vendor-specific attribute handled in this module.
aristaVendorId = 30065

# Default port for Change-of-Authorization (COA) messages.
DEFAULT_COA_PORT = 3799
# Default port for RADIUS over TLS (RadSec).
DEFAULT_RADSEC_PORT = 2083

# UDP port used internally between radsecproxy and the agent.
radsecproxyInternalPort = TacLazyType(
   "Arnet::PrivateUdpPorts" ).radsecproxyInternalPort

radiusNasIdType = Tac.Type( "Radius::NasIdType" )
#
# Below are the indices and names of the Arista vendor-specific attributes (such as
# Arista-AVPair) in the Radius dictionary (see also the freeradius-client package).
# For a Radius server to understand these attributes, they must be defined in its
# dictionary file.
#
# - Short-term solution: we should put both of these definitions in dictionary files
#   and publish them on the web, so users can simply include dictionary.arista in
#   their default dictionary.
# - Long-term solution: Arista's dictionary should ship along with the other default
#   dictionaries in FreeRADIUS and Cisco ACS.
#
# ### FreeRADIUS server ###
# VENDOR Arista 30065
# BEGIN-VENDOR Arista
#
# ATTRIBUTE       Arista-AVPair                   1       string
# ATTRIBUTE       Arista-User-Priv-Level          2       integer
# ATTRIBUTE       Arista-User-Role                3       string
# ATTRIBUTE       Arista-CVP-Role                 4       string
# ATTRIBUTE       Arista-Command                  5       string
# ATTRIBUTE       Arista-BlockMac                 7       string
# ATTRIBUTE       Arista-UnblockMac               8       string
# END-VENDOR Arista
#
# ### Cisco ACS server ###
# [User Defined Vendor]
# Name=Arista
# IETF Code=30065
# VSA 1=Arista-AVPair
# VSA 2=Arista-User-Priv-Level
# VSA 3=Arista-User-Role
# VSA 4=Arista-CVP-Role
# VSA 5=Arista-Command
# VSA 6=Arista-WebAuth
# VSA 7=Arista-BlockMac
# VSA 8=Arista-UnblockMac
# VSA 10=Arista-Captive-Portal
# VSA 11=Arista-Segment-Id
# [Arista-AVPair]
# Type=STRING
# Profile=IN OUT
#
# [Arista-User-Priv-Level]
# Type=INTEGER
# Profile=IN OUT
#
# [Arista-User-Role]
# Type=STRING
# Profile=IN OUT
#
# [Arista-CVP-Role]
# Type=STRING
# Profile=IN OUT
#
# [Arista-Command]
# Type=STRING
# Profile=IN OUT
#
# [Arista-WebAuth]
# Type=INTEGER
#
# [Arista-BlockMac]
# Type=STRING
# Profile=IN OUT
#
# [Arista-UnblockMac]
# Type=STRING
# Profile=IN OUT
#
# [Arista-Captive-Portal]
# Type=STRING
# Profile=IN OUT
#
# [Arista-Segment-Id]
# Type=STRING
# Profile=IN OUT

# Arista vendor-specific attributes, as ( attribute index, dictionary name ).
aristaAvpAttrIndex, aristaAvpAttrName = 1, "Arista-AVPair"
aristaUserPrivLevelIndex, aristaUserPrivLevelName = 2, "Arista-User-Priv-Level"
aristaUserRoleIndex, aristaUserRoleName = 3, "Arista-User-Role"
aristaCommandIndex, aristaCommandName = 5, "Arista-Command"
aristaBlockMacIndex, aristaBlockMacName = 7, "Arista-BlockMac"
aristaUnblockMacIndex, aristaUnblockMacName = 8, "Arista-UnblockMac"


# Cisco vendor-specific attribute, honored when the server is not configured
# with the Arista dictionary.
ciscoVendorId = 9
ciscoAvpAttrIndex, ciscoAvpAttrName = 1, "cisco-avpair"

# Standard (vendor 0) attributes.
defaultVendorId = 0
defaultClassAttrIndex, defaultClassAttrName = 25, "Class"

# AVPair value prefixes. The priv-lvl syntax is shared by Arista and Cisco
# configurations, the roles syntax by Arista and Cisco NX-OS; the cmd syntax
# is used for commands.
privLvlPrefix = "shell:priv-lvl="
rolesPrefix = "shell:roles="
cmdPrefix = "shell:cmd="

class BadConfigError( Exception ):
   """The RADIUS client configuration (e.g. its dictionary) failed to load."""

class AuthenticationError( Exception ):
   """An authentication request could not be built or its reply parsed."""

class AuthorizationError( Exception ):
   """Authorization error; not raised by the code in this module."""

class AccountingError( Exception ):
   """An accounting request could not be built."""

class BadServerError( Exception ):
   """A RADIUS server could not be resolved or added to the client."""

class AcctReq:
   """A RADIUS accounting request being assembled.

   Holds the client handle (rh) and the freeradiusclient AVP list built so far
   (req). Attributes are appended through the *Is() helpers; the list is sent
   and freed by Session.sendAcctReq.
   """
   def __init__( self, rh, username, acctSessionId, privLevel=None ):
      """Start a request with User-Name, an optional Service-Type derived from
      privLevel, and Acct-Session-Id.

      Raises AccountingError if acctSessionId is neither an int nor a str, or
      if an attribute cannot be added.
      """
      # Validate before any AVP is allocated, so a bad session id does not
      # leave a partially built AVP list behind that nothing in this module
      # would free.
      if not isinstance( acctSessionId, ( str, int ) ):
         raise AccountingError( "session id is not int or string type" )
      self.req = None
      self.rh = rh
      self._addVar( PW_USER_NAME, "User Name", username )
      if privLevel is not None:
         self._serviceTypeIs( privLevel )
      self._addVar( PW_ACCT_SESSION_ID, "Accounting Session ID",
                    str( acctSessionId ) )

   def _addVar( self, key, keyString, value, vendor=0 ):
      """Append attribute `key` (of `vendor`) with `value` to self.req.

      bytes and str values use the byte/string add calls; anything else is
      added as an int. keyString is only used in the error message.
      Raises AccountingError if the add fails.
      """
      if isinstance( value, bytes ):
         self.req = rc_avpair_addvar_byte( self.rh, self.req, key,
                                           value, vendor )
      elif isinstance( value, str ):
         self.req = rc_avpair_addvar_string( self.rh, self.req, key, value, vendor )
      else:
         self.req = rc_avpair_addvar_int( self.rh, self.req, key, value, vendor )
      if not self.req:
         raise AccountingError( "failed to add %s %s" %
                                ( keyString, str( value ) ) )

   def _serviceTypeIs( self, privLevel ):
      """Add Service-Type: NAS-Prompt for privLevel <= 1, else Administrative."""
      if int( privLevel ) <= 1:
         service = PW_NAS_PROMPT
      else:
         service = PW_ADMINISTRATIVE
      self._addVar( PW_SERVICE_TYPE, "SERVICE_TYPE", service )

   def acctStatusTypeIs( self, acctStatusType ):
      """Add the Acct-Status-Type attribute."""
      self._addVar( PW_ACCT_STATUS_TYPE, "Accounting Status Type", acctStatusType )

   def portIs( self, tty ):
      """Add the NAS-Port-Id attribute."""
      self._addVar( PW_NAS_PORT_ID, "Port", tty )

   def remoteHostIs( self, remoteHost ):
      """Add the Calling-Station-Id attribute."""
      self._addVar( PW_CALLING_STATION_ID, "Remote Host", remoteHost )

   def authenMethodIs( self, authenMethod ):
      """Add Acct-Authentic; authenMethod must be PW_LOCAL, PW_RADIUS or
      PW_REMOTE."""
      knownMethods = [ PW_LOCAL, PW_RADIUS, PW_REMOTE ]
      assert authenMethod in knownMethods
      self._addVar( PW_ACCT_AUTHENTIC, "Authentication Method", authenMethod )

   def elapsedTimeIs( self, time ):
      """Add Acct-Session-Time, rounding `time` up to an integer."""
      self._addVar( PW_ACCT_SESSION_TIME, "Accounting Session Time",
            int( ceil( time ) ) )

   def commandIs( self, tokens ):
      """Add the command as an Arista-AVPair of the form "shell:cmd=<tokens>"."""
      vendorAttrValue = cmdPrefix + ' '.join( tokens )
      self._addVar( aristaAvpAttrIndex, aristaAvpAttrName, vendorAttrValue,
                    aristaVendorId )

   def sessionId( self ):
      """Return the Acct-Session-Id stored in the request, as an int.

      int() raises ValueError if the session id was a non-numeric string.
      """
      sessionIdAvp = rc_avpair_get( self.req, PW_ACCT_SESSION_ID, 0 )
      sessionIdStr = sessionIdAvp.strvalue
      return int( sessionIdStr )

class Session( AaaPluginLib.Session ):
   """A session with a RADIUS server, or potentially multiple servers in a
   failover arrangement. Since RADIUS runs over UDP, the main purpose is to
   share one client handle across requests, which also keeps the servers'
   status in a failover configuration.

   The usage model for authentication is to create a Session, add one or more
   servers using addServer, and call sendAuthenReq. If adding a server fails,
   addServer raises BadServerError. sendAuthenReq raises AuthenticationError
   only when the request cannot be built or the reply cannot be parsed; the
   outcome of the authentication itself is reported in the returned
   AuthenStatus.status.
   """
   def __init__( self, hostgroup, timeout, retries, deadtime=0 ):
      # initialize the client handle
      AaaPluginLib.Session.__init__( self, hostgroup )
      self.handle_ = None
      self.taskId = 0
      rh = rc_config_init( rc_new( ) )
      rc_add_config( rh, "auth_order", "radius", "config", 0 )
      rc_add_config_int( rh, "radius_timeout", int( timeout ) )
      rc_add_config_int( rh, "radius_retries", retries )
      rc_add_config_int( rh, "radius_deadtime", int( deadtime ) )
      rc_add_config( rh, "dictionary", "/etc/radiusclient/dictionary",
                     "config", 0 )
      rc_add_config( rh, "seqfile", "/var/run/radius_seq_file",
                     "config", 0 )

      if rc_read_dictionary( rh, rc_conf_str( rh, "dictionary" ) ):
         # rh is not stored in self.handle_ yet, so close() would never free
         # it; release it here instead of leaking it.
         rc_destroy( rh )
         raise BadConfigError( "failed to initialize radius dictionary" )
      self.handle_ = rh
      # Per-server counter callbacks (or None), indexed by the server index
      # the counter hook receives; appended to by addServer.
      self.statusCallback = []
      # We create a weak reference to the counterHook function
      # to avoid circular reference (bug 13571). Also hold a
      # reference to the weak reference itself so it won't go away.
      self.counterHookMethod_ = Tac.WeakBoundMethod( self._counterHook )
      rc_add_python_counterhook( rh, self.counterHookMethod_ )

   def _counterHook( self, serverIndex, counterType, delta ):
      """Counter hook registered with freeradiusclient in __init__.

      Forwards ( counterType, delta ) to the callback recorded for
      serverIndex, if any. Always returns 0.
      """
      callback = self.statusCallback[ serverIndex ]
      if callback:
         callback( counterType, delta )
      return 0

   def close( self ):
      """Destroy the client handle; safe to call more than once."""
      if self.handle_:
         rc_destroy( self.handle_ )
         self.handle_ = None
         self.counterHookMethod_ = None

   def __del__( self ):
      # Just in case close() wasn't explicitly called as expected.
      self.close( )

   def addServer( self, host, authPort, acctPort, tlsPort, secret, timeout,
                  retries, ns=DEFAULT_NS, srcIpAddr=None, srcIp6Addr=None,
                  counterCallback=None ):
      """Register host as both an authentication and an accounting server.

      When tlsPort is set, both roles use tlsPort with TLS enabled; otherwise
      authPort and acctPort are used. Missing srcIpAddr/srcIp6Addr default to
      "*". counterCallback, if given, is recorded for the counter hook and is
      also bumped (authnHostUnresolvable) when the host cannot be resolved.

      Raises BadServerError if the host cannot be resolved or cannot be added
      to the client configuration.
      """
      if not srcIpAddr:
         srcIpAddr = "*"        # INADDR_ANY
      if not srcIp6Addr:
         srcIp6Addr = "*"       # in6addr_any

      def _add( host, serverType, port, tls=False ):
         # Resolve the hostname up front so we control error reporting;
         # otherwise freeradiusclient would only resolve it at authentication
         # time.
         try:
            socket.getaddrinfo( host, port )
         except socket.gaierror as e:
            if counterCallback:
               bt( TR_ERROR, "cannot resolve host:", bv( str( e ) ) )
               # 5 is "authnHostUnresolvable", just double check
               attrName = CounterCallback.attrNameFromIndex( 5 )
               assert attrName == "authnHostUnresolvable"
               counterCallback( 5, 1 )
            # pylint: disable-next=raise-missing-from
            raise BadServerError( "Failed to resolve hostname" )
         r = rc_add_config( self.handle_, serverType,
                            "%s|%s|%s|%d|%s|%s|%s|%s|%s" %
                            ( host, ns, port, tls, timeout, retries, srcIpAddr,
                              srcIp6Addr, secret ),
                            "config", 0 )
         if r != 0:
            # Hostname resolution was already checked above, so this failure
            # is reported by rc_add_config itself.
            raise BadServerError( "Failed to add %s host=%s port=%d"
                                  % ( serverType, host, port ) )
      if tlsPort:
         _add( host, "authserver", tlsPort, True )
         _add( host, "acctserver", tlsPort, True )
      else:
         _add( host, "authserver", authPort )
         _add( host, "acctserver", acctPort )
      self.statusCallback.append( counterCallback )

   def sendAuthenReq( self, username, password, privLevel, remoteHost=None,
                      tty=None, nasId=None ):
      """Send an authentication request and parse the server's reply.

      The request carries User-Name, User-Password, optional NAS-Identifier,
      NAS-Port-Id and Calling-Station-Id, and a Service-Type derived from
      privLevel (NAS-Prompt for <= 1, Administrative otherwise).

      Returns an AuthenStatus: status is the rc_auth return code and failText
      the message rc_auth returned. privLevel comes from Arista priv-level
      AVPs; a cisco-avpair "shell:priv-lvl=" is only honored while no
      privilege level has been found yet, and it is None if nothing usable
      was found. roles and rules collect Arista role and command values.
      classAttr holds the Class attribute as bytes, if present.

      Raises AuthenticationError if an attribute cannot be added to the
      request or the reply contains an AVP with an unparseable attribute.
      """
      assert username is not None
      assert password is not None
      assert privLevel is not None
      assert 0 <= privLevel <= 15
      rh = self.handle_
      req = None
      req = rc_avpair_addvar_string( rh, req, PW_USER_NAME, username, 0 )
      if not req:
         raise AuthenticationError( "failed to add User Name", username )

      req = rc_avpair_addvar_string( rh, req, PW_USER_PASSWORD, password, 0 )
      if not req:
         raise AuthenticationError( "failed to add User Password ********" )

      if nasId:
         req = rc_avpair_addvar_string( rh, req, PW_NAS_IDENTIFIER, nasId, 0 )
         if not req:
            raise AuthenticationError( "failed to add NAS Identifier", nasId )

      if tty:
         req = rc_avpair_addvar_string( rh, req, PW_NAS_PORT_ID, tty, 0 )
         if not req:
            raise AuthenticationError( "failed to add Port", tty )

      if remoteHost:
         req = rc_avpair_addvar_string( rh, req, PW_CALLING_STATION_ID,
                                        remoteHost, 0 )
         if not req:
            raise AuthenticationError( "failed to add Remote Host", remoteHost )

      if privLevel <= 1:
         service = PW_NAS_PROMPT
      else:
         service = PW_ADMINISTRATIVE

      req = rc_avpair_addvar_int( rh, req, PW_SERVICE_TYPE, service, 0 )
      if not req:
         raise AuthenticationError( "failed to add SERVICE_TYPE" )

      # send the request
      retcode, resp, msg = rc_auth( self.handle_, 0, req )
      rc_avpair_free( req )

      shellPrivLevel = None
      roles = []
      rules = []
      classAttr = None

      def _getPrivLevel( svalue ):
         # Parse the integer after "shell:priv-lvl="; None if it is not one.
         try:
            return int( svalue[ len( privLvlPrefix ): ] )
         except ValueError:
            return None

      try:
         avp = resp
         while avp:
            try:
               vendorId = VENDOR( avp.attribute )
               attribId = ATTRID( avp.attribute )
               traceX( TR_INFO, "AVPair: vendor", vendorId, "attribute",
                       attribId )
            except TypeError:
               # pylint: disable-next=raise-missing-from
               raise AuthenticationError( "bad avp attribute value",
                                          avp.attribute )

            if vendorId == aristaVendorId:
               if attribId == aristaAvpAttrIndex:
                  # The generic Arista-AVPair
                  strvalue = avp.strvalue
                  if strvalue.startswith( privLvlPrefix ):
                     shellPrivLevel = _getPrivLevel( strvalue )
                  elif strvalue.startswith( rolesPrefix ):
                     roleName = strvalue[ len( rolesPrefix ): ]
                     if roleName:
                        roles.append( roleName )

               elif attribId == aristaUserPrivLevelIndex:
                  val = rc_avpair_get_intval( avp )
                  if 0 <= val <= 15:
                     shellPrivLevel = val
               elif attribId == aristaUserRoleIndex and avp.strvalue:
                  roles.append( avp.strvalue )
               elif attribId == aristaCommandIndex and avp.strvalue:
                  rules.append( avp.strvalue )
            elif ( vendorId == ciscoVendorId and attribId == ciscoAvpAttrIndex and
                   shellPrivLevel is None ):
               # only accept the Cisco priv-lvl if no privilege level has been
               # found yet
               strvalue = avp.strvalue
               if strvalue.startswith( privLvlPrefix ):
                  shellPrivLevel = _getPrivLevel( strvalue )
            elif vendorId == defaultVendorId and attribId == defaultClassAttrIndex:
               # get class as bytes
               classAttr = rc_avpair_get_bytevaule( avp )

            avp = avp.next

         traceX( TR_AUTHEN, "privilege", shellPrivLevel,
                 "roles", roles, "classAttr", classAttr,
                 "status", retcode )
      finally:
         # Free the reply list even when a malformed AVP aborts parsing above.
         rc_avpair_free( resp )

      return AuthenStatus( retcode, msg, shellPrivLevel, roles, rules, classAttr )

   def createAcctReq( self, username, privLevel=None, acctSessionId=None,
                      acctStatusType=PW_STATUS_ALIVE ):
      """Create an AcctReq for username with the given Acct-Status-Type.

      When acctSessionId is None, self.taskId is incremented and used as the
      session id.
      """
      rh = self.handle_
      if acctSessionId is None:
         self.taskId += 1
         acctSessionId = self.taskId
      acctReq = AcctReq( rh, username, acctSessionId, privLevel=privLevel )
      acctReq.acctStatusTypeIs( acctStatusType )
      return acctReq

   def sendAcctReq( self, acctReq ):
      """Send acctReq and return the rc_acct return code.

      The request's AVP list is freed afterwards, so acctReq must not be
      reused.
      """
      retcode = rc_acct( acctReq.rh, 0, acctReq.req )
      rc_avpair_free( acctReq.req )
      return retcode

# Radius::Counters attribute names, ordered like the freeradiusclient
# RC_COUNTER_* indices: position i names counter index i.
radiusCounterAttrs = (
   "authnMessagesSent",            # 0
   "authnMessagesReceived",        # 1
   "authnAcceptsReceived",         # 2
   "authnRejectsReceived",         # 3
   "authnChallengesReceived",      # 4
   "authnHostUnresolvable",        # 5
   "authnRequestsTimeout",         # 6
   "authnRequestsRetransmitted",   # 7
   "authnBadResponses",            # 8
   "authnConnectionErrors",        # 9
   "coaRequestsReceived",          # 10
   "dmRequestsReceived",           # 11
   "coaAckResponses",              # 12
   "dmAckResponses",               # 13
   "coaNakResponses",              # 14
   "dmNakResponses",               # 15
   "acctStartsSent",               # 16
   "acctInterimUpdatesSent",       # 17
   "acctStopsSent",                # 18
)

class CounterCallback:
   """Callable that adds a delta to one RADIUS counter of a single host.

   Invoked as callback( counterType, delta ), matching how Session calls its
   per-server counterCallback. counterType indexes radiusCounterAttrs, and the
   matching attribute of status.counter[ hostspec ] is incremented by delta.
   """
   # One lock shared by every instance, serializing the read-modify-write of
   # status.counter entries.
   mutex = threading.Lock()

   # Counters traced at error level; all others are traced at info level.
   _errorCounterNames = frozenset( ( "authnRejectsReceived",
                                     "authnHostUnresolvable",
                                     "authnRequestsTimeout",
                                     "authnBadResponses" ) )

   @staticmethod
   def attrNameFromIndex( idx ):
      """Return the counter attribute name for counter index idx."""
      # These match up with the freeradiusclient constants starting with RC_COUNTER_
      return radiusCounterAttrs[ idx ]

   def __init__( self, status, hostspec ):
      self.status = status
      self.hostspec = hostspec

   def __call__( self, counterType, delta ):
      attrName = CounterCallback.attrNameFromIndex( counterType )
      if attrName in self._errorCounterNames:
         traceLevel = TR_ERROR
      else:
         traceLevel = TR_INFO
      bt( traceLevel, "RADIUS counter", bv( attrName ), "host",
          bv( self.hostspec.stringValue() ) )
      with self.mutex:
         counterToUse = self.status.counter
         if self.hostspec not in counterToUse:
            # Operator probably removed the radius host in the middle of the
            # authentication request.
            bt( TR_WARN, "RADIUS counters missing entry for",
                bv( self.hostspec.stringValue() ) )
            return
         c = Tac.nonConst( counterToUse[ self.hostspec ] )
         setattr( c, attrName, getattr( c, attrName ) + delta )
         counterToUse[ self.hostspec ] = c

def getHostCounters( host, radiusStatus, radiusInputStatus,
                     useCheckpoint=False ):
   """Return a Radius::Counters value holding host's aggregated counters.

   Each attribute in radiusCounterAttrs is summed over radiusStatus and every
   value of radiusInputStatus that has an entry for host. With useCheckpoint,
   host's entry in radiusStatus.checkpoint (if any) is subtracted from each
   non-zero sum.
   """
   # The per-status lookups do not depend on the counter attribute, so do
   # them once instead of once per attribute.
   statuses = [ radiusStatus ] + list( radiusInputStatus.values() )
   hostEntries = [ h for h in ( s.counter.get( host ) for s in statuses ) if h ]
   checkpointEntry = None
   if useCheckpoint:
      checkpointEntry = radiusStatus.checkpoint.get( host )

   counter = Tac.Value( "Radius::Counters" )
   for attr in radiusCounterAttrs:
      csum = sum( getattr( h, attr, 0 ) for h in hostEntries )
      # It's possible for spec to be not present in the input status but present
      # in checkpoint. For examples server deleted from the group config after
      # clearing counters. It's important to not take into account the checkpoint
      # values in such cases to prevent negative values. BUG605873
      if csum and checkpointEntry:
         csum -= getattr( checkpointEntry, attr, 0 )
      setattr( counter, attr, csum )

   return counter
