# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pkgdeps: rpmwith %{_libdir}/libRsvp.so*

import os

import BasicCliModes
import CliCommand
import CliToken.Clear
import CliToken.Rsvp
from CliPlugin import (
   AclCli,
   MplsCli,
   RsvpShowCli,
)
import LazyMount
import SharkLazyMount
import SmashLazyMount
import Tac
from TypeFuture import TacLazyType

# Mount handles used by the command handlers below. All are None at import
# time and are assigned by Plugin().
aclStatus = None
aclCheckpoint = None
# RSVP clear-session request/ack collections (see ClearMplsRsvpSessionCmd).
rsvpCommandRequest = None
rsvpCommandStatus = None
rsvpSharkStatus = None
rsvpSysdbStatus = None
# Live Smash counter collections and the snapshot collections that
# `clear mpls rsvp counters` copies them into.
rsvpMessageCounters = None
rsvpMessageCountersSnapshot = None
rsvpErrorCounters = None
rsvpErrorCountersSnapshot = None

# NOTE(review): these Tac type handles are not referenced by the code in this
# file section; confirm they are still needed.
RsvpSessionRole = TacLazyType( 'Rsvp::RsvpSessionRole' )
RsvpSpOperState = TacLazyType( 'Rsvp::RsvpSpOperState' )
RsvpLspType = TacLazyType( 'Rsvp::RsvpLspType' )

#--------------------------------------------------------------------------------
# clear mpls rsvp counters ( ip | ipv6 ) access-list
#--------------------------------------------------------------------------------
# Clears the RSVP service-ACL counters for a single address family. The work is
# delegated to AclCli.clearServiceAclCounters() using the ACL status and
# checkpoint mounts set up by Plugin().
class ClearMplsRsvpCountersIpAccessListCmd( CliCommand.CliCommandClass ):
   syntax = 'clear mpls rsvp counters ( ip | ipv6 ) access-list'
   data = {
      'clear' : CliToken.Clear.clearKwNode,
      'mpls' : MplsCli.mplsMatcherForClear,
      'rsvp' : CliToken.Rsvp.rsvpMatcherForClear,
      'counters' : 'Clear RSVP counters',
      'ip' : AclCli.ipKwForClearServiceAclMatcher,
      'ipv6' : AclCli.ipv6KwMatcherForClearServiceAcl,
      'access-list' : AclCli.accessListKwMatcherForServiceAcl,
   }

   @staticmethod
   def handler( mode, args ):
      # Only one of 'ip' / 'ipv6' can be present; anything but ipv6 means ip.
      if 'ipv6' in args:
         family = 'ipv6'
      else:
         family = 'ip'
      AclCli.clearServiceAclCounters( mode, aclStatus, aclCheckpoint, family )

BasicCliModes.EnableMode.addCommandClass( ClearMplsRsvpCountersIpAccessListCmd )

# --------------------------------------------------------------------------------
# clear mpls rsvp counters
# --------------------------------------------------------------------------------
# Clears the RSVP message and error counters. Each live counter collection is
# passed as the source to rxSnapshotFrom()/txSnapshotFrom() on its snapshot
# collection. Presumably show commands report values relative to the snapshot
# (TODO confirm against RsvpShowCli).
class ClearMplsRsvpCounters( CliCommand.CliCommandClass ):
   syntax = 'clear mpls rsvp counters'
   data = {
      'clear': CliToken.Clear.clearKwNode,
      'mpls': MplsCli.mplsMatcherForClear,
      'rsvp': CliToken.Rsvp.rsvpMatcherForClear,
      'counters': 'Clear RSVP counters',
   }

   @staticmethod
   def handler( mode, args ):
      # Message counters first, then error counters; rx before tx for each.
      for snapshot, live in ( ( rsvpMessageCountersSnapshot, rsvpMessageCounters ),
                              ( rsvpErrorCountersSnapshot, rsvpErrorCounters ) ):
         snapshot.rxSnapshotFrom( live )
         snapshot.txSnapshotFrom( live )

