#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
from CliModel import Model
from CliModel import List, Dict, Enum, Int, Str, Bool
from CliPlugin.TunnelModels import tunnelTypeEnumValues
from TableOutput import createTable, Format, TableFormatter
from TypeFuture import TacLazyType

# The ".invalid" values of the QoS traffic-class and DSCP TAC types.
# NOTE(review): neither name is referenced in the visible part of this file;
# presumably other modules import them from here — verify before removing.
invalidTc = Tac.Type( 'Qos::TrafficClass' ).invalid
invalidDscp = Tac.Type( 'Arnet::DscpValue' ).invalid

# Lazily resolved TAC types. vskOverrideEntryStatus supplies the allowed values
# of CbfMplsOverride.status; AddressFamily is used by SrTePolicyCbfFecModel.
vskOverrideEntryStatus = TacLazyType(
      "Mpls::Override::LfibVskOverrideStatusEntry::Status" )
AddressFamily = TacLazyType( "Arnet::AddressFamily" )

# Human-readable text for MPLS CBF override status names. CbfMplsModel.render
# shows these strings in its Status column. The keys presumably mirror the
# vskOverrideEntryStatus enum attributes — keep them in sync if it changes.
statusStrDict = dict(
   unknownStatus='unknown status',
   programmed='programmed',
   outOfResource='out of resource',
   fecNotFound='FEC not found',
   hwNotReady='HW not ready',
   hwRequestPending='HW request pending',
   labelActionNotSupported='label action not supported',
)

# One SR-TE Policy CBF FEC override entry; SrTePolicyCbfFecModel holds a list
# of these and renders one table row per entry.
class SrTePolicyCbfFecOverride( Model ):
   # Address family of the entry: 'ipv4' or 'ipv6'
   ipVersion = Enum( values=[ 'ipv4', 'ipv6' ], help="IP version" )
   # DSCP value of this entry (required here, unlike CbfFecOverride.dscp)
   dscp = Int( help="DSCP value" )
   # FEC ID that is being overridden
   defaultFecId = Int( help="Default FEC ID" )
   # SR-TE Policy FEC ID used in place of defaultFecId
   overrideFecId = Int( help="Overriding SR-TE Policy FEC ID" )

class SrTePolicyCbfFecModel( Model ):
   overrides = List( valueType=SrTePolicyCbfFecOverride,
                     help="Policy FEC override entry" )

   def cmpKey( self, item ):
      '''Sort key: IP version, then default FEC ID, then DSCP.'''
      key = ( item.ipVersion,
              item.defaultFecId,
              item.dscp )
      return key

   def render( self ):
      '''
      Print every override as one row of a four-column table
      (IP version, default FEC ID, DSCP, overriding FEC ID), sorted by cmpKey.
      '''
      def _makeFormat( justify ):
         # Both column styles share the same padding settings
         fmt = Format( justify=justify )
         fmt.noPadLeftIs( True )
         fmt.padLimitIs( True )
         return fmt

      # IP Version column is left-justified; the numeric columns are right-justified
      leftFmt = _makeFormat( 'left' )
      rightFmt = _makeFormat( 'right' )

      table = createTable( ( "IP Version", "Default FEC ID", "DSCP",
                             "Overriding FEC ID" ) )
      table.formatColumns( leftFmt, rightFmt, rightFmt, rightFmt )

      for entry in sorted( self.overrides, key=self.cmpKey ):
         if entry.ipVersion == AddressFamily.ipv4:
            familyLabel = 'IPv4'
         else:
            familyLabel = 'IPv6'
         table.newRow( familyLabel, entry.defaultFecId, entry.dscp,
                       entry.overrideFecId )
      print( table.output() )

class SrTePolicyCbfFecVrfModel( Model ):
   # Adjacent literals are concatenated verbatim, so the separating space must
   # live inside one of them (otherwise the help reads "overrideentries").
   vrfs = Dict( keyType=str, valueType=SrTePolicyCbfFecModel,
                help='Per VRF Segment Routing Traffic Engineering CBF override '
                'entries' )

   def render( self ):
      '''Render each VRF's override table in the dict's iteration order.'''
      for vrfSummary in self.vrfs.values():
         vrfSummary.render()

