# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import MultiRangeRule
import Tracing

from AclCliLib import genericIpProtocols
from socket import IPPROTO_UDP, IPPROTO_TCP, IPPROTO_ICMP, IPPROTO_ICMPV6
# Tac type handles. AddressFamily, TrieGen and IpGenPrefix are not referenced
# in the visible part of this module; presumably they are exposed for
# importers of this module -- confirm before removing any of them.
AddressFamily = Tac.Type( "Arnet::AddressFamily" )
TrieGen = Tac.Type( 'Routing::TrieGen' )
IpGenPrefix = Tac.Type( 'Arnet::IpGenPrefix' )
# Module-level shortcut to the Tracing t2 trace function (used by getIdFromDesc).
t2 = Tracing.t2

# URL schemes supported for importing field-sets
listUrlLocalSchemes = { "file:", "flash:", "drive:" }
listUrlNetworkSchemes = { "http:", "scp:", "https:" }

# IP Protocol matchers { token: ( protocol number, helpdesc ) }
# TCP and UDP are common to the IPv4 and IPv6 extra-protocol tables below, so
# they are defined once here and merged into both tables rather than being
# repeated (this table used to be defined twice in this module).
tcpUdpProtocols = {
   'tcp': ( IPPROTO_TCP, 'TCP' ),
   'udp': ( IPPROTO_UDP, 'UDP' )
}

# IPv4 extras: ICMP followed by the shared TCP/UDP entries (key order is
# 'icmp', 'tcp', 'udp').
extraIpv4Protocols = {
   'icmp': ( IPPROTO_ICMP, 'Internet Control Message Protocol' ),
   **tcpUdpProtocols
}

# IPv6 extras: ICMPv6 followed by the shared TCP/UDP entries (key order is
# 'icmpv6', 'tcp', 'udp').
extraIpv6Protocols = {
   'icmpv6': ( IPPROTO_ICMPV6, 'Internet Control Message Protocol version 6' ),
   **tcpUdpProtocols
}

def getProtectedFieldSetNames( field ):
   '''
   Return the field-set names treated as protected for the given field kind
   (presumably names a user-defined field-set may not take -- confirm with
   callers).

   @field: one of 'prefix', 'port', 'vlan', 'integer', 'mac', 'service';
           any other value trips an assertion.

   A new list is built on every call, so callers may mutate the result.
   '''
   assert field in ( 'prefix', 'port', 'vlan', 'integer', 'mac', 'service' )
   # Every kind protects 'field-set'; 'port' additionally protects
   # 'destination'.
   reserved = { kind: [ 'field-set' ]
                for kind in ( 'prefix', 'vlan', 'integer', 'mac', 'service' ) }
   reserved[ 'port' ] = [ 'field-set', 'destination' ]
   return reserved[ field ]

# ICMP v4 types { token: (type, helpdesc) }
# The helpdesc strings are shown as CLI help text; keep them free of leading
# or trailing whitespace.
icmpV4Types = {
   "echo-reply": ( 0, "Echo reply" ),
   "unreachable": ( 3, "Destination unreachable" ),
   "source-quench": ( 4, "Source quench" ),
   "redirect": ( 5, "Redirect" ),
   "alternate-address": ( 6, "Alternate host address" ),
   "echo": ( 8, "Echo" ),
   "router-advertisement": ( 9, "Router advertisement" ),
   "router-solicitation": ( 10, "Router solicitation" ),
   "time-exceeded": ( 11, "Time exceeded" ),
   "parameter-problem": ( 12, "Parameter problem" ),
   "timestamp-request": ( 13, "Timestamp requests" ),
   "timestamp-reply": ( 14, "Timestamp replies" ),
   "information-request": ( 15, "Information requests" ),
   "information-reply": ( 16, "Information replies" ),
   "mask-request": ( 17, "Address mask request" ),
   "mask-reply": ( 18, "Address mask replies" ),
   "traceroute": ( 30, "Traceroute" ),
   "conversion-error": ( 31, "Datagram conversion error" ),
   "mobile-host-redirect": ( 32, "Mobile host redirect" ),
}

