#!/usr/bin/env python3
# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import itertools

import Arnet
from CliModel import Bool
from CliModel import Dict
from CliModel import Enum
from CliModel import Int
from CliModel import Model
from CliModel import Str
import CliPlugin.IntfCli as IntfCli # pylint: disable=consider-using-from-import
from IntfModels import Interface
import TableOutput

def thresholdTypeToUnits( thresholdType ):
   """Return the display unit suffix for a storm-control threshold type.

   'packetsPerSecond' maps to 'pps', 'percentage' to '%' and 'bitsPerSecond'
   to 'bps'. Any other value fails an assertion; with assertions disabled
   (python -O) the function yields None instead.
   """
   # Checked in this order, comparing thresholdType on the left-hand side.
   knownUnits = ( ( 'packetsPerSecond', 'pps' ),
                  ( 'percentage', '%' ),
                  ( 'bitsPerSecond', 'bps' ) )
   units = next( ( suffix for name, suffix in knownUnits
                   if thresholdType == name ), None )
   assert units is not None, "Unknown thresholdType"
   return units

BPS_UNITS = [ "bps", "Kbps", "Mbps", "Gbps" ]

def convertToHighestBps( level ):
   """Express a bps rate in the largest unit that keeps it a whole number.

   The rate is divided by 1000 (bps -> Kbps -> Mbps -> Gbps) for as long as
   it is an exact multiple of 1000, stopping at the last unit in BPS_UNITS.

   Parameters
   ----------
   level : int
      The bps rate from Threshold. Must be non-negative.

   Returns
   -------
   rate : int
      The rate expressed in `unit`.
   unit : str
      One of BPS_UNITS.

   Examples: 1500 -> ( 1500, 'bps' ), 2000000 -> ( 2, 'Mbps' ),
   10**12 -> ( 1000, 'Gbps' ). A level of 0 divides evenly at every step and
   is therefore reported as ( 0, 'Gbps' ), matching the historical output.
   """
   assert level >= 0
   # Use integer arithmetic throughout: true division ('/') would round the
   # intermediate value through a float and lose precision above 2**53.
   rate = int( level )
   for unit in BPS_UNITS[ : -1 ]:
      if rate % 1000 != 0:
         # Not a whole number in the next larger unit; stop here.
         return ( rate, unit )
      rate //= 1000
   return ( rate, BPS_UNITS[ -1 ] )

# Traffic types, in the order StormControlStatus.render() lists them for
# each interface.
trafficTypes = [
   'all',
   'unknown-unicast',
   'broadcast',
   'multicast',
]
# For each traffic type, render() looks up a '<trafficType>-cpu' entry for
# 'yes' and the plain '<trafficType>' entry for 'no'.
cpuPolicingTypes = [
   'yes',
   'no',
]

class StormControlType( Model ):
   # Storm-control setting for one traffic type; used both per interface
   # (IntfStormControlStatus.trafficTypes) and per aggregate traffic class.

   # Raw configured level. Its meaning depends on thresholdType: for
   # 'percentage' it is in hundredths of a percent (see toPrintableTuple).
   level = Int( help='Storm control level' )
   thresholdType = Enum( values=( 'percentage', 'bitsPerSecond',
                                  'packetsPerSecond' ),
                         help='Type of threshold applied' )
   # NOTE(review): documented here as the broadcast rate in bps, but
   # StormControlStatus.render() prints it unscaled under a "Rate(Mbps)"
   # heading for every traffic type -- confirm the actual unit.
   rate = Int( help='Effective allowed broadcast rate in bps' )
   # Optional counters; presumably None when unavailable. render() blanks
   # the column heading when no entry reports a value.
   drop = Int( help='Packets dropped by storm control', optional=True )
   dropOctets = Int( help='Octets dropped by storm control', optional=True )
   # When True, render() shows the row as 'inactive' with the reason
   # 'storm-control all active'.
   dormant = Bool( help='Rule is superseded by all traffic' )

   def toPrintableTuple( self ):
      """Return the display fields for one interface-table row.

      The result is ( level, units, rate, dormant, drop, dropOctets ):
      level/units are scaled for display according to thresholdType, rate
      is converted with str(), and a missing (None) drop or dropOctets
      counter is replaced by ''.
      """
      if self.thresholdType == 'percentage':
         # Stored in hundredths of a percent; show as a plain percentage.
         shownLevel = self.level / 100.0
         shownUnits = thresholdTypeToUnits( self.thresholdType )
      elif self.thresholdType == 'bitsPerSecond':
         shownLevel, shownUnits = convertToHighestBps( self.level )
      else:
         shownLevel = self.level
         shownUnits = thresholdTypeToUnits( self.thresholdType )

      shownRate = str( self.rate )
      dropCol = '' if self.drop is None else self.drop
      dropOctetsCol = '' if self.dropOctets is None else self.dropOctets
      return ( shownLevel, shownUnits, shownRate, self.dormant,
               dropCol, dropOctetsCol )

class IntfStormControlStatus( Model ):
   # Storm-control state of a single interface, as held in
   # StormControlStatus.interfaces.

   # render() looks up keys built from the module-level trafficTypes list,
   # with a '-cpu' suffix for CPU-policed entries (e.g. 'broadcast' and
   # 'broadcast-cpu').
   trafficTypes = Dict(
      help='Applied storm control setting for a given traffic type',
      keyType=str, valueType=StormControlType )
   # For non-dormant rows, render() prints 'active' when True, else
   # 'inactive'.
   active = Bool( help='Whether storm control has been applied' )
   # Printed in the Reason column for non-dormant rows.
   reason = Str( help='Reason for the status' )
   # Printed as 'yes'/'no' only when the owning StormControlStatus has its
   # errdisable column enabled.
   errdisabled = Bool( help='Whether the interface has been errdisabled' )

