#!/usr/bin/env python3
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from Arnet import IpGenAddr
from ArnetModel import IpGenericAddress
from CliModel import Bool
from CliModel import Dict
from CliModel import Enum
from CliModel import Float
from CliModel import Int
from CliModel import Model
from CliModel import Str
from IntfModels import Interface
import IntfModels
import datetime
import TableOutput
import Tac

# Single left-justified column format object shared by every table rendered in
# this module (each render passes `fl` for all of its columns). padLimitIs is
# applied once here, so every table that uses `fl` gets the same setting.
fl = TableOutput.Format( justify='left' )
fl.padLimitIs( True )

# Tac types for the DPS path route-state enum (used as DpsPath.state values)
# and for the helper whose pathRouteStateToStr() turns such a state into the
# string shown in the "Route State" output.
PathRouteState = Tac.Type( 'Dps::PathRouteState' )
PathRouteStateHelper = Tac.Type( 'Dps::DpsPathEnumStringHelper' )

class DpsLbProfiles( Model ):
   """Path-selection counters of one load-balance profile entry.

   The top-level counters are aggregated for the application; per-path
   counters are nested under ``pathGroups`` (keyed by group name) and then
   ``paths`` (keyed by path index).
   """
   appProfile = Str( help='Name of the application profile' )
   vrf = Str( help='Name of the VRF' )
   peerIp = IpGenericAddress( help='Peer IP address of the ITS session' )
   peerName = Str( help='Name of the peer' )

   # Aggregated counters
   outBytes = Int( help='Number of out bytes for the application' )
   outPkts = Int( help='Number of out packets for the application' )
   throughput = Float( help='Throughput in Mbps for the application' )
   flows = Int( help='Number of flows for the application' )

   class PathGroups( Model ):
      """Counters of one path group.

      No aggregate counters are reported at path-group level; counters are
      only reported per path, in ``paths``.
      """

      class Paths( Model ):
         """Counters of a single path."""
         outBytes = Int( help='Number of bytes going via the path' )
         outPkts = Int( help='Number of packets going via the path' )
         throughput = Float( help='Throughput in Mbps of the path' )
         flows = Int( help='Number of flows going through the path' )

      # key is pathId
      paths = Dict( keyType=int,
                    valueType=Paths,
                    help="A mapping of path index to its counters",
                    optional=True )

   # key is groupName
   pathGroups = Dict( keyType=str,
                      valueType=PathGroups,
                      help="A mapping of group name to its counters",
                      optional=True )

class DpsLbCounters( Model ):
   """Per-path load-balance counters for every load-balance profile.

   Rendered as one table row per (profile, path group, path); the detail
   variant adds byte and packet columns.
   """
   # key is lbId
   lbProfiles = Dict( keyType=int,
                      valueType=DpsLbProfiles,
                      help='A mapping of load-balance profile index to its '
                           'path selection counters' )
   _detail = Bool( help='Show detail information' )

   def render( self ):
      # Columns common to "show path-selection load-balance counters" and its
      # "detail" variant; the detail variant appends two more.
      header = [ 'App Profile', 'Vrf', 'Peer', 'Path Group', 'Path',
                 'Flows', 'Throughput(Mbps)' ]
      if self._detail:
         header += [ 'Out Bytes', 'Out Packets' ]
      table = TableOutput.createTable( header, tableWidth=160 )
      table.formatColumns( *[ fl ] * len( header ) )

      for profile in self.lbProfiles.values():
         nameSuffix = ' (' + profile.peerName + ')' if profile.peerName else ''
         peer = profile.peerIp.stringValue + nameSuffix
         for groupName in sorted( profile.pathGroups ):
            paths = profile.pathGroups[ groupName ].paths
            for pathId in sorted( paths ):
               path = paths[ pathId ]
               cells = [ profile.appProfile, profile.vrf, peer, groupName,
                         f"path{pathId}", path.flows, path.throughput ]
               if self._detail:
                  cells += [ path.outBytes, path.outPkts ]
               table.newRow( *cells )
      print( table.output() )

