# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

from textwrap import wrap
import Arnet
import Ark
import Tac
import TableOutput
from CliModel import Model, Submodel, Bool, Enum, Str, List, Dict, Int, Float
from ArnetModel import IpGenericAddrAndPort
from IntfModels import Interface
from Toggles.McastDnsToggleLib import toggleMcastDnsFloodSuppressionEnabled

class MdnsCounter( Model ):
   # Packet and record counters. render() prints one line per counter, using
   # the help string of the counter as its label.
   txGenErrors = Int( help="MDNS packet generation errors" )
   txDsoGenErrors = Int( help="DSO packet generation errors" )
   txPamErrors = Int( help="Packet transmission errors" )
   rxParseErrors = Int( help="Discarded MDNS packets" )
   rxDsoParseErrors = Int( help="Discarded DSO packets" )
   rxPamErrors = Int( help="Discarded MDNS packets received on the wrong interface" )
   txUdpPkts = Int( help="Sent MDNS packets" )
   txTcpPkts = Int( help="Sent DSO packets" )
   rxUdpPkts = Int( help="Received MDNS packets" )
   rxTcpPkts = Int( help="Received DSO packets" )
   ignoredNsec = Int( help="Ignored NSEC records" )
   ignoredKA = Int( help="Ignored Known Answer records" )
   ignoredProbing = Int( help="Ignored probing queries" )
   ignoredQuestions = Int( help="Ignored One-shot queries" )
   suppressedResponseRecords = Int( help="Suppressed MDNS responses" )

   def render( self ):
      # Lines are ordered by the help string, which is also what gets printed
      # ( minus any trailing '.' ) in front of the counter value
      def helpText( item ):
         return item[ 1 ].help

      for name, info in sorted( self.__attributes__.items(), key=helpText ):
         label = info.help.rstrip( '.' )
         value = str( getattr( self, name ) )
         print( label + ': ' + value )

class MdnsServerGatewayStatus( Model ):
   # A remote gateway address/port together with its connection status
   client = Submodel( help="Remote gateway", valueType=IpGenericAddrAndPort )
   connectionStatus = Str( help="Connection Status" )

class MdnsStatus( Model ):
   # State of the mDNS agent and of the DSO server, plus the DSO connections
   # to servers ( tcpClient ) and from clients ( tcpServerClient )
   mdnsEnabled = Bool( help="mDNS agent is enabled", default=False )
   mdnsRunning = Bool( help="mDNS agent is running", default=False )
   if toggleMcastDnsFloodSuppressionEnabled():
      floodSuppression = Bool( help="Flooding suppression is enabled",
                               default=False )
   dsoEnabled = Bool( help="DSO server is enabled", default=False )
   dsoRunning = Bool( help="DSO server is running", default=False )
   tcpClient = List( help="Connections to DSO servers",
                     valueType=MdnsServerGatewayStatus )
   tcpServerClient = List( help="Connection from DSO clients",
                           valueType=IpGenericAddrAndPort )

   def render( self ):
      # "running" takes precedence over "enabled" in the reported state
      if self.mdnsRunning:
         mdnsState = "running"
      elif self.mdnsEnabled:
         mdnsState = "enabled"
      else:
         mdnsState = "disabled"
      print( "mDNS is", mdnsState )

      if toggleMcastDnsFloodSuppressionEnabled():
         suppressionState = "enabled" if self.floodSuppression else "disabled"
         print( "Flooding suppression is", suppressionState )

      if self.dsoRunning:
         dsoState = "running"
      elif self.dsoEnabled:
         dsoState = "enabled"
      else:
         dsoState = "disabled"
      print( "DSO server is", dsoState )

      # The connection tables are only printed while mDNS is running
      if not self.mdnsRunning:
         return

      gatewayRow = "%-16s %-10s %s"
      print( "\nGateway DSO connections" )
      print( gatewayRow % ( 'Address', 'Port', 'Status' ) )
      for gateway in sorted( self.tcpClient, key=lambda g: g.client.ip.sortKey ):
         remote = gateway.client
         print( gatewayRow % ( remote.ip, remote.port, gateway.connectionStatus ) )

      clientRow = "%-16s %s"
      print( "\nDSO client connections" )
      print( clientRow % ( 'Address', 'Port' ) )
      for peer in sorted( self.tcpServerClient, key=lambda p: p.ip.sortKey ):
         print( clientRow % ( peer.ip, peer.port ) )

