# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliModel import Model, Submodel, Dict, Bool, Int, Enum, Float, List
from ArnetModel import Ip6Address
from CliPlugin import IntfCli
from IntfModels import Interface
from IgmpSnoopingModel import FormattedPrinting
import TableOutput

# pylint: disable-msg=unsubscriptable-object

# Display labels for a tri-state enabled/disabled/default setting
# (not referenced in this section of the file).
statusEnumToStr = dict(
      enabled="Enabled",
      disabled="Disabled",
      default="Default",
)

# Display labels for MLD snooping state. The keys (in this order) also serve
# as the allowed values of the mldSnoopingState Enum attributes below.
stateEnumToStr = dict(
      enabled="Enabled",
      disabled="Disabled",
)

# Display labels for the querier operationalState Enum values.
operationalStateEnumToStr = dict(
      querierStateQuerier="Querier",
      querierStateNonQuerier="Non-Querier",
      querierStateInitial="Pending",
)

# Short column labels for the group filterMode Enum values.
filterModeEnumToStr = dict(
      filterModeExclude="EX",
      filterModeInclude="IN",
)

class MldSnoopingVlanInfo( Model ):
   """MLD snooping status of a single VLAN."""
   # Allowed values are the keys of stateEnumToStr.
   mldSnoopingState = Enum( values=list( stateEnumToStr ),
                            help="State of MLD Snooping" )
   # The default 65534 is rendered as "No limit set" by MldSnoopingInfo.
   maxGroups = Int(
         help="Maximum number of multicast groups that can join the VLAN",
         default=65534 )
   groupsOverrun = Bool(
         help="There has been an attempt to create more than "
         "the maximum number of groups",
         default=False )
   pruningActive = Bool( default=False, help="MLD snooping pruning is active" )
   floodingTraffic = Bool( default=True, help="Flooding traffic to VLAN" )
   evpnProxyActive = Bool( optional=True, default=False,
                           help="MLD proxying enabled via EVPN" )

class MldSnoopingInfo( Model ):
   """Global MLD snooping configuration, MLAG agent status and per-VLAN
   snooping information."""
   mldSnoopingState = Enum( help="Global state of MLD Snooping",
                            values=list( stateEnumToStr ), optional=True )
   robustness = Int(
         help="Number of queries sent to age out a port's membership in group",
         optional=True )
   mlagConnectionState = Enum(
      help="MLD Snooping agent MLAG connection state",
      values=( 'n/a', 'initializing', 'connecting', 'connected',
              'disconnecting', 'disconnected', 'failed', 'unknown' ), optional=True )
   mlagMountState = Enum(
      help="MLD Snooping agent MLAG mount state",
      values=( 'n/a', 'unmounted', 'mounting', 'mounted',
              'failed', 'unknown' ), optional=True )
   vlans = Dict( keyType=int, valueType=MldSnoopingVlanInfo,
         help="A mapping of VLAN's ID to its information", optional=True )

   def render( self ):
      """Print the global MLD snooping settings, then the MLAG agent status
      (only when it carries information), then one section per VLAN sorted
      by VLAN ID. Nothing is printed when the global state is unset."""
      if self.mldSnoopingState is None:
         return
      print( "   Global MLD Snooping configuration:" )
      print( "-------------------------------------------" )
      print( "%-30s : %s" % ( 'MLD snooping',
            stateEnumToStr[ self.mldSnoopingState ] ) )
      print( "%-30s : %s" % ( 'Robustness variable', self.robustness ) )
      print( "" )

      # Both MLAG fields are optional. An unset field carries no more
      # information than 'n/a', so treat it the same way instead of printing
      # the section with a literal "None".
      connectionState = self.mlagConnectionState or 'n/a'
      mountState = self.mlagMountState or 'n/a'
      if connectionState != 'n/a' or mountState != 'n/a':
         print( "    MLD Snooping Agent MLAG Status:" )
         print( "-------------------------------------------" )
         print( "%-30s : %s" % ( 'Connection state', connectionState ) )
         print( "%-30s : %s" % ( 'Mount state', mountState ) )
         print( "" )

      # 'vlans' is optional; render no VLAN sections rather than fail if unset.
      for vlan, info in sorted( ( self.vlans or {} ).items() ):
         print( "VLAN", "%s" % vlan, ":" )
         print( "----------" )
         print( "%-30s : %s" % ( 'MLD snooping',
               stateEnumToStr[ info.mldSnoopingState ] ) )
         mldMaxGroupsLimitStr = 'MLD max group limit'
         # 65534 (the maxGroups default) is displayed as "No limit set".
         if info.maxGroups == 65534:
            print( "%-30s : %s" % ( mldMaxGroupsLimitStr, 'No limit set' ) )
         else:
            print( "%-30s : %d" % ( mldMaxGroupsLimitStr, info.maxGroups ) )
         print( "%-30s : %s" % ( 'Recent attempt to exceed limit',
               'Yes' if info.groupsOverrun else 'No' ) )
         print( "%-30s : %s" % ( 'MLD snooping pruning active',
               info.pruningActive ) )
         print( "%-30s : %s" % ( 'Flooding traffic to VLAN',
               info.floodingTraffic ) )
         print( "%-30s : %s" % ( 'Evpn proxy active',
               info.evpnProxyActive ) )