BasicCliModes.EnableMode.addCommandClass( ClearMplsRsvpCounters )

# --------------------------------------------------------------------------------
# clear mpls rsvp session ( FILTERS | all ) [ detail ]
# --------------------------------------------------------------------------------
def applyClearFilters( sessionStateColl, spStatusColl, filterArg, mode ):
   ''' Select the P2P sub-path entries ( spId ) to clear.

   A session is considered when the filter built from filterArg matches it
   (together with its CLI id from rsvpSysdbStatus). Each of its spMember ids
   that has an spStatus entry and passes the filter's per-sp check is kept.
   Returns a list of ( spId, spStatus ) tuples. `mode` is unused here.
   '''
   tacFilter = RsvpShowCli.createTacFilter( filterArg )

   selected = []
   for session in sessionStateColl.sessionState.values():
      cliIdMap = rsvpSysdbStatus.sessionIdToCliIdColl.sessionIdToCliId
      cliId = cliIdMap.get( session.sessionId )
      if tacFilter.matchSession( session, cliId ):
         for memberId in session.spMember:
            memberStatus = spStatusColl.spStatus.get( memberId )
            # Skip ids with no status entry before asking the filter.
            if memberStatus and tacFilter.matchSp( memberStatus ):
               selected.append( ( memberId, memberStatus ) )

   return selected

def applyP2mpClearFilters( sessionStateColl, spGroupStateColl, spStatusColl,
                           filterArg, mode ):
   ''' Select the P2MP sub-groups to clear.

   Every P2MP session in sessionStateColl that the filter built from filterArg
   matches contributes all of its spGroupMember ids that have an entry in
   spGroupStateColl. Returns a list of ( spGroupId, spGroupState ) tuples.
   `mode` is unused here.
   '''
   tacFilter = RsvpShowCli.createTacFilter( filterArg )

   selected = []
   for session in sessionStateColl.sessionState.values():
      if tacFilter.matchP2mpSession( session, spGroupStateColl, spStatusColl ):
         # P2MP filters select whole sessions, so there is no per-member
         # "matchSp" step here: every member spGroup of a match is cleared.
         for groupId in session.spGroupMember:
            groupState = spGroupStateColl.spGroupState.get( groupId )
            if groupState:
               selected.append( ( groupId, groupState ) )

   return selected

