#!/usr/bin/env python3
#Copyright (c) 2010, 2011, 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import Arnet
import ConfigMount
import LazyMount
from CliPlugin import PimCliLib
from CliPlugin.AclCli import getAclConfig
from CliPlugin.RouterMulticastCliLib import (
      AddressFamily,
      configGetters,
      doConfigMounts,
      getAddressFamilyFromMode,
)
import Tac
from IpLibConsts import DEFAULT_VRF

# Module-level mount handles. They all start out as None; Plugin() assigns
# pimsmConfigColl, pimConfigRoot and aclListConfig when the plugin is loaded.
# NOTE(review): staticRpConfigColl and pimBidirStaticRpConfigColl are never
# assigned in the visible part of this file -- confirm they are still needed.
staticRpConfigColl = None
pimsmConfigColl = None
pimBidirStaticRpConfigColl = None
pimConfigRoot = None
aclListConfig = None

# Af Independent Config Types
# Type name of the PIM sparse-mode config collection (also mounted in Plugin()).
PimsmConfigColl = "Routing::Pim::SparseMode::ConfigColl"

# configGetters() hands back four callables. As used in this file:
#   pimsmConfigRoot( af )                  -> collection root exposing .vrfConfig
#   pimsmConfigRootFromMode( mode )        -> presumably the same root, resolved
#                                             from a CLI mode (confirm in
#                                             RouterMulticastCliLib)
#   pimsmConfig( af, vrfName )             -> per-VRF config accessor
#   pimsmConfigFromMode( mode, legacy=.. ) -> per-VRF config for the mode; the
#                                             handlers below treat a None result
#                                             as "nothing to configure"
# The per-VRF entries of this collection live under 'vrfConfig'.
( pimsmConfigRoot,
  pimsmConfigRootFromMode,
  pimsmConfig,
  pimsmConfigFromMode ) = configGetters( PimsmConfigColl,
                                         collectionName='vrfConfig' )

# Same four getters for the GMP SSM config collection, whose per-VRF entries
# live under 'config' instead of 'vrfConfig'.
SsmConfigColl = "Routing::Gmp::GmpSsmConfigColl"
( ssmConfigRoot,
  ssmConfigRootFromMode,
  ssmConfig,
  ssmConfigFromMode ) = configGetters( SsmConfigColl,
                                         collectionName='config' )

def getPimsmVrfConfig( vrfName ):
   """Look up vrfName in the mounted pimsmConfigColl.vrfConfig collection.

   Returns whatever the collection's get() yields for that VRF name.
   """
   return pimsmConfigColl.vrfConfig.get( vrfName )

def _pimsmConfigCreation( vrfName ):
   """Hook run when a PIM sparse-mode VRF is configured (registered in Plugin).

   Calls the pimsmConfig and ssmConfig getters for IPv4 and IPv6 and discards
   the results.
   NOTE(review): presumably the getters create the per-VRF entry as a side
   effect -- confirm against RouterMulticastCliLib.configGetters.
   """
   for family in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
      pimsmConfig( family, vrfName )
      ssmConfig( family, vrfName )

def _pimsmConfigDeletion( vrfName ):
   """Hook run when a PIM sparse-mode VRF is deleted (registered in Plugin).

   For both address families, removes the VRF's entry from the PIM
   sparse-mode and SSM config collections, but only when the VRF is not the
   default VRF and the entry reports isDefault().
   """
   def _dropIfDefault( vrfColl ):
      # Leave entries that carry non-default config, and never remove the
      # default VRF's entry.
      if vrfName not in vrfColl:
         return
      entry = vrfColl[ vrfName ]
      if vrfName != DEFAULT_VRF and entry.isDefault():
         del vrfColl[ vrfName ]

   for family in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
      _dropIfDefault( pimsmConfigRoot( family ).vrfConfig )
      _dropIfDefault( ssmConfigRoot( family ).config )

def _canDeletePimsmVrfConfig( vrfName ):
   """Report whether the VRF's PIM sparse-mode config may be deleted.

   Returns False as soon as either address family holds an entry for vrfName
   that is not isDefault(); True otherwise. Only the PIM sparse-mode
   collections are consulted here, not the SSM ones.
   """
   for family in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
      vrfColl = pimsmConfigRoot( family ).vrfConfig
      if vrfName in vrfColl and not vrfColl[ vrfName ].isDefault():
         return False
   return True

