# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import time
import datetime
from operator import attrgetter
import Tac
import TacSigint
import Ark
from ArPyUtils import naturalsorted
from CliModel import Enum
from CliModel import List
from CliModel import Int
from CliModel import Str
from CliModel import Dict
from CliModel import Model
from CliModel import Submodel
from ArnetModel import IpGenericAddress
import TableOutput

# Sort key shared by the "All*" models below. It orders per-exporter entries by
# exporter address, then source port, then observation domain ID, read from
# each entry's `exporterId` submodel.
exporterKey = attrgetter(
   'exporterId.address',
   'exporterId.sourcePort',
   'exporterId.observationDomainId',
)

def createTable( headings, widths ):
   """Build a TableOutput table with one left-justified column per width.

   headings -- column headings passed straight to TableOutput.createTable.
   widths   -- minimum width of each column, in column order.

   Each column Format also has noPadLeftIs( True ) and padLimitIs( True ) set.
   Returns the configured table, ready for newRow()/output().
   """
   def leftColumn( minWidth ):
      fmt = TableOutput.Format( justify='left', minWidth=minWidth )
      fmt.noPadLeftIs( True )
      fmt.padLimitIs( True )
      return fmt

   table = TableOutput.createTable( headings )
   table.formatColumns( *[ leftColumn( w ) for w in widths ] )
   return table

#--------------------------------------------------------------------------------
# EAPI Models
#--------------------------------------------------------------------------------
# Help strings for the IpfixExporterCounter attributes. Most of them are also
# used verbatim as the labels printed by IpfixExporterCounter.render().

# Received-item counters.
receivedMessagesHelp = 'Received messages'
receivedTrsHelp = 'Received template records'
receivedOtrsHelp = 'Received options template records'
receivedDrsHelp = 'Received IPFIX data records or sFlow/Postcard flow samples'
receivedOdrsHelp = 'Received options data records'

# Error counters.
invalidMsgHdrErrorsHelp = 'Invalid message header errors'
invalidSetHdrErrorsHelp = 'Invalid set header errors'
invalidSetLenErrorsHelp = 'Invalid set length errors'
invalidSetIdErrorsHelp = 'Invalid set ID errors'
invalidTrErrorsHelp = 'Invalid template record errors'
invalidOtrErrorsHelp = 'Invalid options template record errors'
unknownTidErrorsHelp = 'Unknown template ID errors'
invalidFieldErrorsHelp = 'Invalid data record field errors'
interpretFieldErrorsHelp = 'Interpret data record field errors'
unsupportedOdrErrorsHelp = 'Unsupported options data record errors'
invalidFlowKeyErrorsHelp = 'Invalid flow key in data record errors'

class IpfixExporterId( Model ):
   """Identity of a flow exporter (IPFIX) or agent (other protocols)."""
   address = IpGenericAddress( help='Exporter IP address' )
   sourcePort = Int( help='Exporter source port number' )
   observationDomainId = Int( help='Observation domain ID' )
   protocol = Enum( values=( 'ipfix', 'sflow', 'greent' ),
                    help='Exporter flow exporting protocol' )

   def render( self ):
      """Print a one-line identity header for this exporter."""
      if self.protocol != 'ipfix':
         # Non-IPFIX senders ('sflow', 'greent') are shown by address only.
         print( 'Agent:', self.address )
         return
      print( f'Exporter: {self.address}   Source port: {self.sourcePort}   '
             f'Observation domain ID: {self.observationDomainId}' )

