# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
from CliModel import Model, Dict, Float, Str, Int, Bool
import AclCliLib
from ArnetModel import IpGenericAddress
import TableOutput
from Arnet import IpGenAddr
from TypeFuture import TacLazyType
import Tac
import sys
from HumanReadable import formatTimeInterval
from enum import Enum

# TacLazyType handles for Tac types used below (presumably resolved on first
# use): the address family enum that ConnMonitorHost.degrade compares against,
# and the DSCP value type whose `invalid` value renderHosts uses to decide
# whether to print the DSCP line.
AddressFamily = TacLazyType( "Arnet::AddressFamily" )
DscpVal = TacLazyType( "Arnet::DscpValue" )

class IcmpStat( Model ):
   # ICMP probe results for a single destination address.
   latency = Float( help="Network latency in milliseconds to reach the host using "
                         "the ICMP protocol",
                    optional=True )
   jitter = Float( help="Packet delay variance observed while probing the host",
                   optional=True )
   packetLoss = Int( help="Percentage packet loss when reaching the host" )
   icmpProbeError = Str( help="ICMP probes error message", optional=True )

   def fromTacc( self, stat ):
      """Copy the ICMP probe results from stat.

      latency and jitter are left unset when they are not below
      sys.float_info.max (presumably a "no measurement" sentinel); otherwise
      they are rounded to three decimals. An empty probe error is left unset.
      """
      if stat.latency < sys.float_info.max:
         self.latency = round( stat.latency, 3 )
      if stat.jitter < sys.float_info.max:
         self.jitter = round( stat.jitter, 3 )
      self.packetLoss = stat.packetLoss
      if stat.icmpProbeError:
         self.icmpProbeError = stat.icmpProbeError

class HttpStat( Model ):
   # HTTP probe results for a single URL.
   httpResponseTime = Float( help="HTTP response time in milliseconds to reach "
                                  "the host", optional=True )
   httpProbeError = Str( help="HTTP probes error message", optional=True )

   def fromTacc( self, stat ):
      """Copy the HTTP probe result from stat.

      A response time that is not below sys.float_info.max is left unset
      (presumably a "no measurement" sentinel); otherwise it is rounded to
      three decimals. An empty error string is left unset.
      """
      responseTime = stat.httpResponseTime
      probeError = stat.httpProbeError
      if responseTime < sys.float_info.max:
         self.httpResponseTime = round( responseTime, 3 )
      if probeError:
         self.httpProbeError = probeError

class TcpStat( Model ):
   # TCP probe results for a single destination.
   latency = Float( help="Network latency in milliseconds to reach the host using " +
                         "the TCP protocol", optional=True )
   tcpProbeError = Str( help="TCP probes error message", optional=True )

   def fromTacc( self, stat ):
      """Copy the TCP probe result from stat.

      A latency that is not below sys.float_info.max is left unset
      (presumably a "no measurement" sentinel); otherwise it is rounded to
      three decimals. An empty error string is left unset.
      """
      measured = stat.latency
      errorMsg = stat.tcpProbeError
      if measured < sys.float_info.max:
         self.latency = round( measured, 3 )
      if errorMsg:
         self.tcpProbeError = errorMsg

class Stats( Model ):
   # Probe statistics for one host as seen from one local interface; this is
   # the valueType of ConnMonitorHost.interfaces.
   sequentialFailedProbeCount = Int( help="Number of sequential probes that did not "
                                          "receive any response" )
   hostUnreachable = Bool( help="Has the sequential failed probe count crossed the "
                                 "threshold" )
   # NOTE(review): renderHosts subtracts this from Tac.utcNow() and formats the
   # difference as seconds, so it is presumably a UTC timestamp in seconds —
   # confirm. A value of 0 is rendered as an empty "Last Response" cell.
   lastResponseAt = Float( help="Timestamp when the last response was received" )
   # Keyed by destination IP address.
   icmpDestinations = Dict( help="A mapping between ip address and the "
                            "ICMP statistics generated using it",
                            keyType=IpGenericAddress, valueType=IcmpStat )
   # Keyed by URL.
   httpDestinations = Dict( help="A mapping between url and the HTTP "
                                 "statistics generated using it", keyType=str,
                                 valueType=HttpStat )
   # NOTE(review): keyType is str although the help text describes the key as
   # an ip address — confirm which is intended.
   tcpDestinations = Dict( help="A mapping between ip address and the TCP "
                                 "statistics generated using it",
                                 keyType=str, valueType=TcpStat )