def _cleanupPimsmConfig( vrfName ):
   """Cleanup hook (registered in Plugin): reset the VRF's per-VRF config.

   For both address families, calls reset() on the VRF's entry in the PIM
   sparse-mode and SSM collections when such an entry exists. Entries are
   reset, not removed.
   """
   for family in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
      pimsmVrfs = pimsmConfigRoot( family ).vrfConfig
      if vrfName in pimsmVrfs:
         pimsmVrfs[ vrfName ].reset()

      ssmVrfs = ssmConfigRoot( family ).config
      if vrfName in ssmVrfs:
         ssmVrfs[ vrfName ].reset()

# Clean up vrf agnostic config
def _cleanupPimsmConfigColl( vrfName ):
   """Cleanup hook (registered in Plugin): reset the collection-level config.

   VRF-agnostic settings are configured through the default VRF, so the
   collection roots of both address families are reset only when the hook is
   invoked for DEFAULT_VRF; any other VRF is a no-op.
   """
   if vrfName == DEFAULT_VRF:
      for family in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
         pimsmConfigRoot( family ).reset()

def sourceIsMulticast( rule ):
   """Check if an ACL rule has a multicast source address.

   Returns the isMulticast attribute of the rule's raw source address.
   """
   # getRawAttribute is required here: plain attribute access hands back a
   # str for ipv4 instead of an IpAddr.
   return rule.filter.source.getRawAttribute( "address" ).isMulticast

def sourceIsAny( rule, af ):
   """Tell whether the rule's source equals the default-constructed address.

   The source is compared against a default Arnet::IpAddrWithFullMask value
   when af is 'ipv4' and against a default Arnet::Ip6AddrWithMask value for
   anything else.
   NOTE(review): the default value presumably represents "any" -- confirm.
   """
   anyTypeName = ( "Arnet::IpAddrWithFullMask" if af == 'ipv4'
                   else "Arnet::Ip6AddrWithMask" )
   return rule.filter.source == Tac.Value( anyTypeName )

def allSourcesAreMulticast( acl, af ):
   """Return True when every rule of the ACL has a multicast or "any" source.

   Rules are read from acl.currCfg.ipRuleById for af 'ipv4' and from
   acl.currCfg.ip6RuleById otherwise. An ACL without rules yields True.
   """
   currCfg = acl.currCfg
   ruleColl = currCfg.ipRuleById if af == 'ipv4' else currCfg.ip6RuleById

   for rule in ruleColl.values():
      if not ( sourceIsMulticast( rule ) or sourceIsAny( rule, af ) ):
         return False
   return True

def allPfxEntriesAreMcast( pfxListName ):
   """Return True when every entry of the named prefix list is multicast.

   The prefix list is looked up by subscript in aclListConfig.prefixList, so
   the caller must make sure it exists. An empty prefix list yields True.
   """
   entries = aclListConfig.prefixList[ pfxListName ].prefixEntry
   for entry in entries.values():
      if not entry.prefix.getRawAttribute( 'address' ).isMulticast:
         return False
   return True

#-----------------------------------------------------------------------------
# legacy: switch(config)# [no] ip pim ssm default source < groupAddr > [ sourceAddr ]
# (config-af)# [ no ] ssm default source < groupAddr > [ sourceAddr ]
# (config-af)# [ no ] ssm default source group filter < prefixList > [ sourceAddr ]

