#!/usr/bin/env python3
# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import itertools

from TableOutput import Format, createTable, FormattedCell
import Tac

class Table:
   """Common base for the MCS renderers below.

   Subclasses list their column titles in ``fields``; ``createTable`` turns
   them into a left-justified TableOutput table indented by two columns.
   """
   fields = []

   def createHeader( self, title ):
      """Return a one-column, left-justified header cell holding ``title``."""
      leftJustified = Format( justify="left" )
      return FormattedCell( content=title, nCols=1, format=leftJustified )

   def createTable( self ):
      """Build an empty table whose header row comes from ``self.fields``."""
      headings = []
      for title in self.fields:
         headings.append( ( self.createHeader( title ), "l" ) )
      # Inside a method body ``createTable`` resolves to the module-level
      # function imported from TableOutput, not to this method.
      return createTable( headings, indent=2 )

class McsSender( Table ):
   """Print a table of MCS senders, one row per (S,G) flow entry."""
   fields = [ "Flow", "Device", "Interface", "Bandwidth", "Label",
              "Pairflow", "DSCP", "TC", "ApplyPolicy" ]

   def __init__( self, sender ):
      self.table = self.createTable()
      self.sender = sender
      self.addRows()
      self.printTable()

   def addRows( self ):
      """Append one row per entry of ``self.sender``.

      The flow column comes from the key (source/group); every other column
      comes from the value. Nothing is added when ``self.sender`` is falsy.
      """
      if self.sender:
         for flowKey, info in self.sender.items():
            senderId = info.senderId
            pair = info.pairFlow
            cells = [ f"{flowKey.source} {flowKey.group}",
                      f"{senderId.device.ethAddr}",
                      f"{senderId.intfId}",
                      f"{info.bwInKbps}",
                      f"{info.label}",
                      f"{pair.source} {pair.group}",
                      f"{info.dscp}",
                      f"{info.tc}",
                      f"{info.applyPolicy}" ]
            self.table.newRow( *cells )

   def printTable( self ):
      """Write the rendered table to stdout."""
      print( self.table.output() )

class McsReceiver( Table ):
   """Print a table of MCS receivers, one row per receiver entry."""
   fields = [ "Flow", "Device", "Interface",
              "transactionId", "trackingId", "impactedReason" ]

   def __init__( self, receiver ):
      self.receiver = receiver
      self.table = self.createTable()
      self.addRows()
      self.printTable()

   def addRows( self ):
      """Append one row per entry of ``self.receiver``.

      Flow, device and interface come from the key; transaction id, tracking
      id and impacted reason come from the value. Nothing is added when
      ``self.receiver`` is falsy.
      """
      if self.receiver:
         for rxKey, rxInfo in self.receiver.items():
            flow = rxKey.sg
            self.table.newRow( f"{flow.source} {flow.group}",
                               f"{rxKey.device.ethAddr}",
                               f"{rxKey.intfId}",
                               f"{rxInfo.transactionId}",
                               f"{rxInfo.trackingId}",
                               f"{rxInfo.impactedReason}" )

   def printTable( self ):
      """Write the rendered table to stdout."""
      print( self.table.output() )

class McsFlow( Table ):
   """Print programmed flows: one row per route, extra rows for extra oifs.

   The first row of each route carries the (S,G), device, iif and first oif.
   Any further oifs get their own rows with the other columns left blank.
   """
   fields = [
         "Flow",
         "Device",
         "iif",
         "oif",
         ]

   def __init__( self, flowProgrammed ):
      self.flowProgrammed = flowProgrammed
      self.table = self.createTable()
      self.addRows()
      self.printTable()

   def addRows( self ):
      """Add rows for every route of every flow in ``self.flowProgrammed``."""
      if not self.flowProgrammed:
         return

      for flow in self.flowProgrammed.values():
         # Each element of ``flow.route`` exposes sg, device, iif and oif
         # (presumably route keys of a Tac collection -- confirm).
         for k in flow.route:
            sg = [ f'{k.sg.source} {k.sg.group}' ]
            device = [ k.device.sysName ]
            iif = [ k.iif ]
            oifs = list( k.oif )
            # Pad the shorter columns with '' rather than the default None so
            # continuation rows (and routes with no oifs) show blank cells
            # instead of the literal text "None".
            for f, d, i, o in itertools.zip_longest(
                  sg, device, iif, oifs, fillvalue='' ):
               self.table.newRow( f"{f}", f"{d}", f"{i}", f"{o}" )

   def printTable( self ):
      """Write the rendered table to stdout."""
      print( self.table.output() )