class DpsAppCounters( Model ):
   """Aggregated application counters for every load-balance profile.

   Rendered as one table row per profile.
   """
   # key is lbId
   lbProfiles = Dict( keyType=int,
                      valueType=DpsLbProfiles,
                      help='A mapping of load-balance profile index to its '
                           'application counters' )

   def render( self ):
      # show path-selection application counters
      columns = [ 'App Profile', 'Vrf', 'Peer', 'Throughput(Mbps)',
                  'Out Bytes', 'Out Packets' ]
      table = TableOutput.createTable( columns, tableWidth=120 )
      table.formatColumns( *[ fl ] * len( columns ) )
      for profile in self.lbProfiles.values():
         peer = profile.peerIp.stringValue
         if profile.peerName:
            peer += ' (' + profile.peerName + ')'
         table.newRow( profile.appProfile, profile.vrf, peer,
                       profile.throughput, profile.outBytes, profile.outPkts )
      print( table.output() )

class DpsSession( Model ):
   """Telemetry session state for one traffic class of a DPS path."""
   active = Bool( help='The telemetry session state is active' )
   seconds = Int(
      help='Number of seconds the telemetry session has been in this state' )

class DpsPath( Model ):
   """Routing, telemetry and MTU state of a single DPS path."""
   source = Str( help='Source IP address of the path group' )
   destination = Str( help='Destination IP address of the path group' )
   sourcePort = Int( help='Source UDP port of the DPS path', optional=True )
   destinationPort = Int( help='Destination UDP port of the DPS path',
                          optional=True )
   sourceWanId = Int( help='Source WAN ID of the DPS path', optional=True )
   sourceWanIdPathGroup = Str(
      help='Path group of the source WAN ID of the DPS path', optional=True )
   destinationWanId = Int( help='Destination WAN ID of the DPS path',
                           optional=True )
   destinationWanIdPathGroup = Str(
      help='Path group of the destination WAN ID of the DPS path', optional=True )
   localIntf = Interface( help='Local interface name of this path', optional=True )
   state = Enum( values=( PathRouteState.routeStateInvalid,
                          PathRouteState.ipsecEstablished,
                          PathRouteState.routeResolved,
                          PathRouteState.arpPending,
                          PathRouteState.ipsecPending,
                          PathRouteState.routePending ),
                 help='Routing state of this path' )
   dpsSessions = Dict( keyType=int,
                       valueType=DpsSession,
                       help="A mapping of traffic class to DPS session" )
   pathType = Enum( values=( 'static', 'dynamic' ),
                    help='Type of the path' )
   mtu = Int( help='Path MTU', optional=True )
   mtuConfigType = Enum( values=( 'static', 'disabled', 'auto' ),
                         help='Type of MTU config', optional=True )
   mtuState = Enum( values=( 'notApplicable', 'init', 'probing', 'done', 'error' ),
                    help='MTU discovery state', optional=True )
   mtuErrStateReason = Enum( values=( 'none', 'response message timeout',
                                      'path down event',
                                      'mtu outside defined range' ),
                             help='MTU error state reason', optional=True )
   mtuDiscDueSecs = Int( help='MTU discovery due in secs', optional=True )
   mtuDiscTimeTakenMsecs = Float( help='MTU discovery time taken in msecs',
                                  optional=True )
   mtuDiscIntervalSecs = Int( help='MTU discovery interval in secs', optional=True )
   mtuNumProbesSent = Int( help='Number of Probes sent', optional=True )
   mtuNumReportReqSent = Int( help='Number of Report Requests sent',
                              optional=True )
   mtuNumReportRespRcvd = Int( help='Number of Report Responses received',
                               optional=True )
   mtuNumRespProcErr = Int( help='Number of Response processing error',
                            optional=True )
   mtuNumProbesRcvd = Int( help='Number of Probes received', optional=True )
   mtuNumReportReqRcvd = Int( help='Number of Report Requests received',
                              optional=True )
   mtuNumReportRespSent = Int( help='Number of Report Responses sent',
                               optional=True )
   mtuNumReqProcErr = Int( help='Number of Request processing error',
                           optional=True )
   label = Int( help='Label value of the DPS path', optional=True )

   def setAttrsFromDict( self, data ):
      """Populate this path from ``data`` without modifying ``data``.

      ``dpsSessions`` is rebuilt here as DpsSession submodels with int keys
      (the keys may arrive as strings, hence the int() conversion); every
      other key is handed to the base implementation.
      """
      sessions = data.get( "dpsSessions" ) or {}
      for tc, sessionData in sessions.items():
         submodel = DpsSession()
         submodel.setAttrsFromDict( sessionData )
         self.dpsSessions[ int( tc ) ] = submodel
      # Filtered copy instead of data.pop(): the caller's dict stays intact.
      rest = { key: value for key, value in data.items()
               if key != "dpsSessions" }
      super().setAttrsFromDict( rest )