#-----------------------------------------------------------------------------
def setIpPimSsmConvert( mode, args ):
   """Add an SSM default-source mapping for a group or for a prefix list.

   args must contain 'SOURCE' and may contain 'GROUP' or 'PREFIX_LIST'; the
   presence of the 'pim' key selects the legacy form of the command.
   Nothing is done for IPv6 or when no PIM sparse-mode config is available.

   With 'GROUP', the (group, source) pair is recorded in both the PIM
   sparse-mode ssmConvertConfig and the SSM convertGroup collection. With
   'PREFIX_LIST', the source is recorded under that prefix list in the SSM
   convertPfxList collection; a warning is issued when the prefix list
   already exists and contains non-multicast entries.
   """
   groupAddr = args.get( 'GROUP' )
   sourceAddr = args[ 'SOURCE' ]
   pfxListName = args.get( 'PREFIX_LIST' )
   legacy = 'pim' in args
   pimsmCfg = pimsmConfigFromMode( mode, legacy=legacy )
   ssmCfg = ssmConfigFromMode( mode, legacy=legacy )
   af = getAddressFamilyFromMode( mode, legacy=legacy )

   if af == AddressFamily.ipv6 or not pimsmCfg:
      return

   # Instantiate the convert config on first use.
   if not pimsmCfg.ssmConvertConfig:
      pimsmCfg.ssmConvertConfig = ()
   convertCfg = pimsmCfg.ssmConvertConfig

   if groupAddr:
      try:
         ( source, group ) = PimCliLib.ipPimParseSg( groupAddr, sourceAddr )
      except ValueError:
         mode.addErrorAndStop( "Must enter a multicast group and an unicast source" )

      # The PIM and SSM collections are populated in lockstep.
      if group not in convertCfg.group:
         convertCfg.group.newMember( group )
         ssmCfg.convertGroup.newMember( group )
      convertCfg.group[ group ].groupSource[ source ] = True
      ssmCfg.convertGroup[ group ].groupSource[ source ] = True
      return

   if not pfxListName:
      return

   # The prefix list may not exist yet; only warn about one that does.
   if pfxListName in aclListConfig.prefixList:
      if not allPfxEntriesAreMcast( pfxListName ):
         mode.addWarning( f'{pfxListName} contains non-multicast rule(s)' )

   source = Arnet.IpGenAddr( sourceAddr )
   if pfxListName not in ssmCfg.convertPfxList:
      ssmCfg.convertPfxList.newMember( pfxListName )
   ssmCfg.convertPfxList[ pfxListName ].source[ source ] = True

def noIpPimSsmConvert( mode, args ):
   """Remove SSM default-source mappings.

   args may contain 'GROUP' or 'PREFIX_LIST', each optionally with 'SOURCE';
   the presence of the 'pim' key selects the legacy form of the command.
   Nothing is done for IPv6, when no PIM sparse-mode config is available, or
   when that config has no ssmConvertConfig.

   - Neither 'GROUP' nor 'PREFIX_LIST': every group and prefix-list mapping
     is cleared.
   - 'GROUP' with a source: that source is removed from the group, and the
     group itself is removed once it has no sources left.
   - 'GROUP' without a source: the whole group is removed.
   - 'PREFIX_LIST' with 'SOURCE': that source is removed from the prefix-list
     entry, and the entry itself is removed once it has no sources left.
   - 'PREFIX_LIST' without 'SOURCE': the whole prefix-list entry is removed.
   """
   groupAddr = args.get( 'GROUP' )
   sourceAddr = args.get( 'SOURCE' )
   pfxListName = args.get( 'PREFIX_LIST' )
   legacy = 'pim' in args
   pimsmConfig_ = pimsmConfigFromMode( mode, legacy=legacy )
   ssmConfig_ = ssmConfigFromMode( mode, legacy=legacy )
   af = getAddressFamilyFromMode( mode, legacy=legacy )

   if ( af == AddressFamily.ipv6 or
         not pimsmConfig_ or
         not pimsmConfig_.ssmConvertConfig ):
      return

   config = pimsmConfig_.ssmConvertConfig

   if not groupAddr and not pfxListName:
      config.group.clear()
      ssmConfig_.convertGroup.clear()
      ssmConfig_.convertPfxList.clear()
      return

   if groupAddr:
      try:
         ( source, group ) = PimCliLib.ipPimParseSg( groupAddr, sourceAddr )
      except ValueError:
         mode.addErrorAndStop( "Must enter a multicast group" )

      if group not in config.group:
         return

      if source and group:
         del config.group[ group ].groupSource[ source ]
         if not config.group[ group ].groupSource:
            del config.group[ group ]
         del ssmConfig_.convertGroup[ group ].groupSource[ source ]
         if not ssmConfig_.convertGroup[ group ].groupSource:
            del ssmConfig_.convertGroup[ group ]
      elif group:
         del config.group[ group ]
         del ssmConfig_.convertGroup[ group ]

   elif pfxListName:
      if pfxListName not in ssmConfig_.convertPfxList:
         return

      if sourceAddr:
         source = Arnet.IpGenAddr( sourceAddr )
         pfxEntry = ssmConfig_.convertPfxList[ pfxListName ]
         del pfxEntry.source[ source ]
         # Mirror the group branch: do not leave an empty entry behind once
         # its last source has been removed.
         if not pfxEntry.source:
            del ssmConfig_.convertPfxList[ pfxListName ]
      else:
         del ssmConfig_.convertPfxList[ pfxListName ]