class MdnsLinkEntry( Model ):
   # Link name, default location tag and status for one address family
   family = Enum( values=Tac.Type( 'Arnet::AddressFamily' ).attributes,
                  help='Address family' )
   linkId = Str( help='Configured link name' )
   defaultTag = Str( help='Configured default location tag' )
   status = Enum( help='Interface status', values=( "active", "inactive" ) )

class MdnsLinkList( Model ):
   # A list is used because the same interface could have both an ipv4 and an
   # ipv6 address family entry in the future
   intfInfos = List( help='Interfaces information', valueType=MdnsLinkEntry )

class MdnsLink( Model ):
   interfaces = Dict( help="mDNS enabled interfaces indexed by interface name",
                      keyType=Interface, valueType=MdnsLinkList )

   def render( self ):
      # One table row per ( interface, address family ) pair
      headings = [ 'Interface', 'Address Family', 'Link ID', 'Default Tag',
                   'Status' ]
      table = TableOutput.createTable( headings, indent=0 )
      # every column shares the same left justified, wrapping format
      leftWrapped = TableOutput.Format( justify='left', maxWidth=50, wrap=True )
      table.formatColumns( *( [ leftWrapped ] * len( headings ) ) )
      for intfName in Arnet.sortIntf( self.interfaces ):
         infos = self.interfaces[ intfName ].intfInfos
         for info in sorted( infos, key=lambda i: i.family ):
            table.newRow( intfName, info.family, info.linkId, info.defaultTag,
                          info.status )
      print( table.output() )

class MdnsServiceEntry( Model ):
   # Where a service record was learned ( interface and/or link ) and its
   # location tags
   interface = Interface( optional=True,
                          help='Interface the service record is learned from' )
   link = Str( optional=True, help='Link name the service record is learned from' )
   locations = List( help='Location tags of the service record', valueType=str )
   # conflicts is only part of the model when the toggle is enabled
   if toggleMcastDnsFloodSuppressionEnabled():
      conflicts = Bool( optional=True,
                        help='Existence of conflicting records in the network' )

class MdnsServiceList( Model ):
   # Holds the entries of a service learned from different links
   srvInfos = List( help='MDNS services', valueType=MdnsServiceEntry )

class MdnsServiceRecordServiceRule( Model ):
   # services is per origDname
   services = Dict( help="A mapping of service name to its information",
                    keyType=str, valueType=MdnsServiceList )
   queryLinks = List( help="Configured query links", valueType=Interface )
   responseLinks = List( help="Configured response links", valueType=str )
   responseIntfs = List( help="Configured response interfaces",
                         valueType=Interface )
   srvTypes = List( help="Discovering service types", valueType=str )

   def render( self ):
      # Summary lines: query links, response links, response interfaces and
      # the service types being discovered
      print( 'Query link: ' + ', '.join( Arnet.sortIntf( self.queryLinks ) ) )
      print( 'Response link: ' +
             ', '.join( Arnet.sortIntf( self.responseLinks ) ) )
      print( 'Response interface: ' +
             ', '.join( Arnet.sortIntf( self.responseIntfs ) ) )
      print( 'Service types: ' + ', '.join( sorted( self.srvTypes ) ) )

      # Table of services; the Status column is only present when the flood
      # suppression toggle is enabled
      showStatus = toggleMcastDnsFloodSuppressionEnabled()
      headings = [ 'Service Name', 'Interface/Link', 'Location' ]
      if showStatus:
         headings.append( 'Status' )
      table = TableOutput.createTable( headings, indent=0 )
      leftWrapped = TableOutput.Format( justify='left', maxWidth=50, wrap=True )
      table.formatColumns( *( [ leftWrapped ] * len( headings ) ) )

      # service is sorted by interface first ( service without interface is sorted
      # later ), and then by link ( service without link name is sorted earlier )
      def learnedFromKey( info ):
         if info.interface:
            intfKey = Arnet.intfNameKey( info.interface.stringValue )
         else:
            intfKey = [ '\xff' ]
         if info.link:
            linkKey = Arnet.intfNameKey( info.link )
         else:
            linkKey = [ '\x00' ]
         return ( intfKey, linkKey )

      for srvName in Arnet.sortIntf( self.services ):
         infos = self.services[ srvName ].srvInfos
         for info in sorted( infos, key=learnedFromKey ):
            intfStr = info.interface.stringValue if info.interface else ''
            linkStr = info.link if info.link else ''
            # "intf/link", or just whichever of the two is set
            source = '/'.join( part for part in ( intfStr, linkStr ) if part )
            # trailing status cell, empty tuple when there is no Status column
            if showStatus:
               statusCell = ( 'conflicts' if info.conflicts else '', )
            else:
               statusCell = ()
            for location in Arnet.sortIntf( info.locations ):
               table.newRow( srvName, source, location, *statusCell )
      print( table.output() )

