# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import struct
from Arnet import Ip6Addr
import BasicCli
import CliCommand
import CliMatcher
import ConfigMount
import LazyMount
import Tac
from Tracing import Handle
from Intf.IntfRange import IntfRangeMatcher
import CliParser
from CliPlugin.EthIntfCli import EthPhyAutoIntfType
from CliPlugin.VlanIntfCli import VlanAutoIntfType
from CliPlugin.IpAddrMatcher import IpAddrMatcher
from CliPlugin.McastDnsModel import MdnsStatus, MdnsServerGatewayStatus, \
      MdnsLink, MdnsLinkList, MdnsLinkEntry, MdnsServiceRecordServiceRule, \
      MdnsServiceList, MdnsServiceEntry, MdnsServiceRecordServiceName, \
      MdnsServiceInfo, MdnsServiceHostInfo, MdnsUnknownList, Record, \
      MdnsServiceType, MdnsServiceTypeList, MdnsServiceTypeEntry, \
      MdnsCounter, MdnsRecords, MdnsRecordEntryList, MdnsRecordEntry, \
      MdnsRecordInfoList, MdnsRecordInfo
from CliMode.McastDns import McastDnsMode, McastDnsServiceMode
from ArnetModel import IpGenericAddrAndPort
import ShowCommand
from Toggles.McastDnsToggleLib import toggleMcastDnsFloodSuppressionEnabled

# To be able to mount mdns/config when loading plugin:
# force dependency to make sure that McastDns-lib will be installed
# before McastDns-cli. See AID10.
# pkgdeps: rpmwith %{_libdir}/preinit/McastDns

# Trace handle used by default for this plugin's tracing output.
__defaultTraceHandle__ = Handle( 'McastDnsCli' )
# Module-level handles read by the command handlers below. They are None at
# import time. NOTE(review): presumably assigned by a plugin mount hook that is
# not visible in this part of the file - confirm.
mcastDnsConfig = None
mcastDnsStatus = None
hwCapability = None

# Interface kinds accepted wherever a command takes an interface range.
intfTypes = ( EthPhyAutoIntfType, VlanAutoIntfType )
intfRangeMatcher = IntfRangeMatcher( explicitIntfTypes=intfTypes,
                                     helpdesc='List of interfaces' )
# TCP port number, shared by the remote-gateway and dso server commands.
matcherPort = CliMatcher.IntegerMatcher( 1, 65535,
      helpdesc='Configure TCP port to this value' )
# Optional 'add' / 'remove' action; when absent the handlers replace the list.
addRemoveMatcher = CliMatcher.EnumMatcher( { 'add': 'Add to list',
                                    'remove': 'Remove from list', } )
# 'mdns' keyword used by the show commands.
mdnsMatcher = CliMatcher.KeywordMatcher( 'mdns',
      helpdesc='Multicast DNS Gateway information' )
anyMatcher = CliMatcher.KeywordMatcher( 'any',
      helpdesc='Any service types allowed by the service rule' )
detailMatcher = CliMatcher.KeywordMatcher( 'detail',
      helpdesc='Display information in detail' )

# Short aliases for the Tac types and values used throughout this plugin.
IntfId = Tac.Type( 'Arnet::IntfId' )
IntfKey = Tac.Type( 'McastDns::IntfKey' )
AddressFamily = Tac.Type( 'Arnet::AddressFamily' )
LinkId = Tac.Type( 'McastDns::LinkId' )
RrLinkId = Tac.Type( 'McastDns::RrLinkId' )
IpGenAddr = Tac.Type( 'Arnet::IpGenAddr' )
Port = Tac.Type( 'Arnet::Port' )
# DefaultPort.dsoPort is the fallback when no tcp-port is given on the CLI.
DefaultPort = Tac.Value( 'McastDns::DefaultPort' )
LinkName = Tac.Type( "McastDns::LinkName" )
LocationTag = Tac.Type( "McastDns::LocationTag" )
DnsName = Tac.Type( "McastDns::DnsName" )
DnsNameLower = Tac.Type( "McastDns::DnsNameLower" )
DnsRrType = Tac.Type( "McastDns::DnsRrType" )
DnsRrClass = Tac.Type( "McastDns::DnsRrClass" )
RrSetKey = Tac.Type( "McastDns::RrSetKey" )

# Command adapter: converts every entry of args[ 'INTFS' ] to an IntfId in
# place. Does nothing when 'INTFS' is absent or empty.
def intfIdsAdapter( mode, args, argsList ):
   rawIntfs = args.get( 'INTFS' )
   if not rawIntfs:
      return
   args[ 'INTFS' ] = list( map( IntfId, rawIntfs ) )

# Validate the service types given on the CLI and return the valid ones.
#
# mode:      CLI mode; only mode.addWarning() is used, to report rejected types.
# argsTypes: iterable of service type strings, e.g. '_ipp._tcp.' or
#            '_printer._sub._ipp._tcp.'. The literal 'any' is always accepted.
#
# Returns a set holding every entry of argsTypes that was not rejected. One
# warning is issued listing all rejected types; when exactly one type is
# rejected the warning also carries the reason (if there is a specific one).
def validateServiceType( mode, argsTypes ):
   import re # pylint: disable=import-outside-toplevel
   # RFC6335 Section 5.1: Service name syntax requirements
   # Although RFC6335 states that service names can contain uppercase letters,
   # we restrict service types to be lowercase only, since "case is ignored for
   # comparison purposes" and this avoids extra agent logic that would otherwise
   # be needed for it to know that, e.g., _ftp._tcp. and _FTP._tcp., are equivalent.
   #
   # RFC6763 Section 7.1: Subtypes
   typePattern = re.compile( r'(.+\._sub\.)?_[a-z0-9-]+\._[a-z0-9-]+' )
   letterPattern = re.compile( r'[a-z]' )

   def rejectReason( srvType ):
      # Returns None when srvType is acceptable. Otherwise returns the text to
      # append to the single-type warning ( '' when there is no extra detail ).
      if srvType == '_dns-sd._udp.':
         return '. _dns-sd._udp. is not supported'
      if not typePattern.search( srvType ):
         return ''
      if not srvType.endswith( '.' ):
         return ' . Service type must end with a dot (.)'

      # Drop an optional '<subtype>._sub.' prefix. Partitioning on '._sub.'
      # keeps the leading underscore of the base type, so the [ 1 : ] below
      # strips exactly that underscore and nothing else. It also does not
      # trigger on service names that merely end in 'sub'.
      baseType = srvType.rpartition( '._sub.' )[ 2 ]
      # For '_ipp._tcp.', serviceName is 'ipp'
      serviceName = baseType.split( '.', 1 )[ 0 ][ 1 : ]

      if not 1 <= len( serviceName ) < 16:
         return '. Service name must be 1-15 characters long'
      if not letterPattern.search( serviceName ):
         return '. Service name must contain at least one letter'
      if serviceName.startswith( '-' ):
         return '. Service name must not begin with a hyphen'
      if serviceName.endswith( '-' ):
         return '. Service name must not end with a hyphen'
      if '--' in serviceName:
         return '. Hyphens must not be adjacent to each other'
      return None

   invalidTypes = []
   warningStr = ''
   for t in argsTypes:
      if t == 'any':
         continue
      reason = rejectReason( t )
      if reason is None:
         continue
      # Record the type exactly as the user typed it, so that it is both
      # reported correctly and actually removed from the returned set.
      invalidTypes.append( t )
      if reason:
         warningStr = reason

   invalidTypesStr = ' '.join( invalidTypes )
   if len( invalidTypes ) == 1:
      # pylint: disable-next=consider-using-f-string
      mode.addWarning( "Ignore invalid type: %s%s" %
                       ( invalidTypesStr, warningStr ) )
   elif len( invalidTypes ) > 1:
      # pylint: disable-next=consider-using-f-string
      mode.addWarning( "Ignore invalid types: %s" % invalidTypesStr )

   return set( argsTypes ) - set( invalidTypes )

