#!/usr/bin/env python3
#
# Copyright (c) 2015-2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import os
import time
import threading
import traceback
from copy import deepcopy
from MssPolicyMonitor import Lib
from MssPolicyMonitor.Error import ( FirewallError, FirewallAPIError,
                                     FirewallConfigError )
from MssPolicyMonitor.Lib import t0, t1, t2, t3, t4, t6, b0, v, BYPASS_MODE
from MssPolicyMonitor.PluginLib import ServiceDevice, IPolicyPlugin, IHAStatePlugin


# When True, monitorDevicePolicies profiles its polling loop with cProfile and
# dumps the stats to a 'cprofile_<thread name>' file when the loop exits.
ENABLE_PROFILING = False
# Multipliers applied to a device's 'timeout' and 'queryInterval' settings to
# bound how long DeviceMonitor.stopRunning waits when joining a polling thread.
TIMEOUT_JOIN_FACTOR = 20
INTERVAL_JOIN_FACTOR = 2
# Running counter behind getThreadId(); it is embedded in generated thread names.
threadCount = 0


def getThreadId():
   ''' Bump the module-level threadCount counter and return its new value.
   '''
   global threadCount
   nextId = threadCount + 1
   threadCount = nextId
   return nextId


def genMonitoringThreadName( deviceSetName, deviceId, aggMgrId='' ):
   ''' Build a monitoring thread name of the form
       't#<id>:<deviceSetName>:<deviceId>', where <id> comes from getThreadId().
       When aggMgrId is non-empty, ':via:<aggMgrId>' is appended.
   '''
   baseName = f't#{getThreadId()}:{deviceSetName}:{deviceId}'
   if not aggMgrId:
      return baseName
   return f'{baseName}:via:{aggMgrId}'


def genDeviceMonitor( serviceDeviceType, deviceConfig, sysdbPolicyMgr ):
   ''' Build the DeviceMonitor used to start and stop the device polling thread.
       Called by MssPolicyMonitorAgent.

       Returns None when Lib.getPlugin() yields no plugin for serviceDeviceType;
       in that case no thread name is generated.
   '''
   plugin = Lib.getPlugin( serviceDeviceType )
   t4('startMonitoringPolicies using plugin:', plugin )
   if plugin:
      monitorThreadName = genMonitoringThreadName( deviceConfig[ 'deviceSet' ],
                                                   deviceConfig[ 'ipAddress' ] )
      return DeviceMonitor( plugin, deviceConfig, sysdbPolicyMgr,
                            monitorThreadName )
   return None


def startMonitoringPolicies( devMonitor ):
   ''' Kick off policy monitoring by invoking startRunning() on the passed
       DeviceMonitor. Called by MssPolicyMonitorAgent.
   '''
   launch = devMonitor.startRunning
   launch()

def stopMonitoringPolicies( monitorInstance, waitForCompletion=False ):
   ''' Ask the passed MPM instance to stop policy monitoring, forwarding
       waitForCompletion to its stopRunning(). A falsy monitorInstance is
       ignored. Called by MssPolicyMonitorAgent.
   '''
   if not monitorInstance:
      return
   monitorInstance.stopRunning( waitForCompletion )

