# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import os

from Arnet import IpGenAddr
from CliDynamicSymbol import CliDynamicPlugin
import CliGlobal
from CliPlugin.FirewallCliLib import (
      clearFirewallCountersHook, showFirewallSessionHook )
from CliPlugin.VrfCli import CliVrfMapper
from CommonGuards import ssoStandbyGuard
import LazyMount
import Tac
from TypeFuture import TacLazyType

# Handles to the FirewallModels dynamic plugin and to the TAC types used by
# the handlers in this module.
FirewallModels = CliDynamicPlugin( "FirewallModels" )
FwCountersDir = TacLazyType( "Firewall::CountersDir" )
FwProtocol = TacLazyType( "Firewall::Protocol" )
UniqueId = TacLazyType( 'Ark::UniqueId' )

# Module-wide state shared by the handlers. Every entry starts out as None
# and is assigned a LazyMount handle in Plugin().
gv = CliGlobal.CliGlobal( **dict.fromkeys( (
   'firewallConfig',
   'firewallCounters',
   'firewallCountersRequest',
   'firewallCountersSnapshot',
   'hwCapability',
   'hwConfig',
   'hwStatus',
   'matchListCapability',
) ) )

def getDefaultVrf( mode ):
   """Look up the VRF associated with the CLI mode `mode`.

   A CliVrfMapper is created for the current process's uid/gid and queried
   with `lookupCliModeVrf( mode, None )`. NOTE(review): the trailing None is
   presumably the value used when no VRF applies -- confirm in CliVrfMapper.
   """
   mapper = CliVrfMapper( None, os.getuid(), os.getgid() )
   return mapper.lookupCliModeVrf( mode, None )

# --------------------------------------------------------------------------
# 'show segment-security [ VRF ] [ segment <segment-name> ]' CLI
# --------------------------------------------------------------------------
def showSegmentSecurityHandler( mode, args ):
   """Build the FirewallInfo model for 'show segment-security'.

   A VRF is used when one is given. On platforms without interface/VLAN
   segment support, the CLI mode's VRF is used instead. With a VRF, its
   segments are reported under firewallModel.vrfs[ vrf ]. Otherwise the L2
   segment directory (hwConfig.l2Seg) is reported in firewallModel.segments.
   SEGMENT_NAME, when given, restricts the output to that one segment.
   Nothing beyond the VRF header is filled in when the VRF is unknown or the
   firewall is disabled.
   """
   result = FirewallModels.FirewallInfo()
   requestedSegment = args.get( 'SEGMENT_NAME' )
   vrfName = args.get( 'VRF' )
   if vrfName is None and not gv.hwCapability.intfVlanSupported:
      vrfName = getDefaultVrf( mode )

   vrfModel = None
   if vrfName is not None:
      vrfModel = result.Vrf()
      vrfModel.ipv6PrefixSupported = gv.matchListCapability.ipv6PrefixSupported
      vrfModel.enabled = gv.firewallConfig.enabled
      result.vrfs[ vrfName ] = vrfModel
      # The VRF must be present in both the hardware and the CLI config;
      # otherwise report it as missing and stop here.
      exists = ( vrfName in gv.hwConfig.vrf and
                 vrfName in gv.firewallConfig.vrf )
      vrfModel.vrfExist = exists
      if not exists:
         return result

   if not gv.firewallConfig.enabled:
      # Nothing is displayed while the firewall is disabled.
      return result

   def addSegment( name, hwSegment ):
      """Convert one hardware-config segment into a Segment model."""
      segModel = FirewallModels.Segment()
      segModel.classMap = hwSegment.className

      for rangeSet in hwSegment.intfVlanRangeSet.values():
         if not rangeSet.vlanRange:
            # Interface member without VLAN qualification.
            segModel.interfaces.append( rangeSet.intfId )
            continue
         intfVlans = FirewallModels.InterfaceVlanRange()
         intfVlans.interface = rangeSet.intfId
         for vr in rangeSet.vlanRange:
            vlanModel = FirewallModels.VlanRangeModel()
            vlanModel.startVlanId = vr.vlanBegin
            vlanModel.endVlanId = vr.vlanEnd
            intfVlans.vlanRanges.append( vlanModel )
         segModel.interfaceVlans.append( intfVlans )

      # Prefix-list names come from the CLI config and exist only for VRF
      # segments; anything not found is shown as an empty name.
      ipv4List = ''
      ipv6List = ''
      if vrfModel:
         vrfConfig = gv.firewallConfig.vrf.get( vrfName )
         segConfig = ( vrfConfig.segment.get( name )
                       if vrfConfig is not None else None )
         if segConfig is not None:
            ipv4List = segConfig.ipv4PrefixList
            ipv6List = segConfig.ipv6PrefixList
      segModel.ipv4PrefixListName = ipv4List
      segModel.ipv6PrefixListName = ipv6List

      # '-' stands for the fallback policy. A real source segment named '-'
      # is written afterwards and therefore overwrites it.
      if hwSegment.fallbackPolicy:
         segModel.fromSegments[ '-' ] = hwSegment.fallbackPolicy.name
      for srcEntry in hwSegment.fromSegment.values():
         segModel.fromSegments[ srcEntry.segment.name ] = srcEntry.policy.name

      target = vrfModel.segments if vrfModel else result.segments
      target[ name ] = segModel

   if vrfModel:
      segDir = gv.hwConfig.vrf.get( vrfName )
   else:
      segDir = gv.hwConfig.l2Seg
   if segDir is None:
      return result

   if requestedSegment:
      hwSegment = segDir.segment.get( requestedSegment )
      if hwSegment is not None:
         addSegment( requestedSegment, hwSegment )
   else:
      for name, hwSegment in segDir.segment.items():
         addSegment( name, hwSegment )
   return result

