# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliSave
import Tac
from CliMode.Dps import ( RouterPathSelectionModeBase,
                          DpsPathGroupModeBase,
                          DpsPathGroupRemoteRouterBase,
                          DpsPolicyModeBase,
                          DpsPolicyRuleKeyBase,
                          DpsPolicyDefaultRuleBase,
                          DpsLoadBalanceProfileConfigModeBase,
                          DpsVrfConfigModeBase,
                          DpsPathGroupStunConfigModeBase,
                          DpsPeerDynamicConfigModeBase,
                          DpsIntfModeBase,
                          DEFAULT_PRIORITY,
                          DEFAULT_JITTER, JITTER_SCALE,
                          DEFAULT_LATENCY, LATENCY_SCALE,
                          DEFAULT_LOSSRATE, LOSS_RATE_SCALE, LOSS_RATE_ADJUSTMENT,
                          DEFAULT_UDP_PORT, DEFAULT_KEEPALIVE_INTERVAL,
                          DEFAULT_FEEDBACK_SCALE )
from Toggles.WanTECommonToggleLib import toggleAvtLowestLoadMetricEnabled
from Toggles.DpsToggleLib import toggleFlowAssignmentLanEnabled

# Tac type holding DPS constants. The `constants` instance supplies the
# sentinel/default values (e.g. tcpMssIngressRewriteAuto,
# mtuDiscDefaultInterval, pmtuDisabled) that saveConfig compares the
# configuration against when deciding which commands to emit.
Constants = Tac.Type( "Dps::DpsConstants" )
constants = Constants()
# NOTE(review): MetricOrder is not referenced by the code in this file;
# presumably kept for importers of this module -- confirm before removing.
MetricOrder = Tac.Type( "Avt::MetricOrder" )

class RouterPathSelectionSaveMode( RouterPathSelectionModeBase, CliSave.Mode ):
   """CliSave mode for 'router path-selection'.

   Registered as a child of the global configuration mode. saveConfig fetches
   it as a singleton (getSingletonInstance), and every other DPS save mode in
   this file is registered underneath it, directly or indirectly.
   """
   def __init__( self, param ):
      # the mode base takes no arguments; param is only passed to CliSave.Mode
      RouterPathSelectionModeBase.__init__( self )
      CliSave.Mode.__init__( self, param )

# 'router path-selection' lives directly under global configuration; its own
# global commands are collected in the 'Dps.RouterPathSelection' sequence
CliSave.GlobalConfigMode.addChildMode( RouterPathSelectionSaveMode )
RouterPathSelectionSaveMode.addCommandSequence( 'Dps.RouterPathSelection' )

class DpsPathGroupSaveMode( DpsPathGroupModeBase, CliSave.Mode ):
   """CliSave mode for one 'path-group' under 'router path-selection'.

   param is the ( path-group name, path-group id ) pair used as the mode
   instance key; both parts are also kept as attributes on the mode.
   """
   def __init__( self, param ):
      name, groupId = param
      self.pathGroupName = name
      self.pathGroupId = groupId
      DpsPathGroupModeBase.__init__( self, self.pathGroupName,
                                     self.pathGroupId )
      CliSave.Mode.__init__( self, param )

RouterPathSelectionSaveMode.addChildMode( DpsPathGroupSaveMode )
DpsPathGroupSaveMode.addCommandSequence( 'Dps.DpsPathGroup' )

class DpsIntfSaveMode( DpsPathGroupStunConfigModeBase, CliSave.Mode ):
   """CliSave mode for a local interface entry inside a path group.

   param is ( path-group name, interface, VRF name, public IP ). The STUN
   config base is told that this entry is of kind 'interface'.
   """
   def __init__( self, param ):
      pgName, intf, vrfName, publicIp = param
      self.pgName = pgName
      self.key = intf
      self.vrf = vrfName
      self.publicIp = publicIp
      entryKind = 'interface'
      DpsPathGroupStunConfigModeBase.__init__(
         self, self.pgName, self.key, entryKind, self.vrf, self.publicIp )
      CliSave.Mode.__init__( self, param )

DpsPathGroupSaveMode.addChildMode( DpsIntfSaveMode )
DpsIntfSaveMode.addCommandSequence( 'Dps.DpsIntf' )

