#!/usr/bin/env python3
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliMode.Firewall
import CliMode.Classification
import CliSave
from CliSavePlugin.IntfCliSave import IntfConfigMode

# CliSave counterpart of CliMode.Firewall.FirewallConfigMode. The savers in
# this file fetch its singleton instance from the save root and hang the
# firewall commands and child modes off it.
class FirewallConfigMode( CliMode.Firewall.FirewallConfigMode,
                          CliSave.Mode ):
   def __init__( self, param ):
      # Both bases are initialized explicitly, each with the same param.
      CliMode.Firewall.FirewallConfigMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Register FirewallConfigMode under the global config mode, passing
# after=[ IntfConfigMode ] (presumably an ordering constraint relative to the
# interface config mode -- confirm against CliSave), and declare the
# 'firewall' command sequence used by saveRouterFirewallConfig.
CliSave.GlobalConfigMode.addChildMode( FirewallConfigMode, after=[ IntfConfigMode ] )
FirewallConfigMode.addCommandSequence( 'firewall' )

@CliSave.saver( 'Firewall::Config', 'firewall/config/cli' )
def saveRouterFirewallConfig( entity, root, requireMounts, options ):
   """Save the top-level firewall commands.

   Nothing is saved when the firewall is disabled unless options.saveAll is
   set. Otherwise the 'firewall' command sequence of the FirewallConfigMode
   singleton receives the shutdown state, the default drop-all policy
   setting and the forwarding type. requireMounts is unused.
   """
   isEnabled = entity.enabled
   if not isEnabled and not options.saveAll:
      return

   cmds = root[ FirewallConfigMode ].getSingletonInstance()[ 'firewall' ]
   cmds.addCommand( "no shutdown" if isEnabled else "shutdown" )

   # Any value other than 'enable' / 'disable' produces no command.
   dropSetting = entity.defaultPolicyDrop
   if dropSetting == 'enable':
      cmds.addCommand( "segment policy policy-drop-all default" )
   elif dropSetting == 'disable':
      cmds.addCommand( "no segment policy policy-drop-all default" )

   if entity.forwardType.routed:
      cmds.addCommand( "forwarding-type routed" )
   elif not isEnabled and not entity.forwardType:
      # The explicit negation is only ever written on the disabled
      # (saveAll) path, never while the firewall is enabled.
      cmds.addCommand( "no forwarding-type" )

# CliSave counterpart of CliMode.Firewall.FirewallPolicyConfigMode.
# savePolicyConfig creates one instance per policy, keyed by policy name.
class FirewallPolicyConfigMode( CliMode.Firewall.FirewallPolicyConfigMode,
                                CliSave.Mode ):
   def __init__( self, param ):
      # Both bases are initialized explicitly, each with the same param.
      CliMode.Firewall.FirewallPolicyConfigMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Register the policy mode as a child of FirewallConfigMode, passing
# after=[ 'firewall' ] (presumably ordering it after the 'firewall' command
# sequence -- confirm against CliSave), and declare the 'policy' command
# sequence that holds the rule commands.
FirewallConfigMode.addChildMode( FirewallPolicyConfigMode, after=[ 'firewall' ] )
FirewallPolicyConfigMode.addCommandSequence( 'policy' )

