# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import AgentCommandRequest
from BgpLib import (
   routeTargetToExtCommU64Value,
   vpnAfTypeMapInv,
)
from CliDynamicSymbol import CliDynamicPlugin
from CliPlugin.ArBgpCli import ArBgpCliCommand
from CliPlugin.BgpCliHelperCli import convertPeerAddr
from CliPlugin.RoutingBgpCli import configForVrf
from CliPlugin.RoutingBgpShowCli import getCommunityValuesScalarList
from Toggles import BgpCommonToggleLib

# Dynamically loaded "BgpCliHelperCliHandler" plugin; flattenArgs below uses
# its convertLargeCommunityValues helper.
BgpCliHelperCliHandler = CliDynamicPlugin( "BgpCliHelperCliHandler" )

class VpnCliHelperCommand( ArBgpCliCommand ):
   """Send a BGP VPN CLI command to the 'BgpCliHelper' agent.

   Keyword arguments whose value is not None are forwarded to the agent as
   command parameters; a 'json' parameter is added when the CLI session's
   output format is 'json'. The static helpers convert CLI-parsed argument
   dicts into the flat parameter names/values passed to this constructor.
   """
   def __init__( self, mode, command, nlriAfiSafi, **kwargs ):
      vrfName = kwargs.pop( 'vrfName', None )
      super().__init__(
         mode,
         command,
         vrfName=vrfName,
         nlriAfiSafi=nlriAfiSafi,
         transportAfi=None,
         disableFork=True )

      self.mode = mode

      if self.mode.session_.outputFormat_ == 'json':
         self.addParam( 'json' )

      # Only parameters that were actually supplied are sent to the agent.
      for k, v in kwargs.items():
         if v is not None:
            self.addParam( k, v )

   def run( self, **kwargs ):
      """Run the command in the 'BgpCliHelper' agent and print its output."""
      AgentCommandRequest.runCliPrintSocketCommand( self.entityManager,
                                                    'BgpCliHelper',
                                                    self.command,
                                                    self.paramString(),
                                                    self.mode,
                                                    keepalive=True )

   @staticmethod
   def convertCommunityValues( fromKwargs, toKwargs ):
      """Convert standard-community CLI args into agent parameters.

      Sets toKwargs[ "commValues" ] to the '|'-joined scalar values returned by
      getCommunityValuesScalarList( fromKwargs[ "communityValues" ] ), and,
      when fromKwargs[ "exact" ] is truthy, copies it to
      toKwargs[ "standardCommunitiesExactMatch" ].
      """
      value = fromKwargs.get( "communityValues", None )
      commValues = getCommunityValuesScalarList( value )
      toKwargs[ "commValues" ] = "|".join( [ str( c ) for c in commValues ] )
      exact = fromKwargs.get( "exact", None )
      if exact:
         toKwargs[ "standardCommunitiesExactMatch" ] = exact

   @staticmethod
   def convertExtCommunityValues( fromKwargs, toKwargs ):
      """Convert extended-community CLI args into agent parameters.

      fromKwargs[ "extCommunityValues" ] is an iterable of route-target
      values (e.g. 1:1, 2:2). Each one is converted with
      routeTargetToExtCommU64Value, and the decimal results are joined with
      '|' into toKwargs[ "extCommValues" ]. When fromKwargs[ "exact" ] is
      truthy it is copied to toKwargs[ "extendedCommunitiesExactMatch" ].

      Raises KeyError if "extCommunityValues" is missing from fromKwargs.
      """
      commValues = fromKwargs[ "extCommunityValues" ]
      exact = fromKwargs.get( "exact", None )
      if exact:
         toKwargs[ "extendedCommunitiesExactMatch" ] = exact
      toKwargs[ "extCommValues" ] = "|".join(
         str( routeTargetToExtCommU64Value( extCommValue ) )
         for extCommValue in commValues )

   @staticmethod
   def flattenArgs( fromKwargs, toKwargs ):
      """Flatten a (possibly nested) CLI args dict into toKwargs.

      None values are skipped. Dict values under "prefix",
      "nlriTypeAndPrefixValues" or "prefixValues" are flattened recursively.
      Community, peer-address and route-type args are renamed/converted to
      the agent's parameter names; everything else is copied as-is.
      """
      C = VpnCliHelperCommand
      for k, v in fromKwargs.items():
         if v is None:
            continue
         if k in [ "prefix", "nlriTypeAndPrefixValues", "prefixValues" ] and (
            isinstance( v, dict ) ):
            C.flattenArgs( v, toKwargs )
         elif k == "commValuesAndExact":
            C.convertCommunityValues( v, toKwargs )
         elif k == "extCommValuesAndExact":
            # The first ext-community filter encountered wins. A later one is
            # skipped rather than falling through to the generic branch below,
            # which would forward the raw, unconverted dict to the agent.
            if "extCommValues" not in toKwargs:
               C.convertExtCommunityValues( v, toKwargs )
         elif k == "largeCommValuesAndExact":
            BgpCliHelperCliHandler.convertLargeCommunityValues(
               fromKwargs=v, kwargs=toKwargs, key="largeCommListVal" )
         elif k == "peerAddrValue":
            toKwargs[ "peerAddr" ] = v
            convertPeerAddr( toKwargs )
         elif k == "bgpRouteTypeValue":
            toKwargs[ "routeType" ] = v
         else:
            # The same key must not be supplied at two nesting levels.
            assert k not in toKwargs
            toKwargs[ k ] = v