class DpsPathGroupIpSaveMode( DpsPathGroupStunConfigModeBase, CliSave.Mode ):
   """CliSave mode for a local IP entry inside a path group.

   param is ( path-group name, IP, VRF name, public IP ). The STUN config
   base is told that this entry is of kind 'ip'.
   """
   def __init__( self, param ):
      pgName, localIp, vrfName, publicIp = param
      self.pgName = pgName
      self.key = localIp
      self.vrf = vrfName
      self.publicIp = publicIp
      entryKind = 'ip'
      DpsPathGroupStunConfigModeBase.__init__(
         self, self.pgName, self.key, entryKind, self.vrf, self.publicIp )
      CliSave.Mode.__init__( self, param )

# local IP sub-modes are rendered after the local interface sub-modes
DpsPathGroupSaveMode.addChildMode( DpsPathGroupIpSaveMode,
                                   after=[ DpsIntfSaveMode ] )
DpsPathGroupIpSaveMode.addCommandSequence( 'Dps.DpsPathGroupIp' )

class DpsPathGroupRemoteRouterSaveMode( DpsPathGroupRemoteRouterBase,
                                        CliSave.Mode ):
   """CliSave mode for a remote router entry inside a path group.

   param is ( router IP, path-group name ).
   """
   def __init__( self, param ):
      routerIp, pgName = param
      self.routerIp = routerIp
      self.pgName = pgName
      DpsPathGroupRemoteRouterBase.__init__( self, self.routerIp, self.pgName )
      CliSave.Mode.__init__( self, param )

# remote router sub-modes are rendered after the local IP sub-modes
DpsPathGroupSaveMode.addChildMode( DpsPathGroupRemoteRouterSaveMode,
                                   after=[ DpsPathGroupIpSaveMode ] )
DpsPathGroupRemoteRouterSaveMode.addCommandSequence(
   'Dps.DpsPathGroupRemoteRouter' )

class DpsPathGroupPeerDynamicSaveMode( DpsPeerDynamicConfigModeBase,
                                       CliSave.Mode ):
   """CliSave mode for the 'peer dynamic' sub-mode of a path group.

   param is the path-group name alone (not a tuple).
   """
   def __init__( self, param ):
      self.pgName = param
      DpsPeerDynamicConfigModeBase.__init__( self, self.pgName )
      CliSave.Mode.__init__( self, param )

# NOTE(review): registered without an 'after=' ordering hint, unlike the
# interface / ip / remote-router sub-modes -- confirm the intended position.
DpsPathGroupSaveMode.addChildMode( DpsPathGroupPeerDynamicSaveMode )
DpsPathGroupPeerDynamicSaveMode.addCommandSequence( 'Dps.DpsPathGroupPeerDynamic' )

class DpsPolicySaveMode( DpsPolicyModeBase, CliSave.Mode ):
   """CliSave mode for one path-selection policy.

   param is the policy name, forwarded unchanged to both bases.
   """
   def __init__( self, param ):
      DpsPolicyModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, param )

RouterPathSelectionSaveMode.addChildMode( DpsPolicySaveMode )
DpsPolicySaveMode.addCommandSequence( 'Dps.DpsPolicy' )

class DpsPolicyRuleKeySaveMode( DpsPolicyRuleKeyBase, CliSave.Mode ):
   """CliSave mode for one application-profile rule inside a policy.

   param is ( policy name, rule key, application profile name ).
   """
   def __init__( self, param ):
      policyName = param[ 0 ]
      ruleKey = param[ 1 ]
      appProfileName = param[ 2 ]
      DpsPolicyRuleKeyBase.__init__( self, policyName, ruleKey, appProfileName )
      CliSave.Mode.__init__( self, param )

DpsPolicySaveMode.addChildMode( DpsPolicyRuleKeySaveMode )
DpsPolicyRuleKeySaveMode.addCommandSequence( 'Dps.DpsPolicyRuleKey' )

class DpsPolicyDefaultRuleSaveMode( DpsPolicyDefaultRuleBase, CliSave.Mode ):
   """CliSave mode for the default rule of a policy.

   param is ( policy name, default-rule-configured flag ).
   """
   def __init__( self, param ):
      policyName = param[ 0 ]
      defaultRuleCfgd = param[ 1 ]
      DpsPolicyDefaultRuleBase.__init__( self, policyName, defaultRuleCfgd )
      CliSave.Mode.__init__( self, param )

DpsPolicySaveMode.addChildMode( DpsPolicyDefaultRuleSaveMode )
DpsPolicyDefaultRuleSaveMode.addCommandSequence( 'Dps.DpsPolicyDefaultRule' )