class ConnMonitorHost( Model ):
   # Configuration and per-interface statistics of one monitored host.
   # degrade() rewrites the dict representation for revisions 1 and 2.
   __revision__ = 3
   description = Str( help="A brief description of the host" )
   ipAddr = IpGenericAddress( help="IP address of the host" )
   hostname = Str( help="Hostname destination for ICMP probes" )
   url = Str( help="URL of the host" )
   pingSize = Int( help="Size of ping packet in bytes" )
   icmpDscp = Int( help="DSCP marking for ICMP echo packets" )
   pingCount = Int( help="Number of ICMP pings" )
   icmpEnabled = Bool( help="ICMP probes are enabled" )
   useDport = Bool( help="TCP port is enabled" )
   configError = Str( help="Error string to indicate if there is a config error "
                      "related to this host", optional=True )
   interfaces = Dict( help="A mapping between source interface and the statistics "
                           "generated using it", keyType=str, valueType=Stats )

   def fromTacc( self, hostConfig ):
      """Fill the configuration attributes from a host config entity.

      ipAddr is only set when the configured ICMP address is non-zero.
      interfaces and configError are not populated here.
      """
      self.description = hostConfig.description
      icmpConfig = hostConfig.icmpConfig
      icmpIp = icmpConfig.ipAddr
      if not icmpIp.isAddrZero:
         self.ipAddr = icmpIp.stringValue
      self.hostname = icmpConfig.hostname
      self.pingSize = icmpConfig.pingSize
      self.icmpDscp = icmpConfig.icmpDscp
      self.pingCount = icmpConfig.pingCount
      self.icmpEnabled = icmpConfig.icmpEnabled
      self.useDport = hostConfig.tcpConfig.useDport
      self.url = hostConfig.httpConfig.url

   def degrade( self, dictRepr, revision ):
      """Return dictRepr rewritten for an older model revision.

      Revision 1: drop hostname and configError, and replace a non-IPv4
      ipAddr with '0.0.0.0'.
      Revision 2: a host configured by hostname becomes {} when it has no URL
      (ConnMonitorVrfs.degrade drops such empty hosts); otherwise its ipAddr
      is cleared and hostname/configError are dropped.
      Any other revision, or a revision 2 host without a hostname, is
      returned unchanged.
      """
      if revision == 1:
         # Revision 1 had no hostname endpoint and no configError.
         del dictRepr[ 'hostname' ]
         dictRepr.pop( 'configError', None )
         # Revision 1 only supported IPv4 addresses.
         if IpGenAddr( dictRepr[ 'ipAddr' ] ).af != AddressFamily.ipv4:
            dictRepr[ 'ipAddr' ] = '0.0.0.0'
         return dictRepr

      if revision != 2 or not dictRepr[ 'hostname' ]:
         return dictRepr

      # Revision 2 could only express the ICMP endpoint as an IP address.
      if not dictRepr[ 'url' ]:
         return {}
      dictRepr[ 'ipAddr' ] = None
      del dictRepr[ 'hostname' ]
      dictRepr.pop( 'configError', None )
      return dictRepr

class _ProbeType( Enum ):
   # Selects the column layout used by _createProbeTable().
   ICMP = 1
   HTTP = 2
   TCP = 3

def _createProbeTable( headings, probeType, lossThreshold ):
   """Create an empty statistics table with per-column formats.

   Args:
      headings: column headings; the formats chosen below must line up with
         them column for column.
      probeType: a _ProbeType selecting the column layout.
      lossThreshold: when truthy, the ICMP table uses the loss-threshold
         layout (lost probe count / last response / status columns).
   Returns:
      A TableOutput table ready for newRow() calls.
   """
   table = TableOutput.createTable( headings )
   # Headings formatter
   headingFmt = TableOutput.Format( isHeading=True, justify='left', border=True )
   headingFmt.padLimitIs( True )
   # Text value formatter
   textFmt = TableOutput.Format( justify='left' )
   textFmt.padLimitIs( True )
   # Numbers formatter
   numberFmt = TableOutput.Format( justify='right' )
   numberFmt.padLimitIs( True )
   # Float with units formatter
   floatFmt = TableOutput.Format( justify='right-float' )
   floatFmt.padLimitIs( True )
   table.formatRows( headingFmt )
   if probeType == _ProbeType.ICMP:
      if lossThreshold:
         table.formatColumns( textFmt, textFmt, floatFmt, numberFmt, numberFmt,
                              textFmt )
      else:
         table.formatColumns( textFmt, textFmt, floatFmt, floatFmt, numberFmt,
                              textFmt )
   elif probeType == _ProbeType.HTTP:
      table.formatColumns( textFmt, numberFmt, textFmt )
   else:
      # TCP
      table.formatColumns( textFmt, textFmt, floatFmt, textFmt )
   return table

