# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import weakref
import Tac
from TypeFuture import TacLazyType

# Handle for the 'Dps::DpsConstants' Tac type, obtained through TacLazyType.
# Read below for mtuDiscDefaultInterval.
constants = TacLazyType( 'Dps::DpsConstants' )
# Handle for the 'Avt::PathGroupPriority' Tac type, obtained through
# TacLazyType. Called below to wrap a load-balance path-group priority.
pathGroupPriority = TacLazyType( 'Avt::PathGroupPriority' )

class RouterPathSelectionContext:
   """Thin wrapper around the router path-selection config object."""

   def __init__( self, config ):
      self.config = config

   def setMtuDiscInterval( self, mtuInterval ):
      """Store the MTU discovery interval on the config.

      A falsy mtuInterval ( None, 0 ) selects
      constants.mtuDiscDefaultInterval instead.
      """
      if mtuInterval:
         chosen = mtuInterval
      else:
         chosen = constants.mtuDiscDefaultInterval
      self.config.mtuDiscInterval = chosen

class DpsPathGroupContext:
   """Editing context for a single DPS path group.

   Holds the DPS config / status objects plus the path-group config
   ( pgCfg ) and remote-router config ( routerIpCfg ) most recently selected
   through addOrRemovePathGroup() and addOrRemoveRouterIp().
   """

   def __init__( self, config, status, peerStatus, pathGroupName ):
      self.config = config
      self.status = status
      self.peerStatus = peerStatus
      self.pathGroupName_ = pathGroupName
      # weakref to the mode handed to modeIs(), or None
      self.mode = None
      # path-group config selected by addOrRemovePathGroup()
      self.pgCfg = None
      # router address / config last touched by addOrRemoveRouterIp()
      self.routerIp = None
      self.routerIpCfg = None
      self.pathGroupId_ = 0

   def checkIfIntfOrIPInUse( self, intfName=None, ipAddress=None ):
      """Return True if another path group already lists intfName in its
      localIntf or ipAddress in its localIp, False otherwise.

      The path group named by this context is skipped.
      """
      for pgName, pgCfg in self.config.pathGroupConfig.items():
         if pgName == self.pathGroupName_:
            continue
         if intfName and intfName in pgCfg.localIntf:
            return True
         if ipAddress and ipAddress in pgCfg.localIp:
            return True
      return False

   def modeIs( self, mode ):
      """Remember mode through a weak reference ( None when mode is falsy )."""
      self.mode = weakref.ref( mode ) if mode else None

   def pathGroupName( self ):
      """Return the name of the path group this context refers to."""
      return self.pathGroupName_

   def pathGroupId( self ):
      """Return the id recorded by the last addOrRemovePathGroup() call."""
      return self.pathGroupId_

   def addOrRemovePathGroup( self, pgName, pgId, add=True ):
      """Create / select the path group pgName, or delete it when add is False.

      The context's name and id are updated in both cases. After a delete
      currentPgCfg() returns None.
      """
      pathGroups = self.config.pathGroupConfig
      pgCfg = pathGroups.get( pgName )
      self.pathGroupName_ = pgName
      self.pathGroupId_ = pgId
      if add:
         if not pgCfg:
            pgCfg = self.config.newPathGroupConfig( pgName )
         pgCfg.pathGroupId = pgId
      else:
         del self.config.pathGroupConfig[ pgName ]
         pgCfg = None
      self.pgCfg = pgCfg

   def copyEditPathGroup( self ):
      """Placeholder; intentionally does nothing."""

   def currentPgCfg( self ):
      """Return the currently selected path-group config ( may be None )."""
      return self.pgCfg

   def addOrRemoveRouterIp( self, routerIp, add=True ):
      """Add the remote router routerIp to the current path group and select
      its config, or delete it when add is False.

      routerIp is recorded as the current router IP in both cases.
      """
      remoteViaConfig = self.pgCfg.remoteViaConfig
      router = Tac.Value( "Arnet::IpGenAddr", routerIp )
      if add:
         remoteViaConfig.newMember( router )
         self.routerIpCfg = remoteViaConfig[ router ]
      else:
         # NOTE(review): routerIpCfg is not reset here, so it keeps whatever
         # it referenced before the delete - confirm callers do not reuse it.
         del remoteViaConfig[ router ]
      self.routerIp = routerIp

   def currentRouterIp( self ):
      """Return the router IP last passed to addOrRemoveRouterIp()."""
      return self.routerIp

   def addOrRemoveDynamicPeer( self, add=True ):
      """Set remoteDynamic on the path group; when removing, also reset
      preferLocalIp and dynamicPeerIpsec to False / 'ipsecDefault'."""
      self.pgCfg.remoteDynamic = add
      if not add:
         self.pgCfg.preferLocalIp = False
         self.pgCfg.dynamicPeerIpsec = \
               Tac.enumValue( 'Dps::IpsecConfigState', 'ipsecDefault' )

   def addOrRemoveIntf( self, intfName, vrf=None, publicIp=None, add=True ):
      """Add or remove a local interface of the current path group.

      Returns True ( and changes nothing ) when adding an interface that
      another path group already uses; returns None otherwise. Removing an
      interface also clears and deletes its intfStunConfig entry.
      """
      intf = Tac.Value( "Arnet::IntfId", intfName )
      if add:
         if publicIp:
            pubIp = Tac.Value( "Arnet::IpGenAddr", publicIp )
         else:
            pubIp = Tac.Value( "Arnet::IpGenAddr" )
         if self.checkIfIntfOrIPInUse( intfName=intfName ):
            return True
         intfConfig = Tac.Type( "Dps::LocalIntfConfig" )()
         intfConfig.publicIp = pubIp
         intfConfig.vrfName = vrf
         self.pgCfg.localIntf[ intf ] = intfConfig
      else:
         if intfName in self.pgCfg.intfStunConfig:
            self.pgCfg.intfStunConfig[ intfName ].serverProfile.clear()
            del self.pgCfg.intfStunConfig[ intfName ]
         del self.pgCfg.localIntf[ intf ]
      return None

   def addOrRemoveLocalIp( self, ip, publicIp=None, add=True ):
      """Add or remove a local IP of the current path group.

      Returns True ( and changes nothing ) when adding an IP that another
      path group already uses; returns None otherwise. Removing an IP also
      clears and deletes its ipStunConfig entry.
      """
      ipEntry = Tac.Value( "Arnet::IpGenAddr", ip )
      if add:
         if publicIp:
            pubIp = Tac.Value( "Arnet::IpGenAddr", publicIp )
         else:
            pubIp = Tac.Value( "Arnet::IpGenAddr" )
         if self.checkIfIntfOrIPInUse( ipAddress=ipEntry ):
            return True
         self.pgCfg.localIp[ ipEntry ] = pubIp
      else:
         if ipEntry in self.pgCfg.ipStunConfig:
            self.pgCfg.ipStunConfig[ ipEntry ].serverProfile.clear()
            # Drop the STUN entry together with the local IP, mirroring what
            # addOrRemoveIntf() does for intfStunConfig, so no empty entry
            # is left behind.
            del self.pgCfg.ipStunConfig[ ipEntry ]
         del self.pgCfg.localIp[ ipEntry ]
      return None

   def addOrRemovePathGroupIpsec( self, profileName, add=True ):
      """Set the path group's ipsecProfile, or blank it when add is False."""
      self.pgCfg.ipsecProfile = profileName if add else ""

   def addOrRemoveRemoteRouterIpsec( self, profileName, add=True ):
      """Set the current router's ipsecProfile, or blank it when add is
      False."""
      self.routerIpCfg.ipsecProfile = profileName if add else ""

   def setMss( self, mss ):
      """Store mss as the path group's mssEgress."""
      self.pgCfg.mssEgress = mss

   def setMtu( self, mtu ):
      """Store mtu on the path group."""
      self.pgCfg.mtu = mtu

   def setMtuDiscInterval( self, interval ):
      """Store the MTU discovery interval; a falsy value is stored as 0."""
      self.pgCfg.mtuDiscInterval = interval or 0

   def setItsInterval( self, keepaliveInterval, scale=5 ):
      """Store keepaliveInterval and the feedback scale on the path group."""
      self.pgCfg.keepaliveInterval = keepaliveInterval
      self.pgCfg.feedbackScale = scale

   def pathGroupImportExists( self, remotePg, localPg ):
      """Return True if a path group other than this one already contains
      the ( remotePg, localPg ) import pair."""
      # The pair does not depend on the path group being inspected, so build
      # it once instead of on every iteration.
      pathViaPair = Tac.Value( "Dps::DpsPathViaPair", remotePg, localPg )
      for pgCfg in self.config.pathGroupConfig.values():
         if pgCfg.name == self.pathGroupName_:
            continue
         if pathViaPair in pgCfg.pathViaPairSet:
            return True
      return False

   def setPathGroupImport( self, remotePg, localPg ):
      """Add the ( remotePg, localPg ) import pair to the path group."""
      pathViaPair = Tac.Value( "Dps::DpsPathViaPair", remotePg, localPg )
      self.pgCfg.pathViaPairSet.add( pathViaPair )

   def removePathGroupImport( self, remotePg, localPg ):
      """Remove the ( remotePg, localPg ) import pair from the path group."""
      pathViaPair = Tac.Value( "Dps::DpsPathViaPair", remotePg, localPg )
      self.pgCfg.pathViaPairSet.remove( pathViaPair )

   def addOrRemoveRouterVia( self, viaIp, add=True ):
      """Add viaIp to the current router's remoteEncap, or delete it."""
      remoteEncap = self.routerIpCfg.remoteEncap
      via = Tac.Value( "Arnet::IpGenAddr", viaIp )
      if add:
         remoteEncap[ via ] = True
      else:
         del remoteEncap[ via ]

   def checkPeerName( self, peerName ):
      """Check that the current router has no conflicting peer name.

      Returns None when there is no conflict, otherwise the tuple
      ( router address string, peer name already recorded in peerStatus ).
      """
      peerStatusEntry = self.peerStatus.peerStatusEntry
      router = Tac.Value( "Arnet::IpGenAddr", self.routerIp )
      if router not in peerStatusEntry:
         return None
      entry = peerStatusEntry[ router ]
      if not entry.peerName or entry.peerName == peerName:
         return None
      pathGroups = entry.pathGroupWithPeerName
      if not pathGroups:
         return None
      # The only path group carrying the old name is this one, so the name is
      # simply being changed in place.
      if len( pathGroups ) == 1 and self.pathGroupName_ in pathGroups:
         return None
      return ( router.stringValue, entry.peerName )

   def addOrRemovePeerName( self, peerName, add=True ):
      """Set the current router's peerName, or blank it when add is False."""
      self.routerIpCfg.peerName = peerName if add else ""

   def getPathGroupPeers( self ):
      """Return { 'static': [ router keys ] }, plus 'dynamic': True when the
      path group has remoteDynamic set."""
      peers = { 'static': list( self.pgCfg.remoteViaConfig ) }
      if self.pgCfg.remoteDynamic:
         peers[ 'dynamic' ] = True
      return peers

   def isPathGroupFlowAssignmentLan( self ):
      """Return the path group's flowAssignmentLan flag."""
      return self.pgCfg.flowAssignmentLan

   def checkFlowAssignmentLanPathGroup( self ):
      """Return the name of the first path group with flowAssignmentLan set,
      or None if there is none."""
      for pgName, pg in self.config.pathGroupConfig.items():
         if pg.flowAssignmentLan:
            return pgName
      return None

   def addOrRemoveFlowAssignmentLan( self, add=True ):
      """Set the path group's flowAssignmentLan flag to add."""
      self.pgCfg.flowAssignmentLan = add

