# Copyright (c) 2008-2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import AaaPluginLib
from AaaPluginLib import TR_ERROR, TR_AUTHEN, TR_AUTHZ, \
      TR_ACCT, TR_INFO, hostProtocol
from BothTrace import traceX as bt
from BothTrace import Var as bv
import Tac
from Tracing import traceX
import libtacplus
import re

# The TACACS+ protocol caps one attribute-value pair at 255 characters, and
# the count includes the attribute name and its separator. Example from the
# tacacs server logs:
# Tue Jun 23 17:52:07 2015 [25223]: arg[1]: size=8
# Tue Jun 23 17:52:07 2015 [25223]: cmd=echo
# Tue Jun 23 17:52:07 2015 [25223]: arg[2]: size=11
# Tue Jun 23 17:52:07 2015 [25223]: cmd-arg=foo
TAC_AV_MAX_LENGTH = 255
# Longest value that still fits after "<name>" plus the one-character separator.
TAC_AV_MAX_LENGTH_CMD = TAC_AV_MAX_LENGTH - ( len( "cmd" ) + 1 )
TAC_AV_MAX_LENGTH_CMD_ARG = TAC_AV_MAX_LENGTH - ( len( "cmd-arg" ) + 1 )
# Accounting commands also carry a " <cr>" suffix inside the value.
TAC_AV_MAX_LENGTH_ACCT_CMD = ( TAC_AV_MAX_LENGTH - ( len( "cmd" ) + 1 ) -
                               len( " <cr>" ) )

class _Attr:
   def __init__( self, name, value, optional=False ):
      self.name = name
      self.value = value
      self.optional = optional

class _Server:
   def __init__( self, host, port, vrf, ns, key, timeout, srcIp="", srcIp6="",
                 singleConn=False, statusCallback=None ):
      self.host_ = host
      self.port_ = port
      self.vrf_ = vrf
      self.ns_ = ns
      self.srcIp_ = srcIp
      self.srcIp6_ = srcIp6
      self.key_ = key
      self.timeout_ = timeout
      self.flags_ = 0
      if singleConn:
         self.flags_ |= libtacplus.TAC_SRVR_SINGLE_CONNECT
      self.statusCallback = statusCallback

class Req:
   """Common state for TACACS+ requests: header fields (user, port, remote
   address, privilege level, authentication type/service) plus an indexed set
   of attribute-value pairs.

   Subclasses add the fields specific to authentication, authorization and
   accounting."""
   def __init__( self, authenType, authenService ):
      self.authenType_ = authenType
      self.authenService_ = authenService
      self.user_ = None
      self.port_ = None
      self.privLevel_ = None
      self.remoteAddr_ = None
      # Maps AV-pair index -> _Attr.
      self.attrs_ = {}
      # Next index handed out by setAttr() when no explicit index is given.
      self.attrIdx_ = 0
      self.created_ = False

   def setCreated( self ):
      """Called when the corresponding object is created in the tac_handle,
      at which point some of the properties can no longer be changed."""
      self.created_ = True

   def authenTypeIs( self, authenType ):
      """Set the authentication type; only allowed before creation."""
      assert not self.created_
      types = [ libtacplus.TAC_AUTHEN_TYPE_ASCII ]
      assert authenType in types
      self.authenType_ = authenType

   def authenServiceIs( self, authenService ):
      """Set the authentication service; only allowed before creation."""
      assert not self.created_
      services = [ libtacplus.TAC_AUTHEN_SVC_NONE,
                   libtacplus.TAC_AUTHEN_SVC_LOGIN,
                   libtacplus.TAC_AUTHEN_SVC_ENABLE ]
      assert authenService in services
      self.authenService_ = authenService

   def userIs( self, user ):
      self.user_ = user

   def portIs( self, port ):
      """Port is a funny name -- in TACACS+-speak a port identifies where the
      user is connecting from, eg. the tty."""
      self.port_ = port

   def remoteAddrIs( self, addr ):
      self.remoteAddr_ = addr

   def privLevelIs( self, privLevel ):
      self.privLevel_ = privLevel

   def setAttr( self, name, value, optional=False, index=None ):
      """Store an attribute-value pair.

      With index=None the pair is appended at the next free index. Otherwise
      it replaces whatever is stored at `index`. The auto-index counter is
      always moved past every index in use, so a later append can never
      silently overwrite an explicitly placed pair."""
      if index is None:
         index = self.attrIdx_
      self.attrIdx_ = max( self.attrIdx_, index + 1 )
      self.attrs_[ index ] = _Attr( name, value, optional )

   def attrs( self ):
      """Return the number of AV pairs currently stored."""
      return len( self.attrs_ )

   def attr( self, index ):
      """Return the _Attr stored at `index` (KeyError if there is none)."""
      return self.attrs_[ index ]

   def findAttr( self, name ):
      """Look up AV pairs called `name`.

      Returns None if there are none, the bare index if there is exactly one,
      and otherwise a sorted list of indices."""
      matches = [ i for i in sorted( self.attrs_ )
                  if name == self.attrs_[ i ].name ]
      if not matches:
         return None
      if len( matches ) == 1:
         return matches[ 0 ]
      return matches

   def serviceIs( self, service ):
      """Set the "service" AV pair, replacing an existing one if present."""
      # Allowed values according to the RFC:
      assert service in ( "shell", "slip", "ppp", "arap", "tty-daemon",
                          "connection", "system", "firewall" )
      index = self.findAttr( "service" )
      if isinstance( index, list ):
         # Several "service" pairs were added via setAttr(); update the first
         # one instead of failing on an unhashable list index.
         index = index[ 0 ]
      self.setAttr( "service", service, index=index )