def _printHostSummary( name, host ):
   """Print the host name and the configuration lines shown above its tables."""
   print( 'Host:', name )
   if host.icmpEnabled:
      print( 'Payload size:', host.pingSize )
      if host.pingCount != 0:
         print( 'ICMP ping count:', host.pingCount )
      if host.icmpDscp != DscpVal.invalid:
         icmpDscpName = AclCliLib.dscpNameFromValue( host.icmpDscp )
         # Print DSCP ACL name if there is a match, followed by the 6-bit value.
         if icmpDscpName.isnumeric():
            print( 'DSCP value:', icmpDscpName )
         else:
            print( 'DSCP value:', icmpDscpName, f'({host.icmpDscp:06b})' )
   if ( host.icmpEnabled or host.useDport ) and host.hostname:
      print( 'Hostname:', host.hostname )
      if host.configError:
         print( 'Config error:', host.configError )
   if host.description:
      print( 'Description:', host.description )

def _addIcmpRows( table, intf, stats, lossThreshold, now ):
   """Append one row per ICMP destination in stats to table.

   now is the Tac.utcNow() reference time for the "Last Response" column; it
   is only used when lossThreshold is truthy.
   """
   for ip, icmpStat in stats.icmpDestinations.items():
      latencyStr = ( 'n/a' if icmpStat.latency is None
                     else f'{icmpStat.latency} ms' )
      if lossThreshold:
         timeStr = ''
         if stats.lastResponseAt:
            timeSecs = int( now - stats.lastResponseAt )
            timeStr = formatTimeInterval( timeSecs ) + ' ago'
         statusStr = 'unreachable' if stats.hostUnreachable else 'reachable'
         table.newRow( ip, intf, latencyStr,
                       str( stats.sequentialFailedProbeCount ),
                       timeStr, statusStr )
      else:
         jitterStr = ( 'n/a' if icmpStat.jitter is None
                       else f'{icmpStat.jitter} ms' )
         table.newRow( ip, intf, latencyStr, jitterStr,
                       f'{icmpStat.packetLoss}%',
                       icmpStat.icmpProbeError or 'n/a' )

def _addHttpRows( table, intf, stats ):
   """Append one row per HTTP destination in stats to table."""
   for httpStat in stats.httpDestinations.values():
      responseTime = httpStat.httpResponseTime
      # Unset and zero response times are both shown as "no response".
      if responseTime is None or responseTime == 0:
         rt = 'no response'
      else:
         rt = str( responseTime ) + ' ms'
      table.newRow( intf, rt, httpStat.httpProbeError or 'n/a' )

def _addTcpRows( table, intf, stats ):
   """Append one row per TCP destination in stats to table."""
   for ip, tcpStat in stats.tcpDestinations.items():
      # Unset and zero latencies are both shown as "n/a".
      if tcpStat.latency is None or tcpStat.latency == 0:
         latency = 'n/a'
      else:
         latency = f'{tcpStat.latency} ms'
      table.newRow( ip, intf, latency, tcpStat.tcpProbeError or 'n/a' )

def renderHosts( hosts, lossThreshold ):
   """Print each host in hosts, sorted by name, with its statistics tables.

   For every host, a summary header is printed. It is followed by:
   - the ICMP table ("Network statistics") when ICMP is enabled and an
     address or a valid hostname is configured;
   - the HTTP table when a URL is configured;
   - the TCP table when the TCP port is enabled and an address or a valid
     hostname is configured.

   Args:
      hosts: mapping of host name to ConnMonitorHost.
      lossThreshold: when truthy, the ICMP table shows lost probe count, last
         response time and reachability instead of jitter, packet loss and
         probe error.
   """
   if lossThreshold:
      icmpHeadings = ( 'IP Address', 'Local Interface', 'Latency',
                       'Lost Probe Count', 'Last Response', 'Status' )
   else:
      icmpHeadings = ( 'IP Address', 'Local Interface', 'Latency',
                       'Jitter', 'Packet Loss', 'Probe Error' )
   httpHeadings = ( 'Local Interface', 'Response Time', 'Probe Error' )
   tcpHeadings = ( 'IP Address', 'Local Interface', 'Latency', 'Probe Error' )
   # One reference time keeps every "... ago" value in this output consistent.
   now = Tac.utcNow() if lossThreshold else None

   for name, host in sorted( hosts.items() ):
      _printHostSummary( name, host )

      icmpTable = _createProbeTable( icmpHeadings, _ProbeType.ICMP, lossThreshold )
      httpTable = _createProbeTable( httpHeadings, _ProbeType.HTTP, lossThreshold )
      tcpTable = _createProbeTable( tcpHeadings, _ProbeType.TCP, lossThreshold )

      for intf, stats in sorted( host.interfaces.items() ):
         # The 'default' interface means no source interface was configured.
         intfName = 'none' if intf == 'default' else intf
         _addIcmpRows( icmpTable, intfName, stats, lossThreshold, now )
         _addHttpRows( httpTable, intfName, stats )
         _addTcpRows( tcpTable, intfName, stats )

      # A hostname endpoint counts only when it has no config error; otherwise
      # an IP address must be configured.
      hostIpOrHostnameConfigured = ( not host.configError
                                     if host.hostname else host.ipAddr )

      if host.icmpEnabled and hostIpOrHostnameConfigured:
         print( 'Network statistics:' )
         print( icmpTable.output() )
      if host.url:
         print( 'HTTP statistics:' )
         print( host.url )
         print( httpTable.output() )
      if host.useDport and hostIpOrHostnameConfigured:
         print( 'TCP Statistics:' )
         print( tcpTable.output() )

