# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import CliSave
from CliMode.Mss import ( MssStaticDeviceMode,
                          MssStaticVrfMode, MssStaticL3NextHopMode,
                          MssStaticL3PolicyMode, MssStaticRuleMode )
from CliSavePlugin.MssCliSave import MssConfigSaveMode
from CliSavePlugin.Controllerdb import ( CvxConfigMode,
                                         controllerConfigPath,
                                         getClusterName )
from MssCliLib import ( MssL3PolicyAction, mssL3ModifierVerbatim, ipProtocol )
import Tac

# CLI default values; the V2 saver compares traffic inspection against these.
defaults = Tac.Value( "Mss::CliDefaults" )

# CLI action keyword -> L3 policy action value.
MssActionCliStrToTac = {
   'drop': MssL3PolicyAction.drop,
   'forward': MssL3PolicyAction.bypass,
   'redirect': MssL3PolicyAction.redirect,
}
# Inverse of the table above, used when rendering a saved action back to CLI.
MssActionTacToCliStr = dict( zip( MssActionCliStrToTac.values(),
                                  MssActionCliStrToTac.keys() ) )

# Mount paths read by the savers in this file.
mssConfigPath = 'mss/config'
l3PolicyConfigPath = 'mssl3/policySourceConfig/cli'
l3PolicyV2ConfigPath = 'mssl3/policySourceConfigV2/cli'

#------------------------------------------------------------------------------
# Mss Static Service Device Mode Saver
#------------------------------------------------------------------------------
class MssStaticDeviceConfigSaveMode( MssStaticDeviceMode, CliSave.Mode ):
   """CliSave mode for a single static service device.

   In this file the mode is instantiated with the device name as `param`.
   """
   def __init__( self, param ):
      # Initialise both bases explicitly, in MRO order.
      for base in ( MssStaticDeviceMode, CliSave.Mode ):
         base.__init__( self, param )
      # Per-instance command sequence named after the device.
      perDeviceSeq = 'mss.dev.%s' % param
      self.addCommandSequence( perDeviceSeq )

# Device modes nest under the MSS config mode, ordered after 'mss.config'.
MssConfigSaveMode.addChildMode(
   MssStaticDeviceConfigSaveMode, after=[ 'mss.config' ] )
MssStaticDeviceConfigSaveMode.addCommandSequence( 'mss.device.config' )

class MssStaticVrfConfigSaveMode( MssStaticVrfMode, CliSave.Mode ):
   """CliSave mode for one VRF of a static service device.

   In this file `param` is the ( deviceName, vrfName ) tuple.
   """
   def __init__( self, param ):
      for base in ( MssStaticVrfMode, CliSave.Mode ):
         base.__init__( self, param )

# VRF modes nest inside a device mode.
MssStaticDeviceConfigSaveMode.addChildMode( MssStaticVrfConfigSaveMode )
MssStaticVrfConfigSaveMode.addCommandSequence( 'mss.device.vrf.config' )

class MssStaticL3NextHopConfigSaveMode( MssStaticL3NextHopMode, CliSave.Mode ):
   """CliSave mode for one L3 next-hop interface within a VRF.

   In this file `param` is the ( deviceName, vrfName, nextHopIp ) tuple.
   """
   def __init__( self, param ):
      for base in ( MssStaticL3NextHopMode, CliSave.Mode ):
         base.__init__( self, param )

# Next-hop modes nest inside a VRF mode and hold the 'route ...' commands.
MssStaticVrfConfigSaveMode.addChildMode( MssStaticL3NextHopConfigSaveMode )
MssStaticL3NextHopConfigSaveMode.addCommandSequence(
   'mss.device.vrf.nexthop.config' )

class MssStaticL3PolicyConfigSaveMode( MssStaticL3PolicyMode, CliSave.Mode ):
   """CliSave mode for the L3 policy of one VRF.

   In this file `param` is the ( deviceName, vrfName ) tuple.
   """
   def __init__( self, param ):
      for base in ( MssStaticL3PolicyMode, CliSave.Mode ):
         base.__init__( self, param )

# The policy mode follows the next-hop modes inside a VRF mode.
MssStaticVrfConfigSaveMode.addChildMode(
   MssStaticL3PolicyConfigSaveMode,
   after=[ MssStaticL3NextHopConfigSaveMode ] )