class DpsLoadBalanceProfileConfigSaveMode( DpsLoadBalanceProfileConfigModeBase,
                                          CliSave.Mode ):
   """CliSave mode for one load-balance profile.

   param is the profile name, forwarded unchanged to both bases.
   """
   def __init__( self, param ):
      DpsLoadBalanceProfileConfigModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# load-balance profiles are rendered after the path-group sub-modes
RouterPathSelectionSaveMode.addChildMode( DpsLoadBalanceProfileConfigSaveMode,
                                          after=[ DpsPathGroupSaveMode ] )
DpsLoadBalanceProfileConfigSaveMode.addCommandSequence(
        'Dps.DpsLoadBalanceProfile' )

class DpsVrfConfigSaveMode( DpsVrfConfigModeBase, CliSave.Mode ):
   """CliSave mode for per-VRF path-selection configuration.

   param is the VRF name, forwarded unchanged to both bases.
   """
   def __init__( self, param ):
      DpsVrfConfigModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, param )

RouterPathSelectionSaveMode.addChildMode( DpsVrfConfigSaveMode )
DpsVrfConfigSaveMode.addCommandSequence( 'Dps.DpsVrfConfig' )

class DpsIntfSpeedSaveMode( DpsIntfModeBase, CliSave.Mode ):
   """CliSave mode for per-interface metric (bandwidth) configuration.

   param is the interface id.
   """
   def __init__( self, param ):
      self.intfId = param
      DpsIntfModeBase.__init__( self, self.intfId )
      CliSave.Mode.__init__( self, param )

# Only registered when the lowest-load-metric feature toggle is on; saveConfig
# checks the same toggle before creating instances of this mode.
if toggleAvtLowestLoadMetricEnabled():
   RouterPathSelectionSaveMode.addChildMode( DpsIntfSpeedSaveMode )
   DpsIntfSpeedSaveMode.addCommandSequence( 'Dps.DpsIntfSpeed' )

def _addStunServerProfileCmd( stunConfig, key, cmds, saveAll ):
   """Add the 'stun server-profile' command for one local path-group entry.

   stunConfig -- a path group's intfStunConfig or ipStunConfig collection
   key        -- the interface or IP whose STUN server profiles are saved
   cmds       -- command sequence of the sub-mode that owns ``key``; the
                 command (including the 'no' form) must land here and not in
                 a sibling sub-mode
   saveAll    -- when True, emit 'no stun server-profile' for an entry that
                 exists but lists no profiles

   Nothing is emitted when ``key`` has no entry in ``stunConfig``.
   """
   if key not in stunConfig:
      return
   profiles = sorted( stunConfig[ key ].serverProfile )
   if profiles:
      cmds.addCommand( 'stun server-profile %s' % ' '.join( profiles ) )
   elif saveAll:
      cmds.addCommand( 'no stun server-profile' )

def _saveRouterPathSelectionCmds( entity, cmds, saveAll ):
   """Add the global commands that live directly under 'router path-selection'.

   entity  -- the Dps::DpsCliConfig being saved
   cmds    -- the 'Dps.RouterPathSelection' command sequence
   saveAll -- when True, also emit commands for settings left at default
   """
   udpPort = entity.udpPortConfig
   if udpPort != DEFAULT_UDP_PORT or saveAll:
      cmds.addCommand( 'encapsulation path-telemetry udp port %d' % udpPort )

   if entity.peerDynamicStun:
      cmds.addCommand( 'peer dynamic source stun' )
   elif saveAll:
      cmds.addCommand( 'no peer dynamic source stun' )

   # tcp mss ingress: 'auto' and 'disabled' are encoded as dedicated constant
   # values; any other value is an explicit MSS ceiling
   tcpMssIngress = entity.tcpMssIngressConfig
   if tcpMssIngress == constants.tcpMssIngressRewriteAuto:
      cmds.addCommand( 'tcp mss ceiling ipv4 ingress' )
   elif tcpMssIngress != constants.tcpMssIngressRewriteDisabled:
      cmds.addCommand( 'tcp mss ceiling ipv4 %d ingress' % tcpMssIngress )
   elif saveAll:
      cmds.addCommand( 'no tcp mss ceiling' )

   # global mtu discovery interval
   if entity.mtuDiscInterval != constants.mtuDiscDefaultInterval or saveAll:
      cmds.addCommand(
            'mtu discovery interval %d seconds' % entity.mtuDiscInterval )

   # icmp fragmentation-needed; the rate limit is only shown when it differs
   # from the default
   if entity.icmpFragNeededEnabled:
      cmd = 'mtu discovery hosts'
      if ( entity.icmpFragNeededRateLimit !=
            constants.icmpFragNeededDefaultRateLimit ):
         cmd += ( " fragmentation-needed rate-limit %d packets-per-second" %
                  entity.icmpFragNeededRateLimit )
      cmds.addCommand( cmd )
   elif saveAll:
      cmds.addCommand( 'no mtu discovery hosts' )