# --------------------------------------------------------------------------
# 'show segment-security [ VRF ]
#     to ( ( DST_IP [ from SRC_IP ] ) | ( DST_IPV6 [ from SRC_IPV6 ] ) )' CLI
# --------------------------------------------------------------------------
def showSegmentSecurityPolicyToHandler( mode, args ):
   """Show the segments an address falls into and the policies towards it.

   The destination address (DST_IP or DST_IPV6) and the optional source
   address (SRC_IP or SRC_IPV6) are each mapped to the segment(s) of the VRF
   whose subnets contain them with the longest prefix. If several segments
   tie, a warning is added to `mode`.

   For every destination segment, the model lists the policy from each
   matching source segment. The fallback policy is added under '-' when
   there is no source address, no matching source segment, or a matching
   source segment without its own policy. When no source address is given,
   every configured source-segment policy is also listed.

   Returns an FwPolicyToInfo model. vrfExist is False when the VRF is not
   configured, or when no VRF applies (L2 is not supported here).
   """
   fwPolicyToInfoModel = FirewallModels.FwPolicyToInfo()
   fwPolicyToInfoModel.vrfExist = True
   fwPolicyToInfoModel.enabled = gv.firewallConfig.enabled
   fwPolicyToInfoModel.setShowDashMeaning( False )
   fwPolicyToInfoModel.setShowNaMeaning( False )

   # Set default vrf if needed
   vrfName = args.get( 'VRF' )
   if ( vrfName is None ) and not gv.hwCapability.intfVlanSupported:
      vrfName = getDefaultVrf( mode )
   if vrfName is not None:
      # Check if the specified VRF exists. If not then bailout.
      fwPolicyToInfoModel.vrfName = vrfName
      if ( vrfName not in gv.hwConfig.vrf ) or (
            vrfName not in gv.firewallConfig.vrf ):
         fwPolicyToInfoModel.vrfExist = False
         return fwPolicyToInfoModel

   if not vrfName:
      fwPolicyToInfoModel.vrfExist = False
      # At present, segment-security configurations are not supported at L2
      return fwPolicyToInfoModel

   # Read source (if any) and destination addresses. The grammar supplies
   # either the IPv4 pair or the IPv6 pair. Test for presence rather than
   # truthiness, so the family choice does not depend on how the parsed
   # address object evaluates as a boolean.
   isV4 = args.get( 'DST_IP' ) is not None
   dstIp = args.get( 'DST_IP' ) if isV4 else args.get( 'DST_IPV6' )
   srcIp = args.get( 'SRC_IP' ) if isV4 else args.get( 'SRC_IPV6' )
   dstIp = IpGenAddr( str( dstIp ) ) if dstIp is not None else None
   srcIp = IpGenAddr( str( srcIp ) ) if srcIp is not None else None

   def scanSegments( segDirHwConfig, ip ):
      """Return the names of the segments whose longest matching subnet
      contains `ip`. Ties (only possible between identical prefixes) are all
      returned, and a warning is added to `mode`."""
      # Start below every real prefix length. Then the first match, even a
      # zero-length default prefix (0.0.0.0/0 or ::/0), always records
      # longestPrefix. Starting at 0 left longestPrefix as None for /0
      # matches, and the tie warning below then crashed.
      matchLength = -1
      matchingSegments = []
      longestPrefix = None
      for segmentName, segment in segDirHwConfig.segment.items():
         for prefix in segment.subnets:
            if not prefix.contains( ip ):
               continue
            prefixLen = prefix.len
            if prefixLen > matchLength:
               # New longest prefix match
               matchLength = prefixLen
               longestPrefix = prefix
               matchingSegments = [ segmentName ]
            elif ( prefixLen == matchLength and
                   segmentName not in matchingSegments ):
               matchingSegments.append( segmentName )
      if len( matchingSegments ) > 1:
         # Raise potential undefined behavior if seen
         mode.addWarning(
            f'Segments ({", ".join(matchingSegments)}) match'
            f' {longestPrefix.stringValue} to reach {ip}' )
      return matchingSegments

   segmentDir = gv.hwConfig.vrf.get( vrfName )
   if segmentDir is None:
      return fwPolicyToInfoModel
   fromSegments = scanSegments( segmentDir, srcIp ) if srcIp is not None else None
   toSegments = scanSegments( segmentDir, dstIp )

   # Identify policy between given segments
   def fillRelevantPolicies( segDirHwConfig, toSegment ):
      segInfo = segDirHwConfig.segment.get( toSegment )
      fwPolicyToInfoModel.segments[ toSegment ] = \
         FirewallModels.FwPolicyToInfoSourceSegments()
      outputModel = fwPolicyToInfoModel.segments[ toSegment ].fromSegments
      segmentFactory = FirewallModels.FwPolicyToInfoSourceSegment
      if fromSegments:
         for fromSegment in fromSegments:
            fromSegInfo = segInfo.fromSegment.get( fromSegment )
            if fromSegInfo:
               outputModel[ fromSegment ] = segmentFactory()
               outputModel[ fromSegment ].policy = fromSegInfo.policy.name
      # Some source is not covered by an explicit policy: show the fallback
      # ('-'), or flag it as not applicable when there is no fallback.
      if not fromSegments or len( outputModel ) != len( fromSegments ):
         fwPolicyToInfoModel.setShowDashMeaning( True )
         outputModel[ '-' ] = segmentFactory()
         if segInfo.fallbackPolicy:
            outputModel[ '-' ].policy = segInfo.fallbackPolicy.name
         else:
            fwPolicyToInfoModel.setShowNaMeaning( True )
      if fromSegments is None:
         # No from ip specified, show all policies
         for ( segment, fromSegInfo ) in segInfo.fromSegment.items():
            outputModel[ segment ] = segmentFactory()
            outputModel[ segment ].policy = fromSegInfo.policy.name

   for dst in toSegments:
      fillRelevantPolicies( segmentDir, dst )
   return fwPolicyToInfoModel

