#!/usr/bin/env python
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
from __future__ import absolute_import, division, print_function
import Tac
import time
import sys
import os
import syslog
import Cell
import PyClient
import datetime as dt
from CliPlugin.AsuReloadCli import _asuData

# Force reload the latest AsuPatchBase from swi
# pylint: disable-msg=import-error
# pylint: disable-msg=wrong-import-position
sys.modules.pop( 'AsuPatchBase', None )
import AsuPatchBase

# waitFor and Timeout are copied here from ArPyUtils so that this script does
# not rely on the implementation shipped with older versions.
class Timeout( Exception ):
   """Raised by waitFor() when the condition is still false at the deadline."""

def waitFor( func, timeout=600.0, description=None,
             checkAfter=0.1, firstCheck=0.1, maxDelay=5 ):
   """
   Poll func() until it returns a truthy value, and return that value.

   - func: callable taking no arguments; it is always called at least once.
   - timeout: seconds, measured from entry, after which the wait gives up.
   - description: text placed in the Timeout message.
   - checkAfter: seconds to sleep after the first unsuccessful call.
   - firstCheck: seconds to sleep before the first call (skipped when falsy).
   - maxDelay: upper bound, in seconds, for the sleep between calls.

   The sleep between calls doubles after every unsuccessful call, is capped at
   maxDelay, and is never longer than the time left before the deadline.

   Raises Timeout when func() is still falsy on a call that started at or
   after the deadline.
   """
   deadline = time.time() + timeout
   if firstCheck:
      time.sleep( firstCheck )
   delay = checkAfter
   while True:
      # The clock is sampled before func() runs, so func() is still evaluated
      # once on or after the deadline before we give up.
      deadlinePassed = time.time() >= deadline
      outcome = func()
      if outcome:
         return outcome
      if deadlinePassed:
         raise Timeout( "Timed out waiting for %s" % description )
      # Never sleep past the deadline
      remaining = deadline - time.time()
      if remaining < 0:
         remaining = 0
      if remaining < delay:
         delay = remaining
      time.sleep( delay )
      # Back off exponentially, bounded by maxDelay
      delay = min( maxDelay, delay * 2 )

def getDebugLogPath():
   """
   Return the path of the file that debugSyslog() appends to.

   The directory comes from the PRE_RELOAD_LOGS_DEST_DIR environment variable
   (default "/mnt/flash/debug") and the file name from
   PRE_RELOAD_LOGS_ERROR_FILE (default "pre_reload_monitor_bgp_socket_logs").
   """
   env = os.environ
   fileName = env.get(
      "PRE_RELOAD_LOGS_ERROR_FILE", "pre_reload_monitor_bgp_socket_logs" )
   destDir = env.get( "PRE_RELOAD_LOGS_DEST_DIR", "/mnt/flash/debug" )
   return os.path.join( destDir, fileName )

def debugSyslog( level, log ):
   """
   Record a debug message for this script.

   With the local switch below left at False, the message is prefixed with the
   current local time and appended to the file returned by getDebugLogPath();
   the parent directory is created first with "mkdir -p" run as root. With the
   switch set to True the raw message goes to syslog at 'level' instead.

   A failure to write the file is printed and otherwise ignored, so logging
   never interrupts the caller.
   """
   # Flip to True to send messages to syslog instead of the debug file
   sendToSyslog = False
   timestamp = str( dt.datetime.now() )
   text = str( log )
   stamped = "%s: %s" % ( timestamp, text )
   if sendToSyslog:
      syslog.syslog( level, log )
      return

   logPath = getDebugLogPath()
   logDir = os.path.dirname( logPath )
   Tac.run( [ "mkdir", "-p", logDir ],
            asRoot=True, ignoreReturnCode=True,
            stdout=Tac.DISCARD, stderr=Tac.DISCARD )
   try:
      with open( logPath, "a" ) as logFile:
         logFile.write( stamped + "\n" )
   except Exception as err: # pylint: disable-msg=broad-except
      print( err )
