# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from ArnetModel import MacAddress
from CliModel import Model, DeferredModel, Str, List, Dict, Int, Submodel
from CliPlugin.IntfCli import Intf
from CliPlugin.MacAddr import compareMacs
from CliPlugin.LldpStatusCli import chassisIdToStr
from IntfModels import Interface
from operator import attrgetter
import Tac
from TypeFuture import TacLazyType
import TableOutput

# Tac enum type for the source of a neighbor classification (e.g. sourceLldp,
# sourceCli), obtained through TacLazyType. Neighbor status `source` values are
# compared against its members when building descriptions and summary counts.
Source = TacLazyType( "Identity::NbrClassification::Source" )

@Tac.memoize
def allClassifications():
   """Return the list of supported neighbor classification names.

   Decorated with Tac.memoize, so presumably every call hands back the same
   list object -- treat the result as read-only.
   """
   supported = [ 'phone' ]
   return supported

@Tac.memoize
def allClassificationsWithNone():
   """Return 'none' followed by every supported classification name."""
   result = [ 'none' ]
   result.extend( allClassifications() )
   return result

def findLldpRemoteSystem( lldpStatus, intfId, ethAddr ):
   """Look up the LLDP remote system seen on `intfId` with address `ethAddr`.

   Returns None when `intfId` has no (or a falsy) LLDP port status, or when no
   remote system on that port has a matching `ethAddr`. If several match, the
   first one in `remoteSystem` iteration order wins.
   """
   portStatus = lldpStatus.portStatus.get( intfId )
   if portStatus:
      matches = ( remote for remote in portStatus.remoteSystem.values()
                  if remote.ethAddr == ethAddr )
      return next( matches, None )
   return None

class NbrClassificationModel( DeferredModel ):
   _ethAddr = MacAddress( help="Ethernet address of neighbor" )
   interface = Interface( help="Interface that neighbor is connected to" )
   description = Str( help="Description that helps identify neighbor" )
   classification = List( valueType=str, help="Classifications" )
   source = Str( help="Source of phone neighbor" )

   def populate( self, mode, neighborStatus, lldpStatus ):
      """Fill this model from one neighbor status entry.

      The description is 'User defined' for CLI-sourced neighbors, otherwise
      the chassis ID of the matching LLDP remote system, or '' when there is
      none. A telephone classification is recorded as 'telephone'. `mode` is
      accepted for interface compatibility but not used here.
      """
      intfId = neighborStatus.intfId
      ethAddr = neighborStatus.ethAddr
      self._ethAddr = ethAddr
      self.interface = intfId
      self.source = neighborStatus.source

      remote = findLldpRemoteSystem( lldpStatus, intfId, ethAddr )
      if neighborStatus.source == Source.sourceCli:
         description = 'User defined'
      elif remote is None:
         description = ''
      else:
         chassisId = chassisIdToStr( remote.msap.chassisIdentifier,
                                     short=False, canonical=False )
         description = 'Chassis ID {}'.format( chassisId )
      self.description = description

      if neighborStatus.classification.telephone:
         self.classification.append( 'telephone' )

class NbrClassificationsModel( Model ):
   neighbors = Dict( keyType=MacAddress, valueType=NbrClassificationModel,
         help="A dictionary of neighbors classifications keyed by ethernet address" )

   def populate( self, mode, classificationStatus, lldpStatus,
                 classification, intfs, macAddr ):
      """Fill `neighbors` from `classificationStatus`, optionally filtered.

      Args:
         mode: CLI mode; receives a warning for an unknown classification and
            is forwarded to each neighbor model's populate().
         classificationStatus: object whose `neighbor` collection holds the
            per-neighbor status entries.
         lldpStatus: LLDP status used to describe LLDP-learned neighbors.
         classification: when set, keep only neighbors with this classification
            ('phone' is the only value handled; others yield a warning and no
            neighbors).
         intfs: when set, keep only neighbors on these interfaces, matched by
            either long or short interface name.
         macAddr: when set, keep only neighbors for which compareMacs() is true.
      """
      neighbors = classificationStatus.neighbor.values()

      # Per the syntax, we can only have classification xor macAddr xor intfs.
      # Pre-filter before the model build.
      if classification:
         if classification == 'phone': # Only phone is supported for now.
            neighbors = ( n for n in neighbors if n.classification.telephone )
         else:
            neighbors = ()
            mode.addWarning( "Undefined classification '%s'" % classification )
      elif macAddr:
         neighbors = ( n for n in neighbors if compareMacs( n.ethAddr, macAddr ) )
      elif intfs:
         # Make a set (fast) of all intfs specified w/ both short and long names.
         intfs = set( intfs ).union( intfs.intfNames( shortName=True ) )
         neighbors = ( n for n in neighbors if n.intfId in intfs )

      for neighbor in neighbors:
         nbrItem = NbrClassificationModel()
         nbrItem.populate( mode, neighbor, lldpStatus )
         self.neighbors[ neighbor.ethAddr ] = nbrItem

   def render( self ):
      """Print one table row per neighbor, ordered by port, then MAC address."""
      table = TableOutput.createTable( [ 'Port',
                                         'Mac Address',
                                         'Description',
                                         'Classifications' ] )
      fmt = TableOutput.Format( justify='left' )
      fmtWrap = TableOutput.Format( justify='left', wrap=True )
      table.formatColumns( fmt, fmt, fmtWrap, fmtWrap )
      # The MAC is a tie-breaker so that several neighbors behind the same port
      # are listed in a stable order rather than in dict-iteration order.
      sortKey = attrgetter( 'interface', '_ethAddr.displayString' )
      for nbr in sorted( self.neighbors.values(), key=sortKey ):
         # pylint: disable=protected-access
         classification = ', '.join( nbr.classification )
         # Status records 'telephone'; the CLI keyword shown to users is 'phone'.
         classification = classification.replace( 'telephone', 'phone' )
         classification = classification or 'none'
         table.newRow( Intf.getShortname( nbr.interface ),
                       nbr._ethAddr.displayString,
                       nbr.description,
                       classification )
      print( table.output() )