class IpfixExporterCounter( Model ):
   """Receive and error counters for a single exporter."""
   exporterId = Submodel( valueType=IpfixExporterId, help='Exporter ID' )

   # Received-item counters.
   receivedMessages = Int( help=receivedMessagesHelp )
   receivedTrs = Int( help=receivedTrsHelp )
   receivedOtrs = Int( help=receivedOtrsHelp )
   receivedDrs = Int( help=receivedDrsHelp )
   receivedOdrs = Int( help=receivedOdrsHelp )

   # Error counters. render() prints an error counter only when it is non-zero,
   # and which ones are shown depends on the exporter protocol.
   invalidMsgHdrErrors = Int( help=invalidMsgHdrErrorsHelp )
   invalidSetHdrErrors = Int( help=invalidSetHdrErrorsHelp )
   invalidSetLenErrors = Int( help=invalidSetLenErrorsHelp )
   invalidSetIdErrors = Int( help=invalidSetIdErrorsHelp )
   invalidTrErrors = Int( help=invalidTrErrorsHelp )
   invalidOtrErrors = Int( help=invalidOtrErrorsHelp )
   unknownTidErrors = Int( help=unknownTidErrorsHelp )
   invalidFieldErrors = Int( help=invalidFieldErrorsHelp )
   interpretFieldErrors = Int( help=interpretFieldErrorsHelp )
   unsupportedOdrErrors = Int( help=unsupportedOdrErrorsHelp )
   invalidFlowKeyErrors = Int( help=invalidFlowKeyErrorsHelp )

   def render( self ):
      """Print the exporter identity followed by its counters.

      IPFIX exporters show the message/template/data/options counters and the
      IPFIX error counters. Other protocols show messages, flow samples and
      flow-key errors only. Error counters are skipped when zero.
      """
      self.exporterId.render()

      if self.exporterId.protocol == 'ipfix':
         counters = (
            ( receivedMessagesHelp, self.receivedMessages ),
            ( receivedTrsHelp, self.receivedTrs ),
            ( receivedOtrsHelp, self.receivedOtrs ),
            ( 'Received data records', self.receivedDrs ),
            ( receivedOdrsHelp, self.receivedOdrs ),
         )
         errors = (
            ( invalidMsgHdrErrorsHelp, self.invalidMsgHdrErrors ),
            ( invalidSetHdrErrorsHelp, self.invalidSetHdrErrors ),
            ( invalidSetLenErrorsHelp, self.invalidSetLenErrors ),
            ( invalidSetIdErrorsHelp, self.invalidSetIdErrors ),
            ( invalidTrErrorsHelp, self.invalidTrErrors ),
            ( invalidOtrErrorsHelp, self.invalidOtrErrors ),
            ( unknownTidErrorsHelp, self.unknownTidErrors ),
            ( invalidFieldErrorsHelp, self.invalidFieldErrors ),
            ( interpretFieldErrorsHelp, self.interpretFieldErrors ),
            ( unsupportedOdrErrorsHelp, self.unsupportedOdrErrors ),
         )
      else:
         counters = (
            ( receivedMessagesHelp, self.receivedMessages ),
            ( 'Received flow samples', self.receivedDrs ),
         )
         errors = ( ( invalidFlowKeyErrorsHelp, self.invalidFlowKeyErrors ), )

      for label, value in counters:
         print( '%s: %d' % ( label, value ) )
      for label, value in errors:
         if value != 0:
            print( '%s: %d' % ( label, value ) )

class IpfixAllCounter( Model ):
   """Counters for every exporter, plus optional collector-wide totals."""
   receivedMessages = Int( help='Number of IPFIX/sFlow/Postcard messages received',
                           optional=True )
   sentMessages = Int( help='Number of flow information sent', optional=True )
   exporters = List( valueType=IpfixExporterCounter,
                      help='Counters for all exporters' )

   def render( self ):
      """Print the total received-message count when set and non-zero.

      Then print each exporter's counters in exporterKey order, each followed
      by a blank line. sentMessages is not rendered.
      """
      total = self.receivedMessages
      if total:
         # end='\n\n' emits the same bytes as print( ... ) followed by print().
         print( 'Received messages: %d' % total, end='\n\n' )
      for exporterCounter in sorted( self.exporters, key=exporterKey ):
         exporterCounter.render()
         print()

class IpfixTemplateField( Model ):
   """One field specifier of an IPFIX (options) template."""
   enterpriseNumber = Int( help='IANA Enterprise number' )
   informationElementId = Int( help='Information Element ID' )
   informationElementName = Str( help='Information Element name' )
   length = Int( help='Field length' )

   def row( self ):
      """Return ( enterprise number, element ID, element name, length ).

      The order matches the columns of the template table.
      """
      return attrgetter( 'enterpriseNumber', 'informationElementId',
                         'informationElementName', 'length' )( self )