class MonitorArBgpTcpSocketsThread( AsuPatchBase.AsuPatchBaseThread ):
   """
   Patch thread that watches BGP TCP sockets during the ASU shutdown stages.

   run() registers a stage request for a pseudo agent under the shutdown stage
   status directory in Sysdb, waits for that stage to show up in the shutdown
   progress, checks that no TCP socket using port 179 is left outside the
   'closed' / 'time-wait' states (killing Bgp-main if some remain), and then
   marks the stage complete for the pseudo agent.
   """

   def __init__( self, sysdbPath ):
      """
      Open a PyClient connection to the "Sysdb" agent and keep its agent root.

      sysdbPath is handed to PyClient.PyClient as its first argument; callers
      in this file pass a system name.
      """
      pcl = PyClient.PyClient( sysdbPath, "Sysdb" )
      self.root = pcl.agentRoot()
      super( # pylint: disable=super-with-arguments
         MonitorArBgpTcpSocketsThread, self
      ).__init__()

   def addDynamicStageDependency( self, agentName, stageName ):
      """
      Add a stage request for agentName on stageName in the shutdown stage
      status directory, creating the agent's Stage::AgentStatus entity first
      if it does not exist yet.
      """
      debugSyslog(
         syslog.LOG_NOTICE,
         'Adding new dependency on %s' % ( stageName ) )
      stageDir = self.root.entity[
         Cell.path( "stageAgentStatus/shutdown" )
      ]
      if agentName not in stageDir:
         stageDir.newEntity( "Stage::AgentStatus", agentName )
      asuPatchAgent = stageDir[ agentName ]
      # NOTE(review): the third argument (30) is presumably a timeout in
      # seconds for the stage request -- confirm against Stage::StageRequest.
      stageRequest = Tac.newInstance(
         "Stage::StageRequest", agentName, stageName, 30
      )
      asuPatchAgent.stageRequest.addMember( stageRequest )

   def markAsuPatchAgentStageComplete( self, agentName, stageName ):
      """
      Set the 'complete' entry for ( agentName, stageName, "default" ) to True
      on the agent's entity under the shutdown stage status directory.
      """
      asuPatchAgent = self.root.entity[
         Cell.path( "stageAgentStatus/shutdown/%s" % ( agentName ) )
      ]
      stageKey = Tac.Value(
         "Stage::AgentStatusKey", agentName, stageName, "default"
      )
      asuPatchAgent.complete[ stageKey ] = True

   def doesBgpSocketStateMatch( self, permittedStates ):
      """
      Check whether every TCP socket on the BGP port is in a permitted state.

      Runs 'ss -Hat' with one 'exclude <state>' per entry of permittedStates
      and a filter matching the ports in portList as source or destination.

      Returns a tuple ( matched, output ): 'matched' is True when ss printed
      nothing, i.e. no socket was found outside the permitted states, and
      'output' is the stripped ss output listing the offending sockets.
      """
      portList = [ 179 ]
      # ss -Hat exclude closed exclude time-wait '( ( dport = 179 or sport = 179 ) )'
      cmd = [ 'ss', '-Hat' ]
      for state in permittedStates:
         cmd += [ 'exclude', state ]
      # Per-port clauses must be combined with 'or'; a bare space between them
      # would produce an invalid expression as soon as portList has two ports.
      portFilter = ' or '.join(
         'dport = %d or sport = %d' % ( port, port ) for port in portList
      )
      cmd.append( "( ( %s ) )" % ( portFilter ) )
      filteredOutput = Tac.run(
         cmd, stdout=Tac.CAPTURE, asRoot=True, ignoreReturnCode=True
      ).strip()
      return len( filteredOutput ) == 0, filteredOutput

   def isStageStarted( self, stage ):
      """
      Return True when 'stage' is listed under the 'default' entry of the
      shutdown stage progress, False otherwise (including when there is no
      'default' entry yet).
      """
      shutProgressEntity = self.root.entity[
          Cell.path( "stage/shutdown/progress" )
      ].progress
      if 'default' in shutProgressEntity:
         return stage in shutProgressEntity[ 'default' ].stage
      return False

   def waitForBgpSockets( self ):
      """
      Wait up to 2 seconds for all BGP sockets to be 'closed' or 'time-wait'.

      Returns True when that happened in time; on timeout an error is sent to
      syslog and False is returned.
      """
      try:
         waitFor(
            lambda: self.doesBgpSocketStateMatch( [ 'closed', 'time-wait' ] )[ 0 ],
            timeout=2, maxDelay=1,
            description="BGP sockets closure during ASU shutdown"
         )
         return True
      except Timeout:
         syslog.syslog(
            syslog.LOG_ERR,
            "Timeout: Unable to close all the bgp sockets during ASU shutdown"
         )
      return False

   def checkBgpSockets( self ):
      """
      Wait for the BGP sockets to close; if they do not, kill Bgp-main with
      SIGKILL and wait once more. The outcome of the second wait is not
      acted upon.
      """
      debugSyslog( syslog.LOG_NOTICE, "Waiting for BGP sockets to close.." )
      timeout = not self.waitForBgpSockets()
      if timeout:
         debugSyslog( syslog.LOG_NOTICE, "Killing BGP.." )
         Tac.run(
            [ 'killall', '-9', 'Bgp-main' ], asRoot=True,
            ignoreReturnCode=True
         )
         self.waitForBgpSockets()
      debugSyslog(
         syslog.LOG_NOTICE,
         "BGP socket close monitoring complete!"
      )

   def waitForStageAndStart( self, stage, func ):
      """
      Wait up to 600 seconds for 'stage' to start, then call func().

      The wait also ends when isSSUPatchThreadStopped() becomes true (not
      defined in this class; presumably inherited from the base thread), in
      which case func() is not called. On timeout an error is sent to syslog
      and func() is not called either.
      """
      debugSyslog( syslog.LOG_NOTICE, "Waiting for stage completion.." )
      try:
         waitFor(
            lambda: self.isStageStarted( stage ) or self.isSSUPatchThreadStopped(),
            description="Wait till end of %s starts" % ( stage ),
            maxDelay=1, timeout=600
         )
      except Timeout:
         syslog.syslog(
            syslog.LOG_ERR,
            'Timed out waiting for Stage progression to hit %s' % ( stage )
         )
         return

      if self.isSSUPatchThreadStopped():
         return
      debugSyslog( syslog.LOG_NOTICE, "Starting payload now.." )
      func()

   def clearDebugTraces( self ):
      """
      Remove the debug log file left by a previous run. The 'rm -f' is wrapped
      in 'timeout -k 3 3' so that a stuck removal cannot block the thread.
      """
      Tac.run(
         [ "timeout", "-k", "3", "3", "rm", "-f", getDebugLogPath() ],
         asRoot=True, ignoreReturnCode=True,
         stdout=Tac.DISCARD, stderr=Tac.DISCARD
      )

   def run( self ):
      """
      Thread body: register the stage dependency, wait for the stage, check
      the BGP sockets and finally mark the stage complete for the pseudo
      agent. Returns early, without marking the stage complete, when
      isSSUPatchThreadStopped() reports that the thread was stopped.
      """
      socketCloseStage = 'LacpPacketsConfig'
      socketCloseAgent = "AsuPatchForBgpFlap"

      self.clearDebugTraces()
      debugSyslog(
          syslog.LOG_NOTICE,
          "Engaging %s dependency on %s for monitoring BGP sockets" % (
             socketCloseAgent, socketCloseStage
          )
      )
      self.addDynamicStageDependency( socketCloseAgent, socketCloseStage )
      self.waitForStageAndStart( socketCloseStage, self.checkBgpSockets )
      if self.isSSUPatchThreadStopped():
         debugSyslog( syslog.LOG_NOTICE, "Exit patch thread early" )
         return

      # Reached also when waiting for the stage timed out, so the pseudo agent
      # does not keep the stage open.
      debugSyslog(
          syslog.LOG_NOTICE,
          "Agent %s now marking %s stage as complete" % (
             socketCloseAgent, socketCloseStage
          )
      )
      self.markAsuPatchAgentStageComplete(
         socketCloseAgent, socketCloseStage
      )