class DpsGroup( Model ):
   """The DPS paths of one path group, keyed by path name."""
   dpsPaths = Dict( valueType=DpsPath,
                    keyType=str,
                    help='A mapping of path name to its state' )

   def setAttrsFromDict( self, data ):
      # Each entry is turned into a DpsPath through DpsPath.setAttrsFromDict,
      # which also rebuilds the path's nested dpsSessions.
      pathEntries = data.get( "dpsPaths", {} )
      for pathName, pathData in pathEntries.items():
         pathModel = DpsPath()
         pathModel.setAttrsFromDict( pathData )
         self.dpsPaths[ pathName ] = pathModel

class DpsPeer( Model ):
   """A DPS peer: its name, optional AVT placement IDs and its path groups."""
   peerName = Str( help='Peer name' )
   avtRegionId = Int( help='Adaptive virtual topology region ID', optional=True )
   avtZoneId = Int( help='Adaptive virtual topology zone ID', optional=True )
   avtSiteId = Int( help='Adaptive virtual topology site ID', optional=True )
   dpsGroups = Dict( keyType=str,
                     valueType=DpsGroup,
                     help='A mapping of group name to its paths' )

   def setAttrsFromDict( self, data ):
      """Populate this peer, including its path groups, from ``data``.

      ``dpsGroups`` entries are rebuilt as DpsGroup submodels; all remaining
      keys (peerName and the optional AVT region/zone/site IDs) go to the base
      implementation. Previously only peerName was restored, so the AVT IDs
      were silently dropped and never shown by DpsPaths.renderDetail.
      ``data`` itself is not modified.
      """
      for pgName, groupData in ( data.get( "dpsGroups" ) or {} ).items():
         group = DpsGroup()
         group.setAttrsFromDict( groupData )
         self.dpsGroups[ pgName ] = group
      scalars = { key: value for key, value in data.items()
                  if key != "dpsGroups" }
      super().setAttrsFromDict( scalars )

