#!/usr/bin/env python3
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import TableOutput
from datetime import datetime
from IntfModels import Interface
from Intf.IntfRange import intfListToCanonical
from CliModel import (
   Bool,
   Dict,
   Float,
   Int,
   List,
   Model,
   Str,
   Submodel,
)
from CliPlugin.SftModel import (
   FlowDetailModel,
   TrackingModel,
)
from CliPlugin.TrafficPolicyCliModel import Rule
import Toggles.InbandTelemetryCommonToggleLib
from operator import attrgetter

class Profile( Model ):
   # CLI model for a single inband telemetry profile. Instances are the values
   # of the per-type dictionaries declared in InbandTelemetryProfiles.

   # Compared against 'edge' and 'core' by
   # InbandTelemetryProfiles.renderAttributes; any other value is displayed as
   # a Vxlan profile.
   profileType = Str( help="The type of profile" )
   # Optional: only printed when it has a truthy value.
   sampleRate = Int( help="Sample rate of the profile", optional=True )
   # samplePolicy is declared only when the sample policy feature toggle is
   # enabled at class-definition time; renderAttributes checks the same toggle
   # before reading it.
   if Toggles.InbandTelemetryCommonToggleLib.\
      toggleFeatureInbandTelemetrySamplePolicyEnabled():
      samplePolicy = Str( help="Sample policy", optional=True )
   # Printed under the label "Ingress collection" for Vxlan profiles and
   # "Egress collection" for all other profile types.
   egressCollection = Str( help="The collector name" )
   # Optional: only printed when it has a truthy value.
   egressDrop = Str( help="Whether egress drop is enabled or disabled",
                     optional=True )
   profileStatus = Str(
      help="Whether the port is active or inactive or in error condition" )
   profileErrorReason = Str( help="None or the reason for error" )

class IntfList( Model ):
   # Wrapper model holding a list of interfaces. ModelIntProfileSummary uses it
   # as the value type of its per-profile dictionaries and reads 'Intfs' when
   # rendering.
   Intfs = List( valueType=Interface, help='Interface list', optional=True )

class InbandTelemetryProfiles( Model ):
   # Inband telemetry profiles grouped by type; each dictionary maps a profile
   # name to its Profile model.
   coreProfiles = Dict(
      keyType=str, valueType=Profile,
      help="Maps core profiles to their corresponding configuration.",
      optional=True )
   edgeProfiles = Dict(
      keyType=str, valueType=Profile,
      help="Maps edge profiles to their corresponding configuration.",
      optional=True )
   vxlanProfiles = Dict(
      keyType=str, valueType=Profile,
      help="Maps vxlan profiles to their corresponding configuration.",
      optional=True )
   transitOnly = Bool( help="InbandTelemetry transit only mode" )
   intVersion = Str( help="InbandTelemetry version" )

   def renderAttributes( self, profiles ):
      # Print one "label: value" block per entry of 'profiles' (a mapping of
      # profile name to Profile). Each block ends with an empty line.
      displayNames = { 'edge': 'Edge', 'core': 'Core' }
      for name in profiles:
         entry = profiles[ name ]
         # Anything that is neither 'edge' nor 'core' is shown as Vxlan.
         kind = displayNames.get( entry.profileType, 'Vxlan' )
         print( kind + " profile: " + name )

         if entry.sampleRate:
            print( "Ingress sample rate: " + str( entry.sampleRate ) )
         # samplePolicy is only declared on Profile when the toggle is on, so
         # the toggle has to be evaluated before the attribute is touched.
         if Toggles.InbandTelemetryCommonToggleLib.\
               toggleFeatureInbandTelemetrySamplePolicyEnabled() and \
               entry.samplePolicy:
            print( "Ingress sample policy: " + str( entry.samplePolicy ) )

         # Vxlan profiles label the collection as ingress, the rest as egress.
         if kind == 'Vxlan':
            collectionLabel = "Ingress collection: "
         else:
            collectionLabel = "Egress collection: "
         print( collectionLabel + entry.egressCollection )

         if entry.egressDrop:
            print( "Egress drop: " + entry.egressDrop )
         print( "Profile status: " + entry.profileStatus )
         print( "Profile error reason: " + entry.profileErrorReason + "\n" )

   def render( self ):
      # Render core, then edge, then vxlan profiles, skipping empty groups.
      # On Th3/Th4 only core ports are supported, so edge profiles are left out
      # in transit-only mode.
      groups = (
         self.coreProfiles,
         None if self.transitOnly else self.edgeProfiles,
         self.vxlanProfiles,
      )
      for group in groups:
         if group:
            self.renderAttributes( group )