class ArBgpSocketsCheck( AsuPatchBase.AsuPatchBase ):
   """ASU patch that launches the BGP socket monitoring thread."""

   def check( self ):
      """
      Start a MonitorArBgpTcpSocketsThread for this system and, when
      _asuData has an 'asuPatchThreads' list, append the thread to it.
      Always returns 0.
      """
      monitor = MonitorArBgpTcpSocketsThread( self.em.sysname() )
      monitor.start()
      if not hasattr( _asuData, 'asuPatchThreads' ):
         return 0
      _asuData.asuPatchThreads.append( monitor )
      return 0

   def reboot( self ):
      """Nothing to do for this patch."""

# This method is executed by AsuPatch in the shutdown path by the src image
def execute( stageVal, *args, **kwargs ):
   """
   Build an ArBgpSocketsCheck patch object and forward stageVal and any extra
   arguments to its execute() method, returning whatever that returns.
   """
   patch = ArBgpSocketsCheck( 'ArBgpSocketsCheck' )
   result = patch.execute( stageVal, *args, **kwargs )
   return result

# This block is used if this script is triggered by the event handler
if __name__ == "__main__":
   # EVENT_COUNT is read from the environment; when it is unset it counts as 0
   eventCount = int( os.environ.get( 'EVENT_COUNT', 0 ) )
   if eventCount <= 0:
      # Init path: only register the pseudo agent's stage dependency, no
      # monitoring thread is started.
      initStage = 'LacpPacketsConfig'
      initAgent = "AsuPatchForBgpFlap"
      initMessage = "On Init, engaging pseudo agent dependency on %s" % (
         initStage
      )
      debugSyslog( syslog.LOG_NOTICE, initMessage )
      initMonitor = MonitorArBgpTcpSocketsThread( "ar" )
      initMonitor.addDynamicStageDependency( initAgent, initStage )
   else:
      # At least one event was counted: start the socket monitoring thread
      eventCheck = ArBgpSocketsCheck( 'ArBgpSocketsCheck' )
      eventCheck.check()