class DpsPolicyContext:
   """Editing context for one DPS policy and its currently selected rule."""

   def __init__( self, config, status, policyName ):
      self.config = config
      self.status = status
      self.policyName_ = policyName
      # weakref to the mode handed to modeIs(), or None
      self.mode = None
      # policy config selected by addOrRemovePolicy()
      self.policyCfg = None
      # rule key / rule config / app profile chosen by addOrRemoveRuleKey()
      self.ruleKey = None
      self.ruleKeyCfg = None
      self.appProfile = None
      # whether the default rule ( rather than a keyed rule ) is being edited
      self.defaultRuleCfgd = None

   def modeIs( self, mode ):
      """Keep a weak reference to mode; a falsy mode is stored as None."""
      if mode:
         self.mode = weakref.ref( mode )
      else:
         self.mode = None

   def policyName( self ):
      """Return the policy name given at construction."""
      return self.policyName_

   def addOrRemovePolicy( self, policyName, add=True ):
      """Select ( creating it if missing ) or delete the policy policyName.

      A newly created policy is also added to wanTEPolicyConfig; a delete
      removes it from both collections and clears the current policy.
      """
      if not add:
         del self.config.policyConfig[ policyName ]
         del self.config.wanTEPolicyConfig[ policyName ]
         self.policyCfg = None
         return
      selected = self.config.policyConfig.get( policyName )
      if not selected:
         selected = self.config.newPolicyConfig( policyName )
         self.config.wanTEPolicyConfig.addMember( selected )
      self.policyCfg = selected

   def copyEditPolicyCfg( self ):
      """Placeholder; intentionally does nothing."""

   def currentPolicyCfg( self ):
      """Return the currently selected policy config ( may be None )."""
      return self.policyCfg

   def currentRuleKey( self ):
      """Return the rule key selected by addOrRemoveRuleKey()."""
      return self.ruleKey

   def currentAppProfile( self ):
      """Return the application profile of the selected rule."""
      return self.appProfile

   def currentDefaultRuleCfgd( self ):
      """Return whether the default rule is the one being edited."""
      return self.defaultRuleCfgd

   def addOrRemoveRuleKey( self, ruleKey, appProfile, add=True ):
      """Create / select the rule ruleKey with appProfile, or delete it.

      Either way the context stops pointing at the default rule.
      """
      ruleList = self.policyCfg.appProfilePolicyRuleList
      if add:
         ruleList.newMember( ruleKey )
         ruleList[ ruleKey ].appProfileName = appProfile
         self.ruleKey = ruleKey
         self.appProfile = appProfile
         self.ruleKeyCfg = ruleList[ ruleKey ]
      else:
         del ruleList[ ruleKey ]
         self.ruleKey = self.appProfile = self.ruleKeyCfg = None
      self.defaultRuleCfgd = False

   def addOrRemoveDefaultRule( self, add=True ):
      """Mark the default rule as configured, or unconfigure it and blank
      the policy's defaultActionName."""
      if not add:
         self.policyCfg.defaultRuleCfgd = False
         self.policyCfg.defaultActionName = ""
         self.defaultRuleCfgd = False
         return
      self.policyCfg.defaultRuleCfgd = True
      self.defaultRuleCfgd = True

   def setLbGrpName( self, lbGrpName ):
      """Store lbGrpName as the action of the default rule when that is the
      one being edited, otherwise as the action of the selected rule."""
      if not self.defaultRuleCfgd:
         self.ruleKeyCfg.actionName = lbGrpName
         return
      self.policyCfg.defaultActionName = lbGrpName