class Record( Model ):
   # Raw record data kept as a list of integers
   recordBytes = List( help='Raw bytes of records', valueType=int )

class MdnsUnknownList( Model ):
   # Raw records of a type that has no dedicated model
   unknowns = List( help='Unknown records', valueType=Record )

class MdnsServiceHostInfo( Model ):
   # Per host data: SRV fields, A/AAAA addresses, TXT records and records of
   # unknown types
   priority = Int( optional=True, help="SRV record priority" )
   weight = Int( optional=True, help="SRV record weight" )
   port = Int( optional=True, help="SRV record port" )
   v4Addrs = List( optional=True, valueType=str,
                   help="Information of the A record" )
   v6Addrs = List( optional=True, valueType=str,
                   help="Information of the AAAA record" )
   txts = List( optional=True, valueType=Record, help="TXT records" )
   unknownTypes = Dict( optional=True, keyType=int, valueType=MdnsUnknownList,
                        help="A mapping of unknown types to records" )
   # conflicts is only part of the model when the toggle is enabled
   if toggleMcastDnsFloodSuppressionEnabled():
      conflicts = Bool( optional=True,
                        help='Existence of conflicting records in the network' )


class MdnsServiceInfo( Model ):
   # A service as learned from one link/interface: its TXT records and the
   # hosts providing it
   link = Str( optional=True,
               help='Name of the link the service record is learned from' )
   interface = Interface( optional=True,
                          help='Interface the service record is learned from' )
   srvTxts = List( optional=True, valueType=Record,
                   help='Service TXT records' )
   hosts = Dict( optional=True, keyType=str, valueType=MdnsServiceHostInfo,
                 help="A mapping of Host name to its meta information" )

