#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
from ArnetModel import IpGenericAddress
from CliModel import Dict
from CliModel import Enum
from CliModel import Int
from CliModel import List
from CliModel import Model
from CliModel import Str
from CliPlugin.TunnelModels import TunnelTableEntry, Via
from CliPlugin.TunnelCli import getNhAndIntfStrs
from CliPlugin.TunnelCli import getViaModelFromViaDict
from TableOutput import createTable, Format
import six

class BgpLuTunnelTableEntry( TunnelTableEntry ):
   # One BGP labeled-unicast tunnel entry. It is rendered by
   # renderBgpLuTunnelTableEntry as one table row per via.
   vias = List( valueType=Via, help="List of nexthops" )
   labels = List( valueType=str, help="Label stack" )
   contribution = Str( help="Contribution of tunnel entry to Tunnel RIB" )
   bgpMetric = Int( help="BGP Metric for tunnel entry" )
   bgpMetric2 = Int( help="BGP Metric 2 for tunnel entry" )
   bgpMetricType = Enum( help="BGP Metric Type for tunnel entry",
                         values=( 'metric', 'MED', 'AIGP' ), optional=True )
   bgpPref = Int( help="BGP Preference for tunnel entry" )
   bgpPref2 = Int( help="BGP Preference 2 for tunnel entry" )

   def getPrefStr( self ):
      """Return 'Yes' if this entry's contribution is 'contributing', else 'No'."""
      return 'Yes' if self.contribution == 'contributing' else 'No'

   def renderBgpLuTunnelTableEntry( self, table, tunnelIndex ):
      """Append this entry's rows to `table`.

      The first row carries every column. Each additional via gets a row with
      only nexthop, interface and labels filled in, and '-' in the remaining
      columns.

      Args:
         table: table object from TableOutput.createTable() with the 10
            columns set up by BgpLuTunnelTable.render.
         tunnelIndex: value printed in the 'Index' column of the first row.
      """
      def labelStackStr( labelList ):
         return '[ ' + ' '.join( labelList ) + ' ]'

      labelsStr = labelStackStr( self.labels )
      nhStr = intfStr = '-'
      if self.vias:
         firstVia = self.vias[ 0 ]
         nhStr, intfStr = getNhAndIntfStrs( firstVia )
         # the desired label stack will either be in the entry CAPI model label
         # stack or in the vias in the case of an LU Push entry via, but not both
         if 'labels' in firstVia:
            labelsStr = labelStackStr( firstVia[ 'labels' ] )
      # bgpMetricType is optional. Pad an empty type rather than calling
      # ljust() on None, so the metric still lines up in the right-justified
      # 'Metric/Type' column.
      metricTypeStr = ( self.bgpMetricType or '' ).ljust( 6 )
      bgpMetricStr = str( self.bgpMetric ) + ' ' + metricTypeStr
      table.newRow( tunnelIndex, str( self.endpoint ), nhStr, intfStr, labelsStr,
                    self.getPrefStr(), bgpMetricStr, str( self.bgpMetric2 ),
                    str( self.bgpPref ), str( self.bgpPref2 ) )

      # For vias after the first via, print a dash for these fields:
      # 'Index', 'Endpoint', 'Contributing', 'Metric', 'Metric2', 'Pref', and 'Pref2'
      # since they are common for each tunnel. For the 'Labels' field, we have to
      # check if the label-stack is stored in the tunnel entry's 'labels' attribute
      # or in each individual via. If each via has a label-stack, we have to print
      # each label-stack per via. Otherwise, we can assume the tunnel entry has
      # one label-stack, which we print only for the first via when rendering.
      for via in self.vias[ 1 : ]:
         nhStr, intfStr = getNhAndIntfStrs( via )
         if 'labels' in via:
            labelsStr = labelStackStr( via[ 'labels' ] )
         else:
            labelsStr = '-'
         table.newRow( '-', '-', nhStr, intfStr, labelsStr, '-', '-', '-', '-', '-' )

class BgpLuTunnelTable( Model ):
   __revision__ = 2
   entries = Dict( keyType=int, valueType=BgpLuTunnelTableEntry,
                   help="BGP LU tunnel table entries keyed by tunnel index" )

   def render( self ):
      """Print all tunnel entries as a single table, sorted by tunnel index."""
      left = Format( justify='left' )
      right = Format( justify='right' )
      # ( heading, justification ) for each column, in display order. Only the
      # 'Metric/Type' column is right-justified.
      columns = (
         ( "Index", left ),
         ( "Endpoint", left ),
         ( "Nexthop/Tunnel Index", left ),
         ( "Interface", left ),
         ( "Labels", left ),
         ( "Contributing", left ),
         ( "Metric/Type", right ),
         ( "Metric 2", left ),
         ( "Pref", left ),
         ( "Pref 2", left ),
      )
      table = createTable( tuple( heading for heading, _ in columns ) )
      table.formatColumns( *[ fmt for _, fmt in columns ] )
      for index, entry in sorted( self.entries.items() ):
         entry.renderBgpLuTunnelTableEntry( table, index )

      print( table.output() )

   def degrade( self, dictRepr, revision ):
      """Rewrite `dictRepr` in place for an older JSON revision and return it.

      Called for 'show bgp labeled-unicast tunnel [<tunnel-idx>] | json
      revision 1'. For revision 1, each non-empty via dict is passed to the
      degradeToV1 method of the Via model chosen by getViaModelFromViaDict.
      Any other revision is returned untouched.
      """
      if revision != 1:
         return dictRepr
      for entryDict in dictRepr[ 'entries' ].values():
         for viaDict in entryDict[ 'vias' ]:
            if not viaDict:
               continue
            getViaModelFromViaDict( viaDict ).degradeToV1( viaDict )
      return dictRepr