# One CBF FEC override entry; CbfFecModel holds a list of these and renders
# one table row per entry.
class CbfFecOverride( Model ):
   # Address family of the entry: 'ipv4' or 'ipv6'
   ipVersion = Enum( values=[ 'ipv4', 'ipv6' ], help="IP version" )
   # dscp and tc are both optional. CbfFecModel.getCbfMode treats an entry
   # whose tc is set as traffic-class based, and otherwise as DSCP based.
   dscp = Int( help="DSCP value", optional=True )
   tc = Int( help="Traffic class value", optional=True )
   # FEC ID that is being overridden (rendered in hex)
   defaultFecId = Int( help="Default FEC ID" )
   # FEC ID used in place of defaultFecId (rendered in hex).
   # NOTE(review): the help text says "SR-TE Policy" although this model also
   # carries tunnelType, so the override need not be SR-TE — confirm wording.
   overrideFecId = Int( help="Overriding SR-TE Policy FEC ID" )
   # Tunnel type of the overriding FEC, shown in the "Tunnel Type" column
   tunnelType = Enum( values=tunnelTypeEnumValues,
                      help='Tunnel Types' )

class CbfFecModel( Model ):
   overrides = List( valueType=CbfFecOverride, help="CBF FEC override entry" )

   def cmpKey( self, item ):
      '''
      Sort key for override entries: IP version, default FEC ID, then the
      class value (DSCP or traffic class).

      dscp and tc are both optional, and each entry is expected to set only
      one of them (see getCbfMode). Including both, with None mapped to -1,
      has two effects. In TC mode rows are now ordered by traffic class; the
      previous key ignored tc there. The key also stays comparable, whereas
      None-vs-int comparisons raise TypeError in Python 3.
      '''
      dscp = -1 if item.dscp is None else item.dscp
      tc = -1 if item.tc is None else item.tc
      return ( item.ipVersion, item.defaultFecId, dscp, tc )

   def getIpVersionCli( self, ipVersion ):
      '''Map the model's ipVersion enum ('ipv4'/'ipv6') to its CLI label.'''
      if ipVersion == 'ipv4':
         return 'IPv4'
      else:
         return 'IPv6'

   def getCbfMode( self ):
      '''
      Figure cbfMode from the contents of self.overrides.
      We assume that self.overrides will not have mixed override entries
      ( some DSCP and some Traffic class based ) as we don't support this now.
      We pick up the first override entry, see if it is DSCP/TC based.
      We return False for DSCP based CBF or if self.overrides contains nothing.
      We return True for TC based CBF
      '''

      if not self.overrides:
         return False

      overrideEntry = self.overrides[ 0 ]
      return overrideEntry.tc is not None

   def render( self ):
      '''
      Print the overrides as a five-column table sorted by cmpKey. The third
      column shows the traffic class in TC mode and the DSCP value otherwise.
      FEC IDs are printed in hex.
      '''
      # The CBF mode applies to the whole table, so determine it once here
      # rather than once per row.
      tcMode = self.getCbfMode()
      classHeading = "Traffic Class" if tcMode else "DSCP"
      headings = ( "IP Version", "Default FEC ID", classHeading,
                   "Overriding FEC ID", "Tunnel Type" )

      table = createTable( headings )
      f1 = Format( justify='left' )
      f2 = Format( justify='right' )
      f1.padLimitIs( True )
      table.formatColumns( f1, f2, f2, f2, f1 )

      for overrideEntry in sorted( self.overrides, key=self.cmpKey ):
         classValue = overrideEntry.tc if tcMode else overrideEntry.dscp
         table.newRow( self.getIpVersionCli( overrideEntry.ipVersion ),
                       f"0x{overrideEntry.defaultFecId:x}",
                       classValue,
                       f"0x{overrideEntry.overrideFecId:x}",
                       overrideEntry.tunnelType )

      print( table.output() )

class CbfFecVrfModel( Model ):
   vrfs = Dict( keyType=str, valueType=CbfFecModel,
                help='Per VRF Traffic Engineering CBF override entries' )

   def render( self ):
      '''
      Delegate to each per-VRF CbfFecModel in the dict's iteration order.
      No VRF name header is printed.
      '''
      for _vrfName, vrfModel in self.vrfs.items():
         vrfModel.render()