class SamplePolicyModel( Model ):
   # Match rules of one sample policy, with a text renderer for the CLI.
   rules = List( valueType=Rule, help="Detailed information of match rules" )

   def render( self ):
      # Print the rule count, then for every rule its match line, the DSCP
      # values it matches (if any) and its sampling action (if any).
      print( f"Total number of rules configured: {len( self.rules )}" )
      for rule in self.rules:
         print( f"match {rule.matchOption} {rule.ruleString}:" )

         dscpRanges = rule.matches.dscps
         if dscpRanges:
            # A range whose bounds coincide is shown as a single value.
            dscpText = ", ".join(
               "%d" % r.high if r.low == r.high
               else "%d-%d" % ( r.low, r.high )
               for r in dscpRanges )
            print( "\tDSCP: %s" % dscpText )

         ruleActions = rule.actions
         # 'sample' takes precedence over 'sample all' when both are set.
         if ruleActions.sample:
            print( "\tActions: %s" % "sample" )
         elif ruleActions.sampleAll:
            print( "\tActions: %s" % "sample all" )

class InbandTelemetry( Model ):
   # Top-level inband telemetry status model: global settings, operational
   # state, optional sample policies and the configured profiles.
   enabled = Bool( help="Inband telemetry is enabled on the switch" )
   deviceId = Str( help="Device ID of the switch" )
   probeMarker = Int( help="Configured probe marker for Inband telemetry" )
   udpDstPorts = List( valueType=int,
                     help="Configured udp destination ports for INT packets" )
   probeProtocol = Int( help="IP header protocol/next-header field for INT packets" )
   operStatus = Str( help="Operational status of Inband telemetry" )
   inProgressReason = Str( help="Reason for operational status in progress" )
   policies = Dict( keyType=str, valueType=SamplePolicyModel,
                    help="Maps sample policy name to its configuration",
                    optional=True )
   profiles = Submodel( valueType=InbandTelemetryProfiles,
                        help="Core, Edge and Vxlan profiles" )
   intDetectionMethod = Str( help="Default Inband telemetry detection method" )
   intVersion = Str( help="InbandTelemetry version" )

   def render( self ):
      # Print the global settings, the detection-method specific fields, the
      # operational status, any sample policies and finally the profiles.
      print( "Enabled: True" if self.enabled else "Enabled: False" )
      print( "Device ID: %s" % self.deviceId )
      if self.intDetectionMethod == \
            Tac.Type( "InbandTelemetry::IntDetectionMethod" ).ProbeMarkerBased:
         # The marker is displayed as two hex words: everything above bit 31,
         # followed by the low 32 bits. Python 3 ints have no 'L' suffix in
         # hex(), so the values can be formatted directly.
         hexA = hex( self.probeMarker >> 32 )
         hexB = hex( self.probeMarker & 0xFFFFFFFF )
         print( f"Probe Marker: {hexA} {hexB}" )
         # The UDP destination ports are only shown for the native v1 version.
         if self.intVersion == \
               Tac.Type( "InbandTelemetry::IntVersion" ).VersionNativeV1:
            print( "Probe Marker UDP Destination Ports:", ", ".join(
                                    str( port ) for port in self.udpDstPorts ) )
      else:
         print( "Probe IP Protocol: %s" % self.probeProtocol )
      print( "Operational Status: %s" % self.operStatus )
      if self.inProgressReason:
         print( "Reason: %s" % self.inProgressReason )

      if self.policies:
         print( "\nSample policies:" )
         for policy in self.policies:
            print( "Sample policy %s" % policy )
            self.policies[ policy ].render()
            # Emits two line breaks: the "\n" itself plus print's own newline.
            print( "\n" )
      print( "\nProfiles:" )
      self.profiles.render()