# --------------------------------------------------------------------------
# 'show segment-security policy [ policy-name ]' CLI
# --------------------------------------------------------------------------
def showSegmentSecurityPolicyHandler( mode, args ):
   """Build the FwPolicyInfo model for 'show segment-security policy'.

   Without POLICY_NAME, every policy in the CLI config is reported.
   Otherwise only the named policy is reported, if it exists. `mode` is
   unused.
   """
   result = FirewallModels.FwPolicyInfo()
   requested = args.get( 'POLICY_NAME' )

   def addPolicy( name, policy ):
      """Convert one configured policy and its rules into a Policy model."""
      model = result.Policy( readonly=policy.readonly )
      for seq, rule in policy.rule.items():
         # A next hop is only reported for stateless-redirect rules.
         isRedirect = rule.action == "statelessRedirect"
         model.policyDefs[ seq ] = model.PolicyDef(
            serviceName=rule.serviceName,
            action=rule.action,
            nexthop=rule.nexthop if isRedirect else None,
            log=rule.log )
      result.policies[ name ] = model

   if requested:
      policy = gv.firewallConfig.policy.get( requested )
      if policy is not None:
         addPolicy( requested, policy )
   else:
      for name, policy in gv.firewallConfig.policy.items():
         addPolicy( name, policy )
   return result