class DpsPaths( Model ):
   """DPS path state for every peer, rendered as a table or a detail listing.

   ``_detail`` selects the per-path detail listing over the table, and
   ``_mtuInfoDisplayMode`` selects how much path MTU information is shown.
   """
   dpsPeers = Dict( keyType=IpGenericAddress,
                    valueType=DpsPeer,
                    help='A mapping of peer IP to peer path groups' )
   _detail = Bool( help='Show detail information' )
   _mtuInfoDisplayMode = Enum( values=( 'no mtu', 'mtu only',
                                        'mtu summary', 'mtu detail' ),
                               help='Determines what MTU Info is displayed' )

   def setAttrsFromDict( self, data ):
      """Rebuild ``dpsPeers`` from ``data``.

      Keys of data["dpsPeers"] are converted with IpGenAddr; each value is
      populated through DpsPeer.setAttrsFromDict before being stored.
      """
      for peerIp, peerData in data.get( "dpsPeers", {} ).items():
         peer = DpsPeer()
         peer.setAttrsFromDict( peerData )
         self.dpsPeers[ IpGenAddr( peerIp ) ] = peer

   def getSessionState( self, sess ):
      """Return 'active (<elapsed>)' for an active session, else 'inactive'.

      <elapsed> is ``sess.seconds`` formatted by str(datetime.timedelta).
      """
      if sess.active:
         return 'active (' + str(
            datetime.timedelta( seconds=sess.seconds ) ) + ')'
      return 'inactive'

   def getMtuDiscDueStr( self, path ):
      """Return the time until the next MTU discovery, or '-' if unset/zero."""
      if path.mtuDiscDueSecs:
         return str( datetime.timedelta( seconds=path.mtuDiscDueSecs ) )
      else:
         return "-"

   def formatOptional( self, fmt, value ):
      """Return ``fmt % value``, or '-' when the optional value is None.

      Many MTU attributes are optional; applying %d or %.2f to None would
      raise TypeError and abort the whole render.
      """
      if value is None:
         return '-'
      return fmt % value

   def peerDisplayName( self, peerIp, peer ):
      """Return ``peerIp``, suffixed with ' (<peerName>)' when a name is set."""
      if peer.peerName:
         return peerIp + ' (' + peer.peerName + ')'
      return peerIp

   def formatEndpoint( self, label, address, port, wanId, wanIdGroup,
                       groupName ):
      """Return the 'Source: ...' / 'Destination: ...' line of the detail view.

      Port and WAN ID are appended only when truthy (not None/0). A WAN ID
      whose path group differs from ``groupName`` is marked as imported.
      """
      line = '%s: %s' % ( label, address )
      if port:
         line += ', Port: %d' % port
      if wanId:
         importNote = ''
         if wanIdGroup != groupName:
            importNote = ' (Imported from path group %s)' % wanIdGroup
         line += ', WAN ID: %d%s' % ( wanId, importNote )
      return line

   def renderTable( self ):
      """Print the path table.

      Default and 'mtu only' modes print one row per traffic-class session
      ('mtu only' adds an MTU column). 'mtu summary' prints one row per path,
      because its columns are per-path MTU data; emitting it per session only
      produced identical duplicate rows.
      """
      mode = self._mtuInfoDisplayMode
      if mode == 'mtu summary':
         header = [ "Peer", "Path Group", "Source",
                    "Destination", "Path Name", "MTU",
                    "Type", "State", "Discovery Due in" ]
      else:
         header = [ "Peer", "Path Group", "Source",
                    "Destination", "Path Name", "Type", "TC",
                    "Route State", "Telemetry State" ]
         if mode == 'mtu only':
            header.append( "MTU" )
      table = TableOutput.createTable( header, tableWidth=160 )
      table.formatColumns( *( fl for _ in header ) )
      # Built once rather than per row; assumes pathRouteStateToStr is a pure
      # conversion (NOTE(review): confirm the helper keeps no per-call state).
      convHelper = PathRouteStateHelper()

      for peerIp in sorted( self.dpsPeers, key=IpGenAddr ):
         peer = self.dpsPeers[ peerIp ]
         peerName = self.peerDisplayName( peerIp, peer )
         for groupName, group in sorted( peer.dpsGroups.items() ):
            for pathName, path in sorted( group.dpsPaths.items() ):
               if mode == 'mtu summary':
                  table.newRow( peerName, groupName, path.source,
                                path.destination, pathName, path.mtu,
                                path.mtuConfigType, path.mtuState,
                                self.getMtuDiscDueStr( path ) )
                  continue
               pathState = convHelper.pathRouteStateToStr( path.state )
               for tc, sess in sorted( path.dpsSessions.items() ):
                  row = [ peerName, groupName, path.source, path.destination,
                          pathName, path.pathType, tc, pathState,
                          self.getSessionState( sess ) ]
                  if mode == 'mtu only':
                     row.append( path.mtu )
                  table.newRow( *row )
      print( table.output() )

   def displayMtuDetailInfo( self, path ):
      """Print the MTU configuration, discovery state and probe statistics.

      Discovery state and transmit-side stats are printed only when the MTU
      config type is 'auto'; receive-side stats are always printed. Unset
      numeric values print as '-'.
      """
      opt = self.formatOptional
      print( 'MTU: %s' % opt( '%d', path.mtu ) )
      if path.mtuConfigType == 'disabled':
         print( 'Type: %s (Default DPS MTU)' % path.mtuConfigType )
      else:
         print( 'Type: %s' % path.mtuConfigType )
      if path.mtuConfigType == 'auto':
         print( 'State: %s' % path.mtuState )
         if path.mtuState == 'error':
            print( 'Err State Reason: %s' % path.mtuErrStateReason )
         print( 'Time taken to discover: %s ms' %
                opt( '%.2f', path.mtuDiscTimeTakenMsecs ) )
         print( 'Discovery due in: %s' % self.getMtuDiscDueStr( path ) )
         print( 'Discovery interval: %s secs' %
                opt( '%d', path.mtuDiscIntervalSecs ) )
         print( 'Transmit side stats:' )
         print( 'Num probes sent: %s' % opt( '%d', path.mtuNumProbesSent ) )
         print( 'Num report requests sent: %s' %
                opt( '%d', path.mtuNumReportReqSent ) )
         print( 'Num report responses received: %s' %
                opt( '%d', path.mtuNumReportRespRcvd ) )
         print( 'Num response processing error: %s' %
                opt( '%d', path.mtuNumRespProcErr ) )
      print( 'Receive side stats:' )
      print( 'Num probes received: %s' % opt( '%d', path.mtuNumProbesRcvd ) )
      print( 'Num report requests received: %s' %
             opt( '%d', path.mtuNumReportReqRcvd ) )
      print( 'Num report responses sent: %s' %
             opt( '%d', path.mtuNumReportRespSent ) )
      print( 'Num request processing error: %s' %
             opt( '%d', path.mtuNumReqProcErr ) )

   def renderDetail( self ):
      """Print a multi-line block per path, followed by a blank line."""
      # See renderTable for why a single helper instance is reused.
      convHelper = PathRouteStateHelper()
      mtuMode = self._mtuInfoDisplayMode
      for peerIp in sorted( self.dpsPeers, key=IpGenAddr ):
         peer = self.dpsPeers[ peerIp ]
         peerLine = 'Peer: %s' % self.peerDisplayName( peerIp, peer )
         if peer.avtRegionId:
            peerLine += ', Region: %d' % peer.avtRegionId
         if peer.avtZoneId:
            peerLine += ', Zone: %d' % peer.avtZoneId
         if peer.avtSiteId:
            peerLine += ', Site: %d' % peer.avtSiteId

         for groupName, group in sorted( peer.dpsGroups.items() ):
            for pathName, path in sorted( group.dpsPaths.items() ):
               print( f'{peerLine}, Path group {groupName}' )
               pathNameLine = f'Path name: {pathName}, {path.pathType}'
               if path.label:
                  pathNameLine += ', Label: %d' % path.label
               print( pathNameLine )
               print( self.formatEndpoint( 'Source', path.source,
                                           path.sourcePort, path.sourceWanId,
                                           path.sourceWanIdPathGroup,
                                           groupName ) )
               print( self.formatEndpoint( 'Destination', path.destination,
                                           path.destinationPort,
                                           path.destinationWanId,
                                           path.destinationWanIdPathGroup,
                                           groupName ) )
               if path.localIntf:
                  print( 'Local interface: ' + path.localIntf.stringValue )
               print( 'Route state: %s' %
                      convHelper.pathRouteStateToStr( path.state ) )
               for tc, sess in sorted( path.dpsSessions.items() ):
                  print( 'Traffic class %d: %s' %
                         ( tc, self.getSessionState( sess ) ) )
               if mtuMode == 'mtu only':
                  print( 'MTU: %s' % self.formatOptional( '%d', path.mtu ) )
               elif mtuMode == 'mtu detail':
                  self.displayMtuDetailInfo( path )
               print( '' )

   def render( self ):
      if self._detail:
         self.renderDetail()
      else:
         self.renderTable()