# ICMP v6 types { token: (type, helpdesc) }
# The helpdesc strings are shown as CLI help text.
icmpV6Types = {
   "unreachable": ( 1, "Destination unreachables" ),
   "packet-too-big": ( 2, "Packet too big" ),
   "time-exceeded": ( 3, "Time exceeded" ),
   "parameter-problem": ( 4, "Parameter problem" ),
   "echo-request": ( 128, "Echo request" ),
   "echo-reply": ( 129, "Echo reply" ),
   "mld-query": ( 130, "Multicast listener query" ),
   "mld-report": ( 131, "Multicast listener report" ),
   "mld-done": ( 132, "Multicast listener done" ),
   "router-solicitation": ( 133, "Router solicitation" ),
   "router-advertisement": ( 134, "Router advertisement" ),
   "neighbor-solicitation": ( 135, "Neighbor solicitation" ),
   "neighbor-advertisement": ( 136, "Neighbor advertisement" ),
   "redirect-message": ( 137, "Redirect message" ),
   "renum": ( 138, "Router renumbering" ),
   "icmp-node-query": ( 139, "ICMP node information queries" ),
   "icmp-node-response": ( 140, "ICMP node information responses" ),
   "inverse-nd-solicitation":
   ( 141, "Inverse Neighbor Discovery Solicitation Message" ),
   "inverse-nd-advertise":
   ( 142, "Inverse Neighbor Discovery Advertisement Message" ),
   # MLDv2 is Multicast *Listener* Discovery version 2.
   "mldv2-reports": ( 143, "Multicast Listener Discovery (MLDv2) reports" ),
   "ha-ad-request": ( 144, "Home Agent Address Discovery Request Message" ),
   "ha-ad-reply": ( 145, "Home Agent Address Discovery Reply Message" ),
   "mp-solicit": ( 146, "Mobile Prefix Solicitation" ),
   "mp-advertise": ( 147, "Mobile Prefix Advertisement" ),
   "cert-path-solicitation": ( 148, "Certificate Path Solicitation" ),
   "cert-path-advertise": ( 149, "Certificate Path Advertisement" ),
   # NOTE(review): the IANA ICMPv6 registry lists "FMIPv6 Messages" as type
   # 154 and type 150 as the experimental-mobility (Seamoby) type. Confirm
   # before changing the number, since existing configs match on this value.
   "fmipv6-message": ( 150, "FMIPv6 Messages" ),
   "mcast-router-advertise": ( 151, "Multicast Router Advertisement" ),
   "mcast-router-solicit": ( 152, "Multicast Router Solicitation" ),
   "mcast-router-terminate": ( 153, "Multicast Router Termination" ),
   "rpl-control-message": ( 155, "RPL Control Message" )
}

# ICMP v4 types that have defined ICMP codes { token: (type, helpdesc) }.
# Each type appears twice: once by name and once by its number written as a
# string (presumably so a keyword matcher accepts either form -- confirm).
# Every type number listed here is a key of icmpV4Codes.
icmpV4TypeWithValidCodes = {
   "unreachable": ( 3, "Destination unreachable" ),
   "redirect": ( 5, "Redirect" ),
   "router-advertisement": ( 9, "Router advertisement" ),
   "time-exceeded": ( 11, "Time exceeded" ),
   "parameter-problem": ( 12, "Parameter problem" ),
   # Numeric-string aliases of the named types above.
   "3": ( 3, "Destination unreachable value" ),
   "5": ( 5, "Redirect value" ),
   "9": ( 9, "Router advertisement value" ),
   "11": ( 11, "Time exceeded value" ),
   "12": ( 12, "Parameter problem value" ),
}

# ICMP v6 types that have defined ICMP codes { token: (type, helpdesc) }.
# Each type appears twice: once by name and once by its number written as a
# string. Every type number listed here is a key of icmpV6Codes.
# NOTE(review): unlike icmpV4TypeWithValidCodes, the numeric aliases here
# reuse the plain help text, without a " value" suffix.
icmpV6TypeWithValidCodes = {
   "unreachable": ( 1, "Destination unreachables" ),
   "time-exceeded": ( 3, "Time exceeded" ),
   "parameter-problem": ( 4, "Parameter problem" ),
   # Numeric-string aliases of the named types above.
   "1": ( 1, "Destination unreachables" ),
   "3": ( 3, "Time exceeded" ),
   "4": ( 4, "Parameter problem" ),
}

