#!/usr/bin/env python3
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

"""
This program periodically sends out the following counters to a specified gRPC
server:
(1) 'interface-drop': ingress and egress interface drop packets
(2) 'egress-queue-drop': egress interface per-queue drop packets/bytes
(3) 'priority-flow-control': PFC send and receive packets
(4) 'buffer-usage': total buffer usage

Options:
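  --server-ip=<serverIp>: server IPv4 or IPv6 address, optionally followed by a
       port ('a.b.c.d:p' or '[a:b::c]:p'). May be given more than once
  --server-port=<serverPort>: default server port, used for servers given
       without a port
  --interval=<seconds>: reporting period. Default is 60 seconds
  --source-ip=<sourceIp>: source IP used for dial out
  --data-type=<counterTypes>: comma-separated counter types to subscribe to,
       chosen from 'interface-drop', 'egress-queue-drop',
       'priority-flow-control', 'buffer-usage'
  --dscp=<dscpValue>: DSCP value of gRPC packets
  --debug=<traceSetting>: enable tracing with the given trace setting
  --stdout: just write counter payload to stdout instead of a server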
"""
import optparse # pylint: disable=deprecated-module
import time
import os
import pprint
import json
import socket
import grpc
import Cell
import EntityManager
import LauncherLib
import Plugins
import PyClient
import Tac
import Tracing
import GrpcCounterCollector
import Counter_pb2
import Counter_pb2_grpc
import sys

from GrpcCounterLib import log
from IptablesHelper import RT

trace1 = Tracing.trace1
trace2 = Tracing.trace2

DEFAULT_REPORT_INTERVAL = 60
IPTABLES_FILE_NAME_PREFIX = '/tmp/counterGrpcDialOutIptables'

# options that can be specified in the "exec" command line.
class Config:
   """ Names of the options accepted on the "exec" command line. """
   # where the counters go
   OPTION_SERVER_IP = 'server-ip'
   OPTION_SERVER_PORT = 'server-port'
   OPTION_STDOUT = 'stdout'

   # how the gRPC packets are sent
   OPTION_SOURCE = 'source-ip'
   OPTION_DSCP = 'dscp'

   # what is reported, and how often
   OPTION_DATA = 'data-type'
   OPTION_INTERVAL = 'interval'

class ProviderRegistration:
   """
   Registry of counter provider classes, keyed by the command option that
   selects them. Both the command option and the counter type of a provider
   have to be unique within the registry.
   """
   def __init__( self ):
      self.counterTypes_ = set()    # counter types seen so far
      self.providerClasses_ = {}    # command option -> provider class

   def register( self, providerClass ):
      """
      Add providerClass to the registry under its COMMAND_OPT.

      Asserts that COUNTER_TYPE and COMMAND_OPT are strings and that neither
      has been registered before.
      """
      kind = providerClass.COUNTER_TYPE
      option = providerClass.COMMAND_OPT
      trace1( "provider", providerClass, "registered with type", kind,
              "command option", option )
      for identifier in ( kind, option ):
         assert isinstance( identifier, str )
      assert option not in self.providerClasses_
      assert kind not in self.counterTypes_
      self.counterTypes_.add( kind )
      self.providerClasses_[ option ] = providerClass

   def providerClass( self, commandOption ):
      """ Return the provider class registered for commandOption, or None. """
      try:
         return self.providerClasses_[ commandOption ]
      except KeyError:
         return None

   def commandOptions( self ):
      """ Return the command options of all providers that are not hidden. """
      visible = []
      for option, provider in self.providerClasses_.items():
         if provider.HIDDEN_COMMAND_OPT:
            continue
         visible.append( option )
      return visible

class GrpcClient:
   """
   gRPC client for one counter server.

   Resolves the configured server address, validates the optional source
   address and opens an insecure channel over which DeviceCounter messages are
   sent through the AristaCounter stub.

   Attributes read by callers:
      serverIp, serverPort: address and port taken from getaddrinfo, or None
         when the configured server could not be resolved
      addrFamily: socket address family of the server address
      sourceIp: the configured source address, or None when it is not
         configured or not valid for the server's address family
      client: the AristaCounter stub, or None when no channel was set up
   """
   def __init__( self, cfgServerIp, cfgServerPort, cfgSourceIp, systemStatus,
                 entityMibStatus ):
      trace1( 'Initialize GrpcClient' )
      self.serverIp = None
      self.serverPort = None
      self.sourceIp = None
      self.systemStatus = systemStatus
      self.entityMibStatus = entityMibStatus

      # Defined before any early return below, so that sendCounters() and the
      # agent can test 'client' on a client that failed to initialize.
      self.serverIpAndPort = None
      self.channel = None
      self.client = None

      # ( future, timestamp ) of the SendCounter calls still being tracked
      self.pendingResponse = []
      self.maxPendingResponses = 10

      # determine the address family from the server configuration
      self.addrFamily = socket.AF_UNSPEC
      try:
         # getaddrinfo returns a list of 5-tuples
         # (family, socktype, proto, canonname, sockaddr)
         # sockaddr = (address, port) 2-tuple for AF_INET
         #            (address, port, flow info, scope id) 4-tuple for AF_INET6
         serverAddrInfo = socket.getaddrinfo( cfgServerIp, cfgServerPort )
         self.addrFamily = serverAddrInfo[ 0 ][ 0 ]
         self.serverIp = serverAddrInfo[ 0 ][ -1 ][ 0 ]
         self.serverPort = serverAddrInfo[ 0 ][ -1 ][ 1 ]
      except socket.gaierror:
         return

      # validate the client source IP address if configured. It has to be in the
      # same address family as the server address.
      if cfgSourceIp:
         try:
            socket.inet_pton( self.addrFamily, cfgSourceIp )
            self.sourceIp = cfgSourceIp
         except ( OSError, AttributeError, TypeError, ValueError ):
            # Best effort: carry on without a source address, but say so.
            log( 'Ignoring source IP %s: not a valid address for server %s'
                 % ( cfgSourceIp, self.serverIp ) )

      self.serverIpAndPort = self._formatServerAddress(
         self.addrFamily, self.serverIp, self.serverPort )
      if self.serverIpAndPort is None:
         # Neither IPv4 nor IPv6: leave 'client' unset instead of crashing.
         log( 'Unsupported address family %s for server %s'
              % ( self.addrFamily, self.serverIp ) )
         return

      self.channel = grpc.insecure_channel( self.serverIpAndPort )
      self.client = Counter_pb2_grpc.AristaCounterStub( self.channel )

   @staticmethod
   def _formatServerAddress( addrFamily, serverIp, serverPort ):
      """
      Return the 'ip:port' target string for the channel, with the address in
      brackets for IPv6. Returns None for any other address family.
      """
      if addrFamily == socket.AF_INET:
         return f'{serverIp}:{serverPort}'
      if addrFamily == socket.AF_INET6:
         return f'[{serverIp}]:{serverPort}'
      return None

   def getDeviceName( self ):
      """ Return the hostname from the system status. """
      return self.systemStatus.hostname

   def getDeviceModel( self ):
      """ Return the model name of the entity MIB root. """
      return self.entityMibStatus.root.modelName

   def sendCounters( self, counters, timestamp=None ):
      """
      Send one DeviceCounter message carrying all the given counters.

      counters: dict of counter type -> counter content; each content is JSON
         encoded into its own entry of the message
      timestamp: time of the sample in seconds; the current time is used when
         it is not given

      The call is asynchronous. The return value therefore describes earlier
      calls: it is '' when no tracked call that has finished since the last
      invocation failed, and otherwise a string with the status code,
      timestamp and server of the first such failure. When the client has no
      stub, an 'Unable to connect' message is returned and nothing is sent.
      """
      if not self.client:
         return 'Unable to connect to Server: %s' % self.serverIpAndPort

      returnStr = ''
      # We asynchronously call SendCounter. Now check to see if SendCounter
      # has finished. Set agent status if any error.
      finishedResponses = 0
      for ( resp, ts ) in self.pendingResponse:
         # stop at the first call that is still running
         if resp.running():
            break
         finishedResponses += 1
         # only the first failure is logged and reported
         if resp.exception() and not returnStr:
            log( 'Unable to get response from gRPC server %s: %s at %s'
                 % ( self.serverIpAndPort, resp.details(), ts ) )
            returnStr = f'{resp.code()} at {ts} to {self.serverIpAndPort}'
      if finishedResponses:
         trace2( 'Remove', finishedResponses, 'pending SendCounters' )
         self.pendingResponse = self.pendingResponse[ finishedResponses : ]

      msg = Counter_pb2.DeviceCounter()
      deviceInfo = Counter_pb2.Device( vendorName='Arista',
                                       deviceName=self.getDeviceName(),
                                       deviceModel=self.getDeviceModel() )
      # pylint: disable-msg=E1101
      msg.deviceInfo.CopyFrom( deviceInfo )
      timestamp = timestamp or time.time()
      msg.timestamp = format( timestamp, '.3f' )
      # The default limit on grpc message size is 4MB. For a large system, we might
      # need to break the json content into multiple messages. For now, however, this
      # script is only used on fixed systems. We will worry about the problem later.
      totalLen = 0
      for counterType, counter in counters.items():
         counterMsg = msg.counter.add()
         counterMsg.counterType = counterType
         counterMsg.jsonContent = json.dumps( counter, ensure_ascii=False ).encode()
         totalLen += len( counterMsg.jsonContent )
      # asynchronous call
      trace2( 'Sending counters', list( counters ), 'length:', totalLen )
      counterFuture = self.client.SendCounter.future( msg )
      # Only calls that are still running are tracked, and at most
      # maxPendingResponses of them.
      if counterFuture.running() and (
            len( self.pendingResponse ) < self.maxPendingResponses ):
         self.pendingResponse.append( ( counterFuture, str( msg.timestamp ) ) )
         trace2( 'Add pending SendCounter:', str( msg.timestamp ),
                 'total:', len( self.pendingResponse ) )

      return returnStr

class CounterDialOutAgent:
   """
   CounterDialOutAgent handles the command-line configuration and instantiates
   a DataCollector and one GrpcClient per configured server. The method run()
   periodically retrieves the counters from the DataCollector and either sends
   them to the active server via gRPC through its GrpcClient or, with --stdout,
   prints them to stdout.
   """
   def __init__( self, sysname ):
      """
      Load the counter provider plugins, parse the command line and initialize
      the agent through on_initialized().

      sysname: system name handed to EntityManager.Sysdb and PyClient by
         on_initialized()
      """
      # Command-line option -> handler. on_initialized() walks this mapping in
      # insertion order and feeds each handler its parsed option value.
      handlers = {}
      handlers[ Config.OPTION_SERVER_IP ] = self.handleServerIp
      handlers[ Config.OPTION_SERVER_PORT ] = self.handleServerPort
      handlers[ Config.OPTION_SOURCE ] = self.handleClientSource
      handlers[ Config.OPTION_INTERVAL ] = self.handleInterval
      handlers[ Config.OPTION_DATA ] = self.handleCounterTypeConfig
      handlers[ Config.OPTION_DSCP ] = self.handleDscp
      handlers[ Config.OPTION_STDOUT ] = self.handleStdout
      self.optionToHandler = handlers

      # The plugins register their providers before the command line is parsed:
      # the --data-type help text lists the registered command options.
      self.providerReg_ = ProviderRegistration()
      Plugins.loadPlugins( 'GrpcCounterPlugin', self.providerReg_ )

      # parse command line; positional arguments are ignored
      parsed = self.parseCmdOptions()
      self.cmdOptions = parsed[ 0 ]
      traceSetting = self.getCmdOptionValue( 'debug' )
      if traceSetting:
         Tracing.traceSettingIs( traceSetting )

      # configuration, filled in by the option handlers
      self.serverIpPorts = None
      self.defaultServerPort = None
      self.clientSourceIp = None
      self.dscp = None
      self.stdout = False
      self.reportInterval = DEFAULT_REPORT_INTERVAL
      self.filteredCounterProviders = None

      # Sysdb handles and entities, filled in by on_initialized()
      self.sysname = sysname
      self.entityMgr = None
      self.sysdb = None
      self.systemStatus = None
      self.entityMibStatus = None

      # daemon identity, filled in by setDaemonName()
      self.daemonName = None
      self.agentConfig = None
      self.agentStatus = None
      self.iptablesFileName = IPTABLES_FILE_NAME_PREFIX

      # runtime state
      self.initialized = False
      self.dataCollector = None
      self.grpcClients = []
      self.activeClientIndex = None

      self.on_initialized()

   def parseCmdOptions( self ):
      """
      Parse the configuration provided in the 'exec' command line.

      Returns the ( options, positional arguments ) pair produced by
      optparse's parse_args().
      """
      counterOpts = ','.join( sorted( self.providerReg_.commandOptions() ) )
      # ( option flags, add_option keyword arguments ), in declaration order
      optionSpecs = (
         ( ( '-d', '--server-ip' ),
           dict( action='append',
                 help='server IP v4 or v6 address. Append '
                 'port number if it differs from the default' ) ),
         ( ( '-p', '--server-port' ),
           dict( action='store', help='default server port' ) ),
         ( ( '-i', '--interval' ),
           dict( action='store', type=int, default=DEFAULT_REPORT_INTERVAL,
                 help='reporting interval (seconds)' ) ),
         ( ( '-s', '--source-ip' ),
           dict( action='store', help='client source IP address' ) ),
         ( ( '-t', '--data-type' ),
           dict( action='store', help='data type (%s)' % counterOpts ) ),
         ( ( '--dscp', ),
           dict( action='store', help='dscp value of gRPC packets' ) ),
         ( ( '--debug', ),
           dict( action='store', help='Enable tracing' ) ),
         ( ( '--stdout', ),
           dict( action='store_true',
                 help='Print payload to stdout instead of a server' ) ),
      )
      parser = optparse.OptionParser()
      for flags, attrs in optionSpecs:
         parser.add_option( *flags, **attrs )
      return parser.parse_args()

   def getCmdOptionValue( self, opt ):
      """
      Return the parsed value of command-line option opt, or None when the
      parsed options have no such attribute.
      """
      # The parsed options spell the name with "_" where the command line uses
      # "-": "server-ip" is stored as attribute "server_ip".
      attrName = '_'.join( opt.split( '-' ) )
      return getattr( self.cmdOptions, attrName, None )

   def setAgentStatus( self, stateName, state ):
      """
      Record state under key stateName in the agent status data. Only traces
      the request when there is no agent status entity.
      """
      trace1( 'Set agent status', stateName, 'to', state )
      if self.agentStatus:
         self.agentStatus.data[ stateName ] = state

   def setAgentRunning( self, running=True ):
      """ Set the running flag of the agent status entity, if there is one. """
      status = self.agentStatus
      if not status:
         return
      status.running = running

   @staticmethod
   def _splitServerIpPort( ipPort, defaultPort ):
      """
      Split one --server-ip value into an ( ip, port ) pair.

      '[a:b::c]:p' is an IPv6 address with a port and 'a.b.c.d:p' an IPv4
      address with a port. Anything else is taken as an address without a
      port and gets defaultPort.
      """
      if ipPort.startswith( '[' ) and ']:' in ipPort:
         sip, _, sport = ipPort.rpartition( ']:' )
         return sip.strip( '[]' ), sport
      # Exactly one colon: a value with several colons is an IPv6 address
      # (e.g. the IPv4-mapped '::ffff:1.2.3.4') and must not be split.
      if '.' in ipPort and ipPort.count( ':' ) == 1:
         sip, _, sport = ipPort.rpartition( ':' )
         return sip, sport
      return ipPort, defaultPort

   def on_initialized( self ):
      """
      Connect to Sysdb, apply the command-line options and set up the data
      collector and the gRPC clients.

      On success 'initialized' is set and the agent state becomes
      'Initialized'. When the configuration is incomplete or a server cannot
      be resolved, the reason is recorded in the agent state and the method
      returns with 'initialized' still False.
      """
      trace1( 'on_initialized' )
      self.entityMgr = EntityManager.Sysdb( self.sysname )
      self.sysdb = PyClient.PyClient( self.sysname, 'Sysdb' ).agentRoot()

      self.setDaemonName()
      # clean up the old ip table rules and old agent status.
      self.cleanup()

      self.setAgentStatus( 'State', 'Initializing' )
      for opt, optHandler in self.optionToHandler.items():
         optVal = self.getCmdOptionValue( opt )
         optHandler( optVal )

      # A server is only needed when dialing out; counter types always are.
      missingServer = not self.stdout and not self.serverIpPorts
      if missingServer or not self.filteredCounterProviders:
         self.setAgentStatus( 'State', 'Missing Config' )
         return

      # mount sys/net/status to get host name, hardware/entmib to get model name
      mg = self.entityMgr.mountGroup()
      self.systemStatus = mg.mount( Cell.path( 'sys/net/status' ),
                                    'System::NetStatus', 'r' )
      self.entityMibStatus = mg.mount( 'hardware/entmib',
                                       'EntityMib::Status', 'r' )
      mg.close( blocking=True )

      self.dataCollector = GrpcCounterCollector.DataCollector(
         self.entityMgr, self.sysdb,
         self.filteredCounterProviders )
      self.dataCollector.doInit()

      if not self.stdout:
         # User is allowed to configure multiple servers in any of the following
         # forms:
         # (1) -d 1.1.1.1 -d 1.1.1.2 -p 12345
         # (2) -d 1.1.1.1:12345 -d 1.1.1.2:12345
         # (3) -d 1.1.1.1 -d 1.1.1.2:12345 -p 12345
         # The serverIpPorts might or might not contain the port number. We need to
         # extract the ip and the port information from the -d option. If port is
         # not provided in -d, we will use the port configuration in -p.
         for i, ipPort in enumerate( self.serverIpPorts ):
            sip, sport = self._splitServerIpPort( ipPort,
                                                  self.defaultServerPort )

            grpcClient = GrpcClient( sip, sport, self.clientSourceIp,
                                     self.systemStatus, self.entityMibStatus )

            sourceStr = ''
            if grpcClient.sourceIp:
               sourceStr = 'Source: %s' % grpcClient.sourceIp
            # every server after the first is reported under 'Secondary Server'
            self.setAgentStatus(
               '%s Server' % ( 'Primary' if i == 0 else 'Secondary' ),
               'IP: {} Port: {} {}'.format( grpcClient.serverIp,
                                            grpcClient.serverPort,
                                            sourceStr ) )

            if not grpcClient.serverIp or not grpcClient.serverPort:
               self.setAgentStatus( 'State', 'Invalid Server Config: %s' % ipPort )
               return

            # a client without a stub is kept; run() fails over to the next one
            if not grpcClient.client:
               self.setAgentStatus( 'State', 'Error Connecting to Server: %s:%s' %
                                    ( grpcClient.serverIp, grpcClient.serverPort ) )

            self.grpcClients.append( grpcClient )

         self.activeClientIndex = 0
         self.addIpTableRule()

      Tac.runActivities( 1 )
      self.initialized = True
      self.setAgentStatus( 'State', 'Initialized' )

   def unsupportedOption( self, dummyVal ):
      """ Option handler that only logs; the option value is not used. """
      message = 'Unsupported option'
      log( message )

   def setDaemonName( self ):
      """
      Find the daemon this process was launched as and bind its config and
      status entities.

      The daemon is the first agent in the launcher's CLI config directory
      whose executable contains 'CounterGrpcDialOut' and whose non-empty
      arguments equal this process's command-line arguments, compared as sets.
      When one is found, daemonName, iptablesFileName, agentConfig and
      agentStatus are set, and the status entity is created if it does not
      exist yet. Otherwise they are left untouched.
      """
      agentConfigCliDir = self.sysdb.entity[ LauncherLib.agentConfigCliDirPath ]
      cmdLineArgs = set( sys.argv[ 1 : ] )
      for agentName, agentConfig in agentConfigCliDir.agent.items():
         # Substring test rather than find() > 0, which does not match an
         # executable that starts with 'CounterGrpcDialOut'.
         if 'CounterGrpcDialOut' not in agentConfig.exe:
            continue
         if cmdLineArgs == { s for s in agentConfig.argv.values() if s }:
            trace1( "set daemon name:", agentName )
            self.daemonName = agentName
            break

      if self.daemonName:
         # one rule file per daemon
         self.iptablesFileName = IPTABLES_FILE_NAME_PREFIX + '-' + self.daemonName
         agentConfigDir = self.sysdb.entity[ 'daemon/agent/config' ]
         agentStatusDir = self.sysdb.entity[ 'daemon/agent/status' ]
         self.agentConfig = agentConfigDir[ self.daemonName ]
         if self.daemonName not in agentStatusDir:
            self.agentStatus = agentStatusDir.newEntity(
               "GenericAgent::Status", self.daemonName )
         else:
            self.agentStatus = agentStatusDir[ self.daemonName ]

   def handleInterval( self, interval ):
      """
      Set the reporting interval, in seconds, from the --interval option.

      The value has to convert to a positive, finite number. Anything else is
      logged and the current interval is kept: run() never idles with a zero
      or negative interval. The interval in effect is published in the agent
      status.
      """
      trace1( 'handleInterval:', interval )
      try:
         newInterval = float( interval )
      except ( TypeError, ValueError ):
         newInterval = None
      # the chained comparison also rejects NaN, for which it is False
      if newInterval is not None and 0 < newInterval < float( 'inf' ):
         self.reportInterval = newInterval
      else:
         log( 'Invalid interval input: %s' % interval )
      self.setAgentStatus( 'Interval', str( self.reportInterval ) )

   def handleCounterTypeConfig( self, userInput ):
      """
      Select the counter providers named in the --data-type option.

      userInput: comma-separated command options, or a false value for none.
      Options with no registered provider are logged, repeated ones are
      ignored. The counter types of the selected providers are published in
      the agent status.
      """
      trace1( 'handleCounterTypeConfig:', userInput )

      subscribed = []
      self.filteredCounterProviders = subscribed
      # Based on user's input to "option data", determine which counters
      # are subscribed
      for optName in ( userInput.split( ',' ) if userInput else [] ):
         provider = self.providerReg_.providerClass( optName )
         if not provider:
            log( 'Unsupported counter option %s' % optName )
            continue
         if provider in subscribed:
            trace1( 'Duplicate counter option', optName, 'ignored' )
            continue
         trace1( 'Counter type', provider.COUNTER_TYPE, 'is subscribed' )
         subscribed.append( provider )

      counterTypes = [ p.COUNTER_TYPE for p in subscribed ]
      self.setAgentStatus( 'Subscribed data', ','.join( counterTypes ) )

   def handleServerIp( self, newServerIpPorts ):
      """ Store the values of the --server-ip option as given. """
      self.serverIpPorts = newServerIpPorts
      trace1( 'handleServerIp:', newServerIpPorts )

   def handleServerPort( self, newServerPort ):
      """ Store the value of the --server-port option as the default port. """
      self.defaultServerPort = newServerPort
      trace1( 'handleServerPort:', newServerPort )

   def handleClientSource( self, newClientSourceIp ):
      """ Store the value of the --source-ip option as given. """
      self.clientSourceIp = newClientSourceIp
      trace1( 'handleClientSource:', newClientSourceIp )

   def handleDscp( self, newDscp ):
      """
      Store the value of the --dscp option and, when one is given, publish it
      in the agent status.
      """
      trace1( 'handleDscp:', newDscp )
      self.dscp = newDscp
      if not newDscp:
         return
      self.setAgentStatus( 'DSCP', newDscp )

   def handleStdout( self, newStdout ):
      """ Store the value of the --stdout option as given. """
      self.stdout = newStdout
      trace1( 'handleStdout:', newStdout )

   def deleteIpTableRule( self ):
      """
      Undo the iptables rules recorded in the rule file, then remove the file.

      Each line of the file is run through the shell with '-A' replaced by
      '-D'. Does nothing when the file does not exist.
      """
      rulesFile = self.iptablesFileName
      if not os.path.isfile( rulesFile ):
         return

      trace1( 'deleting existing rule' )
      # NOTE(review): the lines are executed as read, and the default file name
      # lives under /tmp - confirm that nothing else can write this file.
      with open( rulesFile ) as recorded:
         for addCmd in recorded:
            # NOTE(review): replaces every '-A' in the line, not only the action
            # flag - confirm that no chain name can contain '-A'.
            deleteCmd = addCmd.replace( '-A', '-D' )
            os.system( deleteCmd )

      # remove the file if we have deleted the ip table rule
      os.remove( rulesFile )

   def addIpTableRule( self ):
      """
      Install the iptables rules for the configured servers.

      For every client with a resolved server address, a source NAT rule is
      added when the client has a source IP, and a DSCP rule when a DSCP value
      is configured. ip6tables is used for IPv6 servers. The rules are written
      to the rule file before they are run; nothing is written or run when
      there is no rule to add.
      """
      import shlex # pylint: disable=import-outside-toplevel

      snatRule = []
      dscpRule = []
      # The DSCP value is taken from the command line unchecked and ends up in
      # a shell command, so quote it. Values made only of shell-safe characters
      # (letters, digits and a few punctuation marks) come back unchanged.
      dscpArg = shlex.quote( str( self.dscp ) ) if self.dscp else None

      for client in self.grpcClients:
         if not client or not client.serverIp or not client.serverPort:
            continue

         cmd = 'ip6tables' if client.addrFamily == socket.AF_INET6 else 'iptables'
         if client.sourceIp:
            trace1( 'adding new rule for source', client.sourceIp )
            # Add a source NAT rule to change the source IP of gRPC packets
            snatRule.append( RT( 'sudo %s -t nat -A {{postroutingchain}} -p tcp '
                                 '-d %s --dport %s -j SNAT --to %s' ) %
                             ( cmd, client.serverIp, client.serverPort,
                               client.sourceIp ) )

         if self.dscp:
            trace1( 'adding new rule for setting dscp', self.dscp )
            # Add rule to set the dscp value of gRPC packets
            dscpRule.append( RT( 'sudo %s -t mangle -A {{outputchain}} -p tcp '
                                 '-d %s --dport %s -j DSCP --set-dscp %s' ) %
                             ( cmd, client.serverIp, client.serverPort,
                               dscpArg ) )

      if not snatRule and not dscpRule:
         return

      # When daemon is shutdown, we don't get a chance to cleanup the iptable
      # rule that we added for source nat. Store the rule in a tmp file so that
      # this rule can be cleaned up when the daemon is no shut.
      allRules = snatRule + dscpRule
      with open( self.iptablesFileName, 'w' ) as f:
         f.write( '\n'.join( allRules ) )
      for rule in allRules:
         os.system( rule )

   def cleanup( self ):
      """
      Remove the recorded iptables rules and, when there is an agent status
      entity, clear its data and mark the agent as not running.
      """
      self.deleteIpTableRule()
      status = self.agentStatus
      if not status:
         return
      status.data.clear()
      status.running = False

   def run( self ):
      """
      Main loop: collect the counters once per reportInterval and publish
      them, until the daemon is disabled in its agent config.

      Returns at once when initialization did not complete, so that the error
      state recorded by on_initialized() is not overwritten with 'Active'.

      In stdout mode the counters are pretty-printed. Otherwise they are sent
      to the active server; when that reports an error the next server becomes
      active and is tried, each server at most once per iteration.
      """
      if not self.initialized:
         return

      iterNum = 0
      baseTime = time.time()
      self.setAgentStatus( 'State', 'Active' )
      self.setAgentRunning( True )
      while self.initialized:
         iterStart = time.time()
         Tac.runActivities( 0 )
         if self.agentConfig and not self.agentConfig.enabled:
            trace1( "Cleaning up and exiting" )
            self.cleanup()
            return

         counters = self.dataCollector.getCounters()
         if counters and self.stdout:
            pprint.pprint( counters )
         elif counters:
            for _ in range( len( self.grpcClients ) ):
               errMsg = self.grpcClients[ self.activeClientIndex ].sendCounters(
                  counters, iterStart )
               if not errMsg:
                  break
               trace1( 'Server', self.activeClientIndex, 'is down' )
               # move on to the next server, wrapping around after the last one
               self.activeClientIndex = (
                  ( self.activeClientIndex + 1 ) % len( self.grpcClients ) )
               trace1( 'Try next server', self.activeClientIndex )
               self.setAgentStatus( 'Last Message Error', errMsg )

         iterNum += 1
         iterEnd = time.time()
         if iterEnd - iterStart < self.reportInterval:
            # Idle until the next multiple of the interval counted from
            # baseTime, not for a full interval from now.
            idleTime = baseTime + iterNum * self.reportInterval - iterEnd
            if idleTime > 0:
               Tac.runActivities( idleTime )
            else:
               trace1( 'Catch up at', iterStart )
         else:
            trace1( 'Run overtime at', iterStart )