####################################################################################
class DeviceMonitor:
   ''' Monitor service device policies and associated links to Arista switch(es).

       A directly accessed device is polled by monitorDevicePolicies on a single
       thread. An aggregation manager is polled by accessDevicesViaAggMgr, which
       launches one monitorDevicePolicies thread per group member.
   '''
   def __init__( self, mpmPlugin, deviceConfig, agentSysdbMgr, instanceThreadName ):
      ''' Store the configuration for one monitored device; starts no thread.

          mpmPlugin: object providing getPluginObj() / getAggMgrPluginObj()
          deviceConfig: dict, must contain 'deviceSet', 'ipAddress',
             'serviceDeviceType', 'policyTags', 'queryInterval' and 'timeout';
             other keys are read later by the polling methods
          agentSysdbMgr: agent-side object that receives policy/status updates
          instanceThreadName: name given to the thread created by startRunning
      '''
      t2('DeviceMonitor init deviceConfig:', Lib.hidePassword( deviceConfig ) )
      self.mpmAgent = agentSysdbMgr
      self.mpmPlugin = mpmPlugin
      self.deviceSetName = deviceConfig[ 'deviceSet' ]
      self.ipAddr = deviceConfig[ 'ipAddress' ]
      self.deviceType = deviceConfig[ 'serviceDeviceType' ]
      self.policyTags = deviceConfig[ 'policyTags' ]
      self.pollInterval = deviceConfig[ 'queryInterval' ]
      self.timeout = deviceConfig[ 'timeout' ]
      self.deviceConfig = deviceConfig
      self.keepRunning = True   # cleared by stopRunning to end the polling loops
      self.mpmThread = None     # main polling thread, created by startRunning
      self.instanceThreadName = instanceThreadName
      self.serviceDevices = {}  # service devices to monitor, key = ip or serial#
      self.AggMgrPolicies = {}  # key=service device IP, value=SvcDevicePolicy list
      self.errorLogs = {}
      self.AggMgrPoliciesLock = threading.Lock()  # guards self.AggMgrPolicies
      self.isMonitoringInstanceActive = True  # Used to detect AggrMgr HA Failover
      self.deviceThreads = {}   # Threads created to monitor map-only devices

   def __str__( self ):
      return ( f'<{self.__class__.__name__}>:{self.deviceSetName}:'
               f'{self.ipAddr}_{self.instanceThreadName}' )

   def _threadJoinTimeout( self ):
      ''' Longest time to wait for a polling thread to exit: it may be blocked
          in a device request (timeout) and/or sleeping between polls
          (pollInterval), so both are scaled by the module join factors.
      '''
      return ( TIMEOUT_JOIN_FACTOR * self.timeout +
               INTERVAL_JOIN_FACTOR * self.pollInterval )

   def startRunning( self ):
      ''' Create and start the daemon polling thread for this instance.
          The thread runs accessDevicesViaAggMgr when the device is configured
          as an aggregation manager, otherwise monitorDevicePolicies.
      '''
      # daemon=True replaces the deprecated Thread.setDaemon()
      if self.deviceConfig[ 'isAggregationMgr' ]:
         self.mpmThread = threading.Thread( name=self.instanceThreadName,
                                            target=self.accessDevicesViaAggMgr,
                                            daemon=True )
      else:
         self.mpmThread = threading.Thread( name=self.instanceThreadName,
                                            target=self.monitorDevicePolicies,
                                            args=( self.ipAddr, ),
                                            daemon=True )
      b0( '++ START monitoring for', v( self.deviceType) , 'thread:',
         v( self.instanceThreadName ) )
      self.mpmThread.start()

   def stopRunning( self, waitForCompletion=False ):
      ''' Signal the polling loops to stop.
          When waitForCompletion is set, join the main polling thread; if it is
          still alive after the join timeout the whole process is aborted.
      '''
      self.keepRunning = False
      t1('keepRunning:', self.keepRunning, 'for:', self.instanceThreadName )
      if waitForCompletion and self.mpmThread:
         self.mpmThread.join( self._threadJoinTimeout() )
         if self.mpmThread.is_alive():
            t0( 'polling for', self.ipAddr, 'has been stopped but',
                self.instanceThreadName,
                'is still running. Aborting MssPolicyMonitor' )
            os.abort()
         self.mpmThread = None

   def monitorDevicePolicies( self, deviceId, isAccessedViaAggrMgr=False ):
      ''' Monitor policies on a service device.
            deviceId may be an IP address, dns name, serial number etc.

          Thread body: loops until keepRunning is cleared, the device is no
          longer current, or the monitoring instance becomes inactive. Each
          cycle reads HA state, policies, neighbors, interfaces and routes from
          the device plugin and hands them to the agent. Errors are logged and
          passed to handleMonitoringError; polling resumes after pollInterval.
      '''
      threadName = threading.current_thread().name
      t1('monitorDevicePolicies for device', deviceId, 'runs on thread', threadName )
      if deviceId not in self.serviceDevices:
         self.serviceDevices[ deviceId ] = ServiceDevice( deviceId, self.deviceType,
                                                          threadName )
      device = self.serviceDevices[ deviceId ]
      device.deviceSetName = self.deviceSetName
      interfaces = None
      routes = None
      if ENABLE_PROFILING:
         import cProfile # pylint: disable=import-outside-toplevel
         profiler = cProfile.Profile()
         profiler.enable()
      while self.keepRunning:
         startTime = time.time()
         try:
            Lib.checkSslProfileStatus( self.deviceConfig, self.mpmAgent.sslStatus )
            # pylint: disable-next=no-else-break
            if not device.isCurrent:  # happens when dev removed from a group
               self.removeNonCurrentDevice( device )
               break
            elif not device.initComplete():
               self.initServiceDevice( deviceId, isAccessedViaAggrMgr )
               if not self.keepRunning:  # many restarts possible on eos config load
                  break
            # For firewalls, this will always be True
            # For AggrMgr's this will indicate if the monitoring instance
            # is active depending on HA configuration
            if not self.isMonitoringInstanceActive:
               self.removeNonCurrentDevice( device, cleanupStatus=False )
               break

            t0('@ polling', device.name, device.mgmtIp, 'thread', threadName )
            if device.isSingleLogicalDeviceHaModel:
               device.setDeviceInfo( device.plugin.getDeviceInfo() )
            if isinstance( device.plugin, IHAStatePlugin ):
               haState = device.plugin.getHighAvailabilityState()
               t2( device.name, haState )
               device.haPeerMgmtIp = haState.getPeerManagementIp()
               isHaPassiveOrSecondary = haState.isHaPassiveOrSecondary()
               self.mpmAgent.updateHaStatus( deviceId, haState )
            else:
               isHaPassiveOrSecondary = False

            if isHaPassiveOrSecondary:
               t1( deviceId, 'in HA Mode but not Active/Primary device, ignoring')
               # Publish empty policy data for the non-primary HA member
               self.mpmAgent.updatePolicies( {}, [], {}, None, device,
                                             haPrimary=False )
            else:
               if not self.keepRunning:
                  break
               if isinstance( device.plugin, IPolicyPlugin ):  # e.g. PAN, FNET
                  policies = device.plugin.getPolicies( self.policyTags )
               else:  # e.g. CheckPoint
                  with self.AggMgrPoliciesLock:
                     # Wrap policies into a dummy vsys named 'root'
                     # This should be fixed once Checkpoint provides an API to
                     # manage virtual system
                     # deepcopy so the shared AggMgrPolicies entries are not
                     # mutated by populateZoneIntfStatus outside the lock
                     policies = { 'root' : deepcopy(
                        self.AggMgrPolicies.get( device.mgmtIp, [] ) ) }
                  populateZoneIntfStatus( policies, device.plugin )

               neighbors = device.plugin.getInterfaceNeighbors()
               interfaces = device.plugin.getInterfacesInfo()
               routes = device.plugin.getDeviceRoutingTables()
               if not self.keepRunning:  # check again, can be several seconds later
                  break

               # Error recovery
               self.mpmAgent.firewallLogger.recover( deviceId )

               self.mpmAgent.updatePolicies( policies, interfaces, neighbors,
                                             routes, device )
            self.mpmAgent.updateStatus( deviceId, device.threadName )
         except FirewallError as fwError:
            # Log the error if not previously logged
            self.mpmAgent.firewallLogger.log( deviceId, fwError )
            self.handleMonitoringError( device )
         except Exception as ex:  # pylint: disable=W0703
            t0( 'Error:', ex, ' device monitoring will resume on next interval' )
            traceback.print_exc()
            self.handleMonitoringError( device )

         t2( device.name, 'cycle', delta( startTime ), 'threads',
             [ t.name for t in threading.enumerate() ] )
         time.sleep( self.pollInterval )

      b0( '-- monitorDevicePolicies stopping thread:',
            v( threading.current_thread() ) )
      self.closeDeviceApiConnection( device.plugin )
      if ENABLE_PROFILING:
         profiler.disable()
         # pylint: disable-next=consider-using-f-string
         profiler.dump_stats( 'cprofile_%s' % threadName )

   def stopMonitoringMembers( self, aggMgrId, aggMgrThreadName ):
      ''' Stop monitoring the members during HA failover, Active -> Passive
          or when an invalid HA configuration is detected.
          Mark the monitoring instance as inactive. The threads managing the
          firewalls will detect this and terminate if they were created. There
          shouldn't be any change to the policies. Service device list is cleared
          by individual threads.

          Aborts the process if a member thread is still alive after the join
          timeout.
      '''
      self.isMonitoringInstanceActive = False

      # If no service devices were created, nothing to be done
      if not self.serviceDevices:
         return

      b0( 'AggMgrId', v( aggMgrId ), 'AggMgrThread', v( aggMgrThreadName ),
            'Detected failover/invalid config, stopping monitor threads' )

      # Ensure all threads are stopped
      # Have a long timeout to incorporate the python thread scheduling
      joinTimeout = self._threadJoinTimeout()
      for deviceId, thread in self.deviceThreads.items():
         thread.join( joinTimeout )
         if thread.is_alive():
            t0( aggMgrId, 'is passive but monitoring thread', thread.name,
                'is still running for', deviceId, ',aborting MssPolicyMonitor' )
            os.abort()

      self.deviceThreads.clear()

   def cleanupMember( self, aggMgrId, aggMgrPlugin, memberId ):
      ''' Retire a device that is no longer a member of the aggregation
          manager group: flag it as not current (its polling thread removes it
          on its next cycle), forget its thread, clear its cached policies when
          the aggregation manager is the policy source, and tell the agent to
          delete the group member.
      '''
      try:
         self.serviceDevices[ memberId ].isCurrent = False
         # Remove the thread entry
         self.deviceThreads.pop( memberId, None )
         if isinstance( aggMgrPlugin, IPolicyPlugin ):
            with self.AggMgrPoliciesLock:
               self.AggMgrPolicies[ memberId ] = []  # clear policies
      except KeyError:
         pass  # ignore, may have just been removed by device thread
      self.mpmAgent.deleteGroupMember( memberId, aggMgrId )

   def checkInvalidHaConfig( self, haState ):
      ''' Return True when the device set has an aggregation manager pair but
          HA is not enabled on this device, False otherwise (including when the
          device set is no longer known to the agent).
      '''
      # Check if we have an invalid HA configuration
      invalidConfig = False
      try:
         if ( self.mpmAgent.hasAggrMgrPair[ self.deviceSetName ] and
              not haState.isHaEnabled() ):
            # Invalid configuration, there are two aggregation managers
            # but HA is not configured on the current device
            invalidConfig = True
            t0( 'Aggregation manager pair found in deviceSet:',
                self.deviceSetName, 'but high availability not ' +
                'enabled on device:', self.ipAddr )
         else:
            t6( self.ipAddr, 'is',
                'passive' if haState.isHaPassiveOrSecondary() else 'active' )
      except KeyError:
         pass  # Ignore, must have been removed due to deletion of device set
      return invalidConfig

   def accessDevicesViaAggMgr( self ):
      ''' Launches a thread for each service device in an aggregation manager group

          Thread body: loops until keepRunning is cleared. Each cycle checks the
          aggregation manager HA state, and when this instance is active,
          refreshes the group membership, starting threads for new members and
          cleaning up members that left. Errors are logged and the cycle is
          retried after pollInterval.
      '''
      group = self.deviceConfig[ 'group' ]
      t1('accessDevicesViaAggMgr running on thread:', threading.current_thread(),
         'for device group:', group )
      aggMgrThreadName = threading.current_thread().name
      aggMgrId = self.ipAddr  # IP addr or DNS name
      aggMgrPlugin = None
      mgmtIp = ''
      invalidHaConfig = False
      isHaPassiveOrSecondary = False
      while self.keepRunning:
         try:
            Lib.checkSslProfileStatus( self.deviceConfig, self.mpmAgent.sslStatus )
            # Plugin and management IP are fetched lazily so that a failure is
            # retried on the next cycle
            if not aggMgrPlugin:
               aggMgrPlugin = self.mpmPlugin.getAggMgrPluginObj( self.deviceConfig )
            if not mgmtIp:
               devInfo = aggMgrPlugin.getDeviceInfo()
               t1('aggMgr deviceInfo:', devInfo )
               if 'ipAddr' not in devInfo:
                  # An API error must have occurred for this information to be
                  # missing
                  t2( 'deviceInfo is incomplete: ipAddr is missing' )
                  raise FirewallAPIError( 200, None )
               mgmtIp = devInfo[ 'ipAddr' ]

            # HA Support: Query AggrMgr and check if HA is supported
            if isinstance( aggMgrPlugin, IHAStatePlugin ):
               haState = aggMgrPlugin.getHighAvailabilityState()
               isHaPassiveOrSecondary = haState.isHaPassiveOrSecondary()
               invalidHaConfig = self.checkInvalidHaConfig( haState )
               self.mpmAgent.updateHaStatus( aggMgrId, haState )

            if isHaPassiveOrSecondary or invalidHaConfig:
               # When device is passive or when we detect invalid HA configuration
               if self.isMonitoringInstanceActive:
                  self.stopMonitoringMembers( aggMgrId, aggMgrThreadName )
               if invalidHaConfig:
                  self.mpmAgent.updateStatus( aggMgrId, aggMgrThreadName,
                      mgmtIp=mgmtIp )
                  raise FirewallConfigError( 'Aggregation Manager',
                      # pylint: disable-next=consider-using-f-string
                      'Multiple device managers are configured in %s. '
                      'High availability must be enabled on %s.'
                      % ( self.deviceSetName, aggMgrId ) )
            else:
               self.isMonitoringInstanceActive = True
               if isinstance( aggMgrPlugin, IPolicyPlugin ):
                  with self.AggMgrPoliciesLock:
                     t2( '* aggMgr', aggMgrThreadName, 'get policies and ' +
                        'group members' )
                     self.AggMgrPolicies = (
                       aggMgrPlugin.getPolicies( self.policyTags ) )
                     # devices ref'd in rules
                     groupMembers = list( self.AggMgrPolicies )
               else:
                  t3( '* aggMgr', aggMgrThreadName, 'checking device group members' )
                  groupMembers = aggMgrPlugin.getAggMgrGroupMembers( group )

               if not self.keepRunning:  # check again here
                  break
               t1( aggMgrId, 'group:', group, 'currentMembers:', groupMembers )
               # Whatever is left in previousDevices after the loop has left
               # the group
               previousDevices = set( self.serviceDevices.keys() )
               for memberId in groupMembers:
                  if memberId in self.serviceDevices:
                     previousDevices.discard( memberId )  # remove current id
                  else:
                     self.initMemberDevice( memberId, aggMgrId, aggMgrThreadName )

               t3( 'previous members no longer in group:', previousDevices )
               for memberId in previousDevices:
                  self.cleanupMember( aggMgrId, aggMgrPlugin, memberId )

               # Update the thread name in the service device status for member
               # devices. This will ensure that the Active Panorama's firewall
               # threads will be able to update the policies they have read.
               self.mpmAgent.updateMonitoringThreadNames( self.deviceThreads,
                     aggMgrThreadName )

            # Error recovery
            self.mpmAgent.aggrMgrLogger.recover( aggMgrId )

            self.mpmAgent.updateStatus( aggMgrId, aggMgrThreadName, mgmtIp=mgmtIp )
         except FirewallError as fwError:
            # Log the error if not previously logged
            self.mpmAgent.aggrMgrLogger.log( aggMgrId, fwError )
         except Exception as ex:  # pylint: disable=W0703
            t0( 'Error:', ex, ' device monitoring will resume on next interval' )
            traceback.print_exc()

         t3( aggMgrThreadName, 'aggMgr cycle complete')
         time.sleep( self.pollInterval )
      b0( '-- accessDevicesViaAggMgr exiting thread:',
            v( threading.current_thread() ) )
      # aggMgrPlugin is still None here if it was never successfully created
      self.closeDeviceApiConnection( aggMgrPlugin )

   def initMemberDevice( self, memberId, aggMgrId, aggMgrThreadName ):
      ''' Register a new aggregation manager group member and start a daemon
          thread running monitorDevicePolicies for it.
      '''
      memberThreadName = genMonitoringThreadName( self.deviceSetName, memberId,
                                                  aggMgrId )
      self.serviceDevices[ memberId ] = ServiceDevice( memberId, self.deviceType,
                                                       memberThreadName )
      self.mpmAgent.addGroupMember( memberId, memberThreadName,
                                    aggMgrId, aggMgrThreadName )
      t1('+ START monitor thread for group member device: ', memberId )
      # daemon=True replaces the deprecated Thread.setDaemon()
      devThread = threading.Thread( name=memberThreadName,
                                    target=self.monitorDevicePolicies,
                                    args=( memberId, True ),
                                    daemon=True )
      self.deviceThreads[ memberId ] = devThread
      devThread.start()

   def removeNonCurrentDevice( self, device, cleanupStatus=True ):
      ''' Close the device's API connection and drop it from serviceDevices.
          With cleanupStatus set, the agent is first asked to clean up the
          policies for the device; it is left unset on aggregation manager
          failover/invalid config so the policies stay untouched.
      '''
      deviceId = device.deviceId
      if cleanupStatus:
         t1('serviceDevice', deviceId, 'not current, deleting' )
         self.mpmAgent.cleanupPoliciesForDevice( device )
      else:
         t1('serviceDevice', deviceId, 'policies not cleaned, ' +
            'AggrMgr failover/invalid config' )
      serviceDevice = self.serviceDevices[ deviceId ]
      self.closeDeviceApiConnection( serviceDevice.plugin )
      del self.serviceDevices[ deviceId ]

   def closeDeviceApiConnection( self, plugin ):
      ''' Best-effort close of a plugin's API connection; errors raised by the
          plugin are logged and ignored. A plugin of None (never created) is
          silently skipped.
      '''
      if plugin is None:
         return  # nothing was ever opened, so there is nothing to close
      try:
         t4('closing device API connection')
         plugin.closeApiConnection()
      except Exception as ex:  # pylint: disable=W0703
         t0('ignoring error while closing device API connection:', ex )

   def initServiceDevice( self, deviceId, isAccessedViaAggrMgr ):
      ''' Create the plugin object for a service device and store the device
          info it reports on the ServiceDevice entry.
      '''
      t3( '@initSvcDevice id:', deviceId, 'ip:', self.deviceConfig[ 'ipAddress' ],
          'accessViaAggMgr:', isAccessedViaAggrMgr )
      device = self.serviceDevices[ deviceId ]
      if isAccessedViaAggrMgr:
         # Copy configuration from device member and add virtual system
         # into the aggregation manager configuration which is provided to the plugin
         config = self.mpmAgent.getDeviceConfig( self.deviceSetName, deviceId )
         deviceConfig = self.deviceConfig.copy()
         if config:
            deviceConfig[ 'virtualInstance' ] = config[ 'virtualInstance' ]
            deviceConfig[ 'vrouters' ] = config[ 'vrouters' ]
         device.plugin = self.mpmPlugin.getPluginObj( deviceConfig, deviceId )
      else:
         device.plugin = self.mpmPlugin.getPluginObj( self.deviceConfig )
      devInfo = device.plugin.getDeviceInfo()
      t2( 'svcDeviceInfo:', devInfo, 'id:', deviceId, 'ip:',
          self.deviceConfig[ 'ipAddress' ] )
      device.setDeviceInfo( devInfo )

   def handleMonitoringError( self, device ):
      ''' Report an error status for the device to the agent and, when the
          configured exception mode is BYPASS_MODE, have the agent clean up the
          device's policies.
      '''
      t1('exceptionHandlingMode:', self.deviceConfig[ 'exceptionMode' ] )
      self.mpmAgent.updateStatus( device.deviceId, device.threadName, error=True )
      if self.deviceConfig[ 'exceptionMode' ] == BYPASS_MODE:
         self.mpmAgent.cleanupPoliciesForDevice( device )