# Apply argsList to coll: each entry is added when addOrRemove is 'add', and
# deleted from coll for any other value of addOrRemove.
def addRemoveColl( addOrRemove, coll, argsList ):
   adding = addOrRemove == 'add'
   for item in argsList:
      if adding:
         coll.add( item )
      else:
         del coll[ item ]

# This CLI plugin defines the following configuration commands:
#
# mdns
#    [ no | default ] disabled
#    [ no | default ] flooding suppression
#
#    remote-gateway ipv4 <IPADDR> [ tcp-port <PORT> ]
#    dso server ipv4 [ tcp-port <PORT> ]
#
#    service <NAME>
#       type [ ACTION ] { TYPES }
#       query [ ACTION ] <INTFS>
#       response interface [ ACTION ] <INTFS>
#       response link [ ACTION ] <LINKS>
#       match ( by-tag | group [ ACTION ] <LOCATIONTAGS> )

#------------------------------------------------------------------------------
# The "config-mdns" mode.
#------------------------------------------------------------------------------
# CLI mode entered by the top-level 'mdns' command. The mdns-wide commands in
# this file are registered on this mode.
class McastDnsConfigMode( McastDnsMode, BasicCli.ConfigModeBase ):
   name = 'mdns configuration'

   def __init__( self, parent, session ):
      # Both bases are initialised explicitly, McastDnsMode first.
      McastDnsMode.__init__( self )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

#------------------------------------------------------------------------------
# (config)# mdns
#------------------------------------------------------------------------------
class EnterMcastDnsConfigModeCmd( CliCommand.CliCommandClass ):
   syntax = 'mdns'
   noOrDefaultSyntax = syntax
   data = {
      'mdns': 'Multicast DNS Gateway configuration',
   }

   @staticmethod
   def handler( mode, args ):
      # Create the mdns child mode and switch the session into it.
      mode.session_.gotoChildMode( mode.childMode( McastDnsConfigMode ) )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      # resetMdnsConfig() also sets mcastDnsConfig.enabled back to false.
      mcastDnsConfig.resetMdnsConfig()

BasicCli.GlobalConfigMode.addCommandClass( EnterMcastDnsConfigModeCmd )

#------------------------------------------------------------------------------
# (config-mdns)# [ no | default ] disabled
#------------------------------------------------------------------------------
class McastDnsDisabledCmd( CliCommand.CliCommandClass ):
   syntax = 'disabled'
   noOrDefaultSyntax = syntax
   data = {
      'disabled': 'Disable Multicast DNS Gateway',
   }

   @staticmethod
   def handler( mode, args ):
      # 'disabled' turns the feature off.
      mcastDnsConfig.enabled = False

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      # Both 'no disabled' and 'default disabled' turn the feature on.
      mcastDnsConfig.enabled = True

McastDnsConfigMode.addCommandClass( McastDnsDisabledCmd )

#------------------------------------------------------------------------------
# (config-mdns)# [ no | default ] flooding suppression
#------------------------------------------------------------------------------

# CLI guard for the 'suppression' keyword: returns None (keyword allowed) when
# the hardware capability reports flooding suppression support, otherwise
# CliParser.guardNotThisPlatform.
def mdnsFloodSuppressionGuard( mode, token ):
   supported = hwCapability.mdnsFloodingSuppressionSupported
   return None if supported else CliParser.guardNotThisPlatform

class McastDnsFloodSuppressionCmd( CliCommand.CliCommandClass ):
   syntax = 'flooding suppression'
   noOrDefaultSyntax = syntax
   data = {
      'flooding': 'Multicast DNS flooding',
      # 'suppression' is only accepted when mdnsFloodSuppressionGuard allows it.
      'suppression': CliCommand.guardedKeyword( 'suppression',
         helpdesc='Disable Multicast DNS flooding',
         guard=mdnsFloodSuppressionGuard )
   }

   @staticmethod
   def handler( mode, args ):
      mcastDnsConfig.floodSuppression = True

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      mcastDnsConfig.floodSuppression = False

# The command is only registered when the flood suppression toggle is enabled.
if toggleMcastDnsFloodSuppressionEnabled():
   McastDnsConfigMode.addCommandClass( McastDnsFloodSuppressionCmd )

#------------------------------------------------------------------------------
# (config-mdns)# remote-gateway ipv4 <IPADDR> [ tcp-port <PORT> ]
#------------------------------------------------------------------------------
class McastDnsRemoteGatewayCmd( CliCommand.CliCommandClass ):
   syntax = 'remote-gateway ipv4 IPADDR [ tcp-port PORT ]'
   noOrDefaultSyntax = 'remote-gateway ipv4 [ IPADDR ] ...'
   data = {
      'remote-gateway': 'Specify a remote Multicast DNS gateway',
      'ipv4': 'Specify IPv4 remote gateway',
      'IPADDR': IpAddrMatcher( helpdesc='Remote gateway IP address' ),
      'tcp-port': 'Configure TCP port on remote gateway to connect to',
      'PORT': matcherPort,
   }

   @staticmethod
   def handler( mode, args ):
      # Create the gateway entry for the address, then set its port; the port
      # falls back to DefaultPort.dsoPort when none was given.
      gateway = mcastDnsConfig.newRemoteGateway( IpGenAddr( args[ 'IPADDR' ] ) )
      gateway.port = Port( args.get( 'PORT', DefaultPort.dsoPort ) )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      # Without an address every remote gateway is removed; with one, only
      # that gateway.
      target = args.get( 'IPADDR' )
      if not target:
         mcastDnsConfig.remoteGateway.clear()
         return
      del mcastDnsConfig.remoteGateway[ IpGenAddr( target ) ]

McastDnsConfigMode.addCommandClass( McastDnsRemoteGatewayCmd )