class DpsLoadBalanceProfileContext:
   """Editing context for one DPS load-balance profile."""

   def __init__( self, config, status, profileName ):
      self.config = config
      self.status = status
      self.profileName_ = profileName
      # weakref to the mode handed to modeIs(), or None
      self.mode = None
      # profile selected by addProfile(); None until then and after delProfile()
      self.profile = None

   def modeIs( self, mode ):
      """Keep a weak reference to mode; a falsy mode is stored as None."""
      if mode:
         self.mode = weakref.ref( mode )
      else:
         self.mode = None

   def profileName( self ):
      """Return the profile name given at construction."""
      return self.profileName_

   def addProfile( self, profileName ):
      """Select the profile profileName, creating it when it is missing."""
      found = self.config.loadBalanceProfile.get( profileName )
      if found:
         self.profile = found
         return
      self.profile = self.config.newLoadBalanceProfile( profileName )

   def delProfile( self, profileName ):
      """Delete the profile profileName if present and clear the selection."""
      if self.config.loadBalanceProfile.get( profileName ):
         del self.config.loadBalanceProfile[ profileName ]
      self.profile = None

   def addPathGroup( self, pathGroupName, priority ):
      """Record priority for pathGroupName in the selected profile; does
      nothing when no profile is selected."""
      selected = self.currentProfile()
      if not selected:
         return
      selected.pathGroupPriority[ pathGroupName ] = pathGroupPriority( priority )

   def delPathGroup( self, pathGroupName ):
      """Drop pathGroupName from the selected profile when it is listed
      there; does nothing otherwise."""
      selected = self.currentProfile()
      if selected and pathGroupName in selected.pathGroupPriority:
         del selected.pathGroupPriority[ pathGroupName ]

   def setLatency( self, latency ):
      """Store latency on the selected profile."""
      self.currentProfile().latency = latency

   def setJitter( self, jitter ):
      """Store jitter on the selected profile."""
      self.currentProfile().jitter = jitter

   def setLossRate( self, lossRate ):
      """Store lossRate on the selected profile."""
      self.currentProfile().lossRate = lossRate

   def setHopCountLowest( self, enable=False ):
      """Store enable as the selected profile's hopCountLowest flag."""
      self.currentProfile().hopCountLowest = enable

   def currentProfile( self ):
      """Return the selected profile ( may be None )."""
      return self.profile

