#!/usr/bin/env python3
# Copyright (c) 2009,2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Arnet
import IpLibConsts
from SysConstants.in_h import IPPROTO_UDP, IPPROTO_TCP
import Tac
import Toggles.AclToggleLib

NON_DEFAULT_VRFNAME = 'management'

# Every ACL type handled by this library, and the label shown to users for it.
aclTypes = [ 'ip', 'ipv6', 'mac' ]
aclTypeDisplayNames = dict( ip='IPv4', ipv6='IPv6', mac='MAC' )
# The ACL types that may be used as service ACLs.
serviceAclTypes = [ 'ip', 'ipv6' ]

# Directions in which an ACL can be applied.
aclDirections = [ 'in', 'out' ]

# Upper bounds for numeric match fields.
MAX_SEQ = 2 ** 32 - 1                  # rule sequence numbers are 32-bit
MAX_DSCP = 2 ** 6 - 1                  # 63
MAX_IP_PACKET_LEN = 2 ** 16 - 1        # 65535
MAX_IPV6_FLOW_LABEL = 2 ** 20 - 1      # 20-bit flow label

# IP protocol numbers.
IPPROTO_IGMP = 2
IPPROTO_IP_IN_IP = 4
IPPROTO_RSVP = 46
IPPROTO_GRE = 47
IPPROTO_ESP = 50
IPPROTO_AHP = 51
IPPROTO_OSPF = 89
IPPROTO_PIM = 103
IPPROTO_VRRP = 112
IPPROTO_SCTP = 132
IPPROTO_MPLSOVERGRE = 137

# "All" sentinel values for ICMP / ICMPv6.
ICMP_ALL = ICMP6_ALL = 0xFFFF

# Tunnel / encapsulation field limits and well-known values.
TNI_MAX = VNI_MAX = 2 ** 24 - 1
GRE_ALL = GREPROTO_MAX = 2 ** 16 - 1
NVGRE_PROTO = 0x6558
MPLS_PROTO = 0x8847
MPLSLABEL_MAX = 2 ** 20 - 1
TEID_MAX = USERL4_MAX = PAYLOAD_MAX = 2 ** 32 - 1

# MAC ACL "all" sentinels and the IPv4 EtherType.
MACVLAN_ALL = 0xFFFF
MACCOS_ALL = 0xFF
MACPROTO_ALL = 0xFFFFFFFF
MACPROTO_IP4 = 0x0800

# Convenient aliases for the Tac types used by ACL code.
AclDirection = Tac.Type( "Acl::AclDirection" )
# Singleton holding the shared name lookup tables (services, GTP, TCP flags).
AclNameMaps = Tac.singleton( "Acl::AclNameMaps" )
AclType = Tac.Type( "Acl::AclType" )
AclAction = Tac.Type( "Acl::Action" )
CopyToCpuDst = Tac.Type( "Acl::CopyToCpuDst" )
# Match-field specs.
TtlValue = Tac.Type( "Acl::TtlSpec" )
TtlRangeValue = Tac.Type( "Acl::TtlRangeSpec" )
PortValue = Tac.Type( "Acl::PortSpec" )
IpLenValue = Tac.Type( "Acl::IpLenSpec" )
# Payload and metadata matching.
PayloadValue = Tac.Type( "Acl::PayloadValue" )
PayloadSpec = Tac.Type( "Acl::PayloadSpec" )
PayloadStringConstants = Tac.Type( "Acl::PayloadStringConstants" )
MetadataValue = Tac.Type( "Acl::MetadataValue" )
MetadataSpec = Tac.Type( "Acl::MetadataSpec" )
# Filters and rules.
IpFilterValue = Tac.Type( "Acl::IpFilter" )
IpRuleValue = Tac.Type( "Acl::IpRuleConfig" )
Ip6FilterValue = Tac.Type( "Acl::Ip6Filter" )
Ip6RuleValue = Tac.Type( "Acl::Ip6RuleConfig" )
MacFilterValue = Tac.Type( "Acl::MacFilter" )
MplsFilterValue = Tac.Type( "Acl::MplsFilter" )
MplsLabelFilterValue = Tac.Type( "Acl::MplsLabelFilter" )
MacRuleValue = Tac.Type( "Acl::MacRuleConfig" )
# Miscellaneous.
TcpFlagValue = Tac.Type( "Acl::TcpFlag" )
TcamBankSharingMode = Tac.Type( "Acl::TcamBankSharingMode" )
RemarkConfigValue = Tac.Type( "Acl::RemarkConfig" )
Icmp6RuleType = Tac.Type( "Acl::ImplicitIcmp6RulesType" )
EcnValue = Tac.Type( "Acl::Ecn" )
FlowLabelMatchType = Tac.Type( "Acl::FlowLabelMatchType" )
AclPermitResponseType = Tac.Type( "Acl::AclPermitResponseType" )
AclPayloadHdrStart = Tac.Type( "Acl::AclPayloadHdrStart" )
HwAclMechanism = Tac.Type( "Acl::HwAclMechanism" )
ActionDuringAclUpdate = Tac.Type( "Acl::ActionDuringAclUpdate" )
EthType = Tac.Type( "Arnet::EthType" )