#------------------------------------------------------------------------------
# (config-mdns)# dso server ipv4 [ tcp-port <PORT> ]
#------------------------------------------------------------------------------
class McastDnsDsoServerCmd( CliCommand.CliCommandClass ):
   syntax = 'dso server ipv4 [ tcp-port PORT ]'
   noOrDefaultSyntax = 'dso server ipv4 ...'
   data = {
      'dso': 'Accept DSO over a TCP connection',
      'server': 'Configure DSO server to listen on a TCP port',
      'ipv4': 'Listen on an IPv4 address',
      'tcp-port': 'Configure TCP port that DSO server uses',
      'PORT': matcherPort,
   }

   @staticmethod
   def handler( mode, args ):
      # Use the given port, or DefaultPort.dsoPort when none was given.
      requested = args.get( 'PORT', DefaultPort.dsoPort )
      mcastDnsConfig.serverPort = Port( requested )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      # Reset to a default-constructed Port.
      mcastDnsConfig.serverPort = Port()

McastDnsConfigMode.addCommandClass( McastDnsDsoServerCmd )

#------------------------------------------------------------------------------
# The "config-mdns-service" mode.
#------------------------------------------------------------------------------
# CLI mode for one service rule. The rule name is kept on the mode so the
# per-service commands can look the rule up in mcastDnsConfig.serviceRule.
class McastDnsServiceConfigMode( McastDnsServiceMode, BasicCli.ConfigModeBase ):
   name = 'mdns service configuration'

   def __init__( self, parent, session, serviceName ):
      # serviceName is stored before the base initialisers run.
      self.serviceName = serviceName
      McastDnsServiceMode.__init__( self, self.serviceName )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

   def getServiceName( self ):
      # Returns the service rule name this mode was entered with.
      return self.serviceName

#------------------------------------------------------------------------------
# (config-mdns)# service <NAME>
#------------------------------------------------------------------------------
class EnterMcastDnsServiceConfigCmd( CliCommand.CliCommandClass ):
   syntax = 'service NAME'
   noOrDefaultSyntax = syntax
   data = {
      'service': 'Specify a service group',
      'NAME': CliMatcher.PatternMatcher( '[a-zA-Z0-9_-]+', helpname='WORD',
                                         helpdesc='Name of service group' ),
   }

   @staticmethod
   def handler( mode, args ):
      # Create the service rule, then enter the per-service mode for it.
      ruleName = args[ 'NAME' ]
      mcastDnsConfig.newServiceRule( ruleName )
      mode.session_.gotoChildMode(
         mode.childMode( McastDnsServiceConfigMode, serviceName=ruleName ) )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      # Remove the named service rule.
      ruleName = args[ 'NAME' ]
      del mcastDnsConfig.serviceRule[ ruleName ]

McastDnsConfigMode.addCommandClass( EnterMcastDnsServiceConfigCmd )

#------------------------------------------------------------------------------
# (config-mdns-service)# type [ ACTION ] { TYPES }
#------------------------------------------------------------------------------
class McastDnsServiceTypeCmd( CliCommand.CliCommandClass ):
   syntax = 'type [ ACTION ] { TYPES }'
   noOrDefaultSyntax = 'type ...'
   data = {
      'type': 'Specify supported service type(s)',
      'ACTION': addRemoveMatcher,
      'TYPES': CliMatcher.PatternMatcher( r'^(?!add$|remove$).+',
         helpname='WORD',
         helpdesc='List of service types allowed by the service rule.'
               ' The "any" service type allows all service types' ),
   }

   @staticmethod
   def handler( mode, args ):
      rule = mcastDnsConfig.serviceRule[ mode.getServiceName() ]
      accepted = validateServiceType( mode, args.get( 'TYPES' ) )
      if not accepted:
         # Nothing valid was given: leave the configured types untouched.
         return
      action = args.get( 'ACTION' )
      if action:
         addRemoveColl( action, rule.serviceType, accepted )
         return
      # No explicit action: the accepted types replace the existing ones.
      rule.serviceType.clear()
      addRemoveColl( 'add', rule.serviceType, accepted )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      mcastDnsConfig.serviceRule[ mode.getServiceName() ].serviceType.clear()

McastDnsServiceConfigMode.addCommandClass( McastDnsServiceTypeCmd )

#------------------------------------------------------------------------------
# (config-mdns-service)# query [ ACTION ] <INTFS>
#------------------------------------------------------------------------------
class McastDnsServiceQueryCmd( CliCommand.CliCommandClass ):
   syntax = 'query [ ACTION ] INTFS'
   noOrDefaultSyntax = 'query ...'
   data = {
      'query': 'Specify link(s) where queries are accepted',
      'ACTION': addRemoveMatcher,
      'INTFS': intfRangeMatcher,
   }
   adapter = intfIdsAdapter

   @staticmethod
   def handler( mode, args ):
      # args[ 'INTFS' ] has already been converted to IntfIds by the adapter.
      queryIntfs = args[ 'INTFS' ]
      rule = mcastDnsConfig.serviceRule[ mode.getServiceName() ]
      action = args.get( 'ACTION' )
      if action:
         addRemoveColl( action, rule.queryLink, queryIntfs )
         return
      # No explicit action: the given interfaces replace the existing ones.
      rule.queryLink.clear()
      addRemoveColl( 'add', rule.queryLink, queryIntfs )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      mcastDnsConfig.serviceRule[ mode.getServiceName() ].queryLink.clear()

McastDnsServiceConfigMode.addCommandClass( McastDnsServiceQueryCmd )

#------------------------------------------------------------------------------
# (config-mdns-service)# response interface [ ACTION ] <INTFS>
#------------------------------------------------------------------------------
class McastDnsServiceResponseInterfaceCmd( CliCommand.CliCommandClass ):
   syntax = 'response interface [ ACTION ] INTFS'
   noOrDefaultSyntax = 'response interface ...'
   data = {
      'response': 'Specify link(s) where service announcements are accepted',
      'interface': 'Specify local interface(s)',
      'ACTION': addRemoveMatcher,
      'INTFS': intfRangeMatcher,
   }
   adapter = intfIdsAdapter

   @staticmethod
   def handler( mode, args ):
      # responseLink holds both interface-based and name-based entries; this
      # command only touches the entries that carry an intfId.
      service = mcastDnsConfig.serviceRule[ mode.getServiceName() ]
      addOrRemove = args.get( 'ACTION' )
      if not addOrRemove:
         # Replace existing response interfaces. Iterate over a snapshot of
         # the keys so that deleting entries cannot disturb the iteration.
         for link in list( service.responseLink ):
            if link.intfId:
               del service.responseLink[ link ]
         addOrRemove = 'add'

      intfIds = args.get( 'INTFS' )
      if intfIds:
         # Interface entries are IPv4 LinkIds with an empty link name.
         intfLinkIds = [ LinkId( AddressFamily.ipv4, '', i ) for i in intfIds ]
         addRemoveColl( addOrRemove, service.responseLink, intfLinkIds )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      service = mcastDnsConfig.serviceRule[ mode.getServiceName() ]
      # Delete only interfaces from responseLink (name-based entries stay).
      # A snapshot is iterated because entries are deleted inside the loop.
      for link in list( service.responseLink ):
         if link.intfId:
            del service.responseLink[ link ]