class MldSnoopingCountersInterface( Model ):
   """MLD snooping packet counters for a single interface."""
   pimPacketsReceived = Int( help="Number of PIM packets received",
         optional=True )
   shortPacketsReceived = Int( help="Number of packets received"
         " with not enough IP payload" )
   nonIpPacketsReceived = Int( help="Number of non IP packets received" )
   badChecksumIpPacketsReceived = Int( help="Number of packets received"
         " for which IP checksum check failed" )
   unknownIpPacketsReceived = Int( help="Number of packets received"
         " with unknown IP Protocol" )
   badChecksumPimPacketsReceived = Int( help="Number of packets received"
         " for which PIM checksum check failed" )
   otherPacketsSent = Int( help="Number of other packets sent", optional=True )

   badChecksumIcmpV6PacketsReceived = Int( help="Number of packets received"
         " for which ICMP v6 checksum check failed" )
   badMldQueryReceived = Int( help="Number of invalid MLD queries received" )
   mldV1QueryReceived = Int( help="Number of MLD v1 queries received",
         optional=True )
   mldV2QueryReceived = Int( help="Number of MLD v2 queries received",
         optional=True )
   badMldV2ReportReceived = Int( help="Number of invalid MLD v2 reports received" )
   mldV2ReportReceived = Int( help="Number of MLD v2 reports received",
         optional=True )
   otherIcmpPacketsReceived = Int( help="Number of other ICMP v6 packets received",
         optional=True )
   mldQuerySend = Int( help="Number of MLD queries sent",
         optional=True )
   mldReportSend = Int( help="Number of MLD reports sent",
         optional=True )

class MldSnoopingCounters( Model ):
   """MLD snooping packet counters keyed by interface."""
   interfaces = Dict( valueType=MldSnoopingCountersInterface, keyType=Interface,
                      help="Map Interface name with its counter details" )
   _errorSpecific = Bool( help="Display only error counters" )

   def render( self ):
      """Print the counters as a table with one row per interface, sorted.

      With _errorSpecific set, each error counter gets its own column.
      Otherwise the table shows input queries/reports/others/errors and
      output queries/reports/others, where "Errors" is the sum of all the
      error counters.
      """
      # Error counters in column order; also summed for the "Errors" column.
      errorFields = ( 'shortPacketsReceived', 'nonIpPacketsReceived',
                      'badChecksumIpPacketsReceived',
                      'unknownIpPacketsReceived',
                      'badChecksumPimPacketsReceived',
                      'badChecksumIcmpV6PacketsReceived',
                      'badMldQueryReceived', 'badMldV2ReportReceived' )
      rows = sorted( self.interfaces.items() )

      if not self._errorSpecific:
         print( "{:^38}|{:>10}".format( 'Input', 'Output' ) )
         print( "Port   Queries Reports  Others  Errors|"
                "Queries Reports  Others" )
         print( "-" * 62 )
         for intf, counters in rows:
            shortName = IntfCli.Intf.getShortname( intf )
            queries = ( counters.mldV1QueryReceived +
                        counters.mldV2QueryReceived )
            others = ( counters.pimPacketsReceived +
                       counters.otherIcmpPacketsReceived )
            errors = sum( getattr( counters, field ) for field in errorFields )
            print( '%-6s %7d %7d %7d %7d %7d %7d %7d' % (
               shortName, queries, counters.mldV2ReportReceived, others,
               errors, counters.mldQuerySend, counters.mldReportSend,
               counters.otherPacketsSent ) )
         return

      errorCounterFormat = '%-10s %-10s %-9s %-9s %-12s %-9s %-9s %-9s %-9s'
      print( errorCounterFormat % ( '', 'Packet',
            '  Packet', 'Bad IP', 'Unknown', 'Bad PIM', 'Bad ICMP',
            'Bad MLD', 'Bad MLD' ) )
      print( errorCounterFormat % ( 'Port', 'Too Short',
            '  Not IP', 'Checksum', 'IP Protocol', 'Checksum', 'Checksum',
            'Query', 'Report' ) )
      print( "-" * 93 )
      for intf, counters in rows:
         cells = ( IntfCli.Intf.getShortname( intf ), ) + tuple(
               getattr( counters, field ) for field in errorFields )
         print( '%-10s %9s %9s %9s %12s %9s %9s %9s %9s' % cells )

