# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

from ArnetModel import IpGenericPrefix
from CliModel import Model, Bool, Enum, Dict, List, Int, Str, Submodel
from IntfModels import Interface
from TableOutput import createTable, Format
import Tac
import sys
import CliExtensions

# Extension hook for the globally configured decap-group IPv4 prefix length.
# DecapGroups.configuredDecapGroupLen() calls the first registered extension
# with no arguments and uses its return value. A result of None means "not
# configured" and disables the prefix-length annotation in the rendered table.
configuredDecapGroupPrefixLen = CliExtensions.CliHook()

class DecapGroupModelBase( Model ):
   'Base model definition for all ip decap-group types'
   # Tunnel type of the group, one of the Tunnel::Decap::TunnelType values.
   # DecapGroups.render shows no decap entries for the 'unknown' value.
   tunnelType = Enum( help='Decap tunnel type',
                      values=Tac.Type( 'Tunnel::Decap::TunnelType' ).attributes )
   # False marks a dynamic entry; DecapGroups.render flags it with '*' in the
   # "D" column.
   persistent = Bool( 'Decap group is present in system configuration' )
   # Optional; shown in the "Forwarding VRF" column when set.
   forwardingVrf = Str( help='Forwarding VRF', optional=True )

class DecapGroupCounterEntry( Model ):
   'Packet and octet counters of a decap interface or decap IP entry'
   packets = Int( help="Number of received packets on DecapGroup" )
   octets = Int( help="Number of received octets on DecapGroup" )

class DecapIntfModel( Model ):
   'Decap interface entry of a decap-group'
   # An empty IntfId is rendered as 'all' by DecapGroups.render.
   intfId = Interface( 'Decap destination interface' )
   # 'ipv4' is rendered as "IPv4"; any other value is rendered as "IPv6".
   addressFamily = Enum( help='Decap address family',
                         values=Tac.Type( 'Arnet::AddressFamily' ).attributes )
   addressType = Enum( help='Decap address type',
                       values=Tac.Type( 'Tunnel::Decap::AddressType' ).attributes )
   oldConfig = Bool( 'Old way of writing this config' )
   # Optional; left unset (None) when no counter is reported for this entry.
   counter = Submodel( valueType=DecapGroupCounterEntry,
                       help='Counter for the decap interface.',
                       optional=True )

class DecapGroupGreModel( DecapGroupModelBase ):
   'GRE decap-group'
   # Decap interface entries; rendered before the decap IPs.
   decapIntf = List( help='decap intf key list',
                     valueType=DecapIntfModel )
   # Decap IP prefixes of the group.
   decapIp = List( help='decap ip list', valueType=IpGenericPrefix )
   # Optional counters keyed by the decap IP prefixes listed in decapIp.
   decapIpCounters = Dict( keyType=IpGenericPrefix,
                           valueType=DecapGroupCounterEntry,
                           help='Map decap IP address to decap group counter',
                           optional=True )

class DecapGroupWithOnlyGreKeyToVrfMapModel( Model ):
   'GRE key to forwarding VRF mapping for a GRE decap-group'
   # GRE key (int) -> forwarding VRF name; may be empty.
   greKeyToForwardingVrfMapping = Dict( keyType=int, valueType=str,
                                        help='Map GRE key to forwarding VRF' )

class DecapGroupIpIpModel( DecapGroupModelBase ):
   'IP-in-IP decap-group'
   # Decap interface entries; rendered before the decap IPs.
   decapIntf = List( help='decap intf key list',
                     valueType=DecapIntfModel )
   # Decap IP prefixes of the group.
   decapIp = List( help='decap ip list', valueType=IpGenericPrefix )
   # Optional counters keyed by the decap IP prefixes listed in decapIp.
   decapIpCounters = Dict( keyType=IpGenericPrefix,
                           valueType=DecapGroupCounterEntry,
                           help='Map decap IP address to decap group counter',
                           optional=True )