McastDnsServiceConfigMode.addCommandClass( McastDnsServiceResponseInterfaceCmd )

#------------------------------------------------------------------------------
# (config-mdns-service)# response link [ ACTION ] <LINKS>
#------------------------------------------------------------------------------
class McastDnsServiceResponseLinkCmd( CliCommand.CliCommandClass ):
   syntax = 'response link [ ACTION ] { LINKS }'
   noOrDefaultSyntax = 'response link ...'
   data = {
      'response': 'Specify link(s) where service announcements are accepted',
      'link': 'Specify names of these links',
      'ACTION': addRemoveMatcher,
      # Completion offers the link names currently present in the config.
      'LINKS': CliMatcher.DynamicNameMatcher(
         lambda mode: ( l.linkName for l in mcastDnsConfig.link.values() ),
         'Link name' )
   }

   @staticmethod
   def handler( mode, args ):
      # responseLink holds both interface-based and name-based entries; this
      # command only touches the entries that carry a name.
      service = mcastDnsConfig.serviceRule[ mode.getServiceName() ]
      links = args.get( 'LINKS', [] )
      nameLinkIds = []
      # Validate every name before changing any configuration.
      for linkName in links:
         if len( linkName ) > LinkName.maxLength:
            # pylint: disable-next=consider-using-f-string
            mode.addError( "'%s' too long: must be no more than %d characters"
                           % ( linkName, LinkName.maxLength ) )
            return
         # Name entries are IPv4 LinkIds with a default-constructed IntfId.
         nameLinkIds.append( LinkId( AddressFamily.ipv4, linkName, IntfId() ) )

      addOrRemove = args.get( 'ACTION' )
      if not addOrRemove:
         # Replace existing response links. Iterate over a snapshot of the
         # keys so that deleting entries cannot disturb the iteration.
         for link in list( service.responseLink ):
            if link.name:
               del service.responseLink[ link ]
         addOrRemove = 'add'

      if nameLinkIds:
         addRemoveColl( addOrRemove, service.responseLink, nameLinkIds )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      service = mcastDnsConfig.serviceRule[ mode.getServiceName() ]
      # Delete only name-based entries (interface entries stay). A snapshot
      # is iterated because entries are deleted inside the loop.
      for link in list( service.responseLink ):
         if link.name:
            del service.responseLink[ link ]

McastDnsServiceConfigMode.addCommandClass( McastDnsServiceResponseLinkCmd )

#------------------------------------------------------------------------------
# (config-mdns-service)# match ( by-tag | ( group [ ACTION ] <LOCATIONTAGS> ) )
#------------------------------------------------------------------------------
class McastDnsServiceMatchCmd( CliCommand.CliCommandClass ):
   syntax = 'match ( by-tag | ( group [ ACTION ] { LOCATIONTAGS } ) )'
   noOrDefaultSyntax = 'match [ by-tag | group ]...'
   data = {
      'match': 'Filter results from a query',
      'by-tag': ( 'Limit query results to records whose location tag matches the '
                  'query\'s location' ),
      'group': ( 'Limit query results to records whose location tag matches the '
                 'specified filter(s)' ),
      'ACTION': addRemoveMatcher,
      'LOCATIONTAGS': CliMatcher.PatternMatcher( r'^(?!add$|remove$).+',
                                          helpname='WORD',
                                          helpdesc='List of filters' ),
   }

   @staticmethod
   def handler( mode, args ):
      service = mcastDnsConfig.serviceRule[ mode.getServiceName() ]
      # 'by-tag' and 'group' are mutually exclusive: selecting one replaces
      # the configuration of the other.
      if 'by-tag' in args:
         service.matchGroup.clear()
         service.matchByTag = True
      elif 'group' in args:
         tags = args[ 'LOCATIONTAGS' ]
         # Validate every tag before touching the configuration, so that a
         # rejected command leaves matchByTag and matchGroup unchanged.
         for tag in tags:
            if len( tag ) > LocationTag.maxLength:
               # pylint: disable-next=consider-using-f-string
               mode.addError( "The location tag must not be more than %d characters"
                              % ( LocationTag.maxLength ) )
               return
            if any( not c.isalpha() and not c.isdigit() and not c == '-'
               for c in tag ):
               mode.addError( "The location tag must only contain letters, digits " +
                     "or hyphens" )
               return
         service.matchByTag = False
         addOrRemove = args.get( 'ACTION' )
         if not addOrRemove:
            # No explicit action: the given tags replace the existing ones.
            service.matchGroup.clear()
            addOrRemove = 'add'
         addRemoveColl( addOrRemove, service.matchGroup, tags )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      # 'no match' resets both settings; 'no match by-tag' only resets
      # matchByTag and 'no match group' only clears matchGroup.
      service = mcastDnsConfig.serviceRule[ mode.getServiceName() ]
      if 'by-tag' not in args:
         service.matchGroup.clear()
      if 'group' not in args:
         service.matchByTag = False

McastDnsServiceConfigMode.addCommandClass( McastDnsServiceMatchCmd )

#-------------------------------------------------------------------------------
# "show mdns counters"
#-------------------------------------------------------------------------------
class ShowMdnsCounter( ShowCommand.ShowCliCommandClass ):
   syntax = 'show mdns counters'
   data = {
      'mdns': mdnsMatcher,
      'counters': 'Show counters',
   }
   cliModel = MdnsCounter

   @staticmethod
   def handler( mode, args ):
      model = MdnsCounter()
      enumNode = Tac.typeNode( 'McastDns::CounterId::CounterIdEnum' )
      # Stand-in used for every counter id that has no entry in status.
      zeroCounter = Tac.newInstance( 'McastDns::Counter', 0 )
      for enumAttr in enumNode.attributeQ:
         # The model attribute carries the same name as the enum member.
         counter = mcastDnsStatus.counter.get( enumAttr.enumValue, zeroCounter )
         setattr( model, enumAttr.name, counter.counterValue )

      return model

BasicCli.addShowCommandClass( ShowMdnsCounter )

#-------------------------------------------------------------------------------
# "show mdns status"
#-------------------------------------------------------------------------------
class ShowMdnsStatus( ShowCommand.ShowCliCommandClass ):
   syntax = 'show mdns status'
   data = {
      'mdns': mdnsMatcher,
      'status': 'Show status',
   }
   cliModel = MdnsStatus

   @staticmethod
   def handler( mode, args ):
      model = MdnsStatus()
      isEnabled = mcastDnsConfig.enabled
      isRunning = mcastDnsStatus.running
      model.mdnsEnabled = isEnabled
      model.mdnsRunning = isRunning
      if toggleMcastDnsFloodSuppressionEnabled():
         model.floodSuppression = mcastDnsStatus.floodSuppression
      # DSO counts as enabled only when mdns is enabled and a server port is
      # configured.
      model.dsoEnabled = bool( isEnabled and mcastDnsConfig.serverPort )
      model.dsoRunning = model.dsoEnabled and isRunning

      # Connection lists are only reported while mdns is running.
      if isRunning:
         for gwAddr, gwStatus in mcastDnsStatus.tcpClient.items():
            peer = IpGenericAddrAndPort( ip=gwAddr, port=gwStatus.port )
            model.tcpClient.append( MdnsServerGatewayStatus(
               client=peer, connectionStatus=gwStatus.connectionStatus ) )
         for remote in mcastDnsStatus.tcpServerClient:
            model.tcpServerClient.append(
               IpGenericAddrAndPort( ip=remote.address, port=remote.port ) )

      return model