# --------------------------------------------------------------------------
# 'show segment-security application [ application-name ]' CLI
# --------------------------------------------------------------------------
def showSegmentSecurityApplicationHandler( mode, args ):
   """Build the FwServiceInfo model for 'show segment-security application'.

   Without APPLICATION_NAME, every service in the hardware config is
   reported. Otherwise only the named service is reported, if it exists.
   Each service gets one ServiceDef per IP protocol in its protocol ranges,
   carrying the service's source/destination port ranges, plus its source
   and destination prefixes. `mode` is unused.
   """
   result = FirewallModels.FwServiceInfo()
   requested = args.get( 'APPLICATION_NAME' )

   def addService( name, hwService ):
      """Convert one hardware-config service into a Service model."""
      model = result.Service()

      def portRanges( coll, serviceDef ):
         return [ serviceDef.PortRange( start=r.start, end=r.end )
                  for r in coll.values() ]

      for protoRange in hwService.protocol.values():
         # NOTE: See BUG813424 about protocol 0. A 1-255 range is shown as a
         # single "all protocols" entry rather than 255 separate entries.
         if protoRange.start == 1 and protoRange.end == 255:
            protocols = ( None, )
         else:
            protocols = range( protoRange.start, protoRange.end + 1 )
         for proto in protocols:
            serviceDef = model.ServiceDef()
            serviceDef.allProtocols = proto is None
            if proto is not None:
               serviceDef.ipProtocol = proto
            serviceDef.dstPort = portRanges( hwService.dstPort, serviceDef )
            serviceDef.srcPort = portRanges( hwService.srcPort, serviceDef )
            model.serviceDef.append( serviceDef )

      model.srcPrefix = list( hwService.srcIpPrefix )
      model.dstPrefix = list( hwService.dstIpPrefix )
      result.services[ name ] = model

   if requested:
      hwService = gv.hwConfig.services.get( requested )
      if hwService is not None:
         addService( requested, hwService )
   else:
      for name, hwService in gv.hwConfig.services.items():
         addService( name, hwService )
   return result

# --------------------------------------------------------------------------
# 'show segment-security sessions [ VRF ]' CLI
# --------------------------------------------------------------------------
def showSegmentSecuritySessionsHandler( mode, args ):
   """Run every registered 'show segment-security sessions' hook.

   When no VRF is given and the platform lacks interface/VLAN segment
   support, the CLI mode's VRF is used. Each hook is called with
   ( mode, vrfName ). This handler returns None, so any output presumably
   comes from the hooks themselves.
   """
   vrfName = args.get( 'VRF' )
   if not ( vrfName or gv.hwCapability.intfVlanSupported ):
      vrfName = getDefaultVrf( mode )
   for sessionHook in showFirewallSessionHook.extensions():
      sessionHook( mode, vrfName )