# ICMP v4 codes mapping {type: { token: (code, helpdesc) } }
# The outer keys are the ICMP v4 type numbers listed in
# icmpV4TypeWithValidCodes; each inner map lists that type's code keywords.
icmpV4Codes = {
   # type 3: destination unreachable
   3: { "net-unreachable": ( 0, "Net unreachable" ),
        "host-unreachable": ( 1, "Host unreachable" ),
        "protocol-unreachable": ( 2, "Protocol unreachable" ),
        "port-unreachable": ( 3, "Port unreachable" ),
        "packet-too-big": ( 4, "Fragmentation needed but DF was set" ),
        "source-route-failed": ( 5, "Source route failed" ),
        "network-unknown": ( 6, "Network unknown" ),
        "host-unknown": ( 7, "Host unknown" ),
        "host-isolated": ( 8, "Source host isolated" ),
        "dod-net-prohibited": ( 9, "Communication with network prohibited" ),
        "dod-host-prohibited": ( 10, "Communication with host prohibited" ),
        "net-tos-unreachable": ( 11, "Network unreachable for type of service" ),
        "host-tos-unreachable": ( 12, "Host unreachable for type of service" ),
        "administratively-prohibited":
        ( 13, "Communication administratively prohibited" ),
        "host-precedence-unreachable": ( 14, "Host precedence violation" ),
        "precedence-unreachable": ( 15, "Precedence cutoff in effect" ) },
   # type 5: redirect
   5: { "net-redirect": ( 0, "Network redirect" ),
        "host-redirect": ( 1, "Host redirect" ),
        "net-tos-redirect": ( 2, "Network and type of service redirect" ),
        "host-tos-redirect": ( 3, "Host and type of service redirect" ) },
   # type 9: router advertisement (codes 0 and 16 only)
   9: { "normal-router-advertisement": ( 0, "Normal router advertisement" ),
        "not-route-common-traffic": ( 16, "Does not route common traffic" ) },
   # type 11: time exceeded
   11: { "ttl-exceeded": ( 0, "Time to live exceeded in transit" ),
         "reassembly-timeout": ( 1, "Fragment reassembly time exceeded" ) },
   # type 12: parameter problem
   12: { "general-parameter-problem": ( 0, "General parameter problem" ),
         "option-missing": ( 1, "Missing a required option" ),
         "no-room-for-option": ( 2, "Bad length for parameter" ) }
}

# ICMP v6 codes mapping {type: { token: (code, helpdesc) } }
# The outer keys are the ICMP v6 type numbers listed in
# icmpV6TypeWithValidCodes; each inner map lists that type's code keywords.
icmpV6Codes = {
   # type 1: destination unreachable
   1: { "no-route": ( 0, "No route to destination" ),
        "no-admin": ( 1, "Administration prohibited destination" ),
        "beyond-scope": ( 2, "Beyond scope of source address" ),
        "address-unreachable": ( 3, "Address unreachable" ),
        "port-unreachable": ( 4, "Port unreachable" ),
        "source-address-failed":
        ( 5, "Source address failed ingress/egress policy" ),
        "reject-route": ( 6, "Reject route to destination" ),
        "source-routing-error": ( 7, "Error in source routing header" ) },
   # type 3: time exceeded
   3: { "hop-limit-exceeded": ( 0, "Hop limit exceeded in transit" ),
        "fragment-reassembly-exceeded":
        ( 1, "Fragment reassembly time exceeded" ) },
   # type 4: parameter problem
   4: { "erroneous-header": ( 0, "All erroneous header field encountered" ),
        "unrecognized-next-header":
        ( 1, "Unrecognized next header type encountered" ),
        "unrecognized-ipv6-option": ( 2, "Unrecognized IPv6 option encountered" ) }
}

# Human-readable name for each field-set type keyword. The keys are the same
# fieldSetType values that getFieldSetConfigForType dispatches on.
fieldSetTypeToStr = {
   'ipv4' : 'IPv4 prefix',
   'ipv6' : 'IPv6 prefix',
   'l4-port' : 'L4-port',
   'vlan' : 'VLAN',
   'integer': 'Integer',
   'mac': 'MAC',
}

def getKeywordMap( *maps ):
   '''
   Merge { keyword: ( value, helpdesc ) } maps into a single
   { keyword: "helpdesc (value)" } map that can be used by
   DynamicKeywordMatcher.

   Maps are processed in argument order. For a keyword present in several
   maps, the help text comes from the last one, while the keyword keeps the
   position of its first appearance.
   '''
   keywordHelp = {}
   for table in maps:
      for keyword, entry in table.items():
         number = entry[ 0 ]
         helpText = entry[ 1 ]
         # pylint: disable-next=consider-using-f-string
         keywordHelp[ keyword ] = "%s (%d)" % ( helpText, number )
   return keywordHelp