#------------------------------------------------------------------------------
# legacy: (config)# [ no ] ip pim ssm range <acl-name>
# legacy: (config)# [ no ] ip pim ssm range standard  
# (config-af)# [ no ] ssm range <acl-name>
# (config-af)# [ no ] ssm range standard 
#------------------------------------------------------------------------------
def setIpPimsmSsmFilter( mode, args ):
   """Set the SSM range filter to a named standard ACL or to 'standard'.

   args may contain 'ACL'; when it is absent the name 'standard' is used.
   The presence of the 'pim' key selects the legacy form of the command.

   An existing ACL that is not a standard ACL is rejected with an error. An
   ACL name that does not exist yet is accepted as is. For the name
   'standard', both the PIM sparse-mode and the SSM config get their own
   ssmFilterStandard value; otherwise both get the ACL name, with a warning
   when the existing ACL contains rules whose source is neither multicast
   nor "any".

   Nothing is written when no PIM sparse-mode config is available for the
   mode, matching the behavior of noIpPimsmSsmFilter.
   """
   aclNameOrStandard = args.get( 'ACL', 'standard' )
   legacy = 'pim' in args
   pimsmConfig_ = pimsmConfigFromMode( mode, legacy=legacy )
   ssmConfig_ = ssmConfigFromMode( mode, legacy=legacy )
   af = getAddressFamilyFromMode( mode, legacy=legacy )
   aclType = 'ip' if ( af == AddressFamily.ipv4 ) else 'ipv6'
   acl = getAclConfig( aclType ).get( aclNameOrStandard )

   if acl and not acl.standard:
      mode.addError( '%s is not a standard ACL' % aclNameOrStandard )
      return
   # Guard placed after ACL validation so the error above is still reported;
   # without it the attribute accesses below would raise on None.
   if pimsmConfig_ is None:
      return
   if aclNameOrStandard == 'standard':
      pimsmConfig_.ssmFilter = pimsmConfig_.ssmFilterStandard
      ssmConfig_.ssmFilter = ssmConfig_.ssmFilterStandard
   else:
      if acl and not allSourcesAreMulticast( acl, af ):
         mode.addWarning( '%s contains non-multicast rule(s)'
                          % aclNameOrStandard )
      pimsmConfig_.ssmFilter = aclNameOrStandard
      ssmConfig_.ssmFilter = aclNameOrStandard

def noIpPimsmSsmFilter( mode, args ):
   """Restore the SSM range filter to its default.

   Sets ssmFilter on both the PIM sparse-mode and the SSM config back to the
   respective ssmFilterDefault value. The presence of the 'pim' key in args
   selects the legacy form of the command. Nothing is done when no PIM
   sparse-mode config is available for the mode.
   """
   isLegacy = 'pim' in args
   pimsmCfg = pimsmConfigFromMode( mode, legacy=isLegacy )
   ssmCfg = ssmConfigFromMode( mode, legacy=isLegacy )
   if pimsmCfg is not None:
      pimsmCfg.ssmFilter = pimsmCfg.ssmFilterDefault
      ssmCfg.ssmFilter = ssmCfg.ssmFilterDefault

#------------------------------------------------------------------------------
# legacy: (config)# [ no ] ip pim sparse-mode fast-reroute <acl-name>
# (config-af)# [ no ] fast-reroute <acl-name>
#------------------------------------------------------------------------------
def setIpPimsmFrrFilter( mode, args ):
   """Set the fast-reroute filter to the named standard ACL.

   args must contain 'ACL'; the presence of the 'pim' key selects the legacy
   form of the command.

   An existing ACL that is not a standard ACL is rejected with an error. An
   ACL name that does not exist yet is accepted as is. A warning is issued
   when the existing ACL contains rules whose source is neither multicast
   nor "any".

   Nothing is written when no PIM sparse-mode config is available for the
   mode, matching the behavior of noIpPimsmFrrFilter.
   """
   aclName = args[ 'ACL' ]
   legacy = 'pim' in args
   pimsmConfig_ = pimsmConfigFromMode( mode, legacy=legacy )
   af = getAddressFamilyFromMode( mode, legacy=legacy )
   aclType = 'ip' if ( af == AddressFamily.ipv4 ) else 'ipv6'
   acl = getAclConfig( aclType ).get( aclName )

   if acl and not acl.standard:
      mode.addError( '%s is not a standard ACL' % aclName )
      return
   # Guard placed after ACL validation so the error above is still reported;
   # without it the assignment below would raise on None.
   if pimsmConfig_ is None:
      return
   if acl and not allSourcesAreMulticast( acl, af ):
      mode.addWarning( '%s contains non-multicast rule(s)' % aclName )
   pimsmConfig_.frrFilter = aclName