@CliSave.saver( 'Firewall::Config', 'firewall/config/cli' )
def savePolicyConfig( entity, root, requireMounts, options ):
   """Save the rules of every firewall policy that is not readonly.

   Each such policy in entity.policy gets its own FirewallPolicyConfigMode
   instance (created even if the policy has no rules), and one command per
   rule sequence number is added to that instance's 'policy' command
   sequence. requireMounts and options are unused.
   """
   # Actions whose command depends only on the application (service) name.
   # Matched with == in order, like an if/elif chain.
   simpleActions = (
      ( 'deny', 'application %s action drop' ),
      ( 'statelessDeny', 'application %s action drop stateless' ),
      ( 'permit', 'application %s action forward' ),
      ( 'statelessPermit', 'application %s action forward stateless' ),
   )

   for policyName in entity.policy:
      policy = entity.policy[ policyName ]
      if policy.readonly:
         continue
      policyMode = root[ FirewallConfigMode ].getSingletonInstance()[
         FirewallPolicyConfigMode ].getOrCreateModeInstance( policyName )
      cmds = policyMode[ 'policy' ]
      for seq in policy.rule:
         rule = policy.rule[ seq ]
         service = rule.serviceName
         action = rule.action
         nexthop = rule.nexthop
         log = rule.log

         for actionName, template in simpleActions:
            if action == actionName:
               body = template % service
               break
         else:
            # The only remaining action is the stateless redirect, which
            # also carries the next hop.
            assert action == 'statelessRedirect'
            body = 'application %s action redirect next-hop %s '\
               'stateless' % ( service, nexthop )

         cmd = '%s ' % str( seq ) + body
         if log:
            cmd += ' log'
         cmds.addCommand( cmd )

# CliSave counterpart of CliMode.Firewall.FirewallVrfConfigMode. The savers
# in this file create one instance per VRF, keyed by VRF name.
class FirewallVrfConfigMode( CliMode.Firewall.FirewallVrfConfigMode,
                             CliSave.Mode ):
   def __init__( self, param ):
      # Both bases are initialized explicitly, each with the same param.
      CliMode.Firewall.FirewallVrfConfigMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Register the per-VRF mode as a child of FirewallConfigMode, passing
# after=[ 'firewall' ] (presumably ordering it after the 'firewall' command
# sequence -- confirm against CliSave), and declare its 'vrf' command
# sequence.
FirewallConfigMode.addChildMode( FirewallVrfConfigMode, after=[ 'firewall' ] )
FirewallVrfConfigMode.addCommandSequence( 'vrf' )

# CliSave counterpart of CliMode.Firewall.FirewallVrfSegmentConfigMode: a
# segment nested under a firewall VRF mode.
class FirewallVrfSegmentConfigMode( CliMode.Firewall.FirewallVrfSegmentConfigMode,
                                    CliSave.Mode ):
   def __init__( self, param ):
      # param is a 2-tuple; the savers in this file pass
      # ( vrfName, segmentName ). The CliMode base receives the whole tuple,
      # while CliSave.Mode is given only the segment name.
      _, segmentName = param
      CliMode.Firewall.FirewallVrfSegmentConfigMode.__init__( self, param )
      CliSave.Mode.__init__( self, segmentName )

# Register the VRF segment mode as a child of the per-VRF mode and declare
# its 'segment' command sequence.
FirewallVrfConfigMode.addChildMode( FirewallVrfSegmentConfigMode )
FirewallVrfSegmentConfigMode.addCommandSequence( 'segment' )

# CliSave counterpart of CliMode.Firewall.FirewallSegmentConfigMode: a
# segment directly under the firewall mode (no VRF level). The savers in
# this file key its instances by segment name.
class FirewallSegmentConfigMode( CliMode.Firewall.FirewallSegmentConfigMode,
                                 CliSave.Mode ):
   def __init__( self, param ):
      # Both bases are initialized explicitly, each with the same param.
      CliMode.Firewall.FirewallSegmentConfigMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Register the segment mode as a child of FirewallConfigMode and declare its
# 'segment' command sequence.
FirewallConfigMode.addChildMode( FirewallSegmentConfigMode )
FirewallSegmentConfigMode.addCommandSequence( 'segment' )

# CliSave counterpart of CliMode.Firewall.ClassMapFirewallVrfConfigMode,
# holding the match criteria of a segment inside a VRF. saveCmapVrfConfig
# keys its instances by ( vrfName, segmentName ).
class ClassMapFirewallVrfConfigMode( CliMode.Firewall.ClassMapFirewallVrfConfigMode,
                                     CliSave.Mode ):

   def __init__( self, param ):
      # Both bases are initialized explicitly, each with the same param.
      CliMode.Firewall.ClassMapFirewallVrfConfigMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Register the class-map mode as a child of the VRF segment mode and declare