# "Match anything" / zero values used when building filters.
anyIpAddr = Arnet.AddrWithFullMask( "0.0.0.0", 0 )
anyIp6Addr = Arnet.Ip6AddrWithMask( "::/0" )
anyIp6AddrWithFullMask = Arnet.Ip6AddrWithFullMask( "::/::" )
zeroMacAddr = Tac.Value( "Arnet::EthAddr", 0, 0, 0 )
zeroMacAddrString = "00:00:00:00:00:00"
anyTtlValue = TtlValue( "any", 0 )
anyTtlRangeValue = TtlRangeValue( "any", "" )
anyIpLenValue = IpLenValue( "any", "" )
anyPortValue = PortValue( "any", "" )
# UDP destination port 4789 (VXLAN).
vxlanPortValue = PortValue( "eq", "4789" )
anyFlowLabelValue = 0

defaultVrf = IpLibConsts.DEFAULT_VRF

def gtpStrFromProto( gtpProto ):
   """Return the name AclNameMaps.gtpProto holds for *gtpProto*, or '' if none."""
   protoNames = AclNameMaps.gtpProto
   return protoNames.get( gtpProto, '' )

def gtpPortFromStr( gtpProto ):
   """Return the ``port`` of the gtpProtoByName entry named *gtpProto*.

   Despite the parameter name, *gtpProto* is the GTP protocol *name* used as the
   gtpProtoByName key. An unknown name is not handled here (presumably raises
   KeyError -- confirm against the Tac collection behaviour).
   """
   protoDesc = gtpProtoByName[ gtpProto ]
   return protoDesc.port

def genUniqueId():
   """Return a default-constructed ``Ark::UniqueId`` value.

   _addRule() and setAclRules() key rules and sub-configs by the result, so
   they rely on every call producing a distinct id.
   """
   newId = Tac.Value( "Ark::UniqueId" )
   return newId

# Service-name lookup tables taken from the shared AclNameMaps singleton.
commonServiceByName_ = AclNameMaps.commonTcpUdpServiceByName   # TCP and UDP
tcpServiceByName = AclNameMaps.service[ IPPROTO_TCP ].description
udpServiceByName = AclNameMaps.service[ IPPROTO_UDP ].description
gtpProtoByName = AclNameMaps.gtpProtoDesc
sctpServiceByName = AclNameMaps.service[ IPPROTO_SCTP ].description

# IP protocol number -> service-name table, used by getServByName() and
# getServiceMap().
serviceMap = {
   IPPROTO_TCP: tcpServiceByName,
   IPPROTO_UDP: udpServiceByName,
   IPPROTO_SCTP: sctpServiceByName,
}

aristaTcpServiceByName = AclNameMaps.aristaTcpServiceByName

def getServByName( proto, name ):
   """Return the port of service *name* for IP protocol *proto*.

   *name* is passed through str() before the lookup. Returns None when
   *proto* has no (non-empty) service table or *name* is not found in it.
   """
   protoMap = serviceMap.get( proto )
   if not protoMap:
      return None
   serviceDesc = protoMap.get( str( name ) )
   if not serviceDesc:
      return None
   return serviceDesc.port

def getServiceMap( proto ):
   """Return the service-name table for IP protocol *proto*, or None."""
   return serviceMap.get( proto, None )

# Bit values of the individual TCP flags.
tcpFlagFin = 1 << 0
tcpFlagSyn = 1 << 1
tcpFlagRst = 1 << 2
tcpFlagPsh = 1 << 3
tcpFlagAck = 1 << 4
tcpFlagUrg = 1 << 5