class DecapGroupUdpModel( DecapGroupModelBase ):
   'UDP decap-group'
   # Decap interface entries; rendered before the decap IPs.
   decapIntf = List( help='decap intf key list',
                     valueType=DecapIntfModel )
   # Decap IP prefixes of the group.
   decapIp = List( help='decap ip list', valueType=IpGenericPrefix )
   # A falsy (0) port makes DecapGroups.render fall back to the global
   # UDP destination port -> payload type mapping, when one exists.
   destinationPort = Int( help='Decap destination port' )
   payloadType = Enum( help='Decap Payload Type',
               values=Tac.Type( 'Tunnel::Decap::PayloadType' ).attributes )
   # Optional counters keyed by the decap IP prefixes listed in decapIp.
   decapIpCounters = Dict( keyType=IpGenericPrefix,
                           valueType=DecapGroupCounterEntry,
                           help='Map decap IP address to decap group counter',
                           optional=True )

class DecapGroupQosTcModel( DecapGroupModelBase ):
   'Source of the QoS traffic class of a decap-group'
   # False means the traffic class comes from the outer DSCP
   # (shown as "Outer DSCP" by DecapGroupQos.render).
   qosTcFromMplsTc = Bool( help='QoS Traffic Class derived from MPLS Traffic Class',
                           default=False )

class DecapGroupsWithOnlyGreKeyToVrfMap( Model ):
   'GRE key to forwarding VRF mappings of GRE decap-groups, keyed by name'
   decapGroups = Dict( help='A mapping from name to GRE decap-group entry',
                       valueType=DecapGroupWithOnlyGreKeyToVrfMapModel )

   def _getColumnFormat( self, **kwargs ):
      '''Return a left-justified table Format without left padding.

      Keyword arguments override or extend the default Format arguments.
      '''
      baseFormat = { 'justify': 'left' }
      baseFormat.update( kwargs )
      fmt = Format( **baseFormat )
      fmt.noPadLeftIs( True )
      return fmt

   def render( self ):
      '''Print one row per (decap group, GRE key) pair.

      Groups are sorted by name and, within a group, mappings are sorted by
      GRE key so the output does not depend on dict insertion order. A group
      without any mapping gets a row holding only its name. Nothing is
      printed when there are no decap groups.
      '''
      if not self.decapGroups:
         return

      headings = ( 'Decap Group', 'GRE Key', 'Forwarding VRF' )
      table = createTable( headings )
      table.formatColumns( *( [ self._getColumnFormat() ] * len( headings ) ) )

      for name, dg in sorted( self.decapGroups.items() ):
         mapping = dg.greKeyToForwardingVrfMapping
         if not mapping:
            table.newRow( name )
            continue
         for greKey in sorted( mapping ):
            table.newRow( name, greKey, mapping[ greKey ] )
      print( table.output() )