class IpfixTemplateInfo( Model ):
   """A single template received from an exporter."""
   templateId = Int( help='Template ID' )
   scopeFields = List( valueType=IpfixTemplateField,
                       help='List of scope fields' )
   fields = List( valueType=IpfixTemplateField,
                  help='List of fields' )
   timeReceived = Int( help='Unix epoch time when template was last received' )

   def render( self ):
      """Print a header line and a table of this template's fields.

      The header shows the template ID and receive time (via time.ctime).
      Regular fields come first; scope fields, if any, follow after a
      'Scope fields:' marker row.
      """
      receivedAt = time.ctime( self.timeReceived )
      print( f'Template ID: {self.templateId}   Time received:', receivedAt )

      table = createTable( ( 'Enterprise Number', 'Element ID', 'Name', 'Length' ),
                           ( 18, 11, 30, 6 ) )
      for field in self.fields or ():
         table.newRow( *field.row() )
      if self.scopeFields:
         table.newRow( 'Scope fields:' )
         for field in self.scopeFields:
            table.newRow( *field.row() )
      print( table.output() )

class IpfixExporterTemplate( Model ):
   """All templates received from a single exporter."""
   exporterId = Submodel( valueType=IpfixExporterId, help='Exporter ID' )
   templateInfos = List( valueType=IpfixTemplateInfo,
                         help='Templates for this exporter' )

   def render( self ):
      """Print the exporter identity, then each template in list order."""
      self.exporterId.render()
      for info in self.templateInfos:
         info.render()

class IpfixAllTemplate( Model ):
   """Templates for every exporter."""
   exporters = List( valueType=IpfixExporterTemplate,
                     help='Templates for all exporters' )

   def render( self ):
      """Render each exporter's templates in exporterKey order.

      Each exporter is followed by two blank lines: the literal '\\n' plus the
      newline that print() appends.
      """
      for exporterTemplates in sorted( self.exporters, key=exporterKey ):
         exporterTemplates.render()
         print( '\n' )

class IpfixDataFields( Model ):
   """Options-data values keyed by Information Element name."""
   fields = Dict(
      keyType=str,
      valueType=str,
      help='Information Element name to value mapping' )

class IpfixExporterOptionData( Model ):
   """Options data received from a single exporter, grouped by category."""
   exporterId = Submodel( valueType=IpfixExporterId, help='Exporter ID' )
   interfaces = Dict( keyType=str, valueType=str,
                      help='Interface ID to interface name mapping' )
   vrfs = Dict( keyType=str, valueType=str,
                help='VRF ID to VRF name mapping' )
   templates = Dict( keyType=str, valueType=IpfixDataFields,
                     help='Template options data' )
   observationDomains = Dict( keyType=str, valueType=IpfixDataFields,
                              help='Observation domain options data' )

   def render( self ):
      """Print the exporter identity, then one table per non-empty category.

      The categories are interface names, VRF names, per-template options data
      and per-observation-domain options data. Rows are ordered by
      naturalsorted ID. Element/value pairs within one ID keep the order of
      that ID's `fields` dict.
      """

      def printNames( idHeading, names ):
         # Two-column ( ID, name ) table for an ID -> name dict.
         table = createTable( ( idHeading, 'Name' ), ( 16, 16 ) )
         for key in naturalsorted( names ):
            table.newRow( key, names[ key ] )
         print( table.output() )

      def printFieldData( idHeading, dataById ):
         # ( ID, element name, value ) table for an ID -> IpfixDataFields dict,
         # with one row per element of each ID.
         table = createTable( ( idHeading, 'Element Name', 'Data' ),
                              ( 16, 30, 8 ) )
         for key in naturalsorted( dataById ):
            for name, value in dataById[ key ].fields.items():
               table.newRow( key, name, value )
         print( table.output() )

      self.exporterId.render()
      if self.interfaces:
         printNames( 'Interface ID', self.interfaces )
      if self.vrfs:
         printNames( 'VRF ID', self.vrfs )
      if self.templates:
         printFieldData( 'Template ID', self.templates )
      if self.observationDomains:
         printFieldData( 'ObsDomain ID', self.observationDomains )