MssStaticL3PolicyConfigSaveMode.addCommandSequence( 'mss.device.vrf.policy.config' )

class MssStaticRuleConfigSaveMode( MssStaticRuleMode, CliSave.Mode ):
   """CliSave mode for one static L3 policy rule.

   `param` is ( deviceName, vrfName, ruleName, priority ), as built by saveRule
   in this file. Only the first three fields are handed to the base modes. The
   priority is stored separately and drives instance ordering.
   """
   def __init__( self, param ):
      self.priority = param[ 3 ]
      modeParam = tuple( param[ :3 ] )
      for base in ( MssStaticRuleMode, CliSave.Mode ):
         base.__init__( self, modeParam )

   # Static rules must be installed in decreasing priority order, so the
   # instance key is the negated priority.
   def instanceKey( self ):
      return -self.priority

   @classmethod
   def useInsertionOrder( cls ):
      # Enabled because instanceKey() is overridden to return -self.priority.
      # NOTE(review): confirm how CliSave combines this flag with instanceKey().
      return True

# Rule modes nest inside the policy mode, after the policy-level commands.
MssStaticL3PolicyConfigSaveMode.addChildMode(
   MssStaticRuleConfigSaveMode, after=[ 'mss.device.vrf.policy.config' ] )
MssStaticRuleConfigSaveMode.addCommandSequence( 'mss.device.vrf.policy.rule.config' )

#------------------------------------------------------------------------------
# Helper functions
#------------------------------------------------------------------------------
def isStrictPolicyEnforcementConsistency( mssConfig ):
   """Return True iff mssConfig.policyEnforcementConsistency equals "strict"."""
   enforcement = mssConfig.policyEnforcementConsistency
   return enforcement == "strict"

def saveRoutingTable( vrfMode, vrfConfig ):
   """Emit 'route <subnet>' commands for every L3 next-hop interface of a VRF.

   Each entry of vrfConfig.l3Intf gets its own next-hop mode. The mode is keyed
   by ( deviceName, vrfName, nextHopIp ) and holds one route command per
   reachable subnet. Interfaces and subnets are both emitted in sorted order.
   """
   for nextHopIp, nextHop in sorted( vrfConfig.l3Intf.items() ):
      modeKey = ( vrfMode.deviceName, vrfMode.vrfName, nextHopIp )
      nextHopMode = vrfMode[ MssStaticL3NextHopConfigSaveMode ] \
                       .getOrCreateModeInstance( modeKey )

      routeSeq = nextHopMode[ 'mss.device.vrf.nexthop.config' ]
      for subnet in sorted( nextHop.reachableSubnet ):
         routeSeq.addCommand( 'route %s' % subnet )

def getAddrCmdSeq( initialCmd, addresses ):
   """Return `initialCmd` followed by the sorted addresses, space-separated.

   Each address is rendered with str(). With no addresses the result is just
   `initialCmd`.
   """
   # Build the command with one join instead of growing a string with +=.
   return ' '.join( [ initialCmd ] + [ str( ip ) for ip in sorted( addresses ) ] )

def getL4AppCmdSeq( l4Apps ):
   """Group L4 application ports by protocol name.

   Returns a dict that maps each protocol name (looked up via `ipProtocol`) to a
   space-separated string of its ports. Ports appear in the sorted order of
   `l4Apps`. Protocols appear in first-seen order within that sorted sequence.
   """
   portsByProto = {}
   for l4App in sorted( l4Apps ):
      proto = ipProtocol[ l4App.proto ]
      portsByProto.setdefault( proto, [] ).append( str( l4App.port ) )
   # Join each port list once rather than repeatedly concatenating strings.
   return { proto: ' '.join( ports ) for proto, ports in portsByProto.items() }