@Tac.memoize
def getProtocolNumToNameMap():
   '''
   Return a { IP protocol number: protocol keyword } map built from
   extraIpv4Protocols and then genericIpProtocols. When both tables use the
   same number, the genericIpProtocols keyword wins because it is applied
   last. Wrapped in Tac.memoize.

   NOTE(review): IPv6-only entries from extraIpv6Protocols (e.g. icmpv6) are
   not added here unless genericIpProtocols also provides them.
   '''
   numToName = {}
   for protoTable in ( extraIpv4Protocols, genericIpProtocols ):
      for protoName, protoInfo in protoTable.items():
         numToName[ protoInfo[ 0 ] ] = protoName
   return numToName

def rangeSetToNumericalRange( rangeSet, rangeType ):
   '''
   Convert a collection of integers into a list of @rangeType values, one
   per range produced by MultiRangeRule.packRanges.

   @rangeSet: python set (any iterable of integers is iterated the same way).
   @rangeType: Tac type of the range objects to build, e.g.
               Classification::PortRange; each one is created as
               Tac.Value( rangeType, lowerBound, upperBound ).

   Returns [] when packRanges yields nothing.
   '''
   # Each value becomes a degenerate ( v, v ) pair, e.g. {1, 2, 3} ->
   # [ ( 1, 1 ), ( 2, 2 ), ( 3, 3 ) ]; packRanges merges them, e.g. into
   # [ ( 1, 3 ) ].
   singletons = [ ( value, value ) for value in rangeSet ]
   packed = MultiRangeRule.packRanges( singletons )
   return [ Tac.Value( rangeType, low, high )
            for low, high in packed ] if packed else []

def numericalRangeToSet( numRangeList ):
   '''
   Expand a list of Classification::NumericalRange objects into the set of
   integers they cover, both bounds inclusive.

   @numRangeList: iterable of objects exposing rangeStart and rangeEnd
                  (e.g. Classification::NumericalRange).

   Example: ranges 3-7 and 90-95 -> { 3, 4, 5, 6, 7, 90, 91, 92, 93, 94, 95 }
   '''
   return { value
            for numRange in numRangeList
            for value in range( numRange.rangeStart, numRange.rangeEnd + 1 ) }

def computePortRangeSet( rangeSetString ):
   '''
   Parse a canonical multi-range string into the value returned by
   MultiRangeRule.multiRangeFromCanonicalString -- per this function's
   contract, a set of integers, which is more convenient to work with than
   the comma-separated string.

   @rangeSetString: canonical range string, e.g. '1,5,7,10-20'.
   '''
   parseCanonical = MultiRangeRule.multiRangeFromCanonicalString
   return parseCanonical( rangeSetString )

def numericalRangeToRangeString( numRange ):
   '''
   Render a collection of Classification::NumericalRange objects as a
   sorted, comma-separated string, e.g. "3, 7, 90-95".

   @numRange: iterable of range objects; they are ordered with their own
              comparison and each is rendered via its stringValue() method.
   Returns "" for an empty input.
   '''
   return ', '.join( r.stringValue() for r in sorted( numRange ) )

def addOrDeleteRange( numericalRangeList, modifySet, rangeType, add=True ):
   '''
   Add or remove a collection of values (e.g. { 3, 7, 10 }) to/from a range
   represented by a list of Classification::NumericalRange objects.

   @numericalRangeList: current ranges (objects with rangeStart/rangeEnd).
   @modifySet: values to add or remove. Any iterable of integers is
               accepted, not only a set.
   @rangeType: Tac type used to build the resulting range objects.
   @add: True to add @modifySet to the range, False to remove it.

   Returns the result as a new list of @rangeType objects. An empty
   @modifySet is only valid when deleting (asserted); it returns [], which
   causes the current numerical range list to be cleared.
   '''
   if not modifySet:
      # when adding to a range, modifySet must be non-empty
      assert not add
      # return an empty list which will cause the current
      # numerical range list to be cleared.
      return []

   # Convert numericalRangeList to a set of integers
   currentValues = numericalRangeToSet( numericalRangeList )

   # set.union / set.difference accept any iterable, whereas the | and -
   # operators raise TypeError unless both operands are sets.
   if add:
      resultSet = currentValues.union( modifySet )
   else:
      resultSet = currentValues.difference( modifySet )

   # Re-pack the resulting set of integers into range objects
   return rangeSetToNumericalRange( resultSet, rangeType )