# the 'segmentDef' command sequence that holds the match commands.
FirewallVrfSegmentConfigMode.addChildMode( ClassMapFirewallVrfConfigMode )
ClassMapFirewallVrfConfigMode.addCommandSequence( 'segmentDef' )

# CliSave counterpart of CliMode.Firewall.ClassMapFirewallConfigMode,
# holding the match criteria of a segment that is not under a VRF.
class ClassMapFirewallConfigMode( CliMode.Firewall.ClassMapFirewallConfigMode,
                                  CliSave.Mode ):

   def __init__( self, param ):
      # param is the segment name. The CliMode base is given a 2-tuple with
      # None in the first slot (presumably the VRF position, by analogy with
      # the VRF variant -- confirm against CliMode.Firewall).
      CliMode.Firewall.ClassMapFirewallConfigMode.__init__( self, ( None, param ) )
      CliSave.Mode.__init__( self, param )

# Register the class-map mode as a child of the segment mode and declare the
# 'segmentDef' command sequence that holds the match commands.
FirewallSegmentConfigMode.addChildMode( ClassMapFirewallConfigMode )
ClassMapFirewallConfigMode.addCommandSequence( 'segmentDef' )

@CliSave.saver( 'Firewall::Config', 'firewall/config/cli' )
def saveCmapVrfConfig( entity, root, requireMounts, options ):
   """Save the match criteria of per-VRF segment class maps.

   Class-map names are parsed as '__VRF_<vrf>_segment_<segment>'. Names that
   do not contain '__VRF_' are skipped, and so are names without the
   '_segment_' separator, since no VRF or segment name can be extracted from
   them. For each remaining class map the VRF and VRF-segment modes are
   created, and the 'segmentDef' command sequence receives:
     - 'match interface ...' listing the keys of cmap.intfVlanData, if any;
     - the IPv4 / IPv6 prefix match of the corresponding segment under
       entity.vrf, if that VRF and segment exist and the prefix is set.
   requireMounts and options are unused.
   """
   vrfPrefix = '__VRF_'
   segSeparator = '_segment_'

   for cmapName, cmap in entity.classMap.items():
      # The firewall singleton is fetched before any filtering, so it is
      # instantiated whenever at least one class map exists.
      fwMode = root[ FirewallConfigMode ].getSingletonInstance()
      if vrfPrefix not in cmapName:
         continue
      segIndex = cmapName.find( segSeparator )
      if segIndex == -1:
         # Without the separator the slices below would yield bogus VRF and
         # segment names; skip, as saveCmapConfig does for its own marker.
         continue
      vrfName = cmapName[ len( vrfPrefix ) : segIndex ]
      segmentName = cmapName[ segIndex + len( segSeparator ) : ]
      param = ( vrfName, segmentName )

      vrfMode = fwMode[ FirewallVrfConfigMode
                        ].getOrCreateModeInstance( vrfName )
      segMode = vrfMode[ FirewallVrfSegmentConfigMode
                         ].getOrCreateModeInstance( param )

      if cmap.intfVlanData:
         cmds = segMode[ ClassMapFirewallVrfConfigMode
                         ].getOrCreateModeInstance( param )[ 'segmentDef' ]
         cmds.addCommand( 'match interface ' + ' '.join( cmap.intfVlanData ) )

      if vrfName not in entity.vrf:
         continue
      segments = entity.vrf[ vrfName ].segmentDir.segment
      if segmentName not in segments:
         continue
      seg = segments[ segmentName ]

      # The class-map mode is created for every configured segment, even
      # when neither prefix is set.
      cmds = segMode[ ClassMapFirewallVrfConfigMode
                      ].getOrCreateModeInstance( param )[ 'segmentDef' ]
      if seg.ipv4PrefixList:
         if seg.usePrefixListForIpv4:
            cmd = 'match covered prefix-list ipv4 %s' % seg.ipv4PrefixList
         else:
            cmd = 'match prefix-ipv4 %s' % seg.ipv4PrefixList
         cmds.addCommand( cmd )
      if seg.ipv6PrefixList:
         if seg.usePrefixListForIpv6:
            cmd = 'match covered prefix-list ipv6 %s' % seg.ipv6PrefixList
         else:
            cmd = 'match prefix-ipv6 %s' % seg.ipv6PrefixList
         cmds.addCommand( cmd )