# TCP flag keyword -> bitmask, derived from the bit index that
# AclNameMaps.tcpFlagToken stores for each keyword.
tcpFlagTokens = { name: 1 << bitIndex
                  for name, bitIndex in AclNameMaps.tcpFlagToken.items() }

def setTcpFlag( flag, flagName ):
   """Set the attribute named *flagName* on *flag* to True.

   Args:
      flag: object with boolean attributes fin/syn/rst/psh/ack/urg
         (a TcpFlagValue in this module).
      flagName: one of 'fin', 'syn', 'rst', 'psh', 'ack', 'urg'.

   Raises:
      AssertionError: if *flagName* is not one of the names above. It is
         raised explicitly rather than via ``assert`` so the check still
         applies when Python runs with -O.
   """
   for name in ( 'fin', 'syn', 'rst', 'psh', 'ack', 'urg' ):
      # Compare with == (as before) and set the canonical attribute name.
      if flagName == name:
         setattr( flag, name, True )
         return
   raise AssertionError( "Unknown flag!" )

def tcpFlag( flagName ):
   """Return a new TcpFlagValue with only the flag *flagName* set."""
   return tcpFlags( ( flagName, ) )

def tcpFlags( flagNameList ):
   """Return a new TcpFlagValue with every flag named in *flagNameList* set.

   Each name goes through setTcpFlag(), which rejects unknown names.
   """
   result = TcpFlagValue()
   for name in flagNameList:
      setTcpFlag( result, name )
   return result

# TcpFlagValue with all six flags set.
tcpFlagAll = tcpFlags( ( 'fin', 'syn', 'rst', 'psh', 'ack', 'urg' ) )

def tcpFlagFromString( token ):
   """Return the bitmask for TCP flag keyword *token* (KeyError if unknown)."""
   bitmask = tcpFlagTokens[ token ]
   return bitmask

def tcpFlagsFromValue( flag, split=' ' ):
   """Render bitmask *flag* as its TCP flag keywords.

   Args:
      flag: integer-like bitmask tested with ``&`` against tcpFlagTokens.
      split: separator placed between keywords.

   Returns:
      The keywords whose bit is set, sorted alphabetically and joined by
      *split*.
   """
   names = sorted( name for name, bit in tcpFlagTokens.items() if flag & bit )
   return split.join( names )

# helper function to compare and copy ACLs
def copyNonInstantiatingCollection( dest, source ):
   """Make *dest* hold exactly the key/value pairs of *source*, in place.

   An entry is written only if it is missing or different, and only keys
   absent from *source* are deleted, so unchanged entries are left untouched.

   Args:
      dest: mutable mapping-like collection (supports ``items()``, ``in``,
         ``len``, item assignment and ``del``).
      source: mapping-like collection to copy from.
   """
   # First add new members and update changed ones.
   for k, v in source.items():
      if k not in dest or dest[ k ] != v:
         dest[ k ] = v

   # Then delete members that no longer exist in source. After the loop above
   # every source key is present in dest, so dest can only be larger.
   # Collect the stale keys before deleting: removing entries while iterating
   # dest.items() raises RuntimeError for a plain dict as soon as more than
   # one entry has to go.
   if len( dest ) > len( source ):
      staleKeys = [ k for k, _ in dest.items() if k not in source ]
      for k in staleKeys:
         del dest[ k ]

def collEq( c1, c2 ):
   """Return True if mapping-like *c1* and *c2* hold the same key/value pairs."""
   if len( c1 ) != len( c2 ):
      return False
   # With equal sizes, every key of c1 being present in c2 with an equal
   # value implies the two collections match.
   return not any( k not in c2 or c2[ k ] != v for k, v in c1.items() )

def copyAclSubConfig( dest, source, aclType ):
   """Copy an ACL sub-config's counters flag, rules, remarks and rule ordering.

   Args:
      dest: destination sub-config; must not compare equal to *source*.
      source: sub-config to copy from.
      aclType: 'ip', 'ipv6' or 'mac'; selects which rule collection is copied.
   """
   assert dest != source

   dest.countersEnabled = source.countersEnabled

   if aclType == 'ip':
      ruleAttr = 'ipRuleById'
   elif aclType == 'ipv6':
      ruleAttr = 'ip6RuleById'
   else:
      assert aclType == 'mac'
      ruleAttr = 'macRuleById'
   copyNonInstantiatingCollection( getattr( dest, ruleAttr ),
                                   getattr( source, ruleAttr ) )

   # Remarks and the sequence -> rule mapping are copied for every ACL type.
   for attr in ( 'remarkBySequence', 'ruleBySequence' ):
      copyNonInstantiatingCollection( getattr( dest, attr ),
                                      getattr( source, attr ) )

