#!/usr/bin/env python3
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# Contains the CLI model for L2RIB "show commands" ( flood-set/host table ).
import Ethernet
import Tac
from CliModel import (
   Model,
   Int,
   Enum,
   Bool,
   Str,
   List,
   DeferredModel,
   Submodel )
from ArnetModel import (
   IpGenericAddr,
   IpGenericPrefix,
   MacAddress )
from IntfModels import Interface
from CliPlugin.TunnelModels import TunnelId
from CliPlugin.BridgingCliModel import _BaseTableEntry
from TableOutput import createTable, Format
from MultiRangeRule import multiRangeToCanonicalString

# Keep a global instance to get the string representation for host
# entry type.
# Created once when this module is imported; HostEntry.getStr uses it to turn
# the stored entryType into its display string.
hostEntryTypeHelper = Tac.newInstance( 'Bridging::EntryTypeHelper' )

class TableSummary( DeferredModel ):
   'Describes host entries from one source'
   # Name of the source table and the count of host entries it holds.
   name = Str( help='Table name' )
   hosts = Int( help='Number of hosts in table' )

class Summary( DeferredModel ):
   'Describes host entries from all known sources.'
   # One TableSummary entry per source table.
   tables = List( valueType=TableSummary, help='Number of entries in each table' )

class L2RibObject( Model ):
   # This is the base class for all L2RIB objects ( Dest, LoadBalance
   # and Label ). Each subclass renders itself through getStr().
   spacePerLevel = Int( help='Number of spaces to indent per level',
                        default=3 )
   # There is no index for Intf dests
   index = Int( help='Object table key', optional=True )
   level = Int( help='Level in the chain of objects' )

   def getStr( self, detail ):
      # Abstract hook: subclasses return their one-line rendering; `detail`
      # asks for the more verbose form (e.g. including the table index).
      # Raise instead of `assert False`: asserts are stripped under
      # `python -O`, which would make this silently return None.
      raise NotImplementedError(
         "Derived class must implement this for rendering" )

class LoadBalance( L2RibObject ):
   """Describes a load balance entry attributes and its strep."""
   num = Int( help='Number of load balance objects' )

   def getStr( self, detail ):
      """Render this load balance entry, indented by its level.

      detail=False: "   Load Balance entry: 2-way"
      detail=True:  "   Load Balance entry 10: 2-way"
      """
      indent = ' ' * self.spacePerLevel * self.level if self.level else ''
      if detail:
         # The table index is only shown in detailed output.
         assert self.index, "Unknown load balance table index"
         prefix = 'Load Balance entry %d' % self.index
      else:
         prefix = 'Load Balance entry'
      return '%s%s: %d-way' % ( indent, prefix, self.num )

class Label( L2RibObject ):
   """Describes L2Rib::Label attributes and its strep."""
   label = Int( help='MPLS Label' )
   def getStr( self, detail ):
      """Render this label entry, indented by its level.

      detail=False: "<indent>Label entry: <label>"
      detail=True:  "<indent>Label entry <index>: <label>"
      """
      indent = ' ' * self.spacePerLevel * self.level if self.level else ''
      if detail:
         # The table index is only shown in detailed output.
         assert self.index, "Unknown label table index"
         prefix = 'Label entry %d' % self.index
      else:
         prefix = 'Label entry'
      return '%s%s: %d' % ( indent, prefix, self.label )