# Enables/disables the discarding of unimported VPN paths, or sets it to the
# default value if 'discard' is None
def setRouteImportMatchFailureAction( mode, discard=None ):
   """Set the VPN pruning (discard unimported paths) flag for mode's AF.

   Args:
      mode: CLI mode. mode.vrfName selects the BGP config via configForVrf,
         and mode.addrFamily ('vpn-ipv4', 'vpn-ipv6' or 'evpn') selects which
         vpnPruningEnabled* attribute is written.
      discard: value to store, or None to store the attribute's
         corresponding *Default value.

   Any other address family is reported via mode.addError and no config is
   changed.
   """
   def getVal( defaultValue ):
      return discard if discard is not None else defaultValue

   config = configForVrf( mode.vrfName )
   if mode.addrFamily == 'vpn-ipv4':
      config.vpnPruningEnabledVpnV4 = getVal( config.vpnPruningEnabledVpnV4Default )
   elif mode.addrFamily == 'vpn-ipv6':
      config.vpnPruningEnabledVpnV6 = getVal( config.vpnPruningEnabledVpnV6Default )
   elif mode.addrFamily == 'evpn':
      config.vpnPruningEnabledEvpn = getVal( config.vpnPruningEnabledEvpnDefault )
   else:
      # Interpolate the address family into the message itself. Previously it
      # was passed as a separate argument, so '%s' was never filled in (the
      # other addError call in this file passes a single message string).
      mode.addError( "Not supported in '%s' address-family mode" %
                     mode.addrFamily )

def handlerRouteImportMatchFailureDiscardCmd( mode, args ):
   """CLI handler: enable discarding of unimported VPN paths for mode's AF."""
   enableDiscard = True
   setRouteImportMatchFailureAction( mode, enableDiscard )

def noOrDefaultHandlerRouteImportMatchFailureDiscardCmd( mode, args ):
   """CLI no/default handler: restore the AF's default discard setting."""
   # Leaving 'discard' unset (None) selects the address family's default.
   setRouteImportMatchFailureAction( mode )

def setVpnClientIbgpAttributeTunnelingAction( mode, args, value ):
   """Store 'value' as the VPN client iBGP attribute tunneling setting.

   The setting is written to config.vpnClientIbgpAttributeTunneling (config
   of mode.vrfName), keyed by vpnAfTypeMapInv[ mode.addrFamily ]. Allowed
   address families are 'vpn-ipv4' and 'vpn-ipv6', plus 'evpn' when
   toggleEvpnIBgpAsPeCeProtocolEnabled() is true. Any other address family
   reports an error and leaves the config unchanged. 'args' is not used.
   """
   config = configForVrf( mode.vrfName )
   allowedAfiSafis = [ 'vpn-ipv4', 'vpn-ipv6' ]
   if BgpCommonToggleLib.toggleEvpnIBgpAsPeCeProtocolEnabled():
      allowedAfiSafis.append( 'evpn' )
   if mode.addrFamily not in allowedAfiSafis:
      mode.addError( "Invalid address family" )
      # Stop here: continuing would either raise KeyError in the lookup below
      # or write the setting for an address family that was just rejected.
      return
   afiSafi = vpnAfTypeMapInv[ mode.addrFamily ]
   config.vpnClientIbgpAttributeTunneling[ afiSafi ] = value

def handlerVpnClientIbgpAttributeTunneling( mode, args ):
   """CLI handler: enable VPN client iBGP attribute tunneling for mode's AF."""
   setVpnClientIbgpAttributeTunnelingAction( mode, args, value=True )

def noOrDefaultHandlerVpnClientIbgpAttributeTunneling( mode, args ):
   """CLI no/default handler: disable VPN client iBGP attribute tunneling."""
   setVpnClientIbgpAttributeTunnelingAction( mode, args, value=False )