class ConnMonitorClient( Model ):
   # A named connectivity-monitor client with its own set of hosts.
   hosts = Dict( help="A mapping between a host name and its probe statistics",
                 keyType=str, valueType=ConnMonitorHost )
   _probeLossThreshold = Int( help="Number of sequential lost probes after which a "
                                   "host is considered unreachable" )

   def render( self ):
      """Print the loss threshold (when set) followed by the client's hosts."""
      threshold = self._probeLossThreshold
      if threshold:
         print( 'Loss threshold:', threshold )
      renderHosts( self.hosts, threshold )


class ConnMonitorVrf( Model ):
   # Hosts probed directly in one VRF, plus named clients that have their own
   # hosts.
   description = Str( help="A brief description of the VRF" )
   hosts = Dict( help="A mapping between a host name and its probe statistics",
                 keyType=str, valueType=ConnMonitorHost )
   clients = Dict( help="Connectivity monitor clients keyed by name",
                   keyType=str, valueType=ConnMonitorClient, optional=True )
   _probeLossThreshold = Int( help="Number of sequential lost probes after which a "
                                   "host is considered unreachable" )

   def render( self ):
      """Print VRF-level settings, the VRF's hosts, then each client by name."""
      if self._probeLossThreshold:
         print( 'Loss threshold:', self._probeLossThreshold )
      if self.description:
         print( 'Description:', self.description )
      renderHosts( self.hosts, self._probeLossThreshold )
      # clients is declared optional; treat an unset value as "no clients"
      # instead of failing on None.items().
      for clientName, client in sorted( ( self.clients or {} ).items() ):
         print( 'Client:', clientName )
         client.render()