class Dest( L2RibObject ):
   """Describes a Destination entry attributes and its strep."""
   destType = Enum( values=[ 'MPLS',
                             'Tunnel',
                             'VXLAN',
                             'CPU',
                             'Interface' ],
                    help='Destination type' )
   mplsLabel = Int( help='MPLS Label', optional=True )
   tunnelId = Submodel( valueType=TunnelId, help="Tunnel Identifier", optional=True )
   tunnelEndPoint = IpGenericPrefix( help="IP prefix of the MPLS tunnel endpoint",
                                     optional=True )
   color = Int( help="Color of SR-TE policy tunnel", optional=True )
   vtepAddr = Submodel( valueType=IpGenericAddr,
                        help="IP address of VXLAN tunnel endpoint", optional=True )
   interface = Interface( help="Local interface", optional=True )
   domain = Enum( values=[ 'local', 'remote' ], help='EVPN Domain', optional=True )
   vtepType = Enum( values=[ 'controlPlaneVtep', 'dataPlaneVtep', 'remoteDomainVtep',
                             'umrVtep' ], help="VTEP type", optional=True )
   dropMode = Enum( values=[ 'none',
                             'dst',
                             'srcAndDst', ],
                    help='Drop mode',
                    optional=True )
   vlanRange = List( valueType=int, help="VLAN Range", optional=True )

   def __eq__( self, other ):
      """Ensure only unique destinations are displayed in
         'show l2rib [input <>|output] destination [floodset]' commands, so the
         destinations need to be compared and this enables that.

         Only the attributes relevant to each destination type are compared.
         Comparing against something that is not a Dest returns NotImplemented,
         so Python falls back to its default comparison instead of raising
         AttributeError.
      """
      if not isinstance( other, Dest ):
         return NotImplemented
      if self.destType != other.destType:
         return False
      if self.destType == 'VXLAN':
         return ( self.vtepAddr == other.vtepAddr and
                  self.domain == other.domain and
                  self.vtepType == other.vtepType )
      if self.destType == 'MPLS':
         # NOTE(review): mplsLabel is not compared although getStr renders it,
         # so MPLS dests differing only by label compare equal -- confirm.
         return ( self.tunnelId == other.tunnelId and
                  self.tunnelEndPoint == other.tunnelEndPoint and
                  self.color == other.color and
                  self.interface == other.interface and
                  self.domain == other.domain )
      if self.destType == 'Tunnel':
         return ( self.tunnelId == other.tunnelId and
                  self.tunnelEndPoint == other.tunnelEndPoint and
                  self.color == other.color and
                  self.interface == other.interface )
      if self.destType == 'Interface':
         return ( self.interface == other.interface and
                  self.dropMode == other.dropMode )
      return True

   def __hash__( self ):
      # Equal Dests always share destType, so hashing on it alone keeps the
      # hash consistent with __eq__ (at the cost of frequent collisions).
      return hash( self.destType )

   def getStr( self, detail ):
      """Return the one-line rendering of this destination, indented by level.

      `detail` is accepted for interface compatibility with the other
      L2RibObject subclasses; destination rendering does not use it.

      Raises NotImplementedError for an unrecognized destType.
      """
      def getTep():
         # Tunnel endpoint text, with the SR-TE policy color appended when the
         # tunnel is an SR-TE policy; None when no endpoint is set.
         if not self.tunnelEndPoint:
            return None
         tep = str( self.tunnelEndPoint )
         if self.tunnelId and self.tunnelId.type == 'SR-TE Policy':
            tep += ', color %d' % self.color
         return tep

      def getTunnelInfo( prefix ):
         # "<prefix> <underlay>[, TEP <tep>]" where the underlay is the tunnel
         # when one is set, otherwise the local interface. Resolve to None
         # (instead of dereferencing a missing interface) so the malformed-
         # destination assert below is actually reachable.
         if self.tunnelId:
            underlayIntf = self.tunnelId.renderStr()
         elif self.interface is not None:
            underlayIntf = self.interface.stringValue
         else:
            underlayIntf = None
         assert underlayIntf is not None, "Malformed %s destination" % prefix
         tep = getTep()
         if not tep:
            return '%s %s' % ( prefix, underlayIntf )
         return f'{prefix} {underlayIntf}, TEP {tep}'

      content = ' ' * self.spacePerLevel * self.level if self.level else ''
      if self.destType == 'MPLS':
         content += '%s, %d' % ( getTunnelInfo( 'MPLS' ), self.mplsLabel )
      elif self.destType == 'Tunnel':
         content += getTunnelInfo( 'Tunnel' )
      elif self.destType == 'VXLAN':
         content += 'VTEP %s' % self.vtepAddr.formatStr().stringValue
      elif self.destType == 'Interface':
         content += self.interface.stringValue
         if self.dropMode == 'dst':
            content += ' (DROP dst)'
         elif self.dropMode == 'srcAndDst':
            content += ' (DROP src and dst)'
      elif self.destType == 'CPU':
         content += 'CPU'
      else:
         raise NotImplementedError
      return content

