# Copyright (c) 2012 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import CliSave, McastCommonCliLib, IpUtils
from CliMode.Msdp import RoutingMsdpBaseMode, RoutingMsdpVrfMode, \
      RoutingMsdpPeerMode
from CliSavePlugin.PimCliSaveLib import RouterPimBidirBaseConfigMode
from IpLibConsts import DEFAULT_VRF
import Tac
from functools import cmp_to_key

# Tac type handles resolved once at import time.
# NOTE(review): neither name is referenced by the savers in this module;
# presumably kept for other users of this module — confirm before removing.
AddressFamily = Tac.Type( "Arnet::AddressFamily" )
MsdpLegacyConfig = Tac.Type( "Routing::Msdp::LegacyConfig" )

class RouterMsdpBaseConfigMode( RoutingMsdpBaseMode, CliSave.Mode ):
   """CliSave mode wrapping RoutingMsdpBaseMode (the MSDP router base mode).

   Both base classes are initialized explicitly, in a fixed order, with the
   same mode parameter.
   """
   def __init__( self, param ):
      for baseCls in ( RoutingMsdpBaseMode, CliSave.Mode ):
         baseCls.__init__( self, param )

class RouterMsdpVrfConfigMode( RoutingMsdpVrfMode, CliSave.Mode ):
   """CliSave mode wrapping RoutingMsdpVrfMode (a per-VRF MSDP sub-mode).

   Instances are keyed by VRF name (see getRouterVrfMode); both base classes
   are initialized explicitly, in a fixed order, with that parameter.
   """
   def __init__( self, param ):
      for baseCls in ( RoutingMsdpVrfMode, CliSave.Mode ):
         baseCls.__init__( self, param )

class RouterMsdpPeerConfigMode( RoutingMsdpPeerMode, CliSave.Mode ):
   """CliSave mode wrapping RoutingMsdpPeerMode for one MSDP peer.

   The mode parameter is a ( vrfName, peer ) pair; RoutingMsdpPeerMode gets
   the two values separately while CliSave.Mode gets them as a tuple.
   """
   def __init__( self, vrfNameAndPeer ):
      vrfName, peer = vrfNameAndPeer
      RoutingMsdpPeerMode.__init__( self, vrfName, peer )
      modeKey = ( vrfName, peer )
      CliSave.Mode.__init__( self, modeKey )

# Register MSDP command sequences and modes with CliSave.  The statement order
# is left untouched (NOTE(review): registration order may influence the order
# of the generated config — confirm against CliSave before reordering).
# Global 'Ip.Msdp' commands are placed after the PIM bidir router mode.
CliSave.GlobalConfigMode.addCommandSequence( 'Ip.Msdp',
                                          after=[ RouterPimBidirBaseConfigMode ] )
# The MSDP router mode follows the global 'Ip.Msdp' commands.
CliSave.GlobalConfigMode.addChildMode( RouterMsdpBaseConfigMode,
                                       after=[ 'Ip.Msdp' ] )
# Base router mode: default-VRF commands ('Ip.Msdp.config', see getCmdRoot),
# per-VRF sub-modes, and peer sub-modes for default-VRF peers (see getPeerRoot).
RouterMsdpBaseConfigMode.addCommandSequence( 'Ip.Msdp.config' )
RouterMsdpBaseConfigMode.addChildMode( RouterMsdpVrfConfigMode )
RouterMsdpBaseConfigMode.addChildMode( RouterMsdpPeerConfigMode )
# VRF sub-mode: non-default-VRF commands and the peer sub-modes of that VRF.
RouterMsdpVrfConfigMode.addCommandSequence( 'Ip.Msdp.vrf.config' )
RouterMsdpVrfConfigMode.addChildMode( RouterMsdpPeerConfigMode )
# Peer sub-mode commands, shared by peers of every VRF.
RouterMsdpPeerConfigMode.addCommandSequence( 'Ip.Msdp.vrf.peer.config' )

def getRouterMode( root ):
   """Return the singleton MSDP router base-mode instance under root."""
   baseModes = root[ RouterMsdpBaseConfigMode ]
   return baseModes.getSingletonInstance()