BasicCli.addShowCommandClass( ShowMdnsStatus )

#-------------------------------------------------------------------------------
# "show mdns links"
#-------------------------------------------------------------------------------
class ShowMdnsLinks( ShowCommand.ShowCliCommandClass ):
   syntax = 'show mdns links'
   data = {
      'mdns': mdnsMatcher,
      'links': 'Show status of mDNS enabled links',
   }
   cliModel = MdnsLink

   @staticmethod
   def handler( mode, args ):
      model = MdnsLink()
      if not mcastDnsStatus.running:
         mode.addWarning( 'McastDns is not running' )
         return model

      allLinkStatus = mcastDnsStatus.linkStatus
      for linkKey, linkCfg in mcastDnsConfig.link.items():
         linkEntry = MdnsLinkEntry( family=linkKey.family,
                                    linkId=linkCfg.linkName,
                                    defaultTag=linkCfg.defaultTag )
         # A configured link is 'active' only when it has a status entry
         # whose state is 'connected'.
         oper = allLinkStatus.get( IntfKey( linkKey.intfId, linkKey.family ) )
         linkEntry.status = ( 'active' if oper and oper.state == 'connected'
                              else 'inactive' )
         # Entries are grouped per interface.
         perIntf = model.interfaces.setdefault( linkKey.intfId, MdnsLinkList() )
         perIntf.intfInfos.append( linkEntry )
      return model

BasicCli.addShowCommandClass( ShowMdnsLinks )

# For one rrLink and one service rule, report which of the rrLink's two
# identities the rule lists as a response link.
# Returns ( interface, link ): interface is the rrLink's intfId when the rule
# has the interface-based id, link is the rrLink's name when the rule has the
# name-based id; each is None otherwise.
def getAllowedIntfLink( rrLink, srvRule ):
   byIntf = RrLinkId( rrLink.family, LinkName(), rrLink.intfId )
   byName = RrLinkId( rrLink.family, rrLink.name, IntfId() )
   interface = byIntf.intfId if srvRule.responseLinkHas( byIntf ) else None
   link = byName.name if srvRule.responseLinkHas( byName ) else None
   return interface, link

# Look the rrLink up in every configured service rule.
# The interface may be configured in one service rule and the link name in
# another, so the results are accumulated across rules and the search stops
# early once both have been found.
# Returns ( interface, link ); either may be None.
def getAllowedIntfLinkPerRrLink( rrLink ):
   interface = None
   link = None
   for srvRule in mcastDnsConfig.serviceRule.values():
      foundIntf, foundLink = getAllowedIntfLink( rrLink, srvRule )
      interface = foundIntf or interface
      link = foundLink or link
      if interface and link:
         break
   return interface, link

# For the record entry stored under rrSetKey, build a map from each of its
# rrLinks to the ( interface, link ) pair allowed by srvRule.
# An rrLink is included when at least one element of the pair is allowed; it
# is possible that only one of them is, or that the other is allowed by a
# different service rule. Returns an empty dict when there is no such record.
def getAllowedLinksPerSrvRule( rrSetKey, srvRule ):
   rrEntry = mcastDnsStatus.database.record.get( rrSetKey )
   if not rrEntry:
      return {}

   allowed = {}
   for rrLink in rrEntry.rrLinkData:
      pair = getAllowedIntfLink( rrLink, srvRule )
      if any( pair ):
         allowed[ rrLink ] = pair
   return allowed

# Records under a given RrSetKey conflict for a service rule when both hold:
# 1. More than one record exists across the links allowed by the service rule
# 2. At least one of those records is flagged unique
# allowedLinks is the set of rrLinks to consider (iterating it yields rrLinks).
# Returns False when the database has no entry for rrSetKey.
def isRrSetKeyConflictingForServiceRule( rrSetKey, allowedLinks ):
   rrEntry = mcastDnsStatus.database.record.get( rrSetKey )
   if not rrEntry:
      return False

   perLink = rrEntry.rrLinkData
   total = 0
   needsUnique = False
   for rrLink in allowedLinks:
      records = perLink[ rrLink ].data
      total += len( records )
      # A single unique flag on any allowed link makes the whole set require
      # uniqueness.
      needsUnique = needsUnique or any( r.unique for r in records.values() )
      if needsUnique and total > 1:
         return True
   return False

# For the given RrSetKey, go through all the service rules and return True as
# soon as one of them reports a conflict.
# When interface or link is specified, only service rules that allow one of
# them as a response link for this RrSetKey are considered.
def isRrSetKeyConflicting( rrSetKey, interface=None, link=None ):
   for srvRule in mcastDnsConfig.serviceRule.values():
      allowedLinks = getAllowedLinksPerSrvRule( rrSetKey, srvRule )
      if interface or link:
         ruleApplies = any(
            ( interface and interface in pair ) or ( link and link in pair )
            for pair in allowedLinks.values() )
         if not ruleApplies:
            continue
      if isRrSetKeyConflictingForServiceRule( rrSetKey, allowedLinks ):
         return True
   return False

# Detect conflicts for a certain dnsName across all of its record types.
# Example: printer._ipp._tcp.local. has SRV and TXT records. Both of the types
# can have conflicts when flood suppression is enabled.
# Returns False when the database has no entry for the lower-cased name.
def isRecordNameConflicting( dname, interface, link ):
   dnameEntry = mcastDnsStatus.database.domainName.get( dname.lower() )
   if not dnameEntry:
      return False

   # Every record type seen for this name on any link.
   rrTypes = { rrType
               for perLink in dnameEntry.rrLinkType.values()
               for rrType in perLink.rrLinkData }
   return any(
      isRrSetKeyConflicting( RrSetKey( dname, rrType, 1 ).lower(),
                             interface=interface, link=link )
      for rrType in rrTypes )