class IpfixAllOptionData( Model ):
   """Options data for every exporter."""
   exporters = List( valueType=IpfixExporterOptionData,
                     help='Option data for all exporters' )

   def render( self ):
      """Render each exporter's options data in exporterKey order.

      Each exporter is followed by a blank line.
      """
      for exporterOptions in sorted( self.exporters, key=exporterKey ):
         exporterOptions.render()
         print()

class IpfixIntInterval( Model ):
   """Information Element values for one interval of an INT path node."""
   intervalData = Dict(
      keyType=str,
      valueType=str,
      help=( 'Mapping of Information Element name to Information Element value '
             'for an interval in an Inband Telemetry path node' ) )

class IpfixIntNode( Model ):
   """One Inband Telemetry path node: its own data plus per-interval data."""
   nodeData = Dict(
      keyType=str,
      valueType=str,
      help=( 'Mapping of Information Element name to Information Element value '
             'for an Inband Telemetry path node' ) )
   intIntervals = List(
      valueType=IpfixIntInterval,
      help='All intervals for this Inband Telemetry path node' )

# Information Element names looked up by IpfixFlow.row().

# Read from the flow key.
PROTO = 'protocolIdentifier'
SRCPORT = 'sourceTransportPort'
DSTPORT = 'destinationTransportPort'
SRCIP = 'sourceIPv4Address'
DSTIP = 'destinationIPv4Address'
SRCIP6 = 'sourceIPv6Address'
DSTIP6 = 'destinationIPv6Address'
# Flow start time. Despite the "US" suffix, STARTUS names the *millisecond*
# element; IpfixFlow.row() divides it by 1000.
STARTUS = 'flowStartMilliseconds'
STARTSEC = 'flowStartSeconds'

# Read from the flow data.
DELTAOCTET = 'octetDeltaCount'
DELTAPKT = 'packetDeltaCount'

class IpfixFlow( Model ):
   """A single flow with its key/data elements and last-update time.

   It may also carry its Inband Telemetry path nodes.
   """
   flowKey = Dict( keyType=str, valueType=str,
         help='Mapping of Information Element name to value in Flow Key part' )
   flowData = Dict( keyType=str, valueType=str,
         help='Mapping of Information Element name to value in Flow Data part' )
   # Shown as the "Last Updated" column and used to order flow tables.
   lastUpdateTime = Int( help='UTC timestamp in nanoseconds of the last '
                              'update to this flow' )
   intNodes = List( valueType=IpfixIntNode, optional=True,
         help='All Inband Telemetry path nodes for this flow' )

   def row( self ):
      """Return this flow as one row of the flow table.

      The columns are ( protocol, source IP, destination IP, source port,
      destination port, packet delta, octet delta, start time, last updated ).

      Absent elements render as ''. IPv4 addresses take precedence over IPv6
      ones. The start time comes from the flow key's flowStartMilliseconds
      element, falling back to flowStartSeconds, and is formatted as local
      time. When neither yields a non-zero time, the raw flowStartMilliseconds
      string (or '') is shown instead.
      """
      key = self.flowKey
      data = self.flowData

      proto = key.get( PROTO, '' )
      srcPort = key.get( SRCPORT, '' )
      dstPort = key.get( DSTPORT, '' )
      srcIp = key.get( SRCIP, '' ) or key.get( SRCIP6, '' )
      dstIp = key.get( DSTIP, '' ) or key.get( DSTIP6, '' )
      deltaPkt = data.get( DELTAPKT, '' )
      deltaOctet = data.get( DELTAOCTET, '' )

      start = key.get( STARTUS, '' )
      if start:
         # STARTUS is flowStartMilliseconds, so convert to whole seconds.
         sec = int( start ) // 1000
      else:
         sec = int( key.get( STARTSEC ) or 0 )
      if sec:
         start = datetime.datetime.fromtimestamp( sec ).strftime(
            '%Y-%m-%d %H:%M:%S' )

      lastUpdated = Ark.timestampToStr( self.lastUpdateTime // ( 10 ** 9 ),
                                        now=Tac.utcNow() )
      return ( proto, srcIp, dstIp, srcPort, dstPort, deltaPkt, deltaOctet, start,
               lastUpdated )