class ConnMonitorVrfs( Model ):
   # Top-level model: connectivity-monitor hosts and clients keyed by VRF.
   __revision__ = 3
   vrfs = Dict( help="A mapping between VRF and the hosts which are probed "
                "in it", keyType=str, valueType=ConnMonitorVrf )

   def render( self ):
      """Print every VRF that has hosts or clients, sorted by VRF name."""
      if self.vrfs:
         print()
         for vrfName, vrf in sorted( self.vrfs.items() ):
            if vrf.hosts or vrf.clients:
               print( 'VRF:', vrfName )
               vrf.render()

   def handleHostsForRev2( self, hostInfo ):
      """Flatten one host's per-interface statistics to the revision 2 layout.

      hostInfo is a host's dict representation. For each interface entry:
      - ICMP statistics are kept only for the destination equal to the host's
        ipAddr.
      - The icmpDestinations/httpDestinations dicts are replaced by flat
        latency, jitter, packetLoss and httpResponseTime keys.
      Missing latency/jitter/response-time values become sys.float_info.max.
      An interface with no matching ICMP or HTTP entry gets 0 values.
      """
      ipAddr = hostInfo[ "ipAddr" ]
      noValue = round( sys.float_info.max, 3 )
      for intfInfo in hostInfo[ "interfaces" ].values():
         icmpDestinations = intfInfo.pop( "icmpDestinations" )
         httpDestinations = intfInfo.pop( "httpDestinations" )
         latency = 0.0
         jitter = 0.0
         packetLoss = 0
         httpResponseTime = 0.0
         # Revision 2 only reported the host's own address.
         icmpStat = icmpDestinations.get( ipAddr )
         if icmpStat is not None:
            latency = icmpStat.get( 'latency', noValue )
            jitter = icmpStat.get( 'jitter', noValue )
            packetLoss = icmpStat[ 'packetLoss' ]
         # The flat layout holds a single response time; the last URL wins.
         for httpStat in httpDestinations.values():
            httpResponseTime = httpStat.get( 'httpResponseTime', noValue )

         intfInfo[ 'latency' ] = latency
         intfInfo[ 'jitter' ] = jitter
         intfInfo[ 'packetLoss' ] = packetLoss
         intfInfo[ 'httpResponseTime' ] = httpResponseTime

   def degrade( self, dictRepr, revision ):
      """Return dictRepr rewritten for an older model revision.

      Revision 1: a flat ``hosts`` mapping for the default VRF only. Each
      host carries the default interface's ICMP/HTTP statistics directly;
      hosts without a default interface entry are dropped.
      Revision 2: VRFs and clients are kept. Hosts that degraded to an empty
      dict are removed, clients and VRFs left empty are removed, and
      per-interface statistics are flattened by handleHostsForRev2().
      """
      if revision == 1:
         ret = { 'hosts': {} }
         if 'default' not in dictRepr[ 'vrfs' ]:
            return ret
         ret[ 'hosts' ] = dictRepr[ 'vrfs' ][ 'default' ][ 'hosts' ]
         for host in list( ret[ 'hosts' ] ):
            info = ret[ 'hosts' ][ host ]
            interfaces = info[ 'interfaces' ]
            if 'default' not in interfaces:
               del ret[ 'hosts' ][ host ]
               continue
            defaultStats = interfaces[ 'default' ]
            ipAddr = info[ 'ipAddr' ]
            url = info[ 'url' ]
            info[ 'latency' ] = 0.0
            info[ 'jitter' ] = 0.0
            info[ 'packetLoss' ] = 0.0
            info[ 'httpResponseTime' ] = 0
            # NOTE(review): IcmpStat/HttpStat define no 'configError' key, so
            # the 'configError' checks below look always true — confirm.
            if 'icmpDestinations' in defaultStats and \
               ipAddr in defaultStats[ 'icmpDestinations' ]:
               icmpStat = defaultStats[ 'icmpDestinations' ][ ipAddr ]
               if 'configError' not in icmpStat:
                  if 'latency' in icmpStat:
                     info[ 'latency' ] = icmpStat[ 'latency' ]
                  if 'jitter' in icmpStat:
                     info[ 'jitter' ] = icmpStat[ 'jitter' ]
                  if 'packetLoss' in icmpStat:
                     info[ 'packetLoss' ] = icmpStat[ 'packetLoss' ]
            if 'httpDestinations' in defaultStats and \
               url in defaultStats[ 'httpDestinations' ]:
               httpStat = defaultStats[ 'httpDestinations' ][ url ]
               if 'configError' not in httpStat and 'httpResponseTime' in httpStat:
                  info[ 'httpResponseTime' ] = httpStat[ 'httpResponseTime' ]
            info[ 'hostName' ] = host
            del info[ 'interfaces' ]
         return ret
      elif revision == 2:
         vrfs = dictRepr[ "vrfs" ]
         # Iterate over a snapshot so emptied VRFs can be deleted in place.
         for vrfName, vrfInfo in list( vrfs.items() ):
            hostsInfo = vrfInfo[ "hosts" ]
            # Hosts that ConnMonitorHost.degrade() reduced to {} cannot be
            # represented in revision 2.
            for host in list( hostsInfo ):
               if not hostsInfo[ host ]:
                  del hostsInfo[ host ]
            # "clients" is optional; an absent key behaves like no clients.
            clientsInfo = vrfInfo.get( "clients", {} )
            for client in list( clientsInfo ):
               clientHosts = clientsInfo[ client ][ "hosts" ]
               for host in list( clientHosts ):
                  if not clientHosts[ host ]:
                     del clientHosts[ host ]
               if not clientHosts:
                  # Remove the client entry itself; deleting only a local name
                  # would leave the empty client in the output.
                  del clientsInfo[ client ]
            for hostInfo in hostsInfo.values():
               self.handleHostsForRev2( hostInfo )
            for clientInfo in clientsInfo.values():
               for hostInfo in clientInfo[ "hosts" ].values():
                  self.handleHostsForRev2( hostInfo )
            # Drop only this VRF when nothing is left in it; the other VRFs
            # must be preserved.
            if not hostsInfo and not clientsInfo:
               del vrfs[ vrfName ]

      return dictRepr