class MldSnoopingPacketCounters( Model ):
   """Counts of MLD packets by type; every counter defaults to 0."""
   v1GeneralQueries = Int( help="Number of v1 general queries", default=0 )
   v1GSQueries = Int( help="Number of v1 group specific queries", default=0 )
   v1Reports = Int( help="Number of v1 reports", default=0 )
   v1Dones = Int( help="Number of v1 dones", default=0 )
   v2GeneralQueries = Int( help="Number of v2 general queries", default=0 )
   v2GSQueries = Int( help="Number of v2 group specific queries", default=0 )
   v2GSSQueries = Int( help="Number of v2 group source specific queries",
                       default=0 )
   v2Reports = Int( help="Number of v2 reports", default=0 )
   errorPackets = Int( help="Number of error packets", default=0 )
   otherPackets = Int( help="Number of other packets", default=0 )

   def render( self ):
      """Print every counter on its own indented '<label>: <value>' line."""
      labelledCounts = (
         ( "  V1 general queries: ", self.v1GeneralQueries ),
         ( "  V1 group specific queries: ", self.v1GSQueries ),
         ( "  V1 reports: ", self.v1Reports ),
         ( "  V1 dones: ", self.v1Dones ),
         ( "  V2 general queries: ", self.v2GeneralQueries ),
         ( "  V2 group specific queries: ", self.v2GSQueries ),
         ( "  V2 group source specific queries: ", self.v2GSSQueries ),
         ( "  V2 reports: ", self.v2Reports ),
         ( "  Error packets: ", self.errorPackets ),
         ( "  Other packets: ", self.otherPackets ),
      )
      for label, count in labelledCounts:
         print( label + str( count ) )

class MldSnoopingQuerierCounters( Model ):
   """Transmitted and received MLD querier packet counters."""
   txCounters = Submodel( help="TX counters",
                          valueType=MldSnoopingPacketCounters )
   rxCounters = Submodel( help="RX Counters",
                          valueType=MldSnoopingPacketCounters )

   def render( self ):
      """Print the TX counters, then the RX counters, each under a title."""
      for title, counters in ( ( "TX:", self.txCounters ),
                               ( "RX:", self.rxCounters ) ):
         print( title )
         counters.render()

class SourceMembership( Model ):
   """Membership information for a multicast source."""
   mldVersion = Enum( help="Version of MLD snooping querier",
                      values=( 'v1', 'v2', 'unknown' ) )

   def mldVerToText( self ):
      """Return the MLD version for display, with 'unknown' shown as '-'."""
      # Compare the whole enum value, as the querier renderers do, rather
      # than doing a substring replace on it.
      return '-' if self.mldVersion == 'unknown' else self.mldVersion