def inverseMacMask( mask ):
   """Return *mask* with every bit inverted, as a colon-separated MAC string.

   *mask* must match Ethernet.colonPattern, and the pattern's six groups must
   be the hex octets. The result uses two lowercase hex digits per octet.
   """
   import re # pylint: disable=import-outside-toplevel
   import Ethernet # pylint: disable=import-outside-toplevel
   match = re.match( Ethernet.colonPattern, mask )
   assert match is not None
   inverted = [ int( match.group( n ), 16 ) ^ 0xFF for n in range( 1, 7 ) ]
   return ':'.join( f'{octet:02x}' for octet in inverted )

def createIPAcl( config, aclName, aclType='ip' ):
   """Create ACL *aclName* with one sub-config holding two default rules.

   Fixed Ark::UniqueId values (12345, 765, 0..2) are used for the sub-config
   and the two rules.

   Args:
      config: object whose ``config[ aclType ]`` provides ``newAcl``,
         ``type`` and ``acl``.
      aclName: name of the ACL to create.
      aclType: 'ip' or 'ipv6'.

   Returns:
      ( acl, subConfigId, firstRuleId, secondRuleId ). The two rules sit at
      sequence numbers 0 and 1 of the ACL's current sub-config.
   """
   subCfgId, firstRuleId, secondRuleId = [
      Tac.Value( 'Ark::UniqueId', 12345, 765, n ) for n in range( 3 ) ]

   typeConfig = config.config[ aclType ]
   acl = typeConfig.newAcl( aclName, typeConfig.type, True, False )

   acl.newSubConfig( subCfgId )
   subCfg = acl.subConfig[ subCfgId ]
   if aclType == 'ip':
      rules, ruleTypeName = subCfg.ipRuleById, 'Acl::IpRuleConfig'
   else:
      assert aclType == 'ipv6'
      rules, ruleTypeName = subCfg.ip6RuleById, 'Acl::Ip6RuleConfig'
   for ruleId in ( firstRuleId, secondRuleId ):
      rules[ ruleId ] = Tac.Value( ruleTypeName )

   acl.currCfg = typeConfig.acl[ aclName ].subConfig[ subCfgId ]
   for seq, ruleId in enumerate( ( firstRuleId, secondRuleId ) ):
      acl.currCfg.ruleBySequence[ seq ] = ruleId
   return ( acl, subCfgId, firstRuleId, secondRuleId )

def payloadOptToPayloadSpec( payloadOpt=None, headerStart=None, headerType=None ):
   """Build a PayloadSpec from an iterable of payload match options.

   Entries with a zero mask are dropped. For a repeated offset the last entry
   wins. The remaining entries are serialised in ascending offset order as
   "offset pattern mask alias" groups. Override markers are appended to the
   alias when patternOverride / maskOverride is set.

   Args:
      payloadOpt: iterable of payload options, or None for none.
      headerStart: applied to the spec only if at least one entry remains
         and the value is truthy.
      headerType: same rule as *headerStart*.

   Returns:
      The resulting PayloadSpec.
   """
   options = () if payloadOpt is None else payloadOpt
   byOffset = { p.offset: p for p in options if p.mask != 0 }

   tokens = []
   for offset in sorted( byOffset ):
      payload = byOffset[ offset ]
      alias = payload.alias
      if payload.patternOverride:
         alias += ( PayloadStringConstants.overrideDelim +
                    PayloadStringConstants.patternOverride )
      if payload.maskOverride:
         alias += ( PayloadStringConstants.overrideDelim +
                    PayloadStringConstants.maskOverride )
      tokens.extend( ( payload.offset, payload.pattern, payload.mask, alias ) )

   payloadSpec = PayloadSpec( ' '.join( str( t ) for t in tokens ) )
   if byOffset:
      if headerStart:
         payloadSpec.headerStart = headerStart
      if headerType:
         payloadSpec.headerType = headerType
   return payloadSpec