#-------------------------------------------------------------------------------
# "show mdns service rule <NAME> [ type <TYPE> ]"
#-------------------------------------------------------------------------------
class ShowMdnsServiceRule( ShowCommand.ShowCliCommandClass ):
   syntax = 'show mdns service rule NAME [ type TYPE ]'
   data = {
      'mdns': mdnsMatcher,
      'service': 'Multicast DNS services',
      'rule': 'Multicast DNS service rule',
      # Completion offers the names of the configured service rules.
      'NAME': CliMatcher.DynamicNameMatcher(
         lambda mode: mcastDnsConfig.serviceRule,
         'Service rule name' ),
      'type': 'Multicast DNS service types',
      'TYPE': CliMatcher.PatternMatcher( CliParser.namePattern, helpname='WORD',
                                         helpdesc='Multicast DNS service type name' )
   }
   cliModel = MdnsServiceRecordServiceRule

   # Builds the model for one service rule: its query/response links, the
   # service types it covers, and the SRV records in the database that match
   # those types on links the rule allows.
   @staticmethod
   def handler( mode, args ):
      if not mcastDnsStatus.running:
         mode.addWarning( 'McastDns is not running' )
         return MdnsServiceRecordServiceRule()
      reqSrvRuleName = args.get( 'NAME' )
      reqSrvType = args.get( 'TYPE' )
      srvRule = mcastDnsConfig.serviceRule.get( reqSrvRuleName )
      srvRecords = MdnsServiceRecordServiceRule()
      # Unknown rule name: return the empty model.
      if not srvRule:
         return srvRecords

      # Collect query intfIds
      queIntfs = list( srvRule.queryLink )
      # Collect response intfIds or linknames
      responseLinks = []
      responseIntfs = []
      resIds = srvRule.responseLink
      for resId in resIds:
         # An entry with a name is a named link; otherwise it is an interface.
         if resId.name:
            responseLinks.append( resId.name )
         else:
            responseIntfs.append( resId.intfId )
      srvRecords.queryLinks = queIntfs
      srvRecords.responseLinks = responseLinks
      srvRecords.responseIntfs = responseIntfs

      # Collect service types
      conTypes = srvRule.serviceType
      hasAny = 'any' in conTypes
      interestedTypes = { conType
                    for conType in conTypes
                    if conType != 'any' }

      # has service type 'any', should also look at observed service types
      if hasAny:
         allIntfs = set( queIntfs + responseIntfs )
         for key, value in mcastDnsConfig.link.items():
            if key.intfId in allIntfs or value.linkName in responseLinks:
               intfKey = IntfKey( key.intfId, key.family )
               # Note: 'value' is rebound here from the link config to the
               # observed service types of that link.
               value = mcastDnsStatus.observedServiceType.get( intfKey )
               if value:
                  interestedTypes.update( value.serviceType )

      # CIS: Compare service type case-insensitively
      srvTypes = [ srvType.lower() for srvType in interestedTypes ]
      srvRecords.srvTypes = srvTypes
      if reqSrvType:
         reqSrvType = reqSrvType.lower()
         if reqSrvType not in srvTypes:
            # The requested type is not covered by this rule: report the
            # covered types (in their original case) and no services.
            srvRecords.srvTypes = list( interestedTypes )
            return srvRecords
         # Only the requested type is looked up below.
         srvTypes = [ reqSrvType ]

      # Maps the original-case service name to its MdnsServiceList.
      services = {}
      for rrSetKey, rrEntry in mcastDnsStatus.database.record.items():
         # Only SRV records describe services.
         if rrSetKey.type != DnsRrType.srv:
            continue
         dname = rrSetKey.dname
         # dname is of type DnsNameLower and has lower cased service type
         if dname.serviceType not in srvTypes:
            continue

         allowedLinks = getAllowedLinksPerSrvRule( rrSetKey, srvRule )
         if not allowedLinks:
            continue

         # 'conflicts' is only assigned, and only read, when the flood
         # suppression toggle is enabled.
         if toggleMcastDnsFloodSuppressionEnabled():
            # Determine if a rrSetKey is conflicting under the current service rule
            conflicts = \
                  isRrSetKeyConflictingForServiceRule( rrSetKey, allowedLinks )
         # Use a flag to ensure only one record is populated when conflicting
         populated = False

         # Collect services for the same lower cased names and different links.
         for rrLink, ( interface, link ) in allowedLinks.items():
            # CIS: Save different srvEntry for the current link because each rrdata
            # may have different cased origRrSetKey
            srvEntryDict = {}
            for rrData in rrEntry.rrLinkData.get( rrLink ).data.values():
               if toggleMcastDnsFloodSuppressionEnabled():
                  if conflicts and populated:
                     break
               origName = rrData.origRrSetKey.dname.name
               # For the same origName in the same link, there is only one entry
               services.setdefault( origName, MdnsServiceList() )
               if toggleMcastDnsFloodSuppressionEnabled():
                  srvEntryDict.setdefault( origName, MdnsServiceEntry(
                     interface=interface, link=link, conflicts=conflicts ) )
               else:
                  srvEntryDict.setdefault( origName, MdnsServiceEntry(
                     interface=interface, link=link ) )
               srvEntryDict[ origName ].locations.append( rrData.location )
               populated = True

            # Attach this link's entries to the per-name service lists.
            for origName, srvEntry in srvEntryDict.items():
               service = services.get( origName )
               if service:
                  service.srvInfos.append( srvEntry )

      srvRecords.services = services
      return srvRecords

BasicCli.addShowCommandClass( ShowMdnsServiceRule )

# Use a relaxed methed to find the service type and let the look up to determine
# if an entry with an RFC compliant service record is in the database
# https://tools.ietf.org/html/rfc6763#section-4.1.1
# Service Instance Name is a user-friendly name consisting of arbitrary
# Net-Unicode text [RFC5198]. It may include dots.
# In printer.floor1._printer._sub._ipp._tcp.local, printer1.floor1 is considered
# as instance name.
def getServiceType( dname ):
   """Extract the service type portion of a '.local' service instance name.

   The match is deliberately relaxed: everything in front of the service
   type labels is treated as the instance name and may itself contain dots.

   Args:
      dname: the domain name as a string, with or without a trailing dot.

   Returns:
      '<service>.<proto>.' for a plain name such as
      'printer._ipp._tcp.local', '<subtype>._sub.<service>.<proto>.' when the
      name carries a sub-type, or '' when the name is empty, has too few
      labels, or does not end in 'local' (compared case-insensitively).
   """
   # Drop one trailing dot so users may leave it out. endswith() is used
   # instead of indexing so that an empty name yields '' rather than raising
   # IndexError.
   if dname.endswith( '.' ):
      dname = dname[ : -1 ]
   labels = dname.split( '.' )
   # At least <instance>.<service>.<proto>.local is required.
   if len( labels ) <= 3:
      return ''
   if labels[ -1 ].lower() != 'local':
      return ''

   # dname without sub-type
   if labels[ -4 ] != '_sub':
      return '.'.join( labels[ -3 : -1 ] ) + '.'
   # dname with sub-type: <instance>.<subtype>._sub.<service>.<proto>.local
   # needs at least six labels.
   if len( labels ) <= 5:
      return ''
   return '.'.join( labels[ -5 : -1 ] ) + '.'