class HostEntry( _BaseTableEntry ):
   """Base type for input and output host entries used in HostTableInput(
   Output ) Models."""
   seqNo = Int( help='Sequence number' )
   pref = Int( help='Preference for the host entry' )
   # TODO: entryType must be an Enum which aligns with
   # Bridging::EntryType Tac::Enum. The mapping must belong in
   # BridgingCliModel.py.
   entryType = Str( help='Host entry type' )
   dests = List( valueType=L2RibObject,
                 help='collection of destination objects' )
   def getStr( self ):
      """Return the one-line summary of this host entry.

      The entry type is converted to its display form through the
      module-level hostEntryTypeHelper.
      """
      values = ( self.macAddress.displayString, self.vlanId, self.seqNo,
                 self.pref, hostEntryTypeHelper.entryTypeStr( self.entryType ) )
      return '%s, VLAN %d, seq %d, pref %d, %s' % values

class HostEntryAndSource( HostEntry ):
   source = Str( help='host entry source' )
   def getStr( self ):
      """Return the one-line summary of this host entry, including its source.

      NOTE(review): unlike HostEntry.getStr, entryType is printed as stored
      rather than through hostEntryTypeHelper.entryTypeStr -- confirm intended.
      """
      values = ( self.macAddress.displayString, self.vlanId, self.seqNo,
                 self.pref, self.entryType, self.source )
      return '%s, VLAN %d, seq %d, pref %d, %s, source: %s' % values

class HostTable( Model ):
   """Base class for rendering host table ( input and output )."""
   hosts = List( valueType=HostEntry,
                 help='Host table entries' )
   invalidHosts = List( valueType=HostEntry,
                        help='Invalid host table entries' )
   _detail = Bool( help='Detailed output of host entry',
                   optional=True )

   def render( self ):
      """Print the valid hosts, then the dangling (invalid) hosts.

      A ruled "Dangling HostTable Entries" banner is printed only when there
      is at least one invalid host.
      """
      self._renderHosts( self.hosts )
      if self.invalidHosts:
         rule = '-' * 74
         print( '\n'.join( [ rule, "Dangling HostTable Entries", rule ] ) )
      self._renderHosts( self.invalidHosts )

   def _renderHosts( self, hosts ):
      # One line per host followed by one line per destination of that host.
      detail = self._detail
      for entry in hosts:
         print( entry.getStr() )
         for destination in entry.dests:
            print( destination.getStr( detail ) )

class DestsFromHostTable( Model ):
   tableName = Str( help="L2Rib Destination Flood Set table name" )
   source = Str( help="L2Rib Destination Flood Set table source", optional=True )
   dests = List( valueType=Dest, help="Destination summary" )
   destType = Enum( values=[ 'MPLS',
                             'Tunnel',
                             'VXLAN',
                             'CPU',
                             'Interface' ],
                    help='Filter by destination type',
                    optional=True )

   def getStr( self ):
      """Return the destination summary text for this table.

      Only `destType` is rendered when that filter is set; otherwise all
      destination types are. Returns "" when there is nothing to show.

      The text is returned rather than printed so that the caller
      (DestSummaryColl.render) can drop empty summaries and join the rest,
      the same contract as DestFloodSetSummary.getStr.
      """
      if self.destType:
         destTypes = [ self.destType ]
      else:
         destTypes = [ 'VXLAN', 'MPLS', 'Tunnel', 'CPU', 'Interface' ]
      # NOTE(review): unlike DestFloodSetSummary.getStr, self.source is not
      # passed to getDestSummaryStr, so it never appears in the output --
      # confirm whether that is intentional.
      return getDestSummaryStr( self.dests, destTypes, self.tableName )