def metadataListToMetadataSpec( metadataList=None ):
   """Build a MetadataSpec from an iterable of metadata match entries.

   Entries with a zero mask are dropped. For a repeated (source, offset) pair
   the last entry wins. The remaining entries are serialised sorted by
   (source, offset) as "source offset pattern mask" groups.
   """
   entries = () if metadataList is None else metadataList
   latest = { ( m.metadataSource, m.offset ): m for m in entries if m.mask != 0 }

   fields = []
   for key in sorted( latest ):
      entry = latest[ key ]
      fields += [ entry.metadataSource, entry.offset, entry.pattern, entry.mask ]

   return MetadataSpec( ' '.join( str( f ) for f in fields ) )

def _getFilter( proto, srcIp=None, dstIp=None,
                dport=None, dportOper='eq',
                sport=None, sportOper='eq',
                tracked=False, ttl=None,
                ttlRange=None, ttlRangeOper='eq',
                ipLen=None, ipLenOper='eq',
                aclType='ip', flowLabel=None,
                etherType=None ):
   """Build an IPv4 (IpFilterValue) or IPv6 (Ip6FilterValue) filter.

   Args:
      proto: IP protocol number. It also selects the service table used to
         translate symbolic port names (see getServByName).
      srcIp, dstIp: source/destination address; None means "any".
      dport, sport: iterable of ports or service names; None means "any".
      dportOper, sportOper: operator passed to PortValue (e.g. 'eq').
      tracked: value for the filter's ``tracked`` field.
      ttl: TTL matched with 'eq'; None means "any".
      ttlRange, ttlRangeOper: TTL range spec; None means "any".
      ipLen, ipLenOper: iterable of IP length values; None means "any".
         Only used for IPv4 filters.
      aclType: 'ip' builds an IpFilterValue; any other value builds an
         Ip6FilterValue.
      flowLabel: IPv6 flow label matched on all 20 bits; None means "any".
         Ignored for IPv4.
      etherType: ether type value; None is treated as 0.

   Returns:
      The filter value.
   """
   def _portSpec( ports, oper ):
      if ports is None:
         return anyPortValue
      # Names known for this protocol are translated to their port; anything
      # else is passed through verbatim.
      resolved = []
      for p in ports:
         serv = getServByName( proto, p )
         resolved.append( str( p if serv is None else serv ) )
      return PortValue( oper, ' '.join( resolved ) )

   dportSpec = _portSpec( dport, dportOper )
   sportSpec = _portSpec( sport, sportOper )

   if ipLen is None:
      ipLenSpec = anyIpLenValue
   else:
      # str() each value, as is done for ports, so integer lengths are accepted
      # as well as strings (join() would otherwise raise TypeError).
      ipLenSpec = IpLenValue( ipLenOper, ' '.join( str( n ) for n in ipLen ) )

   ttlSpec = anyTtlValue if ttl is None else TtlValue( 'eq', ttl )
   if ttlRange is None:
      ttlRangeSpec = anyTtlRangeValue
   else:
      ttlRangeSpec = TtlRangeValue( ttlRangeOper, ttlRange )
   if etherType is None:
      etherType = 0

   # With the IP-protocol-0 toggle enabled, a non-zero proto gets the full
   # 8-bit mask. Proto 0, or the toggle disabled, leaves protoMask at 0
   # (presumably "protocol not matched" -- confirm against Acl::IpFilter).
   if Toggles.AclToggleLib.toggleAclIpProto0Enabled() and proto != 0:
      protoMask = 0xff
   else:
      protoMask = 0

   if aclType == 'ip':
      if srcIp is None:
         srcIp = anyIpAddr
      if dstIp is None:
         dstIp = anyIpAddr
      return IpFilterValue( proto=proto,
                            protoMask=protoMask,
                            source=srcIp,
                            destination=dstIp,
                            innerSource=anyIpAddr,
                            innerDest=anyIpAddr,
                            innerSource6=anyIp6Addr,
                            innerDest6=anyIp6Addr,
                            dport=dportSpec,
                            sport=sportSpec,
                            tracked=tracked,
                            ttl=ttlSpec,
                            ttlRange=ttlRangeSpec,
                            ipLen=ipLenSpec,
                            etherType=etherType )

   # NOTE(review): ipLenSpec is not passed to the IPv6 filter -- confirm
   # whether Acl::Ip6Filter supports IP length matching.
   if srcIp is None:
      srcIp = anyIp6Addr
   if dstIp is None:
      dstIp = anyIp6Addr
   if flowLabel is None:
      flowLabelMask = 0
      flowLabel = anyFlowLabelValue
   else:
      flowLabelMask = 0x000FFFFF   # all 20 flow-label bits
   return Ip6FilterValue( proto=proto,
                          protoMask=protoMask,
                          source=srcIp,
                          destination=dstIp,
                          innerSource=anyIpAddr,
                          innerDest=anyIpAddr,
                          innerSource6=anyIp6Addr,
                          innerDest6=anyIp6Addr,
                          dport=dportSpec,
                          sport=sportSpec,
                          tracked=tracked,
                          ttl=ttlSpec,
                          ttlRange=ttlRangeSpec,
                          flowLabel=flowLabel,
                          flowLabelMask=flowLabelMask,
                          etherType=etherType )