def getRouterVrfMode( root, vrfName ):
   """Return the MSDP VRF sub-mode for vrfName, creating it if it is absent.

   The sub-mode lives under the singleton router base mode.
   """
   vrfModes = getRouterMode( root )[ RouterMsdpVrfConfigMode ]
   return vrfModes.getOrCreateModeInstance( vrfName )

def getCmdRoot( root, vrfName ):
   """Return the router-level command sequence for vrfName's MSDP config.

   The default VRF writes straight into the base router mode
   ('Ip.Msdp.config'); every other VRF writes into its own VRF sub-mode
   ('Ip.Msdp.vrf.config').
   """
   if vrfName != DEFAULT_VRF:
      return getRouterVrfMode( root, vrfName )[ 'Ip.Msdp.vrf.config' ]
   return getRouterMode( root )[ 'Ip.Msdp.config' ]

def getPeerRoot( root, vrfName, peer ):
   """Return the command sequence of the peer sub-mode for ( vrfName, peer ).

   The peer sub-mode hangs off the base router mode for the default VRF and
   off the VRF sub-mode otherwise; it is created on first use.
   """
   parentMode = ( getRouterMode( root ) if vrfName == DEFAULT_VRF
                  else getRouterVrfMode( root, vrfName ) )
   peerModes = parentMode[ RouterMsdpPeerConfigMode ]
   peerMode = peerModes.getOrCreateModeInstance( ( vrfName, peer ) )
   return peerMode[ 'Ip.Msdp.vrf.peer.config' ]

# Per-peer savers, invoked in registration order by saveConfig as
# saver( peerConfig, cmds, saveAll ).
msdpPeerCliSavers = []
def msdpPeerCliSaver( func ):
   """Decorator: register func as a per-peer CliSave helper.

   Args:
      func: callable taking ( peerConfig, cmds, saveAll ).

   Returns:
      func itself, unchanged.
   """
   msdpPeerCliSavers.append( func )
   # A decorator must hand the function back; returning None (implicitly)
   # would rebind the decorated module-level name to None.
   return func

# Per-VRF (router/VRF-level) savers, invoked in registration order by
# saveConfig as saver( msdpConfigColl, vrfName, cmds, saveAll ).
msdpAgentCliSavers = []
def msdpAgentCliSave( func ):
   """Decorator: register func as a per-VRF CliSave helper.

   Args:
      func: callable taking ( msdpConfigColl, vrfName, cmds, saveAll ).

   Returns:
      func itself, unchanged.
   """
   msdpAgentCliSavers.append( func )
   # A decorator must hand the function back; returning None (implicitly)
   # would rebind the decorated module-level name to None.
   return func

def saveDefaultPeer( peerConfig, cmds, saveAll, msdpConfigColl, vrfName ):
   """Emit 'default-peer' commands for default-peer entries matching this peer.

   An entry matches when its ip equals peerConfig.remote. An entry with an
   empty prefix yields a bare 'default-peer'; otherwise its prefix-list name
   is appended. saveAll is not consulted.
   """
   defaultPeers = msdpConfigColl.vrfConfig[ vrfName ].defaultPeer
   matching = ( entry for entry in defaultPeers.values()
                if entry.ip == peerConfig.remote )
   for entry in matching:
      if entry.prefix != "":
         cmds.addCommand( 'default-peer prefix-list %s' % entry.prefix )
      else:
         cmds.addCommand( 'default-peer' )

def saveMeshGroupPeer( peerConfig, cmds, saveAll, msdpConfigColl, vrfName ):
   """Emit 'mesh-group <name>' for every mesh group this peer belongs to.

   Groups are visited in name order. Membership is tested with
   member.get( peerConfig.remote ), so an entry whose value is falsy counts
   as not a member. saveAll is not consulted.
   """
   meshGroups = msdpConfigColl.vrfConfig[ vrfName ].meshGroup
   byName = sorted( meshGroups.items(), key=lambda entry: entry[ 0 ] )
   for groupName, group in byName:
      if not group.member.get( peerConfig.remote ):
         continue
      cmds.addCommand( 'mesh-group %s' % ( groupName ) )