def _counterChecker( configGetter, counterGetter, requestTime ):
   """Return a zero-argument predicate telling Tac.waitFor when to stop.

   The predicate is true when any of these holds:
   * the platform does not support per-VRF counter requests;
   * configGetter() returns None (the config being waited on is gone);
   * counterGetter() returns counters whose lastUpdateTime is at or after
     requestTime.
   """
   def isDone():
      if not gv.firewallCounters.supportsVrfCounterRequest:
         return True
      if configGetter() is None:
         return True
      counter = counterGetter()
      return counter is not None and counter.lastUpdateTime >= requestTime
   return isDone

# --------------------------------------------------------------------------
# 'show segment-security counters [ VRF ]' CLI
# --------------------------------------------------------------------------
def showSegmentSecurityCountersHandler( mode, args ):
   """Build the FwCounterInfo model for 'show segment-security counters'.

   A VRF is used when one is given. On platforms without interface/VLAN
   segment support, the CLI mode's VRF is used instead. If the platform
   supports per-VRF counter requests and ssoStandbyGuard does not block the
   command, a refresh is requested and awaited for up to 5 seconds. On
   timeout or interrupt, whatever counters exist are shown. Without a VRF,
   the L2 counters are used.

   When no clear-counters hook is registered, the values shown are the delta
   against the stored snapshot. Otherwise the current counters are shown
   as-is.
   """
   vrfName = args.get( 'VRF' )
   model = FirewallModels.FwCounterInfo()
   if not ( vrfName or gv.hwCapability.intfVlanSupported ):
      vrfName = getDefaultVrf( mode )

   if not vrfName:
      # No VRF: report the L2 counters.
      current = gv.firewallCounters.l2CountersDir
      if current is None:
         return model
      model.vrfExist = False
   else:
      canRequest = ( gv.firewallCounters.supportsVrfCounterRequest and
                     ssoStandbyGuard( mode, None ) is None )
      if canRequest:
         requestTime = Tac.now()
         requestDir = gv.firewallCountersRequest.force()
         try:
            requestDir.addVrfCounterRequest( vrfName, requestTime )
            isUpdated = _counterChecker(
               lambda: gv.hwConfig.vrf.get( vrfName ),
               lambda: gv.firewallCounters.vrf.get( vrfName ),
               requestTime )
            Tac.waitFor( isUpdated,
                         timeout=5,
                         description="counters to be updated",
                         sleep="FIREWALL_COHAB_TEST" not in os.environ )
         except ( Tac.Timeout, KeyboardInterrupt ):
            # Best effort: fall through and show whatever is available.
            pass
         finally:
            # Always withdraw the request, whether or not it completed.
            requestDir.deleteVrfCounterRequest( vrfName, requestTime )
      current = gv.firewallCounters.vrf.get( vrfName )
      if current is None:
         return model
      model.vrfExist = True
      model.vrf = vrfName

   if clearFirewallCountersHook.extensions():
      counters = current
   else:
      # Default snapshot implementation when no hook handles clearing: show
      # the difference from the locally stored snapshot.
      counters = FwCountersDir( "", UniqueId() )
      if vrfName:
         snapshot = gv.firewallCountersSnapshot.vrf.get( vrfName )
      else:
         snapshot = gv.firewallCountersSnapshot.l2CountersDir
      if snapshot is not None:
         snapshot.deleteStale( current )
      counters.takeDelta( current, snapshot )

   globalCounters = counters.globalCounters
   if globalCounters:
      model.flowsCreated = globalCounters.flowCreated
      model.invalidPackets = globalCounters.invalidPkts
      model.bypassPackets = globalCounters.firewallBypass

   PolicyCounter = FirewallModels.FwCounterInfo.PolicyCounter
   for policyName, policy in counters.policy.items():
      policyModel = PolicyCounter()
      policyModel.hits = policy.hitCount
      policyModel.drops = policy.dropCount
      policyModel.defaultDrops = policy.defaultDrop
      model.policies[ policyName ] = policyModel

   DestSegment = FirewallModels.FwCounterInfo.DestSegment

   def toSourceModel( segCounters ):
      """Convert one source-segment counter entry into its model."""
      srcModel = DestSegment.SourceSegment()
      if segCounters.policy:
         srcModel.policyName = segCounters.policy
      srcModel.hits = segCounters.hitCount
      # Drop counters are only reported when tracked per segment.
      if counters.hasPerSegmentDrops:
         srcModel.drops = segCounters.dropCount
         srcModel.defaultDrops = segCounters.defaultDrop
      return srcModel

   for dstName, dstCounters in counters.segment.items():
      dstModel = DestSegment()
      # "-" holds packets that hit no source segment. A real source segment
      # named "-" is written afterwards and overwrites it.
      if dstCounters.defaultCounters is not None:
         dstModel.srcSegments[ "-" ] = toSourceModel(
               dstCounters.defaultCounters )
      for srcName, srcCounters in dstCounters.segmentCounters.items():
         dstModel.srcSegments[ srcName ] = toSourceModel( srcCounters )
      model.dstSegments[ dstName ] = dstModel
   return model