def _addRule( acl, seqnum, action, proto, srcIp=None, dstIp=None,
              dport=None, dportOper='eq', sport=None, sportOper='eq',
              tracked=False, ttl=None, ttlRange=None, ttlRangeOper='eq',
              ipLen=None, ipLenOper='eq', aclType='ip' ):
   """Add one rule to ACL sub-config *acl* at sequence number *seqnum*.

   The match arguments are forwarded to _getFilter(). For *aclType* 'ip' the
   rule goes into ``acl.ipRuleById``; otherwise into ``acl.ip6RuleById``.
   Logging is always off.

   Returns:
      The new rule's unique id (None on the 'remark' path).
   """
   if action == 'remark':
      # Currently there are no remarks in default config so this should never
      # hit. Adding this for reference if we decide to add default remarks later
      assert False, "Please add remark rule to default config"
      acl.remarkBySequence[ seqnum ] = RemarkConfigValue( remark='Some remark' )
      return None

   ruleFilter = _getFilter( proto, srcIp=srcIp, dstIp=dstIp,
                            dport=dport, dportOper=dportOper,
                            sport=sport, sportOper=sportOper,
                            tracked=tracked, ttl=ttl,
                            ttlRange=ttlRange, ttlRangeOper=ttlRangeOper,
                            ipLen=ipLen, ipLenOper=ipLenOper,
                            aclType=aclType )
   uniqueId = genUniqueId()
   if aclType == 'ip':
      acl.ipRuleById[ uniqueId ] = IpRuleValue( filter=ruleFilter,
                                                action=action, log=False )
   else:
      acl.ip6RuleById[ uniqueId ] = Ip6RuleValue( filter=ruleFilter,
                                                  action=action, log=False )
   acl.ruleBySequence[ seqnum ] = uniqueId
   return uniqueId

def newServiceAclTypeVrfMap( config, serviceAclMap,
                             aclName, vrf=defaultVrf, aclType='ip',
                             proto=IPPROTO_TCP, port=646, defaultAction='deny',
                             sport=None, createAcl=False, standardAcl=False,
                             readonly=False, addToInputCol=False ):
   """Map ( aclType, vrf ) to *aclName* in *serviceAclMap*.

   If *createAcl* is set and the ACL does not exist, it is created: via
   ``acl.newMember`` when *addToInputCol* is set, otherwise as a standalone
   ``Acl::AclConfig`` added with ``acl.addMember``. The ACL then always gets a
   new, empty current sub-config through setAclRules(), even if it already
   existed.

   *proto*, *port*, *defaultAction* and *sport* are accepted but unused.

   NOTE(review): if the ACL is missing and *createAcl* is False, aclConfig is
   None and setAclRules() fails on it -- confirm callers never do this.

   Returns:
      The ACL config (or None, see note).
   """
   aclTypeConfig = config.config.newMember( aclType )
   aclConfig = aclTypeConfig.acl.get( aclName )
   if aclConfig is None and createAcl:
      if addToInputCol:
         aclTypeConfig.acl.newMember( aclName, aclTypeConfig.type,
                                      standardAcl, readonly )
         aclConfig = aclTypeConfig.acl[ aclName ]
      else:
         aclConfig = Tac.newInstance( "Acl::AclConfig", aclName, aclType,
                                      standardAcl, readonly )
         aclTypeConfig.acl.addMember( aclConfig )

   # Create empty acl.currCfg
   setAclRules( aclConfig, [], aclType )
   mapKey = Tac.Value( "Acl::AclTypeAndVrfName", aclType, vrf )
   serviceAclMap.aclName[ mapKey ] = aclName
   return aclConfig