class StormControlStatus( Model ):
   # Top-level storm-control status: optional aggregate traffic-class
   # settings plus per-interface status.

   aggregateTrafficClasses = Dict(
         help='A mapping of Aggregate traffic class to storm control status',
         keyType=int, valueType=StormControlType, optional=True )

   interfaces = Dict(
      help='Per-interface storm control status',
      keyType=Interface, valueType=IntfStormControlStatus )

   # Set through errdisableEnabledIs(); render() fills in the Errdisabled
   # column only when this is truthy.
   _errdisableEnabled = Bool(
         help='Whether the errdisable column is shown for storm-control.' )

   def errdisableEnabledIs( self, enabled ):
      """Set whether render() populates the Errdisabled column."""
      self._errdisableEnabled = enabled

   def render( self ):
      """Print the storm-control status as CLI text.

      Output, in order:
        * an 'Aggregate:' table with one 'BUM' row per aggregate traffic
          class, only when aggregate classes exist;
        * 'No storm control configured on any interface' when there are
          neither aggregate classes nor interfaces;
        * a per-interface table with one row per configured traffic type
          (and CPU-policed variant), only when interfaces exist.
      """
      tcList = sorted( set( self.aggregateTrafficClasses ) )
      if tcList:
         print( 'Aggregate: ' )
         aggregateTcHdrFormat = '%-10s %-18s %-12s %-5s'
         print( aggregateTcHdrFormat % ( 'Type', 'Traffic Class', 'Level',
                                         'Units' ) )
         for tc in tcList:
            aggregateStatus = self.aggregateTrafficClasses[ tc ]
            print( aggregateTcHdrFormat % (
               'BUM', tc, aggregateStatus.level,
               thresholdTypeToUnits( aggregateStatus.thresholdType ) ) )
         print( '' )

      intfs = Arnet.sortIntf( set( self.interfaces ) )
      if not ( intfs or tcList ):
         print( 'No storm control configured on any interface' )

      if not intfs:
         return

      def anyTrafficReports( counterName ):
         # True when at least one traffic type on any interface carries a
         # non-None value for `counterName`.
         return any( getattr( traffic, counterName ) is not None
                     for intfStatus in self.interfaces.values()
                     for traffic in intfStatus.trafficTypes.values() )

      # Optional columns keep their slot in the table but get an empty
      # heading when the data is not available.
      cntSupported = anyTrafficReports( 'drop' )
      dropOctetsSupported = anyTrafficReports( 'dropOctets' )
      errdisableEnabled = self._errdisableEnabled

      # ( heading, maxWidth, rightJustified ). Keeping justification next to
      # its heading avoids addressing columns by hard-coded index.
      # NOTE(review): the rate column prints StormControlType.rate unscaled,
      # and that field's help text says bps; confirm the "Rate(Mbps)" label.
      tableColumns = [
         ( "Port", 10, False ),
         ( "Type", 16, False ),
         ( "CPU", 3, False ),
         ( "Level", 20, False ),
         ( "Units", 5, False ),
         ( "Rate(Mbps)", 10, True ),
         ( "Status", 8, False ),
         ( "DropPkts" if cntSupported else "", 20, True ),
         ( "DropOctets" if dropOctetsSupported else "", 21, True ),
         ( "Reason", 24, False ),
         ( "Errdisabled" if errdisableEnabled else "", 11, False ),
      ]

      table = TableOutput.createTable(
         heading for heading, _, _ in tableColumns )

      columnFormats = []
      for _, width, rightJustified in tableColumns:
         columnFormat = TableOutput.Format( justify="left", maxWidth=width,
                                            wrap=True )
         columnFormat.padLimitIs( True )
         columnFormat.noPadLeftIs( True )
         if rightJustified:
            # Numeric columns (rate and drop counters) are right-aligned.
            columnFormat.justify_ = "right"
         columnFormats.append( columnFormat )
      table.formatColumns( *columnFormats )

      for intf in intfs:
         intfStatus = self.interfaces[ intf ]

         # Collect this interface's rows in trafficTypes x cpuPolicingTypes
         # order; CPU-policed settings are stored under '<trafficType>-cpu'.
         displayLines = []
         for trafficType, cpu in itertools.product( trafficTypes,
                                                    cpuPolicingTypes ):
            key = trafficType + "-cpu" if cpu == "yes" else trafficType
            if key in intfStatus.trafficTypes:
               trafficStatus = intfStatus.trafficTypes[ key ]
               displayLines.append(
                  ( trafficType, cpu ) + trafficStatus.toPrintableTuple() )

         for rowIdx, row in enumerate( displayLines ):
            ( scType, cpu, level, units, rate, dormant, drop,
              dropOctets ) = row
            # Only the interface's first row carries the port name.
            intfName = IntfCli.Intf.getShortname( intf ) if rowIdx == 0 else ''
            if dormant:
               # This rule is superseded by the 'all' traffic rule.
               stormCtrStatus = 'inactive'
               reason = 'storm-control all active'
            else:
               stormCtrStatus = 'active' if intfStatus.active else 'inactive'
               reason = intfStatus.reason

            errdisabled = ''
            if errdisableEnabled:
               errdisabled = 'yes' if intfStatus.errdisabled else 'no'
            table.newRow( intfName, scType, cpu, str( level ), str( units ),
                          str( rate ), stormCtrStatus, drop, dropOctets,
                          reason, errdisabled )
      print( table.output() )
               