def saveMatch( ruleMode, match ):
   """Emit the match criteria of a rule into its rule command sequence.

   Commands are emitted in this order, each only when its field is non-empty:
   source/destination address lists, source/destination L4 protocol+port
   groups, then the L3 protocol list.
   """
   cmdSeq = ruleMode[ 'mss.device.vrf.policy.rule.config' ]

   # IP address lists: source first, then destination.
   for side, addresses in ( ( 'source', match.srcIp ),
                            ( 'destination', match.dstIp ) ):
      if addresses:
         cmdSeq.addCommand( getAddrCmdSeq( side + ' address', addresses ) )

   # L4 applications, one command per protocol: source first, then destination.
   for side, l4Apps in ( ( 'source', match.srcL4App ),
                         ( 'destination', match.dstL4App ) ):
      if l4Apps:
         for proto, ports in getL4AppCmdSeq( l4Apps ).items():
            cmdSeq.addCommand( '%s protocol %s port %s' % ( side, proto, ports ) )

   # L3 protocols, rendered by name in the collection's iteration order.
   if match.l3App:
      protoNames = [ ipProtocol[ l3App ] for l3App in match.l3App ]
      cmdSeq.addCommand( ' '.join( [ 'protocol' ] + protoNames ) )

def saveAction( ruleMode, ruleOrigin, action ):
   """Emit the rule's 'action ...' command.

   When the origin carries the verbatim modifier, the real action is rendered
   through MssActionTacToCliStr. Otherwise the command is always
   'action ip-redirect'.
   """
   cmdSeq = ruleMode[ 'mss.device.vrf.policy.rule.config' ]

   isVerbatim = mssL3ModifierVerbatim in ruleOrigin.policyModifierSet
   actionKeyword = MssActionTacToCliStr[ action ] if isVerbatim else 'ip-redirect'
   cmdSeq.addCommand( 'action %s' % actionKeyword )

def saveDirection( ruleMode, ruleOrigin ):
   """Emit the rule's direction command, derived from its CLI modifier set.

   - 'forwardOnly' and 'reverseOnly' both set: 'direction forward reverse'
   - only 'forwardOnly' set: 'direction forward'
   - otherwise nothing is emitted (the default for this config is empty)
   """
   cmdSeq = ruleMode[ 'mss.device.vrf.policy.rule.config' ]
   modifiers = ruleOrigin.policyModifierSetCli

   if 'forwardOnly' not in modifiers:
      # NOTE(review): 'reverseOnly' alone emits nothing; confirm this is intended.
      return

   if 'reverseOnly' in modifiers:
      cmdSeq.addCommand( 'direction forward reverse' )
   else:
      cmdSeq.addCommand( 'direction forward' )

def saveRule( deviceName, vrfName, priority, rule, policyMode, isV2 ):
   """Save each origin of an L3 policy rule as its own rule mode.

   Each mode is keyed by ( deviceName, vrfName, ruleName, priority ). For V2
   config the match and action come from each origin. For V1 they come from
   the shared rule. The modifier sets are always taken from the origin.
   """
   for ruleName, ruleOrigin in rule.origin.items():
      modeKey = ( deviceName, vrfName, ruleName, priority )
      ruleMode = policyMode[ MssStaticRuleConfigSaveMode ].getOrCreateModeInstance(
         modeKey )

      source = ruleOrigin if isV2 else rule
      saveMatch( ruleMode, source.match )
      saveAction( ruleMode, ruleOrigin, source.action )
      saveDirection( ruleMode, ruleOrigin )

def savePolicySet( vrfMode, l3Policy, isV2, options ):
   """Save one VRF's L3 policy: its direction command followed by every rule.

   The policy is skipped unless one of these holds: saveAll is requested, the
   policy has rules, or the policy carries the 'forwardOnly' modifier.
   """
   worthSaving = ( options.saveAll or l3Policy.rule or
                   'forwardOnly' in l3Policy.policyModifierSet )
   if not worthSaving:
      return

   deviceName, vrfName = vrfMode.deviceName, vrfMode.vrfName
   policyMode = vrfMode[ MssStaticL3PolicyConfigSaveMode ].getOrCreateModeInstance(
      ( deviceName, vrfName ) )
   policyCmdSeq = policyMode[ 'mss.device.vrf.policy.config' ]

   # 'direction forward reverse' is the default, so it is only shown on saveAll.
   if 'forwardOnly' in l3Policy.policyModifierSet:
      policyCmdSeq.addCommand( 'direction forward' )
   elif options.saveAll:
      policyCmdSeq.addCommand( 'direction forward reverse' )

   for priority, rule in l3Policy.rule.items():
      saveRule( deviceName, vrfName, priority, rule, policyMode, isV2 )

#------------------------------------------------------------------------------
# Entity Saver
#------------------------------------------------------------------------------
@CliSave.saver( 'MssL3::ServiceDeviceSourceConfig',
                'mssl3/serviceDeviceSourceConfig/cli',
                requireMounts=( mssConfigPath, controllerConfigPath,
                                l3PolicyConfigPath, ) )