class MdnsServiceRecordServiceName( Model ):
   # Learned instances of one service name. render() prints a brief list of
   # hosts per interface, or the full record data when _detail is set.
   name = Str( default='', help="Service name" )
   services = List( valueType=MdnsServiceInfo, help="Service information" )
   _detail = Bool( help="Display detailed information" )
   if toggleMcastDnsFloodSuppressionEnabled():
      conflicts = Bool( help='Existence of conflicting records in the network',
                        optional=True )

   # Print title followed by the space separated values of raw, wrapped to
   # width columns with continuation lines prefixed by indent.
   # The output is limited to 3 lines. Only if the data needs more than 3
   # lines is it truncated: a trailing ' ...(# bytes)' with the total number of
   # values is appended to the third line, and values are removed from the end
   # of that line as needed so that it still fits in width with the trailer.
   # Data that fits within 3 lines is printed in full, without a trailer.
   def printWrappedRawData( self, raw, title, indent, width ):
      rawStr = " ".join( str( e ) for e in raw )
      wrapped = wrap( title + rawStr, width=width, subsequent_indent=indent )
      if len( wrapped ) > 3:
         endStr = ' ...(' + str( len( raw ) ) + ' bytes)'
         # room left for data on the third line once the trailer is added
         cutWidth = width - len( endStr )
         lastLine = wrapped[ 2 ]
         if len( lastLine ) > cutWidth:
            # use wrap rather than slicing so that an integer is not cut in half
            lastLine = wrap( lastLine, width=cutWidth )[ 0 ]
         wrapped = wrapped[ : 2 ] + [ lastLine + endStr ]
      for line in wrapped:
         print( line )

   def render( self ):
      if not self.name or not self.services:
         return
      indent = "  "
      conflictMark = "(*)"
      showConflicts = toggleMcastDnsFloodSuppressionEnabled()

      def conflictSuffix( model ):
         # The conflicts attribute is only declared when the toggle is enabled,
         # so it must not be read otherwise
         return conflictMark if showConflicts and model.conflicts else ""

      # sort services by interface
      # use '\xff' to sort remote links after local interfaces
      def intfKey( service ):
         if service.interface:
            return Arnet.intfNameKey( service.interface.stringValue )
         return [ '\xff' ]

      # _TypedList could not be used as a key in sorted
      def bytesKey( record ):
         # pylint: disable-next=unnecessary-comprehension
         return [ e for e in record.recordBytes ]

      def printRecords( records, title ):
         # Continuation lines are aligned with the data of the first line, and
         # 60 columns are reserved for the data itself
         pad = ' ' * len( title )
         for record in sorted( records, key=bytesKey ):
            self.printWrappedRawData( record.recordBytes, title, pad,
                                      60 + len( pad ) )

      serviceSuffix = conflictSuffix( self )
      # Becomes True once any printed name carries the conflict marker, in
      # which case the legend is printed at the end
      conflictsNote = bool( serviceSuffix )
      heading = self.name + serviceSuffix
      orderedServices = sorted( self.services, key=intfKey )

      if not self._detail:
         print( 'Service Name:', heading )
         for index, service in enumerate( orderedServices ):
            if index:
               # empty line between services, but not before the first one
               print()
            interface = service.interface
            print( '{}Interface: {}'.format(
                  indent, interface.stringValue if interface else '--' ) )
            if service.link:
               print( '{}Link: {}'.format( indent * 2, service.link ) )

            hosts = []
            for host in Arnet.sortIntf( service.hosts ):
               suffix = conflictSuffix( service.hosts[ host ] )
               conflictsNote = conflictsNote or bool( suffix )
               hosts.append( host + suffix )
            hostStr = ', '.join( hosts )
            if hostStr:
               print( '{}Host: {}'.format( indent, hostStr ) )
      else:
         # detailed output
         for index, service in enumerate( orderedServices ):
            if index:
               # empty line between services, but not before the first one
               print()
            print( 'Service Name:', heading )
            printRecords( service.srvTxts, 'TXT record: ' )

            interface = service.interface
            print( 'Interface:', interface.stringValue if interface else '--' )
            if service.link:
               print( '{}Link: {}'.format( indent, service.link ) )

            for host in Arnet.sortIntf( service.hosts ):
               hostInfo = service.hosts[ host ]
               suffix = conflictSuffix( hostInfo )
               conflictsNote = conflictsNote or bool( suffix )
               print()
               print( '{}Host: {}'.format( indent, host + suffix ) )
               print( '{}Priority: {}'.format( indent, hostInfo.priority ) )
               print( '{}Weight: {}'.format( indent, hostInfo.weight ) )
               print( '{}Port: {}'.format( indent, hostInfo.port ) )

               v4AddrStr = ', '.join( sorted( hostInfo.v4Addrs ) )
               if v4AddrStr:
                  print( '{}V4 Address: {}'.format( indent, v4AddrStr ) )

               v6AddrStr = ', '.join( sorted( hostInfo.v6Addrs ) )
               if v6AddrStr:
                  print( '{}V6 Address: {}'.format( indent, v6AddrStr ) )

               printRecords( hostInfo.txts, indent + 'TXT record: ' )

               for unknownType, unknownList in \
                     sorted( hostInfo.unknownTypes.items() ):
                  printRecords( unknownList.unknowns,
                                indent + 'Record (Type %s): ' % unknownType )

      if conflictsNote:
         print( '(*) Conflicts with one or more other service records.' )

class MdnsServiceTypeEntry( Model ):
   # A service type observed for one address family and when it was last seen
   family = Enum( values=Tac.Type( 'Arnet::AddressFamily' ).attributes,
                  help='Address family' )
   serviceType = Str( help='Service type' )
   timestamp = Float( help='Timestamp of when the service type was last observed' )