def getFieldSetConfigForType( fieldSetType, fieldSetName, fieldSetConfig ):
   '''
   Look up the config entry of a named field-set of the given type.

   @fieldSetType: one of 'ipv4', 'ipv6', 'l4-port', 'vlan', 'integer', 'mac'.
   @fieldSetName: name of the field-set to look up.
   @fieldSetConfig: config object with one collection attribute per
                    field-set type (fieldSetIpPrefix, fieldSetIpv6Prefix,
                    ...), or a falsy value when no config is available.

   Returns the matching entry, or None when @fieldSetConfig is falsy or the
   collection has no entry named @fieldSetName.
   Raises AssertionError for an unhandled @fieldSetType (checked only when a
   config is supplied).
   '''
   if not fieldSetConfig:
      return None
   # field-set type -> name of the attribute holding that type's collection
   collAttrByType = {
      'ipv4': 'fieldSetIpPrefix',
      'ipv6': 'fieldSetIpv6Prefix',
      'l4-port': 'fieldSetL4Port',
      'vlan': 'fieldSetVlan',
      'integer': 'fieldSetInteger',
      'mac': 'fieldSetMacAddr',
   }
   collAttr = collAttrByType.get( fieldSetType )
   if collAttr is None:
      # Raised explicitly instead of `assert False` so the check survives
      # `python -O` (where the assert was stripped and the function then
      # failed with UnboundLocalError). AssertionError is kept so callers see
      # the same exception type as before.
      raise AssertionError( f'Unhandled fsType: {fieldSetType}' )
   return getattr( fieldSetConfig, collAttr ).get( fieldSetName )

def getTcpFlagAndMasksEst():
   '''
   Return two Classification::TcpFlagAndMask values, in order [ ACK, RST ].
   In each, both the flag value and the mask have exactly that one bit set.
   Together they express "ack | rst" (presumably used to match established
   TCP sessions -- confirm with callers).
   '''
   flagAndMasks = []
   for bit in ( 'ack', 'rst' ):
      value = Tac.Value( "Classification::TcpFlag" )
      mask = Tac.Value( "Classification::TcpFlag" )
      setattr( value, bit, True )
      setattr( mask, bit, True )
      flagAndMasks.append(
         Tac.Value( "Classification::TcpFlagAndMask", value, mask ) )
   return flagAndMasks

def getTcpFlagAndMaskInit():
   '''
   Return the Classification::TcpFlagAndMask for "!ack & syn": ACK and SYN
   are both in the mask, with ACK required clear and SYN required set
   (presumably matching the initial SYN of a TCP connection).
   '''
   flagType = "Classification::TcpFlag"
   value = Tac.Value( flagType )
   mask = Tac.Value( flagType )
   # ACK and SYN both take part in the match ...
   mask.ack = True
   mask.syn = True
   # ... with ACK clear and SYN set.
   value.ack = False
   value.syn = True
   return Tac.Value( "Classification::TcpFlagAndMask", value, mask )

# Policies and matches (rules) embed their numeric id in the description
# using different tags; descTypePolicy selects which tag to look for.
def getIdFromDesc( desc=None, descTypePolicy=True ):
   '''
   Extract the numeric id embedded in a policy or match/rule description.

   @desc: description string, or None.
   @descTypePolicy: True to look for the "#policy-id=" tag, False for the
                    "#rule-id=" tag.

   Returns the integer that follows the tag (up to the next '=' if any), or
   0 when @desc is None, the tag is absent, or that text is not an integer.
   '''
   if desc is None:
      return 0
   tag = "#policy-id=" if descTypePolicy else "#rule-id="
   _, sep, afterTag = desc.partition( tag )
   if not sep:
      return 0
   # Split relative to the tag rather than on the first '=' in the whole
   # description, so an '=' appearing before the tag is not mistaken for the
   # tag's separator.
   idStr = afterTag.split( "=", 1 )[ 0 ]
   try:
      return int( idStr )
   except ValueError as e:
      # tag found but not followed by an integer: fall back to id 0
      t2( 'ValueError', e )
      return 0