@CliSave.saver( 'Firewall::Config', 'firewall/config/cli' )
def saveCmapConfig( entity, root, requireMounts, options ):
   """Save the interface match criteria of non-VRF segment class maps.

   Class maps whose name contains '__segment_' are handled; the text after
   that marker is the segment name. The segment mode and its class-map mode
   are created for each of them, and the 'segmentDef' command sequence
   receives one 'match interface' line for all interfaces without VLAN
   ranges, followed by one 'match interface <intf> vlan <ranges>' line per
   interface that has VLAN ranges. requireMounts and options are unused.
   """
   def formatVlanRanges( intfVlan ):
      # Render the ( vlanBegin, vlanEnd ) pairs, ordered by vlanBegin, as
      # either a single VLAN id or a 'begin-end' span.
      bounds = sorted(
         ( ( r.vlanBegin, r.vlanEnd ) for r in intfVlan.vlanRange ),
         key=lambda pair: pair[ 0 ] )
      return [ str( first ) if first == last else '%d-%d' % ( first, last )
               for first, last in bounds ]

   marker = '__segment_'
   for cmapName, cmap in entity.classMap.items():
      # The firewall singleton is fetched before the name is examined.
      firewallMode = root[ FirewallConfigMode ].getSingletonInstance()
      markerAt = cmapName.find( marker )
      if markerAt == -1:
         continue
      segmentName = cmapName[ markerAt + len( marker ) : ]
      segMode = firewallMode[ FirewallSegmentConfigMode
                              ].getOrCreateModeInstance( segmentName )
      cmds = segMode[ ClassMapFirewallConfigMode
                      ].getOrCreateModeInstance( segmentName )[ 'segmentDef' ]
      if not cmap.intfVlanData:
         continue

      plainIntfs = []  # ' <intf>' fragments for interfaces without VLANs
      vlanCmds = []    # complete commands for interfaces with VLAN ranges
      for intf, intfVlan in cmap.intfVlanData.items():
         if intfVlan.vlanRange:
            vlanCmds.append( 'match interface {} vlan {}'.format(
               intf, ', '.join( formatVlanRanges( intfVlan ) ) ) )
         else:
            plainIntfs.append( ' %s' % intf )

      if plainIntfs:
         cmds.addCommand( 'match interface' + ''.join( plainIntfs ) )
      for vlanCmd in vlanCmds:
         cmds.addCommand( vlanCmd )

# CliSave counterpart of CliMode.Firewall.VrfSegmentPolicyMode, holding the
# policies of a segment inside a VRF. saveVrfSegmentPolicies keys its
# instances by ( vrfName, segmentName ).
class VrfSegmentPoliciesConfigMode( CliMode.Firewall.VrfSegmentPolicyMode,
                                 CliSave.Mode ):
   def __init__( self, param ):
      # Both bases are initialized explicitly, each with the same param.
      CliMode.Firewall.VrfSegmentPolicyMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

   def skipIfEmpty( self ):
      # Always True; presumably tells CliSave to omit this mode from the
      # output when it has no commands -- confirm against CliSave.Mode.
      return True

# Register the policies mode as a child of the VRF segment mode and declare
# the 'segPolicies' command sequence that holds the policy commands.
FirewallVrfSegmentConfigMode.addChildMode( VrfSegmentPoliciesConfigMode )
VrfSegmentPoliciesConfigMode.addCommandSequence( 'segPolicies' )