# -------------------------------------------------------------------------------
# 'show segment-security status [ VRF ] [ segment <segment-name> ]'
# -------------------------------------------------------------------------------
def showSegmentSecurityStatusHandler( mode, args ):
   """Build the FwHwStatusInfo model for 'show segment-security status'.

   For each VRF (only VRF, when given), every destination segment lists the
   hardware install status of each source segment. '-' holds the default
   status. SEGMENT_NAME restricts the output to one destination segment, but
   only when VRF is also given. If any status is "failed" and the platform
   supports error reporting, a warning points the user at
   'show segment-security errors'.
   """
   result = FirewallModels.FwHwStatusInfo()
   vrfName = args.get( 'VRF' )
   segmentName = args.get( 'SEGMENT_NAME' )
   result.ipv6PrefixSupported = gv.matchListCapability.ipv6PrefixSupported

   anyFailed = False

   def addVrf( name, vrfStatus, onlySegment ):
      """Add the status model for one VRF to the result."""
      vrfModel = result.Vrf()
      vrfConfig = gv.firewallConfig.vrf.get( name )

      def addDstSegment( dstName, dstStatus ):
         """Add one destination segment's per-source status to vrfModel."""
         dstModel = vrfModel.DestSegment()

         # Prefix-list names come from the CLI config; empty when absent.
         dstConfig = vrfConfig.segment.get( dstName ) if vrfConfig else None
         if dstConfig is not None:
            dstModel.ipv4PrefixListName = dstConfig.ipv4PrefixList
            dstModel.ipv6PrefixListName = dstConfig.ipv6PrefixList
         else:
            dstModel.ipv4PrefixListName = ""
            dstModel.ipv6PrefixListName = ""

         def toModel( status ):
            nonlocal anyFailed
            if status.status == "failed":
               anyFailed = True
            return dstModel.SourceSegment( status=status.status )

         # "-" holds packets that hit no source segment. A real source
         # segment named "-" is written afterwards and overwrites it.
         dstModel.srcSegments[ "-" ] = toModel( dstStatus.defaultStatus )
         for srcName, srcStatus in dstStatus.fromSegment.items():
            dstModel.srcSegments[ srcName ] = toModel( srcStatus )

         vrfModel.dstSegments[ dstName ] = dstModel

      if onlySegment:
         dstStatus = vrfStatus.segment.get( onlySegment )
         if dstStatus is not None:
            addDstSegment( onlySegment, dstStatus )
      else:
         # NOTE(review): this iterates segmentDir.segment, while the lookup
         # above (and the errors handler) uses segment directly -- confirm
         # both expose the same per-segment status entries.
         for dstName, dstStatus in vrfStatus.segmentDir.segment.items():
            addDstSegment( dstName, dstStatus )

      result.vrfs[ name ] = vrfModel

   if vrfName:
      vrfStatus = gv.hwStatus.status.get( vrfName )
      if vrfStatus is not None:
         addVrf( vrfName, vrfStatus, segmentName )
   else:
      # NOTE: SEGMENT_NAME is ignored if VRF is not provided, to keep the
      # existing behavior for backwards compatibility.
      for name, vrfStatus in gv.hwStatus.status.items():
         addVrf( name, vrfStatus, None )

   if anyFailed and gv.hwCapability.errorReportingSupported:
      mode.addWarning(
         'Some segments failed to install. '
         'See "show segment-security errors" for details about the failures.' )

   return result