class GroupMembership( SourceMembership ):
   """Membership information for a multicast group. mldVersion and
   mldVerToText are inherited from SourceMembership."""
   # Values are the keys of filterModeEnumToStr.
   filterMode = Enum( values=( 'filterModeExclude', 'filterModeInclude' ),
                      help="Filter mode of multicast group" )
   sourceMemberships = Dict( keyType=Ip6Address, valueType=SourceMembership,
                             optional=True,
                             help="Mapping source address to the source info" )

class VlanMembership( Model ):
   """Multicast group memberships within a single VLAN."""
   groupMemberships = Dict( help="Mapping group address to the group info",
                            keyType=Ip6Address, valueType=GroupMembership )

class MldSnoopingMembership( Model ):
   """MLD snooping group memberships, keyed by VLAN ID."""
   vlanMemberships = Dict( keyType=int, valueType=VlanMembership,
                           help="Mapping vlan id to the membership info",
                           optional=True )
   _groupSpecific = Bool( default=False,
         help="Display information in group specific format" )

   def render( self ):
      """Print the memberships as a table.

      By default each VLAN gets a title row, followed by one row per group
      (address, filter mode, version, number of sources). With _groupSpecific
      set, each group gets a title row, followed by one row per source
      (address, version).
      """
      leftFmt = TableOutput.Format( justify="left" )
      leftFmt.noPadLeftIs( True )
      if self._groupSpecific:
         headers = ( "Source", "Version" )
      else:
         headers = ( "Group", "Mode", "Version", "Number of Sources" )
      table = TableOutput.createTable( headers )
      # Title rows span every column of the table.
      nCols = len( headers )

      # vlanMemberships/sourceMemberships are optional; treat unset as empty.
      vlanMemberships = self.vlanMemberships or {}
      for vlan, vlanMembership in sorted( vlanMemberships.items() ):
         if not self._groupSpecific:
            table.startRow()
            table.newFormattedCell( "Memberships for VLAN " + str( vlan ),
                                    nCols=nCols, format=leftFmt )

         for groupAddress, groupMembership in sorted(
               vlanMembership.groupMemberships.items() ):
            sourceMemberships = groupMembership.sourceMemberships or {}
            if self._groupSpecific:
               table.startRow()
               # str() keeps the concatenation valid whether the address is
               # held as a str or as an address object.
               table.newFormattedCell( "Memberships for VLAN " + str( vlan ) +
                                       "   Group " + str( groupAddress ),
                                       nCols=nCols, format=leftFmt )
               for sourceAddress, sourceMembership in sorted(
                     sourceMemberships.items() ):
                  table.newRow( sourceAddress, sourceMembership.mldVerToText() )
            else:
               table.newRow( groupAddress,
                             filterModeEnumToStr[ groupMembership.filterMode ],
                             groupMembership.mldVerToText(),
                             len( sourceMemberships ) )

      print( table.output() )

class MldSnoopingVlanIntfs( Model ):
   """Interfaces listed for one VLAN in the report-flooding output."""
   interfaces = List( help="VLAN interfaces info", valueType=Interface )

class MldSnoopingReportFlooding( Model ):
   """Report-flooding interface ports, keyed by VLAN ID."""
   vlans = Dict( help="Mapping VLAN ID to its interface-ports",
                 keyType=int, valueType=MldSnoopingVlanIntfs )

   def render( self ):
      """Print one row per VLAN (sorted by ID) listing its interface ports
      through FormattedPrinting. Prints nothing when there are no VLANs."""
      if self.vlans:
         print( "Vlan    Interface-ports" )
         print( f'{"-"*4} {" "*2} {"-"*64}' )
         for vlanId, vlanIntfs in sorted( self.vlans.items() ):
            # A fresh printer is used for each VLAN row.
            printer = FormattedPrinting( margin=8, startOffset=8, termwidth=75 )
            print( f"{vlanId:<7}", end=' ' )
            for intfName in map( IntfCli.Intf.getShortname,
                                 sorted( vlanIntfs.interfaces ) ):
               printer.add( intfName )
            printer.display()