class MdnsServiceTypeList( Model ):
   # A list is used because the same interface could have both an ipv4 and an
   # ipv6 address family entry in the future
   intfInfos = List( help='Interfaces information',
                     valueType=MdnsServiceTypeEntry )

class MdnsServiceType( Model ):
   serviceTypes = Dict( help="Observed service types indexed by interface name",
                        keyType=Interface, valueType=MdnsServiceTypeList )

   def render( self ):
      # Header and data rows share one fixed width layout
      rowFormat = "%-16s %-16s %-30s %s"
      print( rowFormat %
             ( 'Interface', 'Address Family', 'Service Type', 'Elapsed Time' ) )
      for intf in Arnet.sortIntf( self.serviceTypes ):
         observed = self.serviceTypes[ intf ].intfInfos
         for info in sorted( observed, key=lambda i: i.family ):
            elapsed = Ark.timestampToStr( info.timestamp )
            print( rowFormat % ( intf, info.family, info.serviceType, elapsed ) )

class MdnsRecordInfo( Model ):
   # Where a record was learned ( interface and/or link ), its location tags
   # and whether it conflicts with other records
   interface = Interface( optional=True,
                          help='Interface the service record is learned from' )
   link = Str( optional=True, help='Link name the service record is learned from' )

   locations = List( help='Location tags of the service record', valueType=str )
   conflicts = Bool( optional=True,
                     help='Existence of conflicting records in the network' )

class MdnsRecordInfoList( Model ):
   # All MdnsRecordInfo entries stored under one original record name
   recordInfoList = List( help='MDNS records', valueType=MdnsRecordInfo )

class MdnsRecordEntry( Model ):
   # Store services from different links
   recordName = Str( help='MDNS record name lower cased' )
   recordType = Enum( help='MDNS record type',
                      values=Tac.Type( 'McastDns::DnsRrType' ).attributes )
   recordClass = Int( help='MDNS record class' )
   recordInfos = Dict(
         help="A mapping of mDNS original record name to its information",
         keyType=str, valueType=MdnsRecordInfoList )

class MdnsRecordEntryList( Model ):
   # All record entries stored under one lower cased record name
   recordEntryList = List( help='MDNS records', valueType=MdnsRecordEntry )

class MdnsRecords( Model ):
   records = Dict(
         help="A mapping of MDNS lower cased record name to record entries",
         keyType=str, valueType=MdnsRecordEntryList )

   def render( self ):
      # One table row per ( original name, type, class, source, location )
      headings = [
            'Name', 'Type', 'Class', 'Interface/Link', 'Location', 'Status' ]
      table = TableOutput.createTable( headings, indent=0 )
      leftWrapped = TableOutput.Format( justify='left', maxWidth=50, wrap=True )
      table.formatColumns( *( [ leftWrapped ] * len( headings ) ) )

      def typeClassKey( entry ):
         return ( entry.recordType, entry.recordClass )

      # As show mdns service rule table, records are sorted by interfaces
      # first ( records without interface are sorted later ), and then by
      # links ( records without link name are sorted earlier )
      def learnedFromKey( info ):
         if info.interface:
            intfKey = Arnet.intfNameKey( info.interface.stringValue )
         else:
            intfKey = [ '\xff' ]
         if info.link:
            linkKey = Arnet.intfNameKey( info.link )
         else:
            linkKey = [ '\x00' ]
         return ( intfKey, linkKey )

      for lowerName in sorted( self.records ):
         entries = self.records[ lowerName ].recordEntryList
         for entry in sorted( entries, key=typeClassKey ):
            infosByName = entry.recordInfos
            for origName in Arnet.sortIntf( infosByName ):
               infos = infosByName[ origName ].recordInfoList
               for info in sorted( infos, key=learnedFromKey ):
                  intfStr = info.interface.stringValue if info.interface else ''
                  linkStr = info.link if info.link else ''
                  # "intf/link", or just whichever of the two is set
                  source = '/'.join(
                        part for part in ( intfStr, linkStr ) if part )
                  status = 'conflicts' if info.conflicts else ''
                  for location in Arnet.sortIntf( info.locations ):
                     table.newRow( origName, entry.recordType.upper(),
                                   entry.recordClass, source, location, status )
      print( table.output() )