@msdpAgentCliSave
def saveGroupLimit( msdpConfigColl, vrfName, cmds, saveAll ):
   """Emit one 'group-limit <limit> source <prefix>' per source SA limit.

   Entries follow the iteration order of sourceSALimit; saveAll is not
   consulted.
   """
   limits = msdpConfigColl.vrfConfig[ vrfName ].sourceSALimit
   for prefix, limit in limits.items():
      cmds.addCommand( 'group-limit {} source {}'.format( limit, prefix ) )

@msdpAgentCliSave
def saveOriginatorId( msdpConfigColl, vrfName, cmds, saveAll ):
   """Emit 'originator-id local-interface <intf>' when an originator-id is set.

   Nothing is emitted when it is unset, even when saveAll is requested: no
   'no originator-id local-interface' line is generated.
   """
   originatorId = msdpConfigColl.vrfConfig[ vrfName ].originatorId
   if originatorId:
      cmds.addCommand( 'originator-id local-interface %s' % originatorId )

@msdpAgentCliSave
def saveRejectedSALimit( msdpConfigColl, vrfName, cmds, saveAll ):
   """Emit 'rejected-limit <n>' when it differs from its default or saveAll."""
   vrfConfig = msdpConfigColl.vrfConfig[ vrfName ]
   limit = vrfConfig.rejectedSALimit
   customized = limit != vrfConfig.rejectedSALimitDefault
   if customized or saveAll:
      cmds.addCommand( 'rejected-limit %d' % limit )

@msdpAgentCliSave
def saveMsdpForwardFlag( msdpConfigColl, vrfName, cmds, saveAll ):
   """Emit 'forward register-packets' when encapDataPkts is set.

   When it is unset, the negated form is emitted only under saveAll.
   """
   if msdpConfigColl.vrfConfig[ vrfName ].encapDataPkts:
      command = 'forward register-packets'
   elif saveAll:
      command = 'no forward register-packets'
   else:
      return
   cmds.addCommand( command )

@msdpAgentCliSave
def saveConnRetryInterval( msdpConfigColl, vrfName, cmds, saveAll ):
   """Emit 'connection retry interval <n>' when non-default or under saveAll."""
   vrfConfig = msdpConfigColl.vrfConfig[ vrfName ]
   period = vrfConfig.connRetryPeriod
   customized = period != vrfConfig.connRetryPeriodDefault
   if customized or saveAll:
      cmds.addCommand( 'connection retry interval %d' % period )

@msdpPeerCliSaver
def saveConnSrc( peerConfig, cmds, saveAll ):
   """Emit 'local-interface <intf>' when the peer has a connection source."""
   connSrc = peerConfig.connSrc
   if connSrc == "":
      return
   cmds.addCommand( 'local-interface %s' % connSrc )

@msdpPeerCliSaver
def saveKeepAlive( peerConfig, cmds, saveAll ):
   """Emit 'keepalive <keepalive> <holdtime>' for the peer.

   The command is emitted when either timer differs from its default, or
   when saveAll is set.
   """
   keepalive = peerConfig.keepalive
   holdtime = peerConfig.holdtime
   customized = ( keepalive != peerConfig.keepaliveDefault or
                  holdtime != peerConfig.holdtimeDefault )
   if customized or saveAll:
      cmds.addCommand( 'keepalive %d %d' % ( keepalive, holdtime ) )

@msdpPeerCliSaver
def saveSAFilterIn( peerConfig, cmds, saveAll ):
   """Emit 'sa-filter in list <name>' when an inbound SA filter is set."""
   listName = peerConfig.saFilterIn
   if not listName:
      return
   cmds.addCommand( 'sa-filter in list %s' % listName )

@msdpPeerCliSaver
def saveSAFilterOut( peerConfig, cmds, saveAll ):
   """Emit 'sa-filter out list <name>' when an outbound SA filter is set."""
   listName = peerConfig.saFilterOut
   if not listName:
      return
   cmds.addCommand( 'sa-filter out list %s' % listName )