def saveMssL3StaticDeviceConfig( entity, root, requireMounts, options ):
   """Save the V1 static MSS service-device config (non-strict enforcement).

   Nothing is saved when policy enforcement consistency is "strict". For each
   device, in sorted order, every VRF's routing table is emitted, followed by
   its L3 policy when the policy source config has one.
   """
   clusterName = getClusterName( requireMounts[ controllerConfigPath ] )
   l3PolicySourceConfig = requireMounts[ l3PolicyConfigPath ]

   if isStrictPolicyEnforcementConsistency( requireMounts[ mssConfigPath ] ):
      return

   for deviceName in sorted( entity.serviceDevice ):
      device = entity.serviceDevice[ deviceName ]
      cvxMode = root[ CvxConfigMode ].getOrCreateModeInstance(
         CvxConfigMode.modeName( clusterName ) )
      mssMode = cvxMode[ MssConfigSaveMode ].getSingletonInstance()
      deviceMode = mssMode[ MssStaticDeviceConfigSaveMode ].getOrCreateModeInstance(
         deviceName )

      for vrfName in sorted( device.vrf ):
         vrfMode = deviceMode[ MssStaticVrfConfigSaveMode ].getOrCreateModeInstance(
            ( deviceName, vrfName ) )
         saveRoutingTable( vrfMode, device.vrf[ vrfName ] )

         policySets = l3PolicySourceConfig.policySet
         if deviceName not in policySets:
            continue
         vrfPolicies = policySets[ deviceName ].policy
         if vrfName in vrfPolicies:
            savePolicySet( vrfMode, vrfPolicies[ vrfName ], False, options )

@CliSave.saver( 'MssL3V2::ServiceDeviceSourceConfig',
                'mssl3/serviceDeviceSourceConfigV2/cli',
                requireMounts=( mssConfigPath, controllerConfigPath,
                                l3PolicyV2ConfigPath ) )
def saveMssL3V2StaticDeviceConfig( entity, root, requireMounts, options ):
   """Save the V2 static MSS service-device config (strict enforcement only).

   For each device, ordered by its MSS device key, this emits the traffic
   inspection setting. It then emits every network VRF's routing table,
   followed by that VRF's L3 policy when the V2 policy source config has one.
   """
   clusterName = getClusterName( requireMounts[ controllerConfigPath ] )
   l3PolicySourceV2Config = requireMounts[ l3PolicyV2ConfigPath ]

   if not isStrictPolicyEnforcementConsistency( requireMounts[ mssConfigPath ] ):
      return

   policySets = l3PolicySourceV2Config.policySet
   for mssDeviceName in sorted( entity.serviceDevice ):
      # CLI modes are keyed by the physical instance name, while the entity and
      # the policy collections are keyed by the MSS device key itself.
      deviceName = mssDeviceName.getPhyInstanceName()
      cvxMode = root[ CvxConfigMode ].getOrCreateModeInstance(
         CvxConfigMode.modeName( clusterName ) )
      mssMode = cvxMode[ MssConfigSaveMode ].getSingletonInstance()
      deviceMode = mssMode[ MssStaticDeviceConfigSaveMode ].getOrCreateModeInstance(
         deviceName )
      deviceCmdSeq = deviceMode[ 'mss.device.config' ]
      device = entity.serviceDevice[ mssDeviceName ]

      # Traffic inspection is shown on saveAll or when it differs from default.
      inspection = device.trafficInspection
      if options.saveAll or inspection != defaults.trafficInspection:
         outboundSuffix = " outbound" if inspection.outbound else ""
         deviceCmdSeq.addCommand( "traffic inspection local" + outboundSuffix )

      for netVrfName in sorted( device.netVrf ):
         vrfMode = deviceMode[ MssStaticVrfConfigSaveMode ].getOrCreateModeInstance(
            ( deviceName, netVrfName ) )
         saveRoutingTable( vrfMode, device.netVrf[ netVrfName ] )

         if mssDeviceName not in policySets:
            continue
         vrfPolicies = policySets[ mssDeviceName ].policy
         if netVrfName in vrfPolicies:
            savePolicySet( vrfMode, vrfPolicies[ netVrfName ], True, options )