def _savePathGroup( dpsMode, pgName, pgCfg, defaultVrf, saveAll ):
   """Add the 'path-group' sub-mode for one path group and all its children.

   dpsMode    -- the 'router path-selection' mode instance
   pgName     -- path group name (key into entity.pathGroupConfig)
   pgCfg      -- that path group's configuration
   defaultVrf -- VRF name used for every local IP sub-mode
   saveAll    -- when True, also emit commands for settings left at default
   """
   pgMode = dpsMode[ DpsPathGroupSaveMode ].getOrCreateModeInstance(
                                              ( pgName, pgCfg.pathGroupId ) )
   pgCmds = pgMode[ 'Dps.DpsPathGroup' ]
   if pgCfg.ipsecProfile != "":
      pgCmds.addCommand( 'ipsec profile %s' % pgCfg.ipsecProfile )
   elif saveAll:
      pgCmds.addCommand( 'no ipsec profile' )

   # local interfaces, each in its own sub-mode
   intfConfig = pgCfg.localIntf
   for intf in intfConfig:
      pubIp = intfConfig[ intf ].publicIp
      vrf = intfConfig[ intf ].vrfName
      intfMode = pgMode[ DpsIntfSaveMode ].getOrCreateModeInstance(
                                             ( pgName, intf, vrf, pubIp ) )
      _addStunServerProfileCmd( pgCfg.intfStunConfig, intf,
                                intfMode[ 'Dps.DpsIntf' ], saveAll )

   # local IPs, each in its own sub-mode; these always use the default VRF.
   # STUN commands go to the IP's own command sequence.
   for ip in pgCfg.localIp:
      pubIp = pgCfg.localIp[ ip ]
      ipMode = pgMode[ DpsPathGroupIpSaveMode ].getOrCreateModeInstance(
                                             ( pgName, ip, defaultVrf, pubIp ) )
      _addStunServerProfileCmd( pgCfg.ipStunConfig, ip,
                                ipMode[ 'Dps.DpsPathGroupIp' ], saveAll )

   if pgCfg.remoteDynamic:
      dynPeerMode = \
            pgMode[ DpsPathGroupPeerDynamicSaveMode ].getOrCreateModeInstance(
                                                                  pgName )
      dynPeerCmds = dynPeerMode[ 'Dps.DpsPathGroupPeerDynamic' ]
      if pgCfg.preferLocalIp:
         dynPeerCmds.addCommand( 'ip local' )
      if pgCfg.dynamicPeerIpsec == 'ipsecEnabled':
         dynPeerCmds.addCommand( 'ipsec' )
      elif pgCfg.dynamicPeerIpsec == 'ipsecDisabled':
         dynPeerCmds.addCommand( 'ipsec disabled' )
      elif saveAll:
         dynPeerCmds.addCommand( 'no ipsec' )
   elif saveAll:
      pgCmds.addCommand( 'no peer dynamic' )

   if pgCfg.mssEgress:
      pgCmds.addCommand( 'tcp mss ceiling ipv4 %s egress' % pgCfg.mssEgress )
   elif saveAll:
      pgCmds.addCommand( 'no tcp mss ceiling' )

   for pathViaPair in pgCfg.pathViaPairSet:
      pgCmds.addCommand( 'import path-group remote %s local %s' %
                         ( pathViaPair.remotePg, pathViaPair.localPg ) )

   # keepalive: -1 selects 'auto'; otherwise a positive interval is shown when
   # it or the failure threshold differs from the default
   keepaliveInterval = pgCfg.keepaliveInterval
   if keepaliveInterval == -1:
      pgCmds.addCommand( 'keepalive interval auto' )
   elif ( ( keepaliveInterval != DEFAULT_KEEPALIVE_INTERVAL or
            pgCfg.feedbackScale != DEFAULT_FEEDBACK_SCALE ) and
          keepaliveInterval > 0 ):
      pgCmds.addCommand( 'keepalive interval %s milliseconds failure-'
                         'threshold %s intervals' % ( keepaliveInterval,
                                                      pgCfg.feedbackScale ) )
   elif saveAll:
      pgCmds.addCommand( 'keepalive interval %s milliseconds failure-'
                         'threshold %s intervals' % ( DEFAULT_KEEPALIVE_INTERVAL,
                                                      DEFAULT_FEEDBACK_SCALE ) )

   # mtu: a dedicated constant means disabled; a positive value is explicit
   if pgCfg.mtu == constants.pmtuDisabled:
      pgCmds.addCommand( 'mtu disabled' )
   elif pgCfg.mtu > 0:
      pgCmds.addCommand( 'mtu %d' % pgCfg.mtu )
   elif saveAll:
      pgCmds.addCommand( 'no mtu' )

   # per-path-group mtu discovery interval; 0 is saved as the 'no' form
   if pgCfg.mtuDiscInterval != 0:
      pgCmds.addCommand( 'mtu discovery interval %d seconds' %
                         pgCfg.mtuDiscInterval )
   elif saveAll:
      pgCmds.addCommand( 'no mtu discovery interval' )

   if toggleFlowAssignmentLanEnabled():
      if pgCfg.flowAssignmentLan:
         pgCmds.addCommand( 'flow assignment lan' )
      elif saveAll:
         pgCmds.addCommand( 'no flow assignment lan' )

   # remote router sub-modes
   for viaCfg in pgCfg.remoteViaConfig:
      viaMode = \
         pgMode[ DpsPathGroupRemoteRouterSaveMode ].getOrCreateModeInstance(
                                                          ( viaCfg, pgName ) )
      viaCmds = viaMode[ 'Dps.DpsPathGroupRemoteRouter' ]
      viaEntry = pgCfg.remoteViaConfig[ viaCfg ]
      peerName = viaEntry.peerName
      # the name is omitted when it is just the router's own address string
      if peerName and peerName != viaEntry.remoteAddr.stringValue:
         viaCmds.addCommand( 'name %s' % peerName )
      for via in viaEntry.remoteEncap:
         viaCmds.addCommand( 'ipv4 address %s' % via )