# One MPLS CBF override: a default LfibVsk, the LfibVsk overriding it for a
# given traffic class, and the programming status. CbfMplsModel holds a list
# of these.
class CbfMplsOverride( Model ):
   # Default LfibVsk identity (source and index) and its via tunnel names
   defaultVskSource = Str( help="Source of Default LfibVsk" )
   defaultVskIndex = Int( help="Index of Default LfibVsk" )
   defaultVias = List( valueType=str, help="Tunnel names of default mpls vias" )
   # Overriding LfibVsk identity (source and index) and its via tunnel names
   overrideVskSource = Str( help="Source of Override LfibVsk" )
   overrideVskIndex = Int( help="Index of Override LfibVsk" )
   overrideVias = List( valueType=str, help="Tunnel names of override mpls vias" )
   # Traffic class this override applies to
   tc = Int( help="Traffic class value" )
   # When falsy (including unset), CbfMplsModel.render lists at most 5 tunnel
   # names per via set; when truthy, all of them are listed.
   indexId = Bool( help="Set if specific Default LfibVsk's index is passed",
                   optional=True )
   # One of the vskOverrideEntryStatus enum names; render maps it to text via
   # statusStrDict
   status = Enum( values=vskOverrideEntryStatus.attributes,
                  help="MPLS CBF override status" )

class CbfMplsModel( Model ):
   overrides = List( valueType=CbfMplsOverride, help='MPLS CBF override LfibVsks' )

   def cmpKey( self, item ):
      '''Sort key: default LfibVsk index, then traffic class.'''
      return ( item.defaultVskIndex, item.tc )

   def _viasTable( self, vias, truncate, fmt, limit=5 ):
      '''
      Build a one-column nested table that lists the tunnel names in vias
      under a "Tunnels:" heading, and return its rendered text.

      When truncate is true and vias has more than limit entries, only the
      first limit names are listed, followed by a note giving the total count.
      '''
      viasTable = TableFormatter()
      viasTable.formatColumns( fmt )
      viasTable.newRow( "Tunnels:" )
      viaCount = len( vias )
      for i, via in enumerate( vias ):
         if truncate and i == limit:
            viasTable.newRow( f"( this Via Set has {viaCount} tids only "
                              f"showing {limit} )" )
            break
         viasTable.newRow( via )
      return viasTable.output()

   def render( self ):
      '''
      Print one row per override, sorted by cmpKey. Each row shows the default
      LfibViaSet, traffic class, overriding LfibViaSet and status, and is
      followed by a row holding the default and overriding tunnel lists.
      '''
      headings = ( "Default LfibViaSet", "Traffic Class", "Overriding LfibViaSet",
                   "Status" )
      table = createTable( headings )
      rightFmt = Format( justify='right' )
      leftFmt = Format( justify='left' )
      wrapFmt = Format( justify='left', wrap=True )
      statusFmt = Format( justify='left' )
      table.formatColumns( leftFmt, rightFmt, leftFmt, statusFmt )

      for entry in sorted( self.overrides, key=self.cmpKey ):
         # Fall back to the raw enum name for any status missing from
         # statusStrDict instead of aborting the whole render with KeyError.
         statusStr = statusStrDict.get( entry.status, entry.status )
         table.newRow( f"Source: {entry.defaultVskSource}, "
                       f"Index: {entry.defaultVskIndex}",
                       entry.tc,
                       f"Source: {entry.overrideVskSource}, "
                       f"Index: {entry.overrideVskIndex}",
                       statusStr )

         # Tunnel lists are capped unless a specific LfibVsk index was requested
         truncate = not entry.indexId
         table.newRow( self._viasTable( entry.defaultVias, truncate, wrapFmt ),
                       '',
                       self._viasTable( entry.overrideVias, truncate, wrapFmt ) )

      print( table.output() )

class CbfMplsVrfModel( Model ):
   # Implicit literal concatenation keeps source indentation out of the help
   # text. A backslash-newline inside a literal keeps the next line's leading
   # spaces as part of the string.
   vrfs = Dict( keyType=str, valueType=CbfMplsModel,
                help='Per VRF Traffic Engineering MPLS CBF override LfibVsks '
                     'entries' )

   def render( self ):
      '''Render each VRF's MPLS override table in the dict's iteration order.'''
      for vrfSummary in self.vrfs.values():
         vrfSummary.render()