def noIpPimsmFrrFilter( mode, args ):
   """Restore the fast-reroute filter to the config's frrFilterDefault.

   The presence of the 'pim' key in args selects the legacy form of the
   command. Nothing is done when no PIM sparse-mode config is available.
   """
   pimsmCfg = pimsmConfigFromMode( mode, legacy=( 'pim' in args ) )
   if pimsmCfg is not None:
      pimsmCfg.frrFilter = pimsmCfg.frrFilterDefault

#------------------------------------------------------------------------------
# legacy:
#   (config)# [ no ] ip pim spt-threshold <infinity|zero> group-list <acl-name>
# (config-af)# [ no ] spt threshold <infinity|zero> match list <acl-name>
#------------------------------------------------------------------------------
def noPimsmSptThresh( mode, args ):
   """Remove an SPT threshold setting.

   With 'ACL' in args, only that ACL's entry in drSwitchAcl is deleted; an
   ACL that has no entry is reported as an error. Without 'ACL', drSwitch is
   set back to 'immediate'. The presence of the 'spt-threshold' key in args
   selects the legacy form of the command. Nothing is done when no PIM
   sparse-mode config is available.
   """
   isLegacy = 'spt-threshold' in args
   pimsmCfg = pimsmConfigFromMode( mode, legacy=isLegacy )
   if pimsmCfg is None:
      return

   aclName = args.get( 'ACL' )
   if not aclName:
      pimsmCfg.drSwitch = 'immediate'
      return

   # Delete only the specific group-list entry the user specified
   if aclName not in pimsmCfg.drSwitchAcl:
      mode.addError( "Unknown ACL: %s" % aclName )
      return
   del pimsmCfg.drSwitchAcl[ aclName ]

def setPimsmSptThresh( mode, args ):
   """Set the SPT threshold, globally or for one ACL.

   args[ 'THRESHOLD' ] of 'infinity' maps to 'never'; any other value maps
   to 'immediate'. With 'ACL' in args the value is stored in drSwitchAcl
   under that ACL name, otherwise in drSwitch. The presence of the
   'spt-threshold' key in args selects the legacy form of the command.
   Nothing is done when no PIM sparse-mode config is available.
   """
   isLegacy = 'spt-threshold' in args
   pimsmCfg = pimsmConfigFromMode( mode, legacy=isLegacy )
   if pimsmCfg is None:
      return

   if args[ 'THRESHOLD' ] == 'infinity':
      thresh = 'never'
   else:
      thresh = 'immediate'

   aclName = args.get( 'ACL' )
   if not aclName:
      pimsmCfg.drSwitch = thresh
      return
   pimsmCfg.drSwitchAcl[ aclName ] = thresh

#------------------------------------------------------------------------------
# legacy: (config)# [ no ] ip pim sparse-mode sg-expiry-timer 120-259200
# (config-af)# [ no ] sg-expiry-timer 120-259200
#------------------------------------------------------------------------------
def setPimsmSgExpiry( mode, args ):
   """Set or restore the (S,G) expiry timer.

   Uses args[ 'EXPIRY' ] when present and the config's sgExpiryTimerDefault
   otherwise, so the same handler serves both the set and the no form. The
   presence of the 'pim' key in args selects the legacy form of the command.
   Nothing is done when no PIM sparse-mode config is available.
   """
   pimsmCfg = pimsmConfigFromMode( mode, legacy=( 'pim' in args ) )
   if pimsmCfg is None:
      return
   fallback = pimsmCfg.sgExpiryTimerDefault
   pimsmCfg.sgExpiryTimer = args.get( 'EXPIRY', fallback )