class FloodSet( Model ):
   # One flood set: its VLAN, MAC address, flood type and destinations.
   # FloodSetSummary.getStr renders each FloodSet as one table row.
   vlanId = Int( help='Vlan Id' )
   macAddr = MacAddress( help ='Flood set MAC address' )
   floodType = Enum( values=[ 'Any', 'All' ], help='Flood type' )
   dests = List( valueType=Dest, help='Flood destination' )

class FloodSetSummary( Model ):
   tableName = Str( help="L2Rib Flood Set table name" )
   source = Str( help="L2Rib Flood Set table source", optional=True )
   floodSets = List( valueType=FloodSet, help="Vlan flood set summary" )

   def getStr( self ):
      """Render this flood set table as a string.

      A header line (plus a "Source:" line when a source is set) is followed
      by a table with one row per flood set. Each row's Destination cell holds
      one line per destination. The "TEP Domain" column is added only when at
      least one destination of any flood set is in the remote domain.
      """
      header = "L2 RIB %s Flood Set: \n" % self.tableName
      if self.source is not None:
         header += "Source: %s\n" % self.source

      showDomain = any( dest.domain == 'remote'
                        for fs in self.floodSets
                        for dest in fs.dests )

      columns = [ "VLAN", "Address", "Type", "Destination" ]
      if showDomain:
         columns.append( "TEP Domain" )

      leftFmt = Format( justify="left" )
      leftFmt.noPadLeftIs( True )
      leftFmt.padLimitIs( True )
      table = createTable( tuple( columns ), tableWidth=100 )
      table.formatColumns( *( [ leftFmt ] * len( columns ) ) )

      for fs in self.floodSets:
         destCell = '\n'.join( dest.getStr( detail=False ) for dest in fs.dests )
         domainCell = '\n'.join( dest.domain or 'n/a' for dest in fs.dests )
         macAddrStr = Ethernet.convertMacAddrCanonicalToDisplay(
            fs.macAddr.stringValue )
         row = [ fs.vlanId, macAddrStr, fs.floodType, destCell ]
         if showDomain:
            row.append( domainCell )
         table.newRow( *row )
      return header + table.output()

class FloodSetSummaryColl( Model ):
   floodSetSummaries = List( valueType=FloodSetSummary,
                             help='Vlan flood set summary collection' )

   def render( self ):
      """Print every flood set summary, each followed by a newline."""
      print( ''.join( summary.getStr() + '\n'
                      for summary in self.floodSetSummaries ) )
   