# -------------------------------------------------------------------------------
# "show bgp labeled-unicast forwarding [<tunnel-id>]"
# -------------------------------------------------------------------------------
class BgpLuForwardingTunnelTableEntry( Model ):
   # One BGP labeled-unicast forwarding entry. It is rendered as one table row
   # per via; an entry with no vias still gets one row.
   vias = List( valueType=Via, help="List of nexthops" )
   labels = List( valueType=str, help="Label stack" )

   def renderBgpLuForwardingTunnelTableEntry( self, table, tunnelIndex ):
      """Append this entry's rows to `table`.

      Only the first row shows `tunnelIndex`; later rows show '-' in the
      'Index' column. Each row's 'Labels' column shows that via's own label
      stack if it has one, and otherwise the entry-level label stack.

      Args:
         table: table object from TableOutput.createTable() with the 4
            columns set up by BgpLuForwardingTunnelTable.render.
         tunnelIndex: value printed in the 'Index' column of the first row.
      """
      def labelStackStr( labelList ):
         return '[ ' + ' '.join( labelList ) + ' ]'

      # the desired label stack will either be in the entry CAPI model label
      # stack or in the vias in the case of an LU Push entry via, but not both
      entryLabelsStr = labelStackStr( self.labels )
      if not self.vias:
         table.newRow( tunnelIndex, '-', '-', entryLabelsStr )
         return

      for position, via in enumerate( self.vias ):
         nhStr, intfStr = getNhAndIntfStrs( via )
         if 'labels' in via:
            labelsStr = labelStackStr( via[ 'labels' ] )
         else:
            # Fall back to the entry's stack instead of reusing the previous
            # via's labels, which would show one via's stack against another.
            labelsStr = entryLabelsStr
         indexStr = tunnelIndex if position == 0 else '-'
         table.newRow( indexStr, nhStr, intfStr, labelsStr )

class BgpLuForwardingTunnelTable( Model ):
   entries = Dict( keyType=int, valueType=BgpLuForwardingTunnelTableEntry,
                   help="BGP LU Forwarding entries keyed by tunnel index" )

   def render( self ):
      """Print the forwarding entries as a left-justified table, sorted by
      tunnel index."""
      leftJustified = Format( justify='left' )
      table = createTable( ( "Index", "Nexthop", "Interface", "Labels" ) )
      table.formatColumns( *( [ leftJustified ] * 4 ) )
      for index, entry in sorted( self.entries.items() ):
         entry.renderBgpLuForwardingTunnelTableEntry( table, index )

      print( table.output() )

# -------------------------------------------------------------------------------
# "show bgp tunnel udp"
# -------------------------------------------------------------------------------
class BgpUdpTunnelTableEntry( Model ):
   # One BGP UDP tunnel entry. It is rendered as one table row per via; the
   # endpoint, source address, TOS and payload type are repeated on every row.
   vias = List( valueType=Via, help="List of nexthops" )
   endpoint = IpGenericAddress( help="Destination address for tunnel entry" )
   sourceAddr = IpGenericAddress( help="Source address for tunnel entry" )
   tos = Int( help="TOS of tunnel entry" )
   payloadType = Enum( help="Payload type for tunnel entry",
                       values=( 'mpls', 'ip', 'ipv6', 'ipvx', ) )

   def payloadTypeToStr( self, payloadType ):
      """Return the display string for a payloadType value, or None if the
      value is not recognized."""
      # The enum above has no 'ipv4' value. 'ip' is presumably IPv4, alongside
      # 'ipv6' and the dual-family 'ipvx'. Previously only 'ipv4' was checked,
      # so a legal 'ip' value rendered as None. 'ipv4' is kept in case a caller
      # passes it explicitly.
      return {
         'ip': "IPv4",
         'ipv4': "IPv4",
         'ipv6': "IPv6",
         'ipvx': "IP",
         'mpls': "MPLS",
      }.get( payloadType )

   def renderBgpUdpTunnelTableEntry( self, table, tunnelIndex ):
      """Append one row per via to `table`.

      Args:
         table: table object from TableOutput.createTable() with the 7
            columns set up by BgpUdpTunnelTable.render.
         tunnelIndex: value printed in the 'Index' column of every row.
      """
      # These columns come from the entry, not the via, so compute them once.
      endpointStr = str( self.endpoint )
      sourceAddrStr = str( self.sourceAddr )
      tosStr = str( self.tos )
      # Show '-' rather than passing None into the table for an unknown type.
      payloadStr = self.payloadTypeToStr( self.payloadType ) or '-'
      for via in self.vias:
         table.newRow( tunnelIndex, endpointStr, str( via.nexthop ),
                       via.interface.stringValue, sourceAddrStr, tosStr,
                       payloadStr )

class BgpUdpTunnelTable( Model ):
   entries = Dict( keyType=int, valueType=BgpUdpTunnelTableEntry,
                   help="BGP UDP tunnel entries keyed by tunnel index" )

   def render( self ):
      """Print the UDP tunnel entries as a table, sorted by tunnel index."""
      left = Format( justify='left' )
      right = Format( justify='right' )
      # ( heading, justification ) for each column, in display order.
      columns = (
         ( "Index", right ),
         ( "Endpoint", left ),
         ( "Next Hop", left ),
         ( "Interface", left ),
         ( "Source Address", left ),
         ( "TOS", right ),
         ( "Payload Type", left ),
      )
      table = createTable( tuple( heading for heading, _ in columns ) )
      table.formatColumns( *[ fmt for _, fmt in columns ] )
      for index, entry in sorted( self.entries.items() ):
         entry.renderBgpUdpTunnelTableEntry( table, index )

      print( table.output() )