class DpsPathGroupExport( Model ):
   """Endpoint exported for a WAN ID from one particular path group."""
   ipAddress = IpGenericAddress( help='Exported IP address' )
   port = Int( help='Exported port number' )
   intf = IntfModels.Interface( help='Local interface' )
   ipSource = Enum( help='Source of the IP address',
                    values=( 'interface', 'intfpublic', 'stun' ) )

class DpsExport( Model ):
   """Exported endpoints of one WAN ID across all of its path groups.

   With WAN ID sharing, a WAN ID can be imported into several path groups and
   is then exported from each of them. ``pathGroups`` holds the export
   information for every such path group, whether the interface is configured
   there or only imported.

   The scalar attributes below predate WAN ID sharing. They were kept, instead
   of revising the model, and describe only the path group where the interface
   is configured; they are deprecated in favor of ``pathGroups``.
   """
   pathGroup = Str(
      help='Path group name where the interface is configured (deprecated; '
           'use pathGroups)' )
   ipAddress = IpGenericAddress(
      help='Exported IP address in the path group where the '
           'interface is configured (deprecated; use pathGroups)' )
   port = Int(
      help='Exported port number in the path group where the '
           'interface is configured (deprecated; use pathGroups)' )
   intf = IntfModels.Interface( help='Local interface (deprecated; use pathGroups)' )
   ipSource = Enum( help='Source of the IP address (deprecated; use pathGroups)',
                    values=( 'interface', 'intfpublic', 'stun' ) )
   pathGroups = Dict( valueType=DpsPathGroupExport,
                      keyType=str,
                      help='A mapping of path group name to its exported endpoints' )