# Convert the first 6 bytes of SRV rdata to priority, weight, and port
# E.g., '\x00\x16\x00\x11\x1f\x90' => priority=22, weight=17, port=8080
def getSrvMeta( data ):
   """Decode the fixed-size fields at the start of SRV rdata.

   Returns a ( priority, weight, port ) tuple read as three big-endian
   unsigned 16-bit integers from the first 6 bytes of data, or
   ( None, None, None ) when data holds fewer than 6 bytes.
   """
   fixedLen = 6
   if len( data ) >= fixedLen:
      return struct.unpack( '>HHH', data[ : fixedLen ] )
   return ( None, None, None )

# Convert e.g., '\n\x00\x00\x02' to '10.0.0.2'
def getV4Addr( data ):
   """Render 4 raw address bytes as a dotted-decimal string.

   Returns '' when data is not exactly 4 elements long.
   """
   if len( data ) == 4:
      return '.'.join( map( str, data ) )
   return ''

# Convert e.g., '\xfe\x80\x00\x00\x00\x00\x00\x00\x04\xb8\x9d\x9b\x08O\xdc\xab' to
# 'fe80::4b8:9d9b:84f:dcab'
def getV6Addr( data ):
   """Render 16 raw address bytes as an IPv6 address string.

   The bytes are read as eight big-endian 16-bit groups, joined as hex text
   and passed through Ip6Addr, whose stringValue is returned. Returns ''
   when data is not exactly 16 bytes long.
   """
   if len( data ) != 16:
      return ''
   groups = struct.unpack( '>8H', data )
   text = ':'.join( '%x' % group for group in groups )
   return Ip6Addr( text ).stringValue

def showMdnsServiceNameHandler( mode, args ):
   """Build the model for "show mdns service name NAME [ detail ]".

   Args:
      mode: CLI mode; used here only to add a warning when McastDns is not
         running.
      args: parsed command arguments. 'NAME' is the requested service name
         and the presence of 'detail' selects the detailed output.

   Returns:
      An MdnsServiceRecordServiceName. It carries no services when McastDns
      is not running, the database is not available yet, or no SRV record
      exists for the requested name.
   """
   reqName = args.get( 'NAME' )
   detail = 'detail' in args

   # pylint: disable-msg=protected-access
   record = MdnsServiceRecordServiceName()
   record._detail = detail

   if not mcastDnsStatus.running:
      mode.addWarning( 'McastDns is not running' )
      return record

   database = mcastDnsStatus.database
   if not database:
      # database is none before mdns is configured
      return record

   srvType = getServiceType( reqName )
   dname = DnsName( reqName )
   dname.serviceType = srvType
   dnameEntry = database.domainName.get( dname.lower() )
   if not dnameEntry:
      return record

   # RR type 33 is SRV, class 1 is IN.
   srvRrSetKey = RrSetKey( dname, 33, 1 )
   rrEntry = database.record.get( srvRrSetKey.lower() )
   if not rrEntry:
      return record

   record.name = reqName
   if toggleMcastDnsFloodSuppressionEnabled():
      record.conflicts = isRrSetKeyConflicting( srvRrSetKey.lower() )

   # pylint: disable=too-many-nested-blocks
   for rrLink, rrLinkType in dnameEntry.rrLinkType.items():
      # Skip if not an SRV record
      srvLinkData = rrLinkType.rrLinkData.get( DnsRrType.srv )
      if not srvLinkData:
         continue

      # Adding link info
      interface, link = getAllowedIntfLinkPerRrLink( rrLink )
      if not interface and not link:
         continue

      srvInfoEntry = MdnsServiceInfo( interface=interface, link=link )

      if detail:
         # Add SRV-TXT record if exists
         txtLinkData = rrLinkType.rrLinkData.get( DnsRrType.txt )
         if txtLinkData:
            for txtRdata, txtRrData in txtLinkData.data.items():
               # Names are compared in their original case.
               if reqName != txtRrData.origRrSetKey.dname.name:
                  continue
               srvInfoEntry.srvTxts.append(
                     Record( recordBytes=list( txtRdata.data ) ) )

      # Adding host info
      for srvRrData in srvLinkData.data.values():
         if reqName != srvRrData.origRrSetKey.dname.name:
            continue
         targetDomainName = srvRrData.targetDomainName
         if not targetDomainName:
            continue
         host = srvInfoEntry.hosts.setdefault(
               targetDomainName.name, MdnsServiceHostInfo() )
         if toggleMcastDnsFloodSuppressionEnabled():
            host.conflicts = \
                  isRecordNameConflicting( targetDomainName, interface, link )
         if not detail:
            continue

         # Add SRV priority, weight and port. getSrvMeta() reports rdata that
         # is too short as ( None, None, None ); that tuple is still truthy,
         # so test a field rather than the tuple itself.
         priority, weight, port = getSrvMeta( srvRrData.rdata.data )
         if priority is not None:
            host.priority = priority
            host.weight = weight
            host.port = port

         # look up target name under the same rrLink
         dnameEntryHost = database.domainName.get( targetDomainName.lower() )
         if not dnameEntryHost:
            continue
         # Distinct names are used below so the outer loop variables are not
         # rebound while their loops are still running.
         hostLinkType = dnameEntryHost.rrLinkType.get( rrLink )
         if not hostLinkType:
            continue
         for rrType, hostLinkData in hostLinkType.rrLinkData.items():
            for hostRdata, hostRrData in hostLinkData.data.items():
               if targetDomainName != hostRrData.origRrSetKey.dname:
                  continue
               if rrType == DnsRrType.a:
                  v4Addr = getV4Addr( hostRrData.rdata.data )
                  if v4Addr:
                     host.v4Addrs.append( v4Addr )
               elif rrType == DnsRrType.aaaa:
                  v6Addr = getV6Addr( hostRrData.rdata.data )
                  if v6Addr:
                     host.v6Addrs.append( v6Addr )
               elif rrType == DnsRrType.txt:
                  dataBytes = list( hostRdata.data )
                  host.txts.append( Record( recordBytes=dataBytes ) )
               else:
                  # Any other record type is reported as raw bytes.
                  unknownType = \
                        host.unknownTypes.setdefault( rrType, MdnsUnknownList() )
                  dataBytes = list( hostRdata.data )
                  unknownType.unknowns.append( Record( recordBytes=dataBytes ) )
      # Hosts could be empty because either there is no valid targetDomainName
      # (failed parsing on SRV rdata) or the original RrSetKey.dname does not match
      # the requested name of this command.
      if srvInfoEntry.hosts:
         record.services.append( srvInfoEntry )

   return record

#-------------------------------------------------------------------------------
# "show mdns service name <NAME> [ detail ]"
#-------------------------------------------------------------------------------
class ShowMdnsServiceName( ShowCommand.ShowCliCommandClass ):
   # Command definition for "show mdns service name NAME [ detail ]".
   # 'detail' is the only optional token in the syntax.
   syntax = 'show mdns service name NAME [ detail ]'
   # Maps every token of the syntax to its matcher or, for plain keywords,
   # to its help text.
   data = {
      'mdns': mdnsMatcher,
      'service': 'Multicast DNS services',
      'name': 'Multicast DNS services name',
      'NAME': CliMatcher.PatternMatcher( CliParser.namePattern, helpname='WORD',
                                         helpdesc='Multicast DNS service name' ),
      'detail': detailMatcher
   }

   # The handler takes ( mode, args ) and returns an instance of cliModel.
   handler = showMdnsServiceNameHandler
   cliModel = MdnsServiceRecordServiceName