@msdpPeerCliSaver
def saveDescription( peerConfig, cmds, saveAll ):
   """Emit 'description <text>' when peerConfig.description is truthy."""
   text = peerConfig.description
   if not text:
      return
   cmds.addCommand( 'description %s' % text )

@msdpPeerCliSaver
def saveDisable( peerConfig, cmds, saveAll ):
   """Emit 'disabled' for a peer whose shutdown flag is set."""
   if not peerConfig.shutdown:
      return
   cmds.addCommand( 'disabled' )

@msdpPeerCliSaver
def saveSALimit( peerConfig, cmds, saveAll ):
   """Emit 'sa-limit <n>' when the peer's SA limit differs from its default.

   At the default value, 'no sa-limit 10' is emitted only under saveAll.
   """
   limit = peerConfig.peerSALimit
   customized = limit != peerConfig.peerSALimitDefault
   if customized:
      cmds.addCommand( 'sa-limit %s' % limit )
   elif saveAll:
      # NOTE(review): the literal 10 is fixed and does not track
      # peerSALimitDefault — confirm this matches the CLI grammar.
      cmds.addCommand( 'no sa-limit 10' )

def _savePeerConfigs( msdpConfigColl, root, vrfName, saveAll ):
   """Emit the peer sub-mode config of every MSDP peer in vrfName.

   Peers are visited in remote-address order (IpUtils.compareIpAddress) so
   the output is stable. Each peer's commands come from saveDefaultPeer,
   then saveMeshGroupPeer, then every @msdpPeerCliSaver in registration
   order.
   """
   vrfConfig = msdpConfigColl.vrfConfig[ vrfName ]
   ipKey = cmp_to_key( IpUtils.compareIpAddress )
   # Sort the peer entries themselves by remote address rather than sorting
   # addresses and looking each back up, which assumed peerConfig is keyed
   # by exactly peer.remote (and would hit None otherwise).
   orderedPeers = sorted( vrfConfig.peerConfig.values(),
                          key=lambda peer: ipKey( peer.remote ) )
   for peer in orderedPeers:
      peerCmds = getPeerRoot( root, vrfName, peer.remote )
      saveDefaultPeer( peer, peerCmds, saveAll, msdpConfigColl, vrfName )
      saveMeshGroupPeer( peer, peerCmds, saveAll, msdpConfigColl, vrfName )
      for saver in msdpPeerCliSavers:
         saver( peer, peerCmds, saveAll )

@CliSave.saver( 'Routing::Msdp::ConfigColl', 'routing/msdp/config',
                requireMounts=( 'routing/hardware/statuscommon',
                                'routing/hardware/status' ) )
def saveConfig( msdpConfigColl, root, requireMounts, options ):
   """CliSave hook rendering MSDP configuration for every VRF.

   A VRF is rendered when it is not at its defaults (vrfConfig.isDefault())
   or when saveAll is in effect. For each rendered VRF, the peer sub-modes
   are emitted first, then every @msdpAgentCliSave saver writes into the
   VRF's router-level command sequence.

   Args:
      msdpConfigColl: the Routing::Msdp::ConfigColl at routing/msdp/config.
      root: CliSave root under which the MSDP modes are created.
      requireMounts: mounted entities; 'routing/hardware/status' is read.
      options: CliSave options; saveAll and saveAllDetail are read.
   """
   saveAll = options.saveAll

   # Save the default config only if the platform supports multicast routing,
   # unless saveAllDetail explicitly asks for everything.
   if not McastCommonCliLib.mcastRoutingSupported(
         None,
         requireMounts[ 'routing/hardware/status' ] ) and \
         not options.saveAllDetail:
      saveAll = False

   for vrfName, vrfConfig in msdpConfigColl.vrfConfig.items():
      if vrfConfig.isDefault() and not saveAll:
         continue

      # Router/VRF-level command sequence used by the per-VRF savers below.
      cmds = getCmdRoot( root, vrfName )

      _savePeerConfigs( msdpConfigColl, root, vrfName, saveAll )

      for func in msdpAgentCliSavers:
         func( msdpConfigColl, vrfName, cmds, saveAll )