def _savePolicy( dpsMode, policy, polCfg, saveAll ):
   """Add the sub-mode for one path-selection policy and its rules.

   dpsMode -- the 'router path-selection' mode instance
   policy  -- policy name (key into entity.policyConfig)
   polCfg  -- that policy's configuration
   saveAll -- when True, also emit commands for settings left at default
   """
   polMode = dpsMode[ DpsPolicySaveMode ].getOrCreateModeInstance( policy )
   for ruleKey in polCfg.appProfilePolicyRuleList:
      ruleKeyEntry = polCfg.appProfilePolicyRuleList[ ruleKey ]
      ruleKeyMode = \
         polMode[ DpsPolicyRuleKeySaveMode ].getOrCreateModeInstance(
                  ( policy, ruleKey, ruleKeyEntry.appProfileName ) )
      ruleKeyCmds = ruleKeyMode[ 'Dps.DpsPolicyRuleKey' ]
      if ruleKeyEntry.actionName != "":
         ruleKeyCmds.addCommand( 'load-balance %s' % ruleKeyEntry.actionName )
      elif saveAll:
         ruleKeyCmds.addCommand( 'no load-balance' )

   if polCfg.defaultRuleCfgd:
      defaultLbGrp = polCfg.defaultActionName
      defaultRuleMode = \
         polMode[ DpsPolicyDefaultRuleSaveMode ].getOrCreateModeInstance(
               ( policy, polCfg.defaultRuleCfgd ) )
      defaultRuleCmds = defaultRuleMode[ 'Dps.DpsPolicyDefaultRule' ]
      if defaultLbGrp != "":
         defaultRuleCmds.addCommand( 'load-balance %s' % defaultLbGrp )
      elif saveAll:
         defaultRuleCmds.addCommand( 'no load-balance' )

