# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

from CliSavePlugin.SrTePolicyCliSave import (
      SrTeMode,
      getOrCreateSrTeModeInst,
   )
from Arnet import IpGenAddr
from SrTePolicyLib import MplsLabel
import CliSave
from CliMode.SrTePolicy import ReplicationSegmentModeBase
from Toggles.SrTePolicyToggleLib import toggleReplicationSegmentEnabled

class ReplicationSegmentMode( ReplicationSegmentModeBase, CliSave.Mode ):
   """CliSave mode holding the config commands of one replication segment.

   Each base class is initialised explicitly (not via super()), in MRO order,
   with the same ``param``.
   """
   def __init__( self, param ):
      for baseCls in ( ReplicationSegmentModeBase, CliSave.Mode ):
         baseCls.__init__( self, param )

# Attach the replication-segment mode as a child of the SR-TE CliSave mode,
# but only when the replication-segment feature toggle is enabled.
if toggleReplicationSegmentEnabled():
   SrTeMode.addChildMode( ReplicationSegmentMode )
# Register the command sequence that saveReplicationSegment writes into.
# NOTE(review): this runs even when the toggle is off -- confirm intended.
ReplicationSegmentMode.addCommandSequence( 'SrTePolicy.replicationsegment.config' )

def _segmentListCommand( segStack, segList ):
   """Return the 'segment-list ...' command for one static segment list.

   Args:
      segStack: a key of ``repSeg.staticSegmentList``; provides ``labelStack``
         and ``nodeAddress``.
      segList: the value stored under ``segStack``; provides ``index`` and
         ``desc``.

   Returns:
      The command string. When ``segStack.nodeAddress`` equals a default
      ``IpGenAddr()`` the labels are emitted as ``label-stack``; otherwise they
      are emitted as ``downstream-rsid`` followed by ``next-hop <address>``.
   """
   stack = segStack.labelStack
   # Join with single spaces: prefixing every label with " " (as the old
   # accumulator did) left a double space after the keyword.
   labels = ' '.join( str( stack.labelStack( i ) )
                      for i in range( stack.stackSize ) )
   cmdStr = 'segment-list index %d' % segList.index
   if segStack.nodeAddress == IpGenAddr():
      cmdStr += ' label-stack %s' % labels
   else:
      cmdStr += ' downstream-rsid %s' % labels
      cmdStr += ' next-hop %s' % str( segStack.nodeAddress )
   if segList.desc:
      cmdStr += ' desc %s' % segList.desc
   return cmdStr

@CliSave.saver( 'SrTePolicy::SrP2mp::Config',
                'te/segmentrouting/replicationsegment/config' )
def saveReplicationSegment( config, root, requireMounts, options ):
   """Save static replication-segment configuration.

   For every entry of ``config.staticReplicationSegment`` (iterated in sorted
   key order) a ReplicationSegmentMode instance is obtained under the SR-TE
   mode, and its root / desc / name / replication-sid / segment-list commands
   are added to the 'SrTePolicy.replicationsegment.config' sequence.

   Args:
      config: the 'SrTePolicy::SrP2mp::Config' entity being saved.
      root: root mode handed to getOrCreateSrTeModeInst.
      requireMounts: not used by this saver.
      options: save options; only ``options.saveAll`` is read. When set, an
         unset replication SID is rendered as 'no replication-sid'.
   """
   saveAll = options.saveAll
   srTeMode = None
   for repId in sorted( config.staticReplicationSegment ):
      # Create the SR-TE mode lazily so nothing is emitted when there are no
      # static replication segments.
      if srTeMode is None:
         srTeMode = getOrCreateSrTeModeInst( root )
      repSeg = config.staticReplicationSegment[ repId ]
      repSegMode = srTeMode[ ReplicationSegmentMode ].getOrCreateModeInstance(
         repSeg )
      cmds = repSegMode[ 'SrTePolicy.replicationsegment.config' ]

      key = repSeg.repSegKey
      if key.rootAddress:
         cmds.addCommand( 'root %s tree-id %s instance-id %s'
                          % ( str( key.rootAddress ), key.treeId,
                              key.instanceId ) )
      if repSeg.desc:
         cmds.addCommand( 'desc %s' % repSeg.desc )
      if repSeg.replicationSegmentName:
         cmds.addCommand( 'name %s' % repSeg.replicationSegmentName )
      if repSeg.replicationSid != MplsLabel.null:
         cmds.addCommand( 'replication-sid %d' % repSeg.replicationSid )
      elif saveAll:
         cmds.addCommand( 'no replication-sid' )

      for segStack in sorted( repSeg.staticSegmentList ):
         segList = repSeg.staticSegmentList[ segStack ]
         cmds.addCommand( _segmentListCommand( segStack, segList ) )