class NbrClassificationSummaryModel( Model ):
   lldp = Int( help="Number of LLDP neighbors", default=0 )
   cli = Int( help="Number of CLI (manual) neighbors", default=0 )
   aaa = Int( help="Deprecated - Number of AAA neighbors", default=0 )

   def total( self ):
      """Return the LLDP plus CLI count; the deprecated aaa counter is excluded."""
      return sum( ( self.lldp, self.cli ) )

   def countSource( self, source ):
      """
      Bump the counter that corresponds to `source`; other sources are ignored.
      """
      if source == Source.sourceLldp:
         self.lldp = self.lldp + 1
         return
      if source == Source.sourceCli:
         self.cli = self.cli + 1

def _nbrMatchesClassification( nbr, classification ):
   """Return True if neighbor model `nbr` counts under `classification`."""
   if classification == 'phone':
      # Either 'phone' or 'telephone' is present for phone classification
      return ( 'phone' in nbr.classification or
               'telephone' in nbr.classification )
   if classification == 'none':
      # 'none' means the neighbor carries no classification at all.
      return not nbr.classification
   # General case
   return classification in nbr.classification

class NbrClassificationsSummaryModel( Model ):
   """
   Model for `show neighbor classification summary` commands.
   Example json output:
   {
      "phone": {"lldp": 1, "cli": 2 },
      "none": {"lldp": 0, "cli": 0 }
   }
   Each summary also has a deprecated "aaa" counter, which this model never
   increments.
   """

   none = Submodel( valueType=NbrClassificationSummaryModel,
                    help='Summary for devices classified as none' )
   phone = Submodel( valueType=NbrClassificationSummaryModel,
                    help='Summary for devices classified as phone' )

   def populate( self, mode, classificationStatus, lldpStatus,
                 nbrsClassification, intfs, macAddr ):
      """Count neighbors per classification and per source.

      The neighbor list is first built (and filtered) exactly as
      NbrClassificationsModel.populate() does with the same arguments. Each
      remaining neighbor is then counted, by its source, under every
      classification it matches ('none' when it has no classification).
      """
      nbrs = NbrClassificationsModel()
      nbrs.populate( mode, classificationStatus, lldpStatus, nbrsClassification,
                     intfs, macAddr )

      # Create summaries
      self.none = NbrClassificationSummaryModel()
      self.phone = NbrClassificationSummaryModel()

      classifications = allClassificationsWithNone()
      for nbr in nbrs.neighbors.values():
         # Make summary for each possible classification (including none)
         for classification in classifications:
            if _nbrMatchesClassification( nbr, classification ):
               summary = self._getSummary( classification )
               summary.countSource( nbr.source )

   def render( self ):
      """Print one row per classification with LLDP/CLI/AAA counts and total."""
      table = TableOutput.createTable( [ 'Classification',
                                         'LLDP',
                                         'CLI',
                                         'AAA',
                                         'Total' ] )
      fmtLeft = TableOutput.Format( justify='left' )
      fmtRight = TableOutput.Format( justify='right' )
      table.formatColumns( fmtLeft, fmtRight, fmtRight, fmtRight, fmtRight )
      for classification, summary in self._getSummaries():
         # Upper-case only the first letter ('phone' -> 'Phone').
         label = classification[ 0 ].upper() + classification[ 1 : ]
         table.newRow( label, summary.lldp, summary.cli, summary.aaa,
                       summary.total() )

      print( table.output() )

   def _getSummary( self, classification ):
      """Return the summary submodel named `classification`, or None."""
      return getattr( self, classification, None )

   def _getSummaries( self ):
      """
      Return ( classification, summary ) pairs for every classification,
      'none' first.
      """
      return [ ( classification, self._getSummary( classification ) )
               for classification in allClassificationsWithNone() ]