class AuthenReq( Req ):
   """Authentication request. It adds the authentication action and an
   optional message to the fields shared by all requests."""
   def __init__( self, service, action, type ): # pylint: disable=redefined-builtin
      super().__init__( type, service )
      self.action_ = action
      self.message_ = None

   def messageIs( self, msg ):
      """Set the message to send with the request (sent only when truthy)."""
      self.message_ = msg

class AuthzReq( Req ):
   """Authorization request. It adds the authentication method and a
   tokenized command, encoded as "cmd"/"cmd-arg" AV pairs, to the shared
   request fields."""
   def __init__( self, authenMethod, authenType, authenService ):
      super().__init__( authenType, authenService )
      self.authenMethod_ = authenMethod
      self.commandTokens = None

   def authenMethodIs( self, authenMethod ):
      """Change the authentication method. Only allowed before the request
      has been created in the tac_handle."""
      assert not self.created_
      allowed = ( libtacplus.TAC_AUTHEN_METH_NOT_SET,
                  libtacplus.TAC_AUTHEN_METH_NONE,
                  libtacplus.TAC_AUTHEN_METH_KRB5,
                  libtacplus.TAC_AUTHEN_METH_LINE,
                  libtacplus.TAC_AUTHEN_METH_ENABLE,
                  libtacplus.TAC_AUTHEN_METH_LOCAL,
                  libtacplus.TAC_AUTHEN_METH_TACACSPLUS,
                  libtacplus.TAC_AUTHEN_METH_RCMD )
      assert authenMethod in allowed
      self.authenMethod_ = authenMethod

   def commandIs( self, tokens ):
      """Encode a tokenized command.

      The first token becomes "cmd" and every following token a "cmd-arg",
      each truncated to fit one AV pair. A final "<cr>" cmd-arg is then
      appended. An empty token sequence produces an optional, empty "cmd"."""
      if len( tokens ) > 0:
         head, tail = tokens[ 0 ], tokens[ 1 : ]
         self.setAttr( "cmd", head[ : TAC_AV_MAX_LENGTH_CMD ] )
         for arg in tail:
            self.setAttr( "cmd-arg", arg[ : TAC_AV_MAX_LENGTH_CMD_ARG ] )
         self.setAttr( "cmd-arg", "<cr>" )
      else:
         self.setAttr( "cmd", "", optional=True )