def newServiceAcl( aclConfig, aclCpConfig, serviceName, aclName, vrf=defaultVrf,
                   aclType='ip', proto=IPPROTO_TCP, port=None, defaultAclName='',
                   defaultAction='deny', sport=None, createAcl=False,
                   standardAcl=False, readonly=False, tracked=False ):
   """Configure control-plane service *serviceName* in *vrf* to use *aclName*.

   The service entry is created if missing. *proto* is only used at that
   point, because an existing entry keeps its protocol. A non-empty *port* or
   *sport* overwrites the service's comma-separated ports / sports. tracked,
   defaultAclName, defaultAction and aclName are always written.

   Returns:
      The ACL named *aclName* for *aclType*. It is created if *createAcl* is
      set and the ACL did not exist; otherwise it is None when absent.
   """
   cpTypeConfig = aclCpConfig.cpConfig.newMember( aclType )
   typeConfig = aclConfig.config.newMember( aclType )
   vrfServiceAcl = cpTypeConfig.serviceAcl.newMember( vrf )

   service = vrfServiceAcl.service.get( serviceName )
   if service is None:
      service = vrfServiceAcl.service.newMember( serviceName, proto, vrf )
   if port:
      service.ports = ','.join( map( str, port ) )
   if sport:
      service.sports = ','.join( map( str, sport ) )
   service.tracked = tracked
   service.defaultAclName = defaultAclName
   service.defaultAction = defaultAction
   service.aclName = aclName

   acl = typeConfig.acl.get( aclName )
   if acl is None and createAcl:
      acl = typeConfig.newAcl( aclName, typeConfig.type, standardAcl, readonly )
   return acl

def setAclRules( acl, rules, aclType='ip', readonly=True, countersEnabled=True ):
   """Install *rules* in a new sub-config and make it *acl*'s current config.

   Args:
      acl: ACL config providing ``newSubConfig`` and ``currCfg``.
      rules: iterable of dicts with required 'action' and 'proto' keys and
         optional srcIp/dstIp/dport/dportOper/sport/sportOper/ttl/ttlRange/
         ttlRangeOper/ipLen/ipLenOper keys. Rules are never tracked.
      aclType: forwarded to _addRule().
      readonly, countersEnabled: copied onto the new sub-config.

   Returns:
      List of rule ids, one per rule. Rules get sequence numbers 10, 20, ...
   """
   subCfg = acl.newSubConfig( genUniqueId() )
   subCfg.readonly = readonly
   subCfg.countersEnabled = countersEnabled

   ruleIds = []
   for position, rule in enumerate( rules, start=1 ):
      ruleIds.append( _addRule(
         subCfg, 10 * position, rule[ 'action' ], rule[ 'proto' ],
         srcIp=rule.get( 'srcIp' ), dstIp=rule.get( 'dstIp' ),
         dport=rule.get( 'dport' ), dportOper=rule.get( 'dportOper', 'eq' ),
         sport=rule.get( 'sport' ), sportOper=rule.get( 'sportOper', 'eq' ),
         tracked=False, ttl=rule.get( 'ttl' ), ttlRange=rule.get( 'ttlRange' ),
         ttlRangeOper=rule.get( 'ttlRangeOper', 'eq' ),
         ipLen=rule.get( 'ipLen' ), ipLenOper=rule.get( 'ipLenOper', 'eq' ),
         aclType=aclType ) )

   acl.currCfg = subCfg
   return ruleIds

def initializeDefaultCpAclName( entMan, cfg, paramCfg ):
   """Register an active callback on *entMan* that may clear the default CP ACL.

   When the callback runs with a truthy *active* and ``hwEpoch/status``
   reports serviceAclEnabled, ``paramCfg.cpAclNameDefault`` is set to ''.
   *cfg* is unused.
   """
   # pkgdeps: rpm Epoch-lib
   hwEpochStatus = entMan.lookup( "hwEpoch/status" )

   def _handleActive( active ):
      # create and use open default cp acl only when service acl is used
      if active and hwEpochStatus.serviceAclEnabled:
         paramCfg.cpAclNameDefault = ''

   entMan.registerActiveCallback( _handleActive )

def getAclTypeDisplayName( aclType ):
   """Return the display label for *aclType* (e.g. 'IPv4'), or None if unknown."""
   return aclTypeDisplayNames.get( aclType, None )