class DecapGroups( Model ):
   '''All ip decap-groups, keyed by decap-group name.

   render() prints one table row per decap interface / decap IP of each group
   (or a single placeholder row for a group that has none). A UDP group that
   falls back to the global UDP destination port mapping gets one extra row
   per additional (port, payload type) pair.
   '''
   __revision__ = 4
   decapGroups = Dict( help='A mapping from name to ip decap-group entry',
                       valueType=DecapGroupModelBase )
   globalUdpDestPortToPayloadType = \
         Dict( keyType=int, valueType=str,
               help='A global mapping from UDP destination port to payload type' )

   def _getColumnFormat( self, **kwargs ):
      '''Return a left-justified, bordered table Format without left padding.

      Keyword arguments override the defaults; passing maxWidth also turns on
      wrapping for the column.
      '''
      baseFormat = { 'justify': 'left', 'border': True }
      baseFormat.update( kwargs )
      if 'maxWidth' in kwargs:
         baseFormat[ 'wrap' ] = True
      fmt = Format( **baseFormat )
      fmt.noPadLeftIs( True )
      return fmt

   def _tunnelTypeEnum( self ):
      '''Return the Tunnel::Decap::TunnelType enum.

      Uses tacTunnelType exported by CliPlugin.DecapGroupCli, falling back to
      a direct Tac lookup when the plugin leaves it set to None.
      '''
      # Imported lazily (presumably to avoid an import cycle with the plugin).
      # pylint: disable-next=import-outside-toplevel
      from CliPlugin.DecapGroupCli import tacTunnelType
      if tacTunnelType is None:
         return Tac.Type( 'Tunnel::Decap::TunnelType' )
      return tacTunnelType

   def _tunnelTypeName( self, tunnelType ):
      '''Return the display name of tunnelType: "IP-in-IP" for ipip, otherwise
      the upper-cased tunnel type.'''
      tunnelTypes = self._tunnelTypeEnum()
      _names = {
         tunnelTypes.ipip: 'IP-in-IP',
         }
      if tunnelType in _names:
         return _names[ tunnelType ]
      return tunnelType.upper()

   def _isIp4Addr( self, ip ):
      return ip.af == 'ipv4'

   def countersPresent( self ):
      '''
      If there is at least one decap group with counters present in the model,
      return True.  Otherwise, return False.
      '''
      for dg in self.decapGroups.values():
         if dg.tunnelType not in [ 'gre', 'ipip', 'udp' ]:
            continue
         if dg[ 'decapIpCounters' ] is not None:
            return True
         for decapIntf in dg[ 'decapIntf' ]:
            if decapIntf[ 'counter' ] is not None:
               return True
      return False

   def configuredDecapGroupLen( self ):
      '''Return the result of the first registered
      configuredDecapGroupPrefixLen extension, or None if there is none.'''
      for hook in configuredDecapGroupPrefixLen.extensions():
         return hook()
      return None

   def _counterCells( self, counter ):
      '''Return the ( packets, octets ) cells of counter, or ( '', '' ) when
      counter is None (counters not shown or not reported for the entry).'''
      if counter is None:
         return '', ''
      return counter.packets, counter.octets

   def _udpPortAndPayloadType( self, dg, tunnelTypes ):
      '''Return the ( UDP dest port, payload type ) cells of decap group dg.

      Non-UDP groups get ( '', '' ). A UDP group without its own destination
      port uses the global port mapping when it is non-empty; both values are
      then parallel lists sorted by port, with 'ipvx' shown as 'ip'.
      '''
      if dg.tunnelType != tunnelTypes.udp:
         return '', ''
      if not dg.destinationPort and self.globalUdpDestPortToPayloadType:
         ports = []
         payloadTypes = []
         for port, payloadType in sorted(
               self.globalUdpDestPortToPayloadType.items() ):
            ports.append( port )
            payloadTypes.append( 'ip' if payloadType == 'ipvx' else payloadType )
         return ports, payloadTypes
      return dg.destinationPort, dg.payloadType

   def _decapIpCell( self, ip, prefixLen ):
      '''Return the Info cell for decap IP ip.

      When a global IPv4 prefix length is configured (prefixLen is not None),
      an IPv4 /32 decap IP is shown as "<addr>/<prefixLen>*". Other values,
      including IPv6 prefixes that happen to be /32, are returned unchanged.
      '''
      if ( prefixLen is None or
           not isinstance( ip, Tac.Type( 'Arnet::IpGenPrefix' ) ) or
           not self._isIp4Addr( ip ) ):
         return ip
      addr, mask = ip.stringValue.split( '/' )
      if mask != '32':
         return ip
      return addr + '/' + str( prefixLen ) + '*'

   def _decapRows( self, dg, tunnelTypes, countersPresent, prefixLen ):
      '''Return the per-entry cells of decap group dg as a list of tuples
      ( info, family, addrType, port, payloadType, packets, octets ).

      Decap interfaces come first, then decap IPs. A group with the unknown
      tunnel type, or with neither decap interfaces nor decap IPs, yields one
      placeholder row; for a known type that row still carries the UDP port
      and payload type.
      '''
      if dg.tunnelType == tunnelTypes.unknown:
         return [ ( '', '', '', '', '', '', '' ) ]

      port, payloadType = self._udpPortAndPayloadType( dg, tunnelTypes )
      allIntfs = Tac.Value( 'Arnet::IntfId', '' )
      rows = []
      for key in dg.decapIntf:
         info = 'all' if key.intfId == allIntfs else key.intfId.stringValue
         family = 'IPv4' if key.addressFamily == 'ipv4' else 'IPv6'
         counter = key.counter if countersPresent else None
         packets, octets = self._counterCells( counter )
         rows.append( ( info, family, key.addressType, port, payloadType,
                        packets, octets ) )

      # Another group may have counters while this one has none, so tolerate
      # a missing counter dict or a missing per-IP entry.
      ipCounters = dg.decapIpCounters if countersPresent else None
      for ip in dg.decapIp:
         family = 'IPv4' if self._isIp4Addr( ip ) else 'IPv6'
         counter = None
         if ipCounters is not None and ip in ipCounters:
            counter = ipCounters[ ip ]
         packets, octets = self._counterCells( counter )
         rows.append( ( self._decapIpCell( ip, prefixLen ), family, '', port,
                        payloadType, packets, octets ) )

      if not rows:
         rows.append( ( '', '', '', port, payloadType, '', '' ) )
      return rows

   def render( self ):
      '''Print the decap-group table, preceded by explanatory notes.

      Nothing is printed when there are no decap groups. Packets/Octets
      columns are added only when countersPresent() is True.
      '''
      if not self.decapGroups:
         return

      countersPresent = self.countersPresent()
      prefixLen = self.configuredDecapGroupLen()

      dynamicFmt = self._getColumnFormat()
      dynamicFmt.padLimitIs( True )
      columns = [
         ( 'D', dynamicFmt ),
         ( 'Name', self._getColumnFormat( maxWidth=20 ) ),
         ( 'Type', self._getColumnFormat() ),
         ( 'Info', self._getColumnFormat() ),
      ]
      versionFmt = self._getColumnFormat()
      versionFmt.padLimitIs( True )
      columns += [
         ( 'Version', versionFmt ),
         ( 'Address Type', self._getColumnFormat() ),
         ( 'UDP Dest Port', self._getColumnFormat() ),
         ( 'Payload Type', self._getColumnFormat() ),
         ( 'Forwarding VRF',
           self._getColumnFormat( maxWidth=14, border=countersPresent ) ),
      ]
      if countersPresent:
         # DecapGroup counters are present in the output, add the counter
         # columns. The Packets column reuses the Version column's Format
         # object (padLimit enabled).
         columns += [
            ( 'Packets', versionFmt ),
            ( 'Octets', self._getColumnFormat( border=True ) ),
         ]
      headings, formats = zip( *columns )
      table = createTable( headings, tableWidth=200 )
      table.formatColumns( *formats )

      tunnelTypes = self._tunnelTypeEnum()
      for name in sorted( self.decapGroups ):
         dg = self.decapGroups[ name ]
         dynamic = ' ' if dg.persistent else '*'
         typeName = self._tunnelTypeName( dg.tunnelType )
         rows = self._decapRows( dg, tunnelTypes, countersPresent, prefixLen )
         for info, family, addrType, port, payloadType, packets, octets in rows:
            vals = [ dynamic, name, typeName, info, family, addrType,
                     port[ 0 ] if isinstance( port, list ) else port,
                     payloadType[ 0 ] if isinstance( payloadType, list )
                     else payloadType ]
            if countersPresent:
               vals += [ dg.forwardingVrf or ' ', packets, octets ]
            elif dg.forwardingVrf:
               vals.append( dg.forwardingVrf )
            table.newRow( *vals )

            if isinstance( port, list ):
               # Remaining global ( port, payload type ) pairs, one per row.
               for extraPort, extraType in zip( port[ 1: ], payloadType[ 1: ] ):
                  extra = [ '' ] * 6 + [ extraPort, extraType ]
                  if countersPresent:
                     extra += [ '', '' ]
                  table.newRow( *extra )

      # Print note about 'D' column
      print( 'NOTE: "D" column indicates dynamic entries' )

      if prefixLen is not None:
         print( 'Global IPv4 prefix length:', prefixLen,
                '("*" indicates affected entries)' )

      # Render the table output
      sys.stdout.write( table.output() )

   def degrade( self, dictRepr, revision ):
      '''Convert the revision-4 dict representation to an older revision in
      place, and return it.

      Revision 1: a GRE group exposes a single decap IP string without its
      mask ('0.0.0.0' when it has none). Other groups drop decapIp and expose
      only the first decap interface's intfId.
      Any other revision <= 3: decap IPs lose their '/len' suffix. In
      revision 2, IP-in-IP groups also expose only the first decap interface's
      intfId.
      Entries without decapIp/decapIntf (e.g. groups of the base model type)
      are left as they are instead of raising KeyError.
      '''
      if revision == 1:
         for entry in dictRepr[ 'decapGroups' ].values():
            if entry[ 'tunnelType' ] == 'gre':
               decapIps = entry.get( 'decapIp', [] )
               entry[ 'decapIp' ] = ( decapIps[ 0 ].partition( '/' )[ 0 ]
                                      if decapIps else '0.0.0.0' )
            else:
               entry.pop( 'decapIp', None )
               self._degradeDecapIntf( entry )
      elif revision <= 3:
         for entry in dictRepr[ 'decapGroups' ].values():
            decapIps = entry.get( 'decapIp' )
            if decapIps:
               # Strip the mask, modifying the list in place.
               decapIps[ : ] = [ ip.partition( '/' )[ 0 ] for ip in decapIps ]
            if revision == 2 and entry[ 'tunnelType' ] == 'ipip':
               self._degradeDecapIntf( entry )

      return dictRepr

   def _degradeDecapIntf( self, entry ):
      '''Replace entry's decapIntf list by the intfId of its first element,
      unless the list is missing/empty or already holds a plain string.'''
      decapIntfs = entry.get( 'decapIntf' )
      if decapIntfs and not isinstance( decapIntfs[ 0 ], str ):
         entry[ 'decapIntf' ] = decapIntfs[ 0 ][ 'intfId' ]