class DpsVrfConfigContext:
   """Editing context for the DPS configuration of one VRF."""

   def __init__( self, config, status, vrfName ):
      self.config = config
      self.status = status
      self.vrfName_ = vrfName
      # weakref to the mode handed to modeIs(), or None
      self.mode = None
      # VRF config selected by addVrfConfig(); None after delVrfConfig()
      self.vrfConfig = None

   def modeIs( self, mode ):
      """Keep a weak reference to mode; a falsy mode is stored as None."""
      if mode:
         self.mode = weakref.ref( mode )
      else:
         self.mode = None

   def vrfName( self ):
      """Return the VRF name given at construction."""
      return self.vrfName_

   def currentVrfCfg( self ):
      """Return the selected VRF config ( may be None )."""
      return self.vrfConfig

   def addVrfConfig( self, vrfName ):
      """Select the config of vrfName; when it does not exist yet, create it
      and add it to wanTEVrfConfig as well."""
      existing = self.config.vrfConfig.get( vrfName )
      if existing:
         self.vrfConfig = existing
         return
      created = self.config.newVrfConfig( vrfName )
      self.config.wanTEVrfConfig.addMember( created )
      self.vrfConfig = created

   def delVrfConfig( self, vrfName ):
      """Delete vrfName from vrfConfig and wanTEVrfConfig if it exists, then
      clear the selection."""
      if self.config.vrfConfig.get( vrfName ):
         del self.config.vrfConfig[ vrfName ]
         del self.config.wanTEVrfConfig[ vrfName ]
      self.vrfConfig = None

   def setPolicy( self, policyName ):
      """Store policyName on the selected VRF config."""
      self.vrfConfig.policyName = policyName