class ModelIntProfileSummary( Model ):
   # Summary of which interfaces each profile is applied to, grouped by profile
   # type. Every dictionary maps a profile name to an IntfList.
   coreIntfList = Dict( keyType=str,
                        valueType=IntfList,
                        help='Core profiles',
                        optional=True )
   edgeIntfList = Dict( keyType=str,
                        valueType=IntfList,
                        help='Edge profiles',
                        optional=True )
   vxlanIntfList = Dict( keyType=str,
                         valueType=IntfList,
                         help='Vxlan profiles',
                         optional=True )

   def render( self ):
      # Print one section per non-empty profile type, in the order core, edge,
      # vxlan. Every section after the first printed one is preceded by an
      # empty line.
      sections = (
         ( 'Core profiles', self.coreIntfList ),
         ( 'Edge profiles', self.edgeIntfList ),
         ( 'Vxlan profiles', self.vxlanIntfList ),
      )
      leading = ''
      for title, profileMap in sections:
         if not profileMap:
            continue
         print( leading + title )
         # Profiles are listed by name; each line shows the sorted interfaces
         # in canonical form, comma separated.
         for profName, members in sorted( profileMap.items() ):
            canonical = intfListToCanonical( sorted( members.Intfs ) )
            print( '{}: {}'.format( profName, ','.join( canonical ) ) )
         leading = '\n'

class IntFlowDetailIntervalStats( Model ):
   # Statistics of one flow interval. The list attributes hold one entry per
   # device in the path, as stated by their help strings.

   # Passed to datetime.fromtimestamp() and used as the (descending) sort key
   # by IntTrackingModel.renderIntFlowDetailIntervalStats.
   timestamp = Float( help="Interval start time" )
   pkts = Int( help="Number of packets" )
   # Rendered as 'c' for a true entry and '-' for a false one.
   congestions = List( valueType=bool, help="Congestion per device in path" )
   avgLatencies = List( valueType=int,
         help="Average latency per device in path (ns)" )
   maxLatencies = List( valueType=int,
         help="Maximum latency per device in path (ns)" )
   minLatencies = List( valueType=int,
         help="Minimum latency per device in path (ns)" )

class IntFlowDetailNodeInfo( Model ):
   # Inband telemetry information for one device in a flow's path.
   # IntTrackingModel.renderIntFlowDetailNodeInfo prints the device ID and the
   # two port IDs in hexadecimal and the remaining integers as they are.
   deviceId = Int( help="Device ID" )
   ingressPortId = Int( help="Ingress port" )
   egressPortId = Int( help="Egress port ID" )
   egressQueueId = Int( help="Egress queue ID" )
   ttl = Int( help="Inband telemetry TTL (hop count)" )
   # Rendered as "congested" when true and "not congested" otherwise.
   lastPacketCongestion = Bool(
         help="Congestion detected in the last sampled packet" )
   lastPacketLatency = Int( help="Last packet latency (ns)" )

class IntFlowDetailModel( FlowDetailModel ):
   # FlowDetailModel extended with inband telemetry counters, per-device
   # information and per-interval statistics.

   # NOTE(review): "Transistions" is misspelled, but the attribute name is read
   # by IntTrackingModel.renderIntFlowDetail and is presumably part of the
   # model's external schema, so renaming it needs a coordinated change.
   pathTransistions = Int( help="Path transistions" )
   pathPackets = Int( help="Path packets" )
   devicesInPath = Int( help="Devices in path" )
   flowIntervals = Int( help="Flow intervals" )
   # Printed in lower case ("true"/"false") by renderIntFlowDetail.
   hopCountExceeded = Bool( help="Hop count exceeded" )
   devicesInformation = List( valueType=IntFlowDetailNodeInfo,
                          help="List of inband telemetry device information" )
   flowIntervalStats = List( valueType=IntFlowDetailIntervalStats,
                          help="List of inband telemetry interval statistics" )