class MldSnoopingVlanQuerier( Model ):
   """MLD querier information for a single VLAN."""
   querierAddress = Ip6Address(
         help="Address of MLD querier in the VLAN" )
   mldVersion = Enum( help="Version of MLD snooping querier in the vlan",
         values=( 'v1', 'v2', 'unknown' ) )
   querierInterface = Interface( help="Interface where the querier is located" )
   queryResponseInterval = Float( help="Effective maximum period a recipient"
         " can wait before responding with a membership report in seconds" )
   # Values are the keys of operationalStateEnumToStr.
   operationalState = Enum( help="Operational state of querier in the Vlan",
         values=( 'querierStateInitial', 'querierStateQuerier',
                  'querierStateNonQuerier' ), optional=True )
   queryInterval = Float( help="Period between MLD Membership Query messages",
                          optional=True )
   # The help fragments are implicitly concatenated; the trailing space after
   # "no" keeps the words from running together.
   listenerTimeout = Float( help="Time before the querier decides there are no "
            "more listeners for a multicast address in the VLAN", optional=True )
   adminState = Enum( help="State of MLD querier in the Vlan",
         values=( 'enabled', 'disabled' ), optional=True )

class MldSnoopingQuerier( Model ):
   """MLD querier summary, keyed by VLAN ID."""
   vlans = Dict( valueType=MldSnoopingVlanQuerier, keyType=int,
                 help="Mapping vlan Id to the vlan querier info" )
   _vlanSpecific = Bool(
         help="Display information in vlan specific format if True",
         default=False )

   def render( self ):
      """Print the querier of every VLAN as a table sorted by VLAN ID, or,
      when _vlanSpecific is set, a key/value view of a single VLAN's querier.
      Prints nothing when there are no VLANs."""
      if not self.vlans:
         return

      if not self._vlanSpecific:
         print( "%-5s %-24s %-8s %s" % (
               'Vlan', 'IP Address', 'Version', 'Port' ) )
         print( "-" * 44 )
         for vlanId in sorted( self.vlans ):
            querier = self.vlans[ vlanId ]
            version = ( '-' if querier.mldVersion == 'unknown'
                        else querier.mldVersion )
            print( '%-5s %-24s %-8s %s' % ( vlanId, querier.querierAddress,
                  version, IntfCli.Intf.getShortname( querier.querierInterface ) ) )
         return

      # Only the first VLAN entry is rendered here; presumably the handler
      # populates a single VLAN in this mode -- confirm against the caller.
      querier = next( iter( self.vlans.values() ) )
      version = '-' if querier.mldVersion == 'unknown' else querier.mldVersion
      print( '%-20s : %s\n%-20s : %s\n%-20s : %s\n%-20s : %s' % (
            'IP Address', querier.querierAddress,
            'MLD Version', version,
            'Port', IntfCli.Intf.getShortname( querier.querierInterface ),
            'Max response time', querier.queryResponseInterval ) )

class MldSnoopingQuerierDetail( Model ):
   """Detailed MLD querier information, keyed by VLAN ID."""
   vlans = Dict( valueType=MldSnoopingVlanQuerier, keyType=int,
                 help="Mapping vlan Id to the vlan querier info" )

   def render( self ):
      """Print one paragraph per VLAN (sorted by ID), followed by a blank line.

      The admin state and the timer lines are printed only when their values
      are truthy. An unset or unknown operational state is shown as
      "Disabled".
      """
      for vlanId in sorted( self.vlans ):
         querier = self.vlans[ vlanId ]
         print( "VLAN " + str( vlanId ) )
         print( "IP address:", querier.querierAddress )
         print( "Version:",
                '-' if querier.mldVersion == 'unknown' else querier.mldVersion )
         print( "Interface:",
                IntfCli.Intf.getShortname( querier.querierInterface ) )
         if querier.adminState:
            print( "Admin state:", querier.adminState )
         timers = ( ( "Query interval:", querier.queryInterval ),
                    ( "Response time:", querier.queryResponseInterval ),
                    ( "Listener timeout:", querier.listenerTimeout ) )
         for label, seconds in timers:
            if seconds:
               print( label, seconds, "seconds" )
         opState = operationalStateEnumToStr.get( querier.operationalState,
                                                  "Disabled" )
         print( "Operational state:", opState )
         print()