class McsDevice( Table ):
   """Print a table of MCS devices: system id, hostname and description."""
   fields = [ "SystemID", "Hostname", "Description" ]

   def __init__( self, mcsDevice ):
      self.mcsDevice = mcsDevice
      self.table = self.createTable()
      self.addRows()
      self.printTable()

   def addRows( self ):
      """Append one row per device value; nothing when the input is falsy."""
      if self.mcsDevice:
         for device in self.mcsDevice.values():
            row = [ f"{device.ethAddr}",
                    f"{device.sysName}",
                    f"{device.sysDesc}" ]
            self.table.newRow( *row )

   def printTable( self ):
      """Write the rendered table to stdout."""
      print( self.table.output() )

class McsEndpoint( Table ):
   """Print a table of MCS endpoints with their bandwidth figures."""
   fields = [
         "Device",
         "Interface",
         "Direction",
         "totalBandwidth",
         "availableBandwidth",
         "usedBandwidth",
         "Description",
         ]

   def __init__( self, endpoint ):
      self.mcsEndpoint = endpoint
      self.table = self.createTable()
      self.addRows()
      self.printTable()

   def addRows( self ):
      """Add one row per element of ``self.mcsEndpoint``.

      Every row has exactly ``len( self.fields )`` cells.
      """
      if not self.mcsEndpoint:
         return

      for k in self.mcsEndpoint:
         self.table.newRow(
               f"{k.device.sysName}",
               f"{k.intfId}",
               f"{k.direction}",
               f"{k.bw.totalInKbps}",
               f"{k.bw.availableInKbps}",
               f"{k.bw.usedInKbps}",
               # The header declares a "Description" column, but no
               # description is read from the endpoint here. Emit an empty
               # cell so rows line up with the 7-column header.
               # TODO(review): fill from the endpoint if it carries one.
               "" )

   def printTable( self ):
      """Write the rendered table to stdout."""
      print( self.table.output() )

class McsNetworkLink( Table ):
   """Print network links with both endpoints (A, then B) side by side."""
   fields = [ "Device", "Interface", "Direction",
              "totalBandwidth", "availableBandwidth", "usedBandwidth",
              "Peer Device", "Peer Interface", "Peer Direction",
              "Peer totalBandwidth", "Peer availableBandwidth",
              "Peer usedBandwidth" ]

   def __init__( self, network ):
      self.mcsNetworkLink = network
      self.table = self.createTable()
      self.addRows()
      self.printTable()

   @staticmethod
   def _endpointCells( endpoint ):
      """Return the six display cells (device .. used bandwidth) of one end."""
      bw = endpoint.bw
      return [ f"{endpoint.device.sysName}",
               f"{endpoint.intfId}",
               f"{endpoint.direction}",
               f"{bw.totalInKbps}",
               f"{bw.availableInKbps}",
               f"{bw.usedInKbps}" ]

   def addRows( self ):
      """Append one row per link: endpoint A's cells followed by B's."""
      if self.mcsNetworkLink:
         for link in self.mcsNetworkLink:
            cells = ( self._endpointCells( link.endpointA ) +
                      self._endpointCells( link.endpointB ) )
            self.table.newRow( *cells )

   def printTable( self ):
      """Write the rendered table to stdout."""
      print( self.table.output() )

class McsActiveFlows( Table ):
   """Print active flows, one or more rows per flow entry.

   The first row of each entry carries the referrer, (S,G), device, iif,
   first oif, first IGMP snooping pair and in-HW flag. Further oifs or IGMP
   snooping pairs get their own rows, with the other columns left blank.
   """
   fields = [
         "Referrar",
         "Flow",
         "Device",
         "iif",
         "oif",
         "IGMP Snooping",
         "In HW",
         ]

   def __init__( self, activeflows ):
      self.mcsActiveFlows = activeflows
      self.table = self.createTable()
      self.addRows()
      self.printTable()

   def addRows( self ):
      """Add rows for every (key, value) entry of ``self.mcsActiveFlows``."""
      if not self.mcsActiveFlows:
         return

      for k, v in self.mcsActiveFlows.items():
         # One (snooping key, interface) pair per interface of each snooping
         # entry. The previous nested generator referenced ``iv`` before it
         # was bound and raised NameError.
         igmps = [ ( ik, oif )
                   for ik, iv in v.igmpSnooping.items()
                   for oif in iv.interface ]
         columns = (
               [ k.referrar ],
               [ f'{k.sg.source} {k.sg.group}' ],
               [ k.mcsDevice.sysName ],
               [ v.iif ],
               list( v.oif ),
               igmps,
               [ v.inHw ],
               )
         # Pad only the missing cells with ''. Real falsy values (e.g. an
         # inHw of False) are still displayed instead of being blanked.
         for row in itertools.zip_longest( *columns, fillvalue='' ):
            self.table.newRow( *[ f"{cell}" for cell in row ] )

   def printTable( self ):
      """Write the rendered table to stdout."""
      print( self.table.output() )