#------------------------------------------------------------------------------
# (config-af)# [ no ] source advertisement inactivity timeout 45-30000 seconds
#------------------------------------------------------------------------------
def setPimsmSgAdvertiseInactiveInterval( mode, args ):
   """Set or restore the source advertisement inactivity interval.

   Uses args[ 'ADVERTISE_INACTIVE' ] when present and the config's
   sgAdvertiseInactiveIntervalDefault otherwise. A warning is issued for
   values above 300, but the value is applied regardless. Nothing is done
   when no PIM sparse-mode config is available.
   """
   pimsmCfg = pimsmConfigFromMode( mode )
   if pimsmCfg is None:
      return

   fallback = pimsmCfg.sgAdvertiseInactiveIntervalDefault
   interval = args.get( 'ADVERTISE_INACTIVE', fallback )
   if interval > 300:
      mode.addWarning( "Configuring a higher advertise inactive interval"
                       " %d is not recommended. " % interval )
   pimsmCfg.sgAdvertiseInactiveInterval = interval
   
#------------------------------------------------------------------------------
# (config-af)# [ no ] sso synchronization timeout 70-1200 seconds
#------------------------------------------------------------------------------
def setPimsmSyncTimeout( mode, args ):
   """Set or restore the SSO synchronization timeout.

   The value is stored on the collection root obtained from the mode, not on
   a per-VRF config. Uses args[ 'TIMEOUT' ] when present and the root's
   ssoSyncTimeoutDefault otherwise. Nothing is done when no root is
   available for the mode.
   """
   collRoot = pimsmConfigRootFromMode( mode )
   if collRoot is None:
      return
   fallback = collRoot.ssoSyncTimeoutDefault
   collRoot.ssoSyncTimeout = args.get( 'TIMEOUT', fallback )

#------------------------------------------------------------------------------
# [no] make-before-break
#------------------------------------------------------------------------------
def setMbb( mode, args ):
   """Enable or disable make-before-break.

   disableMbb is set to True exactly when the 'disabled' key is present in
   args. Nothing is done when no PIM sparse-mode config is available.
   """
   pimsmCfg = pimsmConfigFromMode( mode )
   if pimsmCfg is not None:
      pimsmCfg.disableMbb = 'disabled' in args

def setRouteSgInstallThresh( mode, args ):
   """Set or restore the (S,G) route install threshold.

   Uses args[ 'CRITERIA' ] when present and the config's
   sgInstallThreshDefault otherwise. Nothing is done when the PIM
   sparse-mode config for the mode is falsy.
   """
   pimsmCfg = pimsmConfigFromMode( mode )
   if not pimsmCfg:
      return
   fallback = pimsmCfg.sgInstallThreshDefault
   pimsmCfg.sgInstallThresh = args.get( 'CRITERIA', fallback )

def Plugin( entityManager ):
   """CLI plugin entry point: mount config state and register PIM hooks.

   Mounts the address-family independent config collections, fills in the
   module-level pimsmConfigColl, pimConfigRoot and aclListConfig handles, and
   registers this module's VRF create / delete / can-delete / cleanup
   handlers with the corresponding PimCliLib hooks.
   """
   global pimsmConfigColl, pimConfigRoot, aclListConfig

   # Af independent Config mounts
   doConfigMounts( entityManager, [ PimsmConfigColl, SsmConfigColl ] )

   # Writable mount of the sparse-mode config collection.
   pimsmConfigColl = ConfigMount.mount(
         entityManager, 'routing/pim/sparsemode/config',
         'Routing::Pim::SparseMode::ConfigColl', 'w' )

   # Read-only lazy mounts.
   pimConfigRoot = LazyMount.mount(
         entityManager, 'routing/pim/config',
         'Routing::Pim::ConfigColl', 'r' )
   aclListConfig = LazyMount.mount(
         entityManager, 'routing/acl/config',
         'Acl::AclListConfig', 'r' )

   # Registration order is kept as is; both cleanup handlers share one hook.
   hookHandlers = (
      ( PimCliLib.pimSparseModeVrfConfiguredHook, _pimsmConfigCreation ),
      ( PimCliLib.pimSparseModeVrfDeletedHook, _pimsmConfigDeletion ),
      ( PimCliLib.canDeletePimSparseModeVrfHook, _canDeletePimsmVrfConfig ),
      ( PimCliLib.pimSparseModeCleanupHook, _cleanupPimsmConfig ),
      ( PimCliLib.pimSparseModeCleanupHook, _cleanupPimsmConfigColl ),
   )
   for hook, handler in hookHandlers:
      hook.addExtension( handler )