class AcctReq( Req ):
   """Accounting request. It adds the authentication method and the
   accounting action to the shared request fields, plus setters for the
   accounting AV pairs (task_id, start_time, elapsed_time, timezone, cmd,
   priv-lvl)."""
   def __init__( self, authenMethod, authenType, authenService,
                 acctAction ):
      Req.__init__( self, authenType, authenService )
      self.authenMethod_ = authenMethod
      self.acctAction_ = acctAction

   def _replaceAttr( self, name, value ):
      """Set AV pair `name`, overwriting its first existing occurrence rather
      than appending a duplicate pair."""
      index = self.findAttr( name )
      if isinstance( index, list ):
         index = index[ 0 ]
      self.setAttr( name, value, index=index )

   def authenMethodIs( self, authenMethod ):
      """Change the authentication method; only allowed before creation."""
      assert not self.created_
      methods = [ libtacplus.TAC_AUTHEN_METH_NOT_SET,
                  libtacplus.TAC_AUTHEN_METH_NONE,
                  libtacplus.TAC_AUTHEN_METH_KRB5,
                  libtacplus.TAC_AUTHEN_METH_LINE,
                  libtacplus.TAC_AUTHEN_METH_ENABLE,
                  libtacplus.TAC_AUTHEN_METH_LOCAL,
                  libtacplus.TAC_AUTHEN_METH_TACACSPLUS,
                  libtacplus.TAC_AUTHEN_METH_RCMD ]
      assert authenMethod in methods
      self.authenMethod_ = authenMethod

   def startTimeIs( self, time=None ):
      """Append start_time as integer seconds since the epoch. A falsy
      `time` (None or 0) means "now" (Tac.utcNow())."""
      self.setAttr( "start_time", int( time ) if time else int( Tac.utcNow() ) )

   def elapsedTimeIs( self, time ):
      """Append elapsed_time as whole seconds, rounded up."""
      from math import ceil # pylint: disable=import-outside-toplevel
      self.setAttr( "elapsed_time", int( ceil( time ) ) )

   def timeZoneIs( self, timezone=None ):
      """Append the timezone pair. It defaults to the local zone name,
      choosing the DST name when DST is in effect."""
      if not timezone:
         import time # pylint: disable=import-outside-toplevel
         isdst = time.localtime().tm_isdst
         if isdst:
            timezone = time.tzname[ 1 ]
         else:
            timezone = time.tzname[ 0 ]
      self.setAttr( "timezone", timezone )

   def commandIs( self, cmd ):
      """Append the whole command as one "cmd" pair with a " <cr>" suffix.
      An empty command produces an optional, empty "cmd"."""
      # set 'cmd' as-is
      if cmd:
         self.setAttr( "cmd", cmd[ : TAC_AV_MAX_LENGTH_ACCT_CMD ] + ' <cr>' )
      else:
         self.setAttr( "cmd", "", optional=True )

   def privLevelIs( self, privLevel ):
      """Set the privilege level, both in the header and as a priv-lvl pair."""
      # in IOS, the priv-lvl is also sent as an AV-pair for accounting
      self.privLevel_ = privLevel
      # Replace rather than append so repeated calls keep a single priv-lvl.
      self._replaceAttr( "priv-lvl", privLevel )

   def taskIdIs( self, taskId ):
      """Set task_id. This replaces the one Session.createAcctReq() already
      assigned instead of sending two task_id pairs."""
      self._replaceAttr( "task_id", taskId )

   def taskId( self ):
      """Return the task_id value (the first one if several exist), or 0."""
      index = self.findAttr( "task_id" )
      if index is None:
         return 0
      if isinstance( index, list ):
         index = index[ 0 ]
      return self.attr( index ).value
      
class AuthenticationError( Exception ):
   """Raised when a TACACS+ authentication exchange fails."""

class AuthorizationError( Exception ):
   """Raised when a TACACS+ authorization request fails."""

class AccountingError( Exception ):
   """Raised when a TACACS+ accounting request fails."""

class BadServerException( Exception ):
   """Raised when libtacplus refuses to register a server."""