class McsStatus( Table ):
   """Print per-component status codes together with their messages.

   Each component (Redis, HA, Agent, API, Topology) gets one row per status
   code. The component name appears only on its first row.
   """
   fields = [
         "Name",
         "Code",
         "Message",
         ]

   def __init__( self, status ):
      self.mcsStatus = status
      self.table = self.createTable()
      self.addRows()
      self.printTable()

   def addRows( self ):
      """Add rows for every status code of every component.

      Messages are looked up in ``self.mcsStatus.msgCode``. A code with no
      entry there is shown with an empty message. A component with no codes
      still gets one row carrying just its name.
      """
      if not self.mcsStatus:
         return

      status = self.mcsStatus.status
      msgCode = self.mcsStatus.msgCode
      allStatus = [ ( 'Redis', status.redisStatus ),
                    ( 'HA', status.haStatus ),
                    ( 'Agent', status.agentStatus ),
                    ( 'API', status.apiStatus ),
                    ( 'Topology', status.topology ),
                    ]
      for nK, codes in allStatus:
         rcodes = list( codes )
         # Keep messages index-aligned with their codes. Skipping unknown
         # codes (as before) shifted every later message onto the wrong code.
         msgs = [ msgCode[ c ] if c in msgCode else '' for c in rcodes ]
         # Pad with '' so continuation rows get a blank name (not a raw None)
         # and code-less components don't print "None".
         for n, c, m in itertools.zip_longest( [ nK ], rcodes, msgs,
                                               fillvalue='' ):
            self.table.newRow( f"{n}", f"{c}", f"{m}" )

   def printTable( self ):
      """Write the rendered table to stdout."""
      print( self.table.output() )

class McsAgentStatusRender:
   """Print each collection of an MCS agent status object as a titled table.

   Every public method prints a section title, then builds, fills and prints
   the matching table. When ``status`` is omitted (or falsy) every collection
   is replaced by '', so each section prints only its header row.
   """
   # Collections read off ``status``; each is mirrored onto the instance as
   # ``<name>_`` (e.g. ``inactiveSender_``).
   _collectionNames = (
         'inactiveSender',
         'activeSender',
         'inactiveReceiver',
         'activeReceiver',
         'impactedReceiver',
         'failedReceiver',
         'flowProgrammed',
         'mcsClientDevice',
         'maintenanceDevice',
         'endPoint',
         'networkLink',
         'activeFlows',
         )

   def __init__( self, status=None ):
      self.mcsAgentStatus = status
      for name in self._collectionNames:
         setattr( self, name + '_', getattr( status, name ) if status else '' )
      # McsStatus reads ``status`` and ``msgCode`` off the whole object.
      self.status_ = status if status else ''

   def _show( self, title, tableClass, data ):
      """Print ``title``, then let ``tableClass`` render and print ``data``."""
      print( title )
      tableClass( data )

   def inactiveSender( self ):
      self._show( "InactiveSender:", McsSender, self.inactiveSender_ )

   def activeSender( self ):
      self._show( "ActiveSender:", McsSender, self.activeSender_ )

   def inactiveReceiver( self ):
      self._show( "InactiveReceiver:", McsReceiver, self.inactiveReceiver_ )

   def activeReceiver( self ):
      self._show( "ActiveReceiver:", McsReceiver, self.activeReceiver_ )

   def impactedReceiver( self ):
      self._show( "ImpactedReceiver:", McsReceiver, self.impactedReceiver_ )

   def failedReceiver( self ):
      self._show( "FailedReceiver:", McsReceiver, self.failedReceiver_ )

   def flowProgrammed( self ):
      self._show( "FlowProgrammed:", McsFlow, self.flowProgrammed_ )

   def mcsClientDevice( self ):
      self._show( "McsClientDevice:", McsDevice, self.mcsClientDevice_ )

   def maintenanceDevice( self ):
      self._show( "MaintenanceDevice:", McsDevice, self.maintenanceDevice_ )

   def endPoint( self ):
      self._show( "Endpoint:", McsEndpoint, self.endPoint_ )

   def networkLink( self ):
      self._show( "NetworkLink:", McsNetworkLink, self.networkLink_ )

   def status( self ):
      self._show( "Status:", McsStatus, self.status_ )

   def activeFlows( self ):
      self._show( "ActiveFlows:", McsActiveFlows, self.activeFlows_ )

if __name__ == "__main__":
   render = McsAgentStatusRender()
   # Print every section in a fixed order. Each call prints its title and
   # then its table.
   sections = ( render.inactiveSender, render.activeSender,
                render.inactiveReceiver, render.activeReceiver,
                render.impactedReceiver, render.failedReceiver,
                render.flowProgrammed, render.mcsClientDevice,
                render.maintenanceDevice, render.endPoint,
                render.networkLink, render.activeFlows, render.status )
   for section in sections:
      section()