# CliSave counterpart of CliMode.Firewall.SegmentPolicyMode, holding the
# policies of a segment that is not under a VRF.
class SegmentPoliciesConfigMode( CliMode.Firewall.SegmentPolicyMode,
                                 CliSave.Mode ):
   def __init__( self, param ):
      # param is the segment name. The CliMode base is given a 2-tuple with
      # None in the first slot (presumably the VRF position, by analogy with
      # the VRF variant -- confirm against CliMode.Firewall).
      CliMode.Firewall.SegmentPolicyMode.__init__( self, ( None, param ) )
      CliSave.Mode.__init__( self, param )

   def skipIfEmpty( self ):
      # Always True; presumably tells CliSave to omit this mode from the
      # output when it has no commands -- confirm against CliSave.Mode.
      return True

# Register the policies mode as a child of the segment mode and declare the
# 'segPolicies' command sequence that holds the policy commands.
FirewallSegmentConfigMode.addChildMode( SegmentPoliciesConfigMode )
SegmentPoliciesConfigMode.addCommandSequence( 'segPolicies' )

@CliSave.saver( 'Firewall::Config', 'firewall/config/cli' )
def saveVrfSegmentPolicies( entity, root, requireMounts, options ):
   """Save the policies attached to every segment of every firewall VRF.

   For each VRF in entity.vrf the VRF mode is created; for each of its
   segments the VRF-segment mode and the policies mode are created, and the
   'segPolicies' command sequence receives one 'from <segment> policy
   <policy>' command per entry of segment.policy, followed by the fallback
   policy if one is set. requireMounts and options are unused.
   """
   for vrfName in entity.vrf:
      vrfMode = root[ FirewallConfigMode ].getSingletonInstance()[
         FirewallVrfConfigMode ].getOrCreateModeInstance( vrfName )
      segments = entity.vrf[ vrfName ].segmentDir.segment
      for segmentName in segments:
         modeKey = ( vrfName, segmentName )
         segMode = vrfMode[ FirewallVrfSegmentConfigMode
                            ].getOrCreateModeInstance( modeKey )
         segment = segments[ segmentName ]
         cmds = segMode[ VrfSegmentPoliciesConfigMode
                         ].getOrCreateModeInstance( modeKey )[ 'segPolicies' ]

         policies = segment.policy
         # A falsy policy collection is not iterated at all.
         for fromSegment in policies or ():
            cmds.addCommand( 'from {} policy {}'.format(
               fromSegment, policies[ fromSegment ] ) )

         fallback = segment.fallbackPolicy
         if fallback:
            cmds.addCommand( 'fallback policy %s' % ( fallback ) )

@CliSave.saver( 'Firewall::Config', 'firewall/config/cli' )
def saveSegmentPolicies( entity, root, requireMounts, options ):
   """Save the policies attached to every non-VRF firewall segment.

   Nothing is saved when entity.segmentDir or its segment collection is
   empty. Otherwise, for each segment, the segment mode and the policies
   mode are created and the 'segPolicies' command sequence receives one
   'from <segment> policy <policy>' command per entry of segment.policy,
   followed by the fallback policy if one is set. requireMounts and options
   are unused.
   """
   if not ( entity.segmentDir and entity.segmentDir.segment ):
      return

   firewallMode = root[ FirewallConfigMode ].getSingletonInstance()
   segments = entity.segmentDir.segment
   for segmentName in segments:
      segMode = firewallMode[ FirewallSegmentConfigMode
                              ].getOrCreateModeInstance( segmentName )
      segment = segments[ segmentName ]
      cmds = segMode[ SegmentPoliciesConfigMode
                      ].getOrCreateModeInstance( segmentName )[ 'segPolicies' ]

      policies = segment.policy
      # A falsy policy collection is not iterated at all.
      for fromSegment in policies or ():
         cmds.addCommand( 'from {} policy {}'.format(
            fromSegment, policies[ fromSegment ] ) )

      fallback = segment.fallbackPolicy
      if fallback:
         cmds.addCommand( 'fallback policy %s' % ( fallback ) )