def _saveLoadBalanceProfile( dpsMode, profileName, profile, saveAll ):
   """Add the sub-mode for one load-balance profile.

   dpsMode     -- the 'router path-selection' mode instance
   profileName -- profile name (key into entity.loadBalanceProfile)
   profile     -- that profile's configuration
   saveAll     -- when True, also emit commands for settings left at default
   """
   profileMode = dpsMode[ DpsLoadBalanceProfileConfigSaveMode ].\
                     getOrCreateModeInstance( profileName )
   profileCmds = profileMode[ 'Dps.DpsLoadBalanceProfile' ]
   # stored thresholds are scaled; divide by the *_SCALE constants for CLI units
   if profile.latency != DEFAULT_LATENCY:
      profileCmds.addCommand( 'latency %d' %
                              ( profile.latency // LATENCY_SCALE ) )
   elif saveAll:
      profileCmds.addCommand( 'no latency' )
   if profile.jitter != DEFAULT_JITTER:
      profileCmds.addCommand( 'jitter %d' %
                              ( profile.jitter // JITTER_SCALE ) )
   elif saveAll:
      profileCmds.addCommand( 'no jitter' )
   if profile.lossRate != DEFAULT_LOSSRATE:
      # loss rate is stored offset by LOSS_RATE_ADJUSTMENT as well as scaled
      loss = profile.lossRate - LOSS_RATE_ADJUSTMENT
      profileCmds.addCommand( 'loss-rate %.2f' %
                              ( float( loss ) / LOSS_RATE_SCALE ) )
   elif saveAll:
      profileCmds.addCommand( 'no loss-rate' )

   if profile.hopCountLowest:
      profileCmds.addCommand( 'hop count lowest' )

   # path groups sorted by priority, then by path-group name
   for pgName, priority in sorted( profile.pathGroupPriority.items(),
                                   key=lambda item: ( item[ 1 ], item[ 0 ] ) ):
      cmd = 'path-group %s' % pgName
      if priority != DEFAULT_PRIORITY:
         cmd += ' priority %d' % priority
      profileCmds.addCommand( cmd )

@CliSave.saver( 'Dps::DpsCliConfig', 'dps/input/cli' )
def saveConfig( entity, root, requireMounts, options ):
   """Save the Dps::DpsCliConfig mounted at dps/input/cli as CLI commands.

   Nothing is saved unless 'router path-selection' has been configured. When
   options.saveAll is set, commands are also emitted for settings that are
   at their default value.
   """
   defaultVrf = Tac.Type( "L3::VrfName" ).defaultVrf

   # "router path-selection" sets dpsConfigured to True
   if not entity.dpsConfigured:
      return

   # when this is true, show even default configs
   saveAll = options.saveAll
   dpsMode = root[ RouterPathSelectionSaveMode ].getSingletonInstance()
   _saveRouterPathSelectionCmds( entity, dpsMode[ 'Dps.RouterPathSelection' ],
                                 saveAll )

   for pgName in entity.pathGroupConfig:
      _savePathGroup( dpsMode, pgName, entity.pathGroupConfig[ pgName ],
                      defaultVrf, saveAll )

   for policy in entity.policyConfig:
      _savePolicy( dpsMode, policy, entity.policyConfig[ policy ], saveAll )

   for profileName in entity.loadBalanceProfile:
      _saveLoadBalanceProfile( dpsMode, profileName,
                               entity.loadBalanceProfile[ profileName ],
                               saveAll )

   for vrfName in entity.vrfConfig:
      vrfMode = dpsMode[ DpsVrfConfigSaveMode ].getOrCreateModeInstance( vrfName )
      vrfCmds = vrfMode[ 'Dps.DpsVrfConfig' ]
      vrfCfg = entity.vrfConfig[ vrfName ]
      if vrfCfg.policyName:
         vrfCmds.addCommand( 'path-selection-policy %s' % vrfCfg.policyName )
      elif saveAll:
         vrfCmds.addCommand( 'no path-selection-policy' )

   # per-interface bandwidth; the mode is only registered under this toggle
   if toggleAvtLowestLoadMetricEnabled():
      for intfId in entity.intfConfig:
         intfMode = dpsMode[ DpsIntfSpeedSaveMode ].getOrCreateModeInstance( intfId )
         intfCmds = intfMode[ 'Dps.DpsIntfSpeed' ]
         intfConfig = entity.intfConfig[ intfId ]

         if intfConfig.txBandwidth:
            intfCmds.addCommand( f'metric bandwidth transmit '
                                 f'{ intfConfig.txBandwidth } Mbps' )
         if intfConfig.rxBandwidth:
            intfCmds.addCommand( f'metric bandwidth receive '
                                 f'{ intfConfig.rxBandwidth } Mbps' )
