#!/usr/bin/env python3
# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliSave
import CliSavePlugin.MrouteCliSave
from CliSavePlugin.MrouteCliSave import RouterMulticastBaseConfigMode
from CliSavePlugin.MrouteCliSave import RouterMulticastAfConfigMode
from CliSavePlugin.MrouteCliSave import RouterMulticastVrfConfigMode
from CliMode.Mroute import RoutingMulticastAfMplsStaticTunnelMode
from IpLibConsts import DEFAULT_VRF

class RouterMulticastMplsStaticTunnelConfigMode(
      RoutingMulticastAfMplsStaticTunnelMode, CliSave.Mode ):
   """CliSave mode for one static MPLS multicast tunnel.

   Instances are keyed by the 4-tuple (vrfName, af, source, group), which is
   passed both to the CliMode base (as separate arguments) and to
   CliSave.Mode (as the tuple itself).
   """
   def __init__( self, vrfAfSrcGrpSet ):
      vrfName, af, source, group = vrfAfSrcGrpSet
      modeKey = ( vrfName, af, source, group )
      RoutingMulticastAfMplsStaticTunnelMode.__init__( self, *modeKey )
      CliSave.Mode.__init__( self, modeKey )

# The static tunnel mode is nested under the multicast address-family mode and
# owns the command sequence that getCmdRoot hands out to savers.
RouterMulticastAfConfigMode.addChildMode( RouterMulticastMplsStaticTunnelConfigMode )
RouterMulticastMplsStaticTunnelConfigMode.addCommandSequence(
   'MplsStaticTunnel.vrf.af.config' )

def getCmdRoot( root, vrfName, af, source, group ):
   """Return the command sequence of one static MPLS multicast tunnel mode.

   Walks, creating modes as needed: router multicast -> [vrf <vrfName>] ->
   address family <af> -> static tunnel (source, group). It then returns that
   tunnel mode's 'MplsStaticTunnel.vrf.af.config' command sequence.

   Special case: in the default VRF with a falsy `af`, the router multicast
   base mode instance itself is returned.
   NOTE(review): callers add commands to the result, so confirm the base mode
   supports that, or that `af` is always set for the default VRF.
   """
   baseMode = root[ RouterMulticastBaseConfigMode ].getSingletonInstance()
   if vrfName == DEFAULT_VRF:
      if not af:
         return baseMode
      # Default-VRF address-family modes hang directly off the base mode.
      afParent = baseMode
   else:
      vrfModes = baseMode[ RouterMulticastVrfConfigMode ]
      afParent = vrfModes.getOrCreateModeInstance( vrfName )

   afModes = afParent[ RouterMulticastAfConfigMode ]
   afMode = afModes.getOrCreateModeInstance( ( vrfName, af ) )

   tunnelModes = afMode[ RouterMulticastMplsStaticTunnelConfigMode ]
   tunnelKey = ( vrfName, af, source, group )
   tunnelMode = tunnelModes.getOrCreateModeInstance( tunnelKey )
   return tunnelMode[ 'MplsStaticTunnel.vrf.af.config' ]

@CliSave.saver( 'Tunnel::Static::Multicast::Config',
                'tunnel/static/multicast/config' )
def saveStaticTunnelConfig( staticTunnelConfig, root, requireMounts, options ):
   """Save static MPLS multicast tunnel next-hops to the running config.

   Visits every VRF entry and every (source, group) in it. For each via, it
   emits one command into the matching static tunnel mode:
      next-hop <nexthop> <intfId> label-stack <label> [<label> ...]
   With options.saveAllDetail, the tunnel mode is created even when the
   (source, group) has no vias.
   """
   saveAllDetail = options.saveAllDetail

   for vrfName in staticTunnelConfig.entry:
      staticTunnelConfigEntry = staticTunnelConfig.entry[ vrfName ]
      af = staticTunnelConfigEntry.af

      for sourceGroup in staticTunnelConfigEntry.sourceGroupEntry:
         vias = staticTunnelConfigEntry.sourceGroupEntry[ sourceGroup ].via
         if not vias and not saveAllDetail:
            continue

         cmds = getCmdRoot( root, vrfName, af, sourceGroup.source,
                            sourceGroup.group )
         for via in vias:
            labelsObj = via.labels
            # Labels are listed from the highest stack index down to index 0.
            labelStack = [ str( labelsObj.labelStack( idx ) )
                           for idx in reversed( range( labelsObj.stackSize ) ) ]
            if not labelStack:
               # A via with an empty label stack emits no command.
               # NOTE(review): presumably a bare trailing 'label-stack' is not a
               # complete command; confirm against the CLI grammar.
               continue
            savedCmd = [ 'next-hop', via.nexthop.stringValue, via.intfId,
                         'label-stack' ] + labelStack
            # One command per via, added only after the full label stack has
            # been collected. Adding it inside the label loop duplicated labels
            # and emitted a partial command for every label.
            cmds.addCommand( ' '.join( savedCmd ) )

RouterMulticastBaseConfigMode.addCommandSequence( 'MvpnConfig.config',
                                                  after=[ 'Multicast.config' ] )
@CliSave.saver( 'Routing::Multicast::MvpnConfig',
                'routing/multicast/mvpn/config' )
def saveMvpnConfig( mvpnConfig, root, requireMounts, options ):
   """Save the 'mvpn ipv4 static pmsi' setting under router multicast.

   If mvpnStatic is set, the positive command is emitted. Otherwise the 'no'
   form is emitted only when saveAll or saveAllDetail is requested, and
   nothing is emitted in any other case.
   """
   saveAll, saveAllDetail = options.saveAll, options.saveAllDetail

   if mvpnConfig.mvpnStatic:
      mvpnCmd = 'mvpn ipv4 static pmsi'
   elif saveAll or saveAllDetail:
      mvpnCmd = 'no mvpn ipv4 static pmsi'
   else:
      return

   baseCmds = CliSavePlugin.MrouteCliSave.getCmdRoot( root, DEFAULT_VRF, None,
                                                      None )
   baseCmds.addCommand( mvpnCmd )

RouterMulticastVrfConfigMode.addCommandSequence( 'MvpnStaticVrfLabel.config',
                                                  after=[ 'Multicast.vrf.config' ] )
@CliSave.saver( 'Mpls::VrfLabelConfigInput',
                'routing/mpls/multicast/vrfLabel/input/cli' )
def saveMvpnVrfLabel( entity, root, requireMounts, options ):
   """Save one 'mpls static label <label>' command per configured VRF label.

   Each command is added to the command root that MrouteCliSave.getCmdRoot
   returns for the entry's vrfName.
   """
   for labelEntry in entity.vrfLabel.values():
      labelCmd = 'mpls static label {}'.format( labelEntry.label )
      vrfCmds = CliSavePlugin.MrouteCliSave.getCmdRoot(
         root, labelEntry.vrfName, None, None )
      vrfCmds.addCommand( labelCmd )