def populateZoneIntfStatus( policies, plugin ):
   ''' Populate policy zone interface status with latest service device
       interface status.

       policies: dict mapping a virtual system name to a list of policies; each
          policy's srcZoneInterfaces / dstZoneInterfaces lists are replaced
          in place
       plugin: device plugin; getInterfacesInfo( resolveZoneNames=False ) is
          indexed as [ vsys ][ interface name ]

       Policy interfaces that the device does not report for the policy's
       virtual system are dropped from the refreshed lists. The reported
       interface objects get their zone set from the policy interface.
   '''
   intfsInfo = plugin.getInterfacesInfo( resolveZoneNames=False )
   for vsys, policyList in policies.items():
      # Interface lookup must be done per virtual system: intfsInfo is keyed
      # by vsys first, so testing the interface name against the top level
      # would never match the entries indexed below.
      vsysIntfs = intfsInfo[ vsys ] if vsys in intfsInfo else {}
      for policy in policyList:
         for attrName in ( 'srcZoneInterfaces', 'dstZoneInterfaces' ):
            updatedZoneIntfs = []
            for polZoneIntf in getattr( policy, attrName ):
               if polZoneIntf.name in vsysIntfs:
                  currentIntf = vsysIntfs[ polZoneIntf.name ]
                  currentIntf.zone = polZoneIntf.zone
                  updatedZoneIntfs.append( currentIntf )
            setattr( policy, attrName, updatedZoneIntfs )


def delta( startTime ):
   ''' Format the seconds elapsed since startTime (a time.time() value) as
       'time=<seconds, two decimals>s'.
   '''
   elapsed = time.time() - startTime
   return f'time={elapsed:.2f}s'