# Clears RSVP sessions through a request/ack handshake:
#   1. Empty any leftover clear requests and wait for old acks to drain.
#   2. Add one entry per selected spId (P2P) / spGroupId (P2MP) to the request
#      collections in rsvpCommandRequest.
#   3. Wait (up to 600s) until the ack collections in rsvpCommandStatus are at
#      least as large as the request collections, report the ack count, and
#      empty the request collections again.
class ClearMplsRsvpSessionCmd( CliCommand.CliCommandClass ):
   syntax = 'clear mpls rsvp session ( FILTERS | all ) [ detail ]'
   data = {
      'clear' : CliToken.Clear.clearKwNode,
      'mpls' : MplsCli.mplsMatcherForClear,
      'rsvp' : CliToken.Rsvp.rsvpMatcherForClear,
      'session' : 'Clear RSVP sessions',
      'FILTERS' : RsvpShowCli.MplsRsvpSessionFilterArg,
      'all' : 'Clear all RSVP sessions',
      'detail' : 'Display more information about cleared sessions',
   }

   @staticmethod
   def handler( mode, args ):
      filterArg = args.get( 'filterArg' ) or []
      allSessions = 'all' in args
      detail = 'detail' in args

      # Btest expects Tac.runActivities to be run in order to trigger timer tasks.
      # Use sleep=False for the btest and sleep=True otherwise.
      cliClearTest = os.environ.get( 'CLI_CLEAR_TEST' )
      sleep = not cliClearTest

      # If any of the status collections are not initialized, return early.
      if not ( rsvpSharkStatus.sessionStateColl and
               rsvpSharkStatus.spStatusColl and
               rsvpSharkStatus.p2mpSessionStateColl and
               rsvpSharkStatus.p2mpSpGroupStateColl and
               rsvpSharkStatus.p2mpSpStatusColl and
               rsvpCommandStatus.clearAckColl and
               rsvpCommandStatus.p2mpClearAckColl and
               rsvpSysdbStatus.sessionIdToCliIdColl ):
         return

      sessionStateColl = rsvpSharkStatus.sessionStateColl
      spStatusColl = rsvpSharkStatus.spStatusColl
      p2mpSessionStateColl = rsvpSharkStatus.p2mpSessionStateColl
      p2mpSpGroupStateColl = rsvpSharkStatus.p2mpSpGroupStateColl
      p2mpSpStatusColl = rsvpSharkStatus.p2mpSpStatusColl

      clearRequests = rsvpCommandRequest.clearRequestColl.clearRequest
      clearAcks = rsvpCommandStatus.clearAckColl.clearAck
      p2mpClearRequests = rsvpCommandRequest.p2mpClearRequestColl.clearRequest
      p2mpClearAcks = rsvpCommandStatus.p2mpClearAckColl.clearAck

      # Clear any stale clearRequests and wait for clearAck collection to be drained
      # before issuing new clearRequests.
      if clearRequests or p2mpClearRequests:
         mode.addWarning( 'Previous clear requests are present, emptying the clear '
                          'request queue before clearing sessions' )
         clearRequests.clear()
         p2mpClearRequests.clear()
      # Any exception other than these two propagates to the caller unchanged.
      try:
         Tac.waitFor( lambda: not clearAcks and not p2mpClearAcks,
                      description='any preexisting clear acks to be drained',
                      warnAfter=None, sleep=sleep, maxDelay=1, timeout=600 )
      except Tac.Timeout:
         mode.addWarning(
            'Timed out during cleanup from previously interrupted clear command' )
         return
      except KeyboardInterrupt:
         return

      # Select what to clear: ( spId, spStatus ) pairs for P2P sessions and
      # ( spGroupId, spGroupState ) pairs for P2MP sessions.
      if allSessions:
         clearSps = spStatusColl.spStatus.items()
         clearSpGroupIds = p2mpSpGroupStateColl.spGroupState.items()
      else:
         clearSps = applyClearFilters( sessionStateColl, spStatusColl,
                                       filterArg, mode )
         clearSpGroupIds = applyP2mpClearFilters( p2mpSessionStateColl,
                                                  p2mpSpGroupStateColl,
                                                  p2mpSpStatusColl,
                                                  filterArg, mode )

      for spId, spStatus in clearSps:
         if detail:
            mode.addMessage( f'Clearing { spStatus.sessionName }' )
         clearRequests.add( spId )
      for spGroupId, spGroupState in clearSpGroupIds:
         if detail:
            mode.addMessage( f'Clearing { spGroupState.sessionName }' )
         p2mpClearRequests.add( spGroupId )

      # Wait until every request has been acked. Whatever happens, report the
      # number of acks received and withdraw the requests.
      try:
         Tac.waitFor( lambda: ( len( clearAcks ) >= len( clearRequests ) and
                                len( p2mpClearAcks ) >= len( p2mpClearRequests ) ),
                      description='all specified sessions to be cleared',
                      warnAfter=None, sleep=sleep, maxDelay=1, timeout=600 )
      except Tac.Timeout:
         mode.addWarning( 'Timed out, some sessions may not have been cleared' )
      except KeyboardInterrupt:
         mode.addWarning( 'Clearing interrupted, stop clearing' )
      finally:
         totalClearedSessions = len( clearAcks ) + len( p2mpClearAcks )
         mode.addMessage( f'Cleared { totalClearedSessions } sessions' )
         clearRequests.clear()
         p2mpClearRequests.clear()