def getDestSummaryStr( dests, destTypes, tableName, floodSet=False, source=None ):
   """Render one table per destination type for `dests` as a single string.

   Args:
      dests: destination models to render (Dest-like: destType, domain,
         vlanRange, vtepType, getStr()).
      destTypes: destination types to render, in display order. A table is
         emitted only for types with at least one matching destination.
         Destinations whose type is not listed are ignored.
      tableName: name shown in the header line.
      floodSet: when True the header reads "Destination Flood Set".
      source: when not None, a "Source: <source>" line follows the header.

   Returns:
      The rendered text, or "" when no destination matches `destTypes`.
   """
   destSummary = "L2 RIB %s Destination" % tableName
   if floodSet:
      destSummary += " Flood Set"
   destSummary += ": \n"
   if source is not None:
      destSummary += "Source: %s\n" % source

   # Mapping from VTEP type to CLI friendly view
   vtepTypeStr = { 'controlPlaneVtep': 'control plane',
                   'dataPlaneVtep': 'data plane',
                   'remoteDomainVtep': 'remote domain',
                   'umrVtep': 'Unknown-MAC-Route originator' }

   def getDestDetails():
      # Per-type destination counts and remote-domain flags, plus whether any
      # counted destination carries a VLAN range. Destinations of a type not
      # in destTypes are skipped: they are never rendered below, and indexing
      # the per-type dicts with them would raise KeyError.
      remoteDomainExists = { destType : False for destType in destTypes }
      destTypeCount = { destType : 0 for destType in destTypes }
      displayVlan = False
      for dest in dests:
         if dest.destType not in destTypeCount:
            continue
         destTypeCount[ dest.destType ] += 1
         if dest.domain == 'remote':
            remoteDomainExists[ dest.destType ] = True
         if dest.vlanRange:
            displayVlan = True
      return ( destTypeCount, remoteDomainExists, displayVlan )

   # 1. Destination table for a specific destination type will be displayed only if
   # at least one destination of that type is present.
   # 2. TEP Domain column will be displayed only if at least one remote domain TEP
   # is present
   ( destTypeCount, remoteDomainExists, displayVlan ) = getDestDetails()

   # Return empty output if there are no destinations
   if not any( destTypeCount.values() ):
      return ""

   for destType in destTypes:
      if not destTypeCount[ destType ]:
         continue
      destSummary += "\nDestination Type: %s\n\n" % destType
      fl = Format( justify="left" )
      fl.noPadLeftIs( True )
      fl.padLimitIs( True )
      columns = [ "Destination" ]
      if destType == 'VXLAN':
         columns.append( "TEP Type" )
      if remoteDomainExists[ destType ]:
         columns.append( "TEP Domain" )
      if displayVlan:
         columns.append( "VLAN" )
      table = createTable( tuple( columns ), tableWidth=100 )
      table.formatColumns( *( len( columns ) * [ fl ] ) )

      for dest in dests:
         if dest.destType != destType:
            continue
         row = [ dest.getStr( detail=False ) ]
         if destType == 'VXLAN':
            row.append( vtepTypeStr[ dest.vtepType ] )
         if remoteDomainExists[ destType ]:
            row.append( dest.domain if dest.domain else 'n/a' )
         if displayVlan:
            row.append( multiRangeToCanonicalString( dest.vlanRange ) )
         table.newRow( *row )

      destSummary += table.output()
   return destSummary

class DestFloodSetSummary( Model ):
   tableName = Str( help="L2Rib Destination Flood Set table name" )
   source = Str( help="L2Rib Destination Flood Set table source", optional=True )
   dests = List( valueType=Dest, help="Destination summary" )
   destType = Enum( values=[ 'MPLS',
                             'Tunnel',
                             'VXLAN',
                             'CPU',
                             'Interface' ],
                    help='Filter by destination type',
                    optional=True )

   def getStr( self ):
      """Return the destination flood set summary text for this table.

      Only `destType` is rendered when that filter is set; otherwise every
      destination type is. The source line is included when set.
      """
      if self.destType:
         wantedTypes = [ self.destType ]
      else:
         wantedTypes = [ 'VXLAN', 'MPLS', 'Tunnel', 'CPU', 'Interface' ]
      return getDestSummaryStr( self.dests, wantedTypes, self.tableName,
                                floodSet=True, source=self.source )

class DestSummaryColl( Model ):
   destSummaries = List( valueType=DestsFromHostTable,
                         help='Destination floodset summary collection' )

   def render( self ):
      """Print each summary whose getStr() yields non-empty text.

      Each kept summary is followed by a newline. Nothing is printed when no
      summary yields text.
      """
      rendered = ( summary.getStr() for summary in self.destSummaries )
      output = ''.join( text + '\n' for text in rendered if text )
      if output:
         print( output )

class DestFloodSetSummaryColl( Model ):
   destFloodSetSummaries = List( valueType=DestFloodSetSummary,
                                 help='Destination floodset summary collection' )

   def render( self ):
      """Print every non-empty destination flood set summary.

      Each kept summary is followed by a newline. Nothing is printed when all
      summaries render empty.
      """
      rendered = ( summary.getStr() for summary in self.destFloodSetSummaries )
      output = ''.join( text + '\n' for text in rendered if text )
      if output:
         print( output )