# -------------------------------------------------------------------------------
# 'show segment-security errors [ VRF [ segment <segment-name> ] ]'
# -------------------------------------------------------------------------------
def showSegmentSecurityErrorsHandler( mode, args ):
   """Build the FwHwErrorsInfo model for 'show segment-security errors'.

   For each VRF (only VRF, when given), list the failure reason of every
   source segment whose status is "failed" and which has a reason. '-' is
   the default entry. Destination segments without such failures are
   omitted, and dstSegments stays None when none remain. A VRF-level error
   is included only when no SEGMENT_NAME is given. SEGMENT_NAME is only
   honored together with VRF. `mode` is unused.
   """
   result = FirewallModels.FwHwErrorsInfo()
   vrfName = args.get( 'VRF' )
   segmentName = args.get( 'SEGMENT_NAME' )

   def addVrf( name, vrfStatus, onlySegment ):
      """Add the error model for one VRF to the result."""
      vrfModel = result.Vrf( dstSegments=None )

      def addDstSegment( dstName, dstStatus ):
         """Record failures for one destination segment, if it has any."""
         dstModel = vrfModel.DestSegment()

         def recordFailure( srcName, segStatus ):
            if segStatus.status != "failed" or segStatus.reason is None:
               return
            dstModel.srcSegments[ srcName ] = dstModel.SourceSegment(
               error=segStatus.reason )

         # "-" holds packets that hit no source segment. A real source
         # segment named "-" is written afterwards and overwrites it.
         recordFailure( '-', dstStatus.defaultStatus )
         for srcName, segStatus in dstStatus.fromSegment.items():
            recordFailure( srcName, segStatus )

         if dstModel.srcSegments:
            if vrfModel.dstSegments is None:
               vrfModel.dstSegments = {}
            vrfModel.dstSegments[ dstName ] = dstModel

      if onlySegment:
         dstStatus = vrfStatus.segment.get( onlySegment )
         if dstStatus is not None:
            addDstSegment( onlySegment, dstStatus )
      else:
         vrfModel.error = vrfStatus.error
         for dstName, dstStatus in vrfStatus.segment.items():
            addDstSegment( dstName, dstStatus )
      result.vrfs[ name ] = vrfModel

   if vrfName:
      vrfStatus = gv.hwStatus.status.get( vrfName )
      if vrfStatus is not None:
         addVrf( vrfName, vrfStatus, segmentName )
   else:
      for name, vrfStatus in gv.hwStatus.status.items():
         addVrf( name, vrfStatus, None )
   return result

def Plugin( em ):
   """Create the LazyMount handles the handlers read through `gv`.

   Every path is mounted with mode 'r', except firewall/countersRequest.
   That one is mounted 'w' because the counters handler adds and removes
   counter requests there.
   """
   # ( gv attribute, mount path, TAC type, mount mode ), in mount order.
   mounts = (
      ( 'firewallConfig', 'firewall/config/cli', 'Firewall::Config', 'r' ),
      ( 'firewallCounters', 'firewall/counters', 'Firewall::Counters', 'r' ),
      ( 'firewallCountersRequest', 'firewall/countersRequest',
        'Firewall::CountersRequestDir', 'w' ),
      ( 'firewallCountersSnapshot', 'firewall/snapshot/counters',
        'Firewall::Counters', 'r' ),
      ( 'hwCapability', 'firewall/hw/capability',
        'Firewall::HwCapability', 'r' ),
      ( 'hwConfig', 'firewall/hw/config', 'Firewall::HwConfig', 'r' ),
      ( 'hwStatus', 'firewall/hw/status', 'Firewall::HwStatus', 'r' ),
      ( 'matchListCapability', 'matchlist/hw/capability',
        'MatchList::HwCapability', 'r' ),
   )
   for attr, path, tacType, mountMode in mounts:
      setattr( gv, attr, LazyMount.mount( em, path, tacType, mountMode ) )