class DpsEncapConfigContext:
   """Editing context for the DPS encapsulation settings."""

   def __init__( self, config, status ):
      self.config = config
      self.status = status
      # weakref to the mode handed to modeIs(), or None
      self.mode = None

   def modeIs( self, mode ):
      """Remember mode through a weak reference ( None when mode is falsy )."""
      self.mode = weakref.ref( mode ) if mode else None

   def currentUdpPortCfg( self ):
      """Return the configured UDP port.

      udpPortConfig is a plain attribute ( setUdpPort() assigns to it ), so it
      is read rather than called; calling the stored value would raise
      TypeError.
      """
      return self.config.udpPortConfig

   def setUdpPort( self, portNum ):
      """Store portNum as the configured UDP port."""
      self.config.udpPortConfig = portNum

class DpsPeerDynamicStunConfigContext:
   """Editing context for the peerDynamicStun flag of the DPS config."""

   def __init__( self, config, status ):
      self.config = config
      self.status = status
      # weakref to the mode handed to modeIs(), or None
      self.mode = None

   def modeIs( self, mode ):
      """Keep a weak reference to mode; a falsy mode is stored as None."""
      if mode:
         self.mode = weakref.ref( mode )
      else:
         self.mode = None

   def addOrRemovePeerStunConfig( self, add=True ):
      """Store add as the config's peerDynamicStun value."""
      cfg = self.config
      cfg.peerDynamicStun = add

class DpsPathGroupStunConfigContext:
   """Editing context for the STUN server profiles attached to one local
   interface or local IP of a path group.

   key is the interface name or IP address; keyType selects how it is
   converted ( 'ip' -> Arnet::IpGenAddr, anything else -> Arnet::IntfId ).
   stunProfiles is the collection the entry for key lives in.
   """

   def __init__( self, config, status, pgName, key, vrf, pubIp, stunProfiles,
                 keyType ):
      self.config = config
      self.status = status
      self.pgName = pgName
      self.key = key
      self.vrf = vrf
      self.publicIp = pubIp
      self.stunProfiles = stunProfiles
      # weakref to the mode handed to modeIs(), or None
      self.mode = None
      self.keyType = keyType

   def modeIs( self, mode ):
      """Remember mode through a weak reference ( None when mode is falsy )."""
      self.mode = weakref.ref( mode ) if mode else None

   def getStunProfileEntry( self ):
      """Return the stunProfiles entry for this context's key.

      The entry is obtained with newMember(); presumably that creates it when
      it does not exist yet - verify against the collection's semantics.
      """
      keyType = "Arnet::IpGenAddr" if self.keyType == 'ip' else "Arnet::IntfId"
      keyVal = Tac.Value( keyType, self.key )
      return self.stunProfiles.newMember( keyVal )

   def addOrRemoveStunProfile( self, profileSet=None, add=True ):
      """Synchronise or trim the entry's server profiles.

      add=True:  make serverProfile equal to profileSet - missing profiles
                 are added, profiles not in profileSet are removed. Any
                 iterable is accepted; None is treated as the empty set.
      add=False: remove the profiles in profileSet, or clear every profile
                 when profileSet is empty or None.
      """
      serverProfile = self.getStunProfileEntry().serverProfile
      if add:
         # Normalise first: the default ( None ) and non-set iterables have no
         # .difference(), which used to raise AttributeError here.
         wanted = set( profileSet or () )
         current = set( serverProfile )
         # add profiles which are required
         for profile in wanted - current:
            serverProfile.add( profile )
         # drop profiles which are no longer needed
         for profile in current - wanted:
            serverProfile.remove( profile )
         return
      if not profileSet:
         serverProfile.clear()
         return
      for profile in profileSet:
         serverProfile.remove( profile )