BasicCliModes.EnableMode.addCommandClass( ClearMplsRsvpSessionCmd )

def Plugin( entityManager ):
   '''Mount every entity used by the clear commands in this file and publish
   the mount handles through the module-level globals.

   ACL status/checkpoint and RSVP command request/status/Sysdb status are
   LazyMount mounts; RSVP Shark status is a SharkLazyMount 'shadow' mount; the
   RSVP message/error counters are SmashLazyMount mounts, read with 'reader'
   mount info and snapshotted into collections mounted with the type's
   writerMountInfo().
   '''
   global aclStatus, aclCheckpoint
   global rsvpCommandRequest, rsvpCommandStatus
   global rsvpSharkStatus, rsvpSysdbStatus
   global rsvpMessageCounters, rsvpMessageCountersSnapshot
   global rsvpErrorCounters, rsvpErrorCountersSnapshot

   aclStatus = LazyMount.mount( entityManager, "acl/status/all", "Acl::Status", "r" )
   aclCheckpoint = LazyMount.mount( entityManager, "acl/checkpoint",
                                    "Acl::CheckpointStatus", "w" )

   RsvpCommandRequest = Tac.Type( "Rsvp::RsvpCommandRequest" )
   rsvpCommandRequest = LazyMount.mount( entityManager,
                                         RsvpCommandRequest.mountPath,
                                         "Rsvp::RsvpCommandRequest", 'w' )

   RsvpCommandStatus = Tac.Type( "Rsvp::RsvpCommandStatus" )
   rsvpCommandStatus = LazyMount.mount( entityManager,
                                        RsvpCommandStatus.mountPath,
                                        "Rsvp::RsvpCommandStatus", 'r' )

   autoUnmount = True
   RsvpSharkStatus = Tac.Type( "Rsvp::RsvpSharkStatus" )
   rsvpSharkStatus = SharkLazyMount.mount( entityManager,
                                           RsvpSharkStatus.mountPath,
                                           "Rsvp::RsvpSharkStatus",
                                           SharkLazyMount.mountInfo( 'shadow' ),
                                           autoUnmount )

   RsvpSysdbStatus = Tac.Type( "Rsvp::RsvpSysdbStatus" )
   rsvpSysdbStatus = LazyMount.mount( entityManager,
                                      RsvpSysdbStatus.mountPath,
                                      "Rsvp::RsvpSysdbStatus", 'rS' )

   # One type lookup serves both the live counters and their snapshot.
   RsvpMessageCounterColl = Tac.Type( "Rsvp::Smash::MessageCounterColl" )
   rsvpMessageCounters = SmashLazyMount.mount(
                           entityManager,
                           RsvpMessageCounterColl.mountPath,
                           "Rsvp::Smash::MessageCounterColl",
                           SmashLazyMount.mountInfo( 'reader' ) )
   rsvpMessageCountersSnapshot = SmashLazyMount.mount(
                           entityManager,
                           RsvpMessageCounterColl.snapshotMountPath,
                           "Rsvp::Smash::MessageCounterColl",
                           RsvpMessageCounterColl.writerMountInfo() )

   RsvpErrorCounterColl = Tac.Type( "Rsvp::Smash::ErrorCounterColl" )
   rsvpErrorCounters = SmashLazyMount.mount(
                           entityManager,
                           RsvpErrorCounterColl.mountPath,
                           "Rsvp::Smash::ErrorCounterColl",
                           SmashLazyMount.mountInfo( 'reader' ) )
   rsvpErrorCountersSnapshot = SmashLazyMount.mount(
                           entityManager,
                           RsvpErrorCounterColl.snapshotMountPath,
                           "Rsvp::Smash::ErrorCounterColl",
                           RsvpErrorCounterColl.writerMountInfo() )