class Session( AaaPluginLib.Session ):
   """A session with a TACACS+ server, or potentially multiple servers in a
   failover arrangement.

   The usage model for authentication is to create a Session, add one or more
   servers using addServer, create an authentication request using
   createAuthenReq, configuring the request by setting the user, etc, then
   calling sendAuthenReq.  Check the return value from sendAuthenReq, then call
   continueAuthenReq one or more times, providing information as requested by
   the server until the authentication negotiation completes.

   The usage model for authorization is to create a Session, add one or more
   servers using addServer, create an authorization request using
   createAuthzReq, configuring the request by setting the user, command, etc,
   then calling sendAuthzReq and checking the return value.

   Accounting follows the same pattern with createAcctReq and sendAcctReq.

   When a failure occurs during the initial communication with a server, ie.
   before a response has been successfully received by this class, the code
   herein will retry up to once per server.  Any failures that happen later
   during an authentication session will result in an AuthenticationError
   being raised.  Authorization is treated much the same way, except an
   AuthorizationError is raised.  Accounting requests are attempted up to
   max( 2, number of servers ) times and raise AccountingError."""
   def __init__( self, hostgroup ):
      bt( TR_AUTHEN, "creating tacacs+ session" )
      AaaPluginLib.Session.__init__( self, hostgroup )
      self.handle_ = libtacplus.tac_open()
      # We create a weak reference to the counterHook function
      # to avoid circular reference (bug 13571). Also hold a
      # reference to the weak reference itself so it won't go away.
      self.counterHookMethod_ = Tac.WeakBoundMethod( self._counterHook )
      libtacplus.tac_add_python_counterhook(
         self.handle_, self.counterHookMethod_ )
      # Kept in registration order, so a libtacplus server index is also an
      # index into this list.
      self.servers_ = []
      self.currentAuthenReq_ = None
      # Initialized here so sendAuthzReq() fails its assertion cleanly rather
      # than raising AttributeError when createAuthzReq() was never called.
      self.currentAuthzReq_ = None
      self.currentAcctReq_ = None
      self.continueExpected_ = False
      # Last task_id auto-assigned by createAcctReq().
      self.taskId = 0
      self.maxUsernameLength_ = 255

   def _counterHook( self, serverIndex, counterType, delta ):
      """Forward a libtacplus counter update to the statusCallback of server
      `serverIndex`, if it has one. Always returns 0."""
      if not 0 <= serverIndex < len( self.servers_ ):
         # No matching entry in servers_; there is nobody to forward to.
         traceX( TR_ERROR, "counter hook for unknown server index", serverIndex,
                 counterType, delta )
         return 0
      callback = self.servers_[ serverIndex ].statusCallback
      if callback:
         callback( counterType, delta )
      return 0

   def close( self ):
      """Close the libtacplus handle. Calling it again is a no-op."""
      # getattr(): __del__ may run on an instance whose __init__ raised before
      # handle_ was assigned.
      if getattr( self, "handle_", None ):
         bt( TR_AUTHEN, "closing tacacs+ session" )
         libtacplus.tac_close( self.handle_ )
         self.handle_ = None
         self.counterHookMethod_ = None

   def closeConnection( self ):
      """Close the current server connection but keep the handle."""
      if self.handle_:
         bt( TR_ERROR, "closing tacacs+ connection" )
         libtacplus.tac_close_connection( self.handle_ )

   def __del__( self ):
      # Just in case close() wasn't explicitly called as expected.
      self.close()

   def maxUsernameLengthIs( self, length ):
      """Set the length to which usernames are truncated when sent."""
      self.maxUsernameLength_ = length

   def addServer( self, *args, **kwargs ):
      """Register a server with libtacplus; arguments are those of _Server.

      Raises BadServerException if tac_add_server() fails."""
      s = _Server( *args, **kwargs )
      assert s.ns_ and s.ns_ != ''
      traceX( TR_INFO, "addServer:", s.host_, ":", s.port_, "ns:", s.ns_,
              "vrf:", s.vrf_, "srcIP:", s.srcIp_, "srcIp6:", s.srcIp6_,
              "key:", s.key_, "timeout:", s.timeout_, "flags:", s.flags_ )
      # We add the server now, as tac_add_server() might throw an error
      # through our counter hook that relies on this. If the add fails, the
      # entry is removed again so servers_ stays aligned with libtacplus's
      # server indices.
      self.servers_.append( s )
      try:
         r = libtacplus.tac_add_server( self.handle_, s.host_, s.port_, s.ns_,
                                        s.srcIp_, s.srcIp6_, s.key_,
                                        s.timeout_, s.flags_ )
      except BaseException:
         self.servers_.pop()
         raise
      if r != 0:
         self.servers_.pop()
         # The only way I know this can happen is when the host is specified
         # as a hostname that fails to resolve, so increment the DNS error.
         if s.statusCallback:
            s.statusCallback( "dnsErrors", 1 )
         errMsg = libtacplus.tac_strerror( self.handle_ )
         # We failed to resolve the hostname. This might happen if we have a
         # temporary network glitch or the system isn't fully initialized.
         # Let's set a shorter timeout.
         self.expireBy_ = Tac.now() + 300
         raise BadServerException( errMsg )

   def getCurServer( self ):
      """Return an Aaa::HostSpec describing the server libtacplus is
      currently using."""
      serverIndex = libtacplus.tac_cur_server( self.handle_ )
      bt( TR_INFO, "cur server", serverIndex )
      server = self.servers_[ serverIndex ]
      return Tac.Value( "Aaa::HostSpec",
                        hostname=server.host_,
                        port=server.port_,
                        acctPort=0,
                        vrf=server.vrf_,
                        protocol=hostProtocol.protoTacacs )

   def setCurServer( self, hostSpec ):
      """Make the registered server that matches hostSpec (hostname, port,
      vrf) the current one. If none matches, nothing changes."""
      if self.getCurServer() == hostSpec:
         return
      for i, s in enumerate( self.servers_ ):
         if ( s.host_ == hostSpec.hostname and
              s.port_ == hostSpec.port and
              s.vrf_ == hostSpec.vrf ):
            bt( TR_INFO, "set cur server", i )
            libtacplus.tac_set_cur_server( self.handle_, i )
            break

   def createAuthenReq( self, service=libtacplus.TAC_AUTHEN_SVC_LOGIN,
                        action=libtacplus.TAC_AUTHEN_LOGIN,
                        # pylint: disable-next=redefined-builtin
                        type=libtacplus.TAC_AUTHEN_TYPE_ASCII ):
      """Create a new AuthenReq and make it the session's current one."""
      req = AuthenReq( service=service, action=action, type=type )
      self.currentAuthenReq_ = req
      return req

   def sendAuthenReq( self, req, isContinue=False ):
      """Create the authentication packet for `req` in the handle and send it.

      Returns ( status, noecho, serverMsg ). An initial request is retried up
      to once per server. A continuation is never retried. Raises
      AuthenticationError."""
      assert req.__class__ == AuthenReq
      assert req is self.currentAuthenReq_
      # I will attempt to perform the initial step at least once per server.  I
      # rely on libtacplus to advance to the next server when an error occurs.
      attempts = 0
      while True:
         attempts += 1
         req.setCreated()
         r = libtacplus.tac_create_authen( self.handle_, req.action_,
                                           req.authenType_, req.authenService_,
                                           isContinue )
         if r != 0:
            err = libtacplus.tac_strerror( self.handle_ )
            raise AuthenticationError( err )

         try:
            return self._doAuthenReq()
         except AuthenticationError as e:
            if isContinue or attempts >= len( self.servers_ ):
               raise
            bt( TR_ERROR, "sendAuthenReq: retrying after:", bv( str( e ) ) )

   def continueAuthenReq( self ):
      """Send the current authentication request as a continuation. Only
      valid when the previous reply asked for more data, user or password."""
      if not self.continueExpected_:
         raise AuthenticationError( "Cannot continue authentication session" )
      return self.sendAuthenReq( self.currentAuthenReq_, isContinue=True )

   def _doAuthenReq( self ):
      """Copy the current authentication request into the handle and send it.

      Returns ( status, noecho, serverMsg ) and records in continueExpected_
      whether the server asked for more input."""
      self._syncReq( self.currentAuthenReq_, AuthenticationError )
      r = libtacplus.tac_send_authen( self.handle_ )
      if r == -1:
         err = libtacplus.tac_strerror( self.handle_ )
         if not err:
            err = "Unknown error sending request"
         bt( TR_ERROR, "tac_send_authen failed:", bv( err ) )
         raise AuthenticationError( err )

      status = libtacplus.TAC_AUTHEN_STATUS( r )
      noecho = libtacplus.TAC_AUTHEN_NOECHO( r )
      traceX( TR_AUTHEN, "tac_send_authen returned", r, "status:", status,
              "noecho:", noecho )
      self.continueExpected_ = status in (
         libtacplus.TAC_AUTHEN_STATUS_GETDATA,
         libtacplus.TAC_AUTHEN_STATUS_GETUSER,
         libtacplus.TAC_AUTHEN_STATUS_GETPASS )
      return ( status, noecho, libtacplus.tac_get_msg( self.handle_ ) )

   def abortAuthen( self ):
      # TODO: implement this.  The underlying libtacplus API doesn't expose a
      # way to send the abort message to the server, but closing the connection
      # (perhaps by calling tac_close?) is probably good enough since we're not
      # worrying too much about trying to keep connections open, etc.  If we
      # do call tac_close we'll need to remember that and call tac_open later
      # if someone wants to keep using this Session.
      pass

   def _syncReq( self, req, exceptionType ):
      """Copy req's header fields and AV pairs into the libtacplus handle.

      Any non-authentication request must carry a "service" AV pair. Raises
      exceptionType if that pair is missing or a libtacplus setter fails."""
      def _check( r ):
         if r != 0:
            err = libtacplus.tac_strerror( self.handle_ )
            if not err:
               err = "Unknown error setting request parameters"
            raise exceptionType( err )
      if type( req ) == AuthenReq: # pylint: disable=unidiomatic-typecheck
         if req.message_:
            r = libtacplus.tac_set_msg( self.handle_, req.message_ )
            _check( r )
      elif req.findAttr( "service" ) is None:
         # require service if not authentication
         raise exceptionType( "No service specified in request" )
      if req.user_:
         u = req.user_[ :self.maxUsernameLength_ ]
         r = libtacplus.tac_set_user( self.handle_, u )
         _check( r )
      if req.port_:
         r = libtacplus.tac_set_port( self.handle_, req.port_ )
         _check( r )
      if req.remoteAddr_:
         r = libtacplus.tac_set_rem_addr( self.handle_, req.remoteAddr_ )
         _check( r )
      if req.privLevel_ is not None:
         r = libtacplus.tac_set_priv( self.handle_, req.privLevel_ )
         _check( r )
      libtacplus.tac_clear_avs( self.handle_ )
      for idx in sorted( req.attrs_ ):
         attr = req.attrs_[ idx ]
         # pylint: disable-next=consider-using-f-string
         arg = "%s%s%s" %( attr.name, "*" if attr.optional else "=",
                           attr.value )
         r = libtacplus.tac_set_av( self.handle_, idx, arg )
         _check( r )

   def createAuthzReq( self, authzService="shell",
                       authenMethod=libtacplus.TAC_AUTHEN_METH_TACACSPLUS,
                       authenType=libtacplus.TAC_AUTHEN_TYPE_ASCII,
                       authenService=libtacplus.TAC_AUTHEN_SVC_LOGIN ):
      """Create a new AuthzReq for `authzService` and make it the current
      authorization request."""
      req = AuthzReq( authenMethod=authenMethod, authenType=authenType,
                      authenService=authenService )
      req.serviceIs( authzService )
      self.currentAuthzReq_ = req
      return req

   def sendAuthzReq( self, req ):
      """Send an authorization request, retrying up to once per server.

      Returns ( status, serverMsg, mandatoryAvs, optionalAvs ). Raises
      AuthorizationError."""
      assert req.__class__ == AuthzReq
      assert req is self.currentAuthzReq_
      # I will attempt to perform the initial step at least once per server.  I
      # rely on libtacplus to advance to the next server when an error occurs.
      attempts = 0
      while True:
         attempts += 1
         traceX( TR_AUTHZ, "tac_create_author: authenMethod:", req.authenMethod_,
                 "authenType:", req.authenType_, "authenService:",
                 req.authenService_ )
         req.setCreated()
         r = libtacplus.tac_create_author( self.handle_, req.authenMethod_,
                                           req.authenType_, req.authenService_ )
         if r != 0:
            err = libtacplus.tac_strerror( self.handle_ )
            raise AuthorizationError( err )
         self._syncReq( req, AuthorizationError )
         try:
            r = libtacplus.tac_send_author( self.handle_ )
            if r == -1:
               err = libtacplus.tac_strerror( self.handle_ )
               if not err:
                  err = "Unknown error sending request"
               bt( TR_ERROR, "tac_send_author failed:", bv( err ) )
               raise AuthorizationError( err )
            status = libtacplus.TAC_AUTHOR_STATUS( r )
            av_count = libtacplus.TAC_AUTHEN_AV_COUNT( r )
            serverMsg = libtacplus.tac_get_msg( self.handle_ )
            traceX( TR_AUTHZ, "tac_send_author returned", r, ", status:",
                    status, "msg", serverMsg, "av_count:", av_count )
            if status == 0:
               # clearpass sends status 0 if the user login has expired;
               # retry in this case.

               # Manually advance the server.
               libtacplus.tac_set_cur_server( self.handle_,
                  ( libtacplus.tac_cur_server( self.handle_ ) + 1 ) %
                                              len( self.servers_ ) )
               # pylint: disable-next=consider-using-f-string
               raise AuthorizationError( "unexpected status 0: %s" % serverMsg )
            m_av, o_av = self._attributes( av_count )
            return ( status, serverMsg, m_av, o_av )
         except AuthorizationError as e:
            if attempts >= len( self.servers_ ):
               raise
            bt( TR_ERROR, "sendAuthzReq: retrying after:", bv( str( e ) ) )

   def createAcctReq( self, acctAction, acctService="shell",
                      authenMethod=libtacplus.TAC_AUTHEN_METH_TACACSPLUS,
                      authenType=libtacplus.TAC_AUTHEN_TYPE_ASCII,
                      authenService=libtacplus.TAC_AUTHEN_SVC_LOGIN,
                      taskId=None ):
      """Create a new AcctReq and make it the current accounting request.

      Without an explicit taskId, the session's task counter is incremented
      and used."""
      req = AcctReq( authenMethod, authenType, authenService, acctAction )
      if taskId is None:
         self.taskId += 1
         req.setAttr( "task_id", self.taskId )
      else:
         req.setAttr( "task_id", taskId )
      req.serviceIs( acctService )
      self.currentAcctReq_ = req
      return req

   def sendAcctReq( self, req ):
      """Send an accounting request.

      It is attempted up to max( 2, number of servers ) times. Returns
      ( status, serverMsg ). Raises AccountingError."""
      assert req.__class__ == AcctReq
      assert req is self.currentAcctReq_

      def sendAcctReqHelper():
         req.setCreated()
         r = libtacplus.tac_create_acct( self.handle_, req.acctAction_,
                                         req.authenMethod_,
                                         req.authenType_,
                                         req.authenService_ )
         if r != 0:
            err = libtacplus.tac_strerror( self.handle_ )
            raise AccountingError( err )
         self._syncReq( req, AccountingError )
         return libtacplus.tac_send_acct( self.handle_ )

      attempts = 0
      # If we have only one server we try again in case we fail.
      maxAttempts = max( 2, len( self.servers_ ) )
      while True:
         attempts += 1
         try:
            r = sendAcctReqHelper()

            if r == -1:
               err = libtacplus.tac_strerror( self.handle_ )
               err = err or "Unknown error sending request"
               bt( TR_ERROR, "tac_send_acct failed:", bv( err ) )
               raise AccountingError( err )

            status = libtacplus.TAC_ACCT_STATUS( r )
            traceX( TR_ACCT, "tac_send_acct returned", status )
            return ( status, libtacplus.tac_get_msg( self.handle_ ) )
         except AccountingError as e:
            if attempts >= maxAttempts:
               raise
            # Trace the swallowed failure, as sendAuthenReq/sendAuthzReq do.
            bt( TR_ERROR, "sendAcctReq: retrying after:", bv( str( e ) ) )

   def socketFd( self ):
      """ Returns the file descriptor for the socket opened by the
      libtacplus handle.  For testing purposes only!"""
      return libtacplus.tac_socket_from_handle( self.handle_ )

   def revalidateConnection( self ):
      libtacplus.tac_revalidate_connection( self.handle_ )

   def _attributes( self, count ):
      """Parse the `count` AV pairs of the most recent response.

      Returns ( mandatory, optional ). These are two dicts that map attribute
      name to value, for pairs separated by '=' and '*' respectively. If a
      name appears more than once, its value is a list of all its values in
      order. Pairs that do not parse are traced and skipped."""
      # Split at the FIRST '=' or '*': values may themselves contain those
      # characters (e.g. "cmd-arg=a=b"). DOTALL keeps values with embedded
      # newlines intact instead of silently truncating them.
      avpattern = re.compile( "([^=*]+)([=*])(.*)", re.DOTALL )
      mandatory_av = {}
      optional_av = {}
      for i in range( count ):
         av = libtacplus.tac_get_av( self.handle_, i )
         traceX( TR_INFO, 'av pair', i, ':', av )
         m = avpattern.match( av )
         if m is None:
            traceX( TR_ERROR, 'bad av pair:', av )
            continue
         a, j, v = m.groups()
         r = mandatory_av if j == '=' else optional_av
         current = r.get( a )
         if current is None:
            r[ a ] = v
         elif isinstance( current, list ):
            current.append( v )
         else:
            r[ a ] = [ current, v ]
      return mandatory_av, optional_av