class IpfixExporterFlow( Model ):
   """Flows reported by a single exporter, split into IPv4 and IPv6 groups."""
   exporterId = Submodel( valueType=IpfixExporterId, help='Exporter ID' )
   v4Flows = List( valueType=IpfixFlow, help='IP flows for this exporter' )
   v6Flows = List( valueType=IpfixFlow, help='IPv6 flows for this exporter' )

   def render( self ):
      """Print the exporter identity, then one table per non-empty flow group.

      Groups are IPv4 then IPv6. Within each table, the most recently updated
      flows come first. TacSigint.check() runs after every row so a long table
      can be interrupted.
      """
      self.exporterId.render()

      headings = ( 'Proto', 'Src IP', 'Dest IP', 'Src Port', 'Dest Port',
                   'Pkts', 'Bytes', 'Start Time', 'Last Updated' )
      widths = ( 6, 25, 25, 8, 9, 5, 8, 20, 15 )
      for groupName, groupFlows in ( ( 'IPv4', self.v4Flows ),
                                     ( 'IPv6', self.v6Flows ) ):
         if not groupFlows:
            continue
         print( '\nGroup: %s   Number of flows: %d' %
                ( groupName, len( groupFlows ) ) )

         table = createTable( headings, widths )
         newestFirst = sorted( groupFlows,
                               key=lambda flow: flow.lastUpdateTime,
                               reverse=True )
         for flow in newestFirst:
            table.newRow( *flow.row() )
            TacSigint.check()
         print( table.output() )

class IpfixAllFlow( Model ):
   """Flows for every exporter."""
   exporters = List( valueType=IpfixExporterFlow, help='Flows for all exporters' )

   def render( self ):
      """Render each exporter's flows in exporterKey order.

      Each exporter is followed by a blank line.
      """
      for exporterFlows in sorted( self.exporters, key=exporterKey ):
         exporterFlows.render()
         print()

class IpfixHostname( Model ):
   """Resolved host names for a single IP address."""
   address = IpGenericAddress( help='IP address' )
   hostnames = List( valueType=str, help='Host names of this address' )
   lastUpdateTime = Int( help='UTC timestamp in nano second of the last '
                              'hostname resolution' )

   def rows( self ):
      """Return table rows of ( address, hostname, last updated ).

      Only the first row carries the address and the last-update time. Each
      further hostname gets its own row with those two columns left ''.
      Returns [] when there are no hostnames.
      """
      names = list( self.hostnames )
      if not names:
         return []
      firstName, *otherNames = names
      lastUpdated = Ark.timestampToStr( self.lastUpdateTime // ( 10 ** 9 ),
                                        now=Tac.utcNow() )
      return ( [ ( self.address, firstName, lastUpdated ) ] +
               [ ( '', name, '' ) for name in otherNames ] )

class IpfixAllHostname( Model ):
   """Hostname records for all IPv4 and IPv6 addresses."""
   v4Hostnames = List( valueType=IpfixHostname,
                       help='Hostname information for all IP addresses' )
   v6Hostnames = List( valueType=IpfixHostname,
                       help='Hostname information for all IPv6 addresses' )

   def render( self ):
      """Print one table per non-empty address group (IPv4, then IPv6).

      The most recently updated records come first. TacSigint.check() runs
      after every row so a long table can be interrupted.
      """
      groups = ( ( 'IPv4', self.v4Hostnames ), ( 'IPv6', self.v6Hostnames ) )
      for groupName, records in groups:
         if not records:
            continue
         print( '\nGroup: %s   Number of records: %d' %
                ( groupName, len( records ) ) )

         table = createTable( ( 'Address', 'Name', 'Last Updated' ),
                              ( 25, 40, 15 ) )
         newestFirst = sorted( records,
                               key=lambda record: record.lastUpdateTime,
                               reverse=True )
         for record in newestFirst:
            for row in record.rows():
               table.newRow( *row )
               TacSigint.check()
         print( table.output() )