class IntTrackingModel( TrackingModel ):
   # TrackingModel whose per-flow detail output is followed by the flow's
   # inband telemetry information.

   def renderIntFlowDetailNodeInfo( self, intDetail ):
      # Print the path as a chain of device IDs followed by a table with one
      # row per entry of intDetail.devicesInformation. Prints nothing when
      # that list is empty.
      if not intDetail.devicesInformation:
         return
      nodeInfoHeadings = ( "Device ID", "Ingress Port ID", "Egress Port ID",
            "Egress Queue ID", "TTL", "Congestion (last pkt)",
            "Latency (last pkt) (ns)" )
      formatCommon = TableOutput.Format( justify="right", maxWidth=7, wrap=True )
      formatTTL = TableOutput.Format( justify="right", maxWidth=5, wrap=True )
      formatCongestion = TableOutput.Format( justify="left", maxWidth=13, wrap=True )
      nodeInfoTable = TableOutput.createTable( nodeInfoHeadings, indent=6 )
      nodeInfoTable.formatColumns( formatCommon, formatCommon, formatCommon,
            formatCommon, formatTTL, formatCongestion, formatCommon )
      pathIDList = []
      for nodeInfo in intDetail.devicesInformation:
         # hex() already returns a str; compute it once for both the path
         # summary and the table row.
         deviceIdHex = hex( nodeInfo.deviceId )
         pathIDList.append( deviceIdHex )
         if nodeInfo.lastPacketCongestion:
            lastPacketCongestion = "congested"
         else:
            lastPacketCongestion = "not congested"
         nodeInfoTable.newRow( deviceIdHex,
                               hex( nodeInfo.ingressPortId ),
                               hex( nodeInfo.egressPortId ),
                               nodeInfo.egressQueueId,
                               nodeInfo.ttl,
                               lastPacketCongestion,
                               nodeInfo.lastPacketLatency )
      tab = " " * 6
      pathIDStr = ' -> '.join( pathIDList )
      print( f"{tab}Path device IDs: {pathIDStr}" )
      print( nodeInfoTable.output() )

   def renderIntFlowDetailIntervalStats( self, intDetail ):
      # Print a table of intDetail.flowIntervalStats, newest interval first.
      # Prints nothing when that list is empty.
      if not intDetail.flowIntervalStats:
         return
      tab = " " * 6
      intrvlStatsHeadings = ( "Time Stamp", "Pkts", "Congestion Per Device In Path",
            ( "Latency Per Device In Path (ns)", ( "Avg", "Max", "Min" ) ) )
      intrvlStatsTable = TableOutput.createTable( intrvlStatsHeadings, indent=6 )
      formatTS = TableOutput.Format( justify="right", maxWidth=10, wrap=True )
      formatPkts = TableOutput.Format( justify="right", maxWidth=5, wrap=True )
      formatCongestion = TableOutput.Format( justify="left", maxWidth=10, wrap=True )
      formatLatencyTime = TableOutput.Format( justify="right", maxWidth=10,
            wrap=True )
      intrvlStatsTable.formatColumns( formatTS, formatPkts, formatCongestion,
            formatLatencyTime, formatLatencyTime, formatLatencyTime )
      intervalStats = sorted( intDetail.flowIntervalStats,
            key=attrgetter( 'timestamp' ), reverse=True )
      for intrvlStat in intervalStats:
         # Per-device values are shown in tuple notation, e.g. "(c, -)" and
         # "(10, 20)"; the quotes around the congestion markers are removed.
         congestion = tuple( 'c' if c else '-' for c in intrvlStat.congestions )
         congestionStr = str( congestion ).replace( "'", "" )
         # fromtimestamp() without tz yields local time.
         startTime = datetime.fromtimestamp(
               intrvlStat.timestamp ).strftime( "%Y-%m-%d %H:%M:%S" )
         # Python 3 ints carry no 'L' suffix, so the tuple text is used as is.
         intrvlStatsTable.newRow( startTime, intrvlStat.pkts, congestionStr,
               str( tuple( intrvlStat.avgLatencies ) ),
               str( tuple( intrvlStat.maxLatencies ) ),
               str( tuple( intrvlStat.minLatencies ) ) )
      print( "%sFlow interval statistics" % ( tab ) )
      print( intrvlStatsTable.output() )

   def renderIntFlowDetail( self, detail ):
      # Print the inband telemetry summary lines of 'detail', then its
      # per-device table and its interval statistics table.
      tab = " " * 6
      print( "%sInband telemetry information" % ( tab ) )
      print( "%sPath transitions: %d, Path packets: %d, Hop count exceeded: "
            "%s" % ( tab, detail.pathTransistions, detail.pathPackets,
                  str( detail.hopCountExceeded ).lower() ) )
      print( "%sDevices in path: %d, Flow intervals: %d" %
             ( tab, detail.devicesInPath, detail.flowIntervals ) )

      self.renderIntFlowDetailNodeInfo( detail )
      self.renderIntFlowDetailIntervalStats( detail )

   def renderFlowDetail( self, flow, key ):
      # Let the base class print its part first, then append the inband
      # telemetry information taken from flow.flowDetail.
      super().renderFlowDetail( flow, key )
      self.renderIntFlowDetail( flow.flowDetail )