# Register the command class with BasicCli.
BasicCli.addShowCommandClass( ShowMdnsServiceName )

#-------------------------------------------------------------------------------
# "show mdns service type [ interface <INTF> ]"
#-------------------------------------------------------------------------------
class ShowMdnsServiceTypes( ShowCommand.ShowCliCommandClass ):
   # Reports the observed service types per interface, optionally restricted
   # to the interfaces given with 'interface INTFS'.
   syntax = 'show mdns service type [ interface INTFS ]'
   data = {
      'mdns': mdnsMatcher,
      'service': 'Multicast DNS services',
      'type': 'Multicast DNS service type',
      'interface': 'Specify local interface(s)',
      'INTFS': intfRangeMatcher,
   }

   cliModel = MdnsServiceType

   @staticmethod
   def handler( mode, args ):
      """Return an MdnsServiceType filled from the observed service types.

      Without 'INTFS' every observed entry is reported. With 'INTFS' only the
      IPv4 and IPv6 keys of the listed interfaces are looked up. The model is
      returned empty, with a warning, when McastDns is not running.
      """
      result = MdnsServiceType()
      if not mcastDnsStatus.running:
         mode.addWarning( 'McastDns is not running' )
         return result

      def addObserved( intfKey, observed ):
         # One entry per observed service type, grouped by interface id.
         for serviceType, timestamp in observed.serviceType.items():
            entry = MdnsServiceTypeEntry( family=intfKey.family,
                                          serviceType=serviceType,
                                          timestamp=timestamp )
            perIntf = result.serviceTypes.setdefault( intfKey.intfId,
                                                      MdnsServiceTypeList() )
            perIntf.intfInfos.append( entry )

      requestedIntfs = args.get( 'INTFS' )
      if not requestedIntfs:
         # No filter: report everything that has been observed.
         for intfKey, observed in mcastDnsStatus.observedServiceType.items():
            addObserved( intfKey, observed )
         return result

      for intfName in requestedIntfs:
         for family in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
            intfKey = IntfKey( IntfId( intfName ), family )
            observed = mcastDnsStatus.observedServiceType.get( intfKey )
            if observed:
               addObserved( intfKey, observed )
      return result

BasicCli.addShowCommandClass( ShowMdnsServiceTypes )

# -------------------------------------------------------------------------------
# "show mdns service record name <NAME>"
# -------------------------------------------------------------------------------
class ShowMdnsServiceRecord( ShowCommand.ShowCliCommandClass ):
   # Reports, for one record name, every record type stored under that name
   # together with the interfaces/links and locations it was seen on.
   syntax = 'show mdns service record name NAME'
   data = {
      'mdns': mdnsMatcher,
      'service': 'Multicast DNS services',
      'record': 'Multicast DNS records',
      'name': 'Multicast DNS record name',
      'NAME': CliMatcher.PatternMatcher( CliParser.namePattern, helpname='WORD',
                                         helpdesc='Multicast DNS record name' )
   }
   cliModel = MdnsRecords

   @staticmethod
   def handler( mode, args ):
      """Return an MdnsRecords model for the requested record name.

      The model is returned empty when McastDns is not running (a warning is
      added), when the database is not available, or when the name is not
      present in the database. Otherwise it holds one entry list keyed by the
      lower-cased requested name.
      """
      reqName = args.get( 'NAME' )
      records = MdnsRecords()

      if not mcastDnsStatus.running:
         mode.addWarning( 'McastDns is not running' )
         return records

      database = mcastDnsStatus.database
      if not database:
         # No database exists until mdns has been configured.
         return records

      srvType = getServiceType( reqName )
      dname = DnsName( reqName )
      dname.serviceType = srvType
      dnameEntry = database.domainName.get( dname.lower() )
      if not dnameEntry:
         return records
      recordsKey = reqName.lower()
      records.records[ recordsKey ] = MdnsRecordEntryList()

      # Every record type stored for this name on any link.
      rrTypes = set()
      for linkType in dnameEntry.rrLinkType.values():
         rrTypes.update( linkType.rrLinkData )

      lookupKeys = [ RrSetKey( dname, rrType, 1 ).lower() for rrType in rrTypes ]

      # Record type value -> record type name.
      typeNames = \
            { getattr( DnsRrType, name ): name for name in DnsRrType.attributes }

      for rrSetKey in lookupKeys:
         rrEntry = mcastDnsStatus.database.record.get( rrSetKey )
         if not rrEntry:
            continue

         recordInfos = {}
         for rrLink, linkData in rrEntry.rrLinkData.items():
            interface, link = getAllowedIntfLinkPerRrLink( rrLink )
            if not ( interface or link ):
               continue

            conflicts = \
                  isRrSetKeyConflicting( rrSetKey, interface=interface, link=link )

            # Each original (case-preserved) name gets its own info object on
            # this link so that locations are attributed to the right name.
            infoByOrigName = {}
            for rrData in linkData.data.values():
               origName = rrData.origRrSetKey.dname.name
               if origName not in infoByOrigName:
                  infoByOrigName[ origName ] = MdnsRecordInfo(
                     interface=interface, link=link, conflicts=conflicts )
               infoByOrigName[ origName ].locations.append( rrData.location )

            for origName, info in infoByOrigName.items():
               infoList = recordInfos.setdefault( origName, MdnsRecordInfoList() )
               infoList.recordInfoList.append( info )

         records.records[ recordsKey ].recordEntryList.append( MdnsRecordEntry(
            recordName=rrSetKey.dname.name,
            recordType=typeNames[ rrSetKey.type ],
            recordClass=rrSetKey.clazz,
            recordInfos=recordInfos ) )

      return records

# The command is only registered when flood suppression is toggled on.
if toggleMcastDnsFloodSuppressionEnabled():
   BasicCli.addShowCommandClass( ShowMdnsServiceRecord )

#------------------------------------------------------------------------------
# Plugin Setup
#------------------------------------------------------------------------------

def Plugin( entityManager ):
   """Mount the McastDns config, status and hardware-capability paths.

   The results are stored in the module-level globals mcastDnsConfig
   (writable, via ConfigMount), mcastDnsStatus and hwCapability (read-only,
   via LazyMount).
   """
   global mcastDnsConfig, mcastDnsStatus, hwCapability

   mcastDnsConfig = ConfigMount.mount(
      entityManager, 'mdns/config', 'McastDns::Config', 'w' )
   mcastDnsStatus = LazyMount.mount(
      entityManager, 'mdns/status', 'McastDns::Status', 'r' )
   hwCapability = LazyMount.mount(
      entityManager, 'mdns/hardware/capabilities',
      'AleMcastDns::MdnsHwCapabilities', 'r' )