class DpsPathsExport( Model ):
   """Exported endpoints for every WAN ID, plus an optional VTEP address.

   Rendered as one table row per (WAN ID, path group), preceded by a
   "VTEP: ..." line when a VTEP address is set.
   """
   wans = Dict( keyType=int,
                valueType=DpsExport,
                help='A mapping of WAN ID to exported endpoint' )
   vtep = IpGenericAddress( help='VTEP IP', optional=True )

   def ipSourceDisplay( self, export ):
      """Return the display label for ``export.ipSource``.

      Unrecognized (or unset) sources are shown as 'Unknown'.
      """
      labels = { 'interface': "Interface",
                 'intfpublic': "Public IP",
                 'stun': "STUN" }
      return labels.get( export.ipSource, 'Unknown' )

   def render( self ):
      header = [ "Path Group", "WAN ID", "Interface",
                 "IP Address-Port", "IP Source" ]
      table = TableOutput.createTable( header, tableWidth=160 )
      table.formatColumns( *( fl for _ in header ) )
      for wanId, export in sorted( self.wans.items() ):
         # Path groups are sorted by name as well (like the WAN IDs and like
         # the nested group dicts in the other renderers), so the row order
         # does not depend on dict insertion order.
         for pgName, pgExport in sorted( export.pathGroups.items() ):
            endpoint = f"{pgExport.ipAddress}:{pgExport.port}"
            table.newRow( pgName, wanId, pgExport.intf.stringValue,
                          endpoint, self.ipSourceDisplay( pgExport ) )
      if self.vtep:
         print( "VTEP: %s" % self.vtep )
      print( table.output() )