class DpsPeerDynamicConfigContext:
   """Editing context for the dynamic-peer settings of one path group."""

   def __init__( self, config, status, pgName ):
      self.config = config
      self.status = status
      # weakref to the mode handed to modeIs(), or None
      self.mode = None
      self.pgName = pgName
      # looked up once at construction; None when pgName is not configured
      self.pgCfg = self.config.pathGroupConfig.get( self.pgName )

   def modeIs( self, mode ):
      """Keep a weak reference to mode; a falsy mode is stored as None."""
      if mode:
         self.mode = weakref.ref( mode )
      else:
         self.mode = None

   def updateDynamicPeerIpsec( self, disable, default=False ):
      """Set the path group's dynamicPeerIpsec state.

      default wins over disable: 'ipsecDefault' when default is truthy, else
      'ipsecDisabled' when disable is truthy, else 'ipsecEnabled'.
      """
      if default:
         stateName = 'ipsecDefault'
      else:
         stateName = 'ipsecDisabled' if disable else 'ipsecEnabled'
      newState = Tac.enumValue( 'Dps::IpsecConfigState', stateName )
      self.pgCfg.dynamicPeerIpsec = newState

   def addOrRemovePreferLocalIp( self, add=True ):
      """Store add as the path group's preferLocalIp value."""
      self.pgCfg.preferLocalIp = add

class DpsIntfContext:
   """Editing context for the DPS settings ( tx / rx bandwidth ) of one
   interface."""

   def __init__( self, config, status, intfName ):
      self.intfName_ = intfName
      self.config = config
      self.status = status
      # Dps::IntfConfig value selected by addOrRemoveIntf()
      self.intfConfig = None
      # weakref to the mode handed to modeIs(), or None
      self.mode = None

   def modeIs( self, mode ):
      """Keep a weak reference to mode; a falsy mode is stored as None."""
      if mode:
         self.mode = weakref.ref( mode )
      else:
         self.mode = None

   def currentIntfConfig( self ):
      """Return the selected interface config ( may be None )."""
      return self.intfConfig

   def intfName( self ):
      """Return the interface name this context currently refers to."""
      return self.intfName_

   def addOrRemoveIntf( self, intfName, add=True ):
      """Select the config of intfName, adding a zero-bandwidth entry when
      none exists, or delete the entry when add is False.

      The context's interface name is updated in both cases.
      """
      intfId = Tac.Value( "Arnet::IntfId", intfName )
      intfColl = self.config.intfConfig
      self.intfName_ = intfName
      if not add:
         # NOTE(review): self.intfConfig is not reset on removal and keeps
         # its previous value - confirm callers do not rely on it afterwards.
         del intfColl[ intfId ]
         return
      if intfId not in intfColl:
         intfColl.addMember( Tac.Value( "Dps::IntfConfig", intfId, 0, 0 ) )
      self.intfConfig = intfColl[ intfId ]

   def _storeBandwidth( self, txBandwidth, rxBandwidth ):
      # Build a fresh Dps::IntfConfig value for the selected interface and
      # write it back into the config collection with addMember().
      intfId = self.intfConfig.intfId
      self.intfConfig = Tac.Value( "Dps::IntfConfig", intfId, txBandwidth,
              rxBandwidth )
      self.config.intfConfig.addMember( self.intfConfig )

   def setTxBandwidth( self, txBandwidth ):
      """Replace the selected config with one carrying txBandwidth and the
      current rx bandwidth."""
      self._storeBandwidth( txBandwidth, self.intfConfig.rxBandwidth )

   def setRxBandwidth( self, rxBandwidth ):
      """Replace the selected config with one carrying rxBandwidth and the
      current tx bandwidth."""
      self._storeBandwidth( self.intfConfig.txBandwidth, rxBandwidth )