class DecapGroupQos( Model ):
   'QoS traffic-class source of every ip decap-group, keyed by group name'
   groups = Dict( help='IP decap groups, keyed by their name',
                  valueType=DecapGroupQosTcModel )

   def render( self ):
      '''Print a Name / Type / QoS Traffic Class table sorted by group name.

      Nothing is printed when there are no groups.
      '''
      if not self.groups:
         return

      nameFmt = Format( justify='left' )
      nameFmt.noPadLeftIs( True )
      # Fixed 7-character Type column.
      typeFmt = Format( justify='left', minWidth=7, maxWidth=7, padding=0 )
      typeFmt.noPadLeftIs( True )
      typeFmt.padLimitIs( True )
      qosFmt = Format( justify='left', minWidth=20, maxWidth=20 )
      qosFmt.noPadLeftIs( True )

      table = createTable( ( "Name", "Type", "QoS Traffic Class" ),
                           tableWidth=100 )
      table.formatColumns( nameFmt, typeFmt, qosFmt )
      for groupName, qosModel in sorted( self.groups.items() ):
         source = ( "MPLS Traffic Class" if qosModel.qosTcFromMplsTc
                    else "Outer DSCP" )
         table.newRow( groupName, qosModel.tunnelType.upper(), source )

      print( table.output() )
