# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import BasicCliModes
import CliCommand
from CliPlugin.IpAddrMatcher import IpAddrMatcher
from CliPlugin.RsvpLerShowCli import RsvpLerTunnelFilterArg
from CliToken.Refresh import refreshMatcherForExec
import LazyMount
import Tac
from TypeFuture import TacLazyType
import Toggles.RsvpToggleLib

# Handles to the Rsvp TAC types, obtained through TacLazyType
# (presumably resolved on first use rather than at import time -- confirm
# against TypeFuture).
RsvpLerCliConfig = TacLazyType( 'Rsvp::RsvpLerCliConfig' )
RsvpLerTunnelSpecId = TacLazyType( 'Rsvp::RsvpLerTunnelSpecId' )
tunnelSourceCli = TacLazyType( 'Rsvp::RsvpLerTunnelSource' ).tunnelSourceCli
RsvpSessionType = TacLazyType( 'Rsvp::RsvpSessionType' )
# Session type value that doOptimizeRsvpTunnel compares tunnelSpec.sessionType
# against to decide whether avoid-ECMP hops may be applied.
p2pLspTunnel = RsvpSessionType.p2pLspTunnel

# Mounted 'mpls/rsvp/lerClient' directory; stays None until Plugin() runs.
lerClientDir = None

def doOptimizeRsvpTunnel( mode, tunnelSpec, avoidEcmpHops=None ):
   '''Request a re-optimization of a single tunnel spec.

   mode          -- CLI mode, used only to report an error to the user.
   tunnelSpec    -- the tunnel spec entity to update.
   avoidEcmpHops -- optional iterable of hop addresses to avoid; each one is
                    stored as an Arnet::IpGenAddr in
                    tunnelSpec.refreshAvoidEcmpHop.

   When the spec's session type is not p2pLspTunnel and hops were supplied,
   an error is added to the mode and the spec is left untouched. Otherwise
   forceOptimize and changeCount are both bumped by one; for p2pLspTunnel
   specs the avoid-hop collection is first cleared and refilled from
   avoidEcmpHops (it ends up empty when none were given).
   '''
   notP2p = tunnelSpec.sessionType != p2pLspTunnel
   if notP2p and avoidEcmpHops:
      # Avoiding ECMP hops is not supported for non-p2p (e.g. p2mp) tunnels
      mode.addError( "This command only supported for p2p" )
      return

   if not notP2p:
      # Always drop the hops from any earlier request, then record the new
      # ones, if any were entered
      tunnelSpec.refreshAvoidEcmpHop.clear()
      for hopAddr in avoidEcmpHops or ():
         tunnelSpec.refreshAvoidEcmpHop.add(
            Tac.ValueConst( 'Arnet::IpGenAddr', hopAddr ) )

   # Bump both counters to signal the optimization request
   tunnelSpec.forceOptimize += 1
   tunnelSpec.changeCount += 1

def optimizeRsvpTunnels( mode, args ):
   '''Handler for 'refresh rsvp tunnel optimization ...'.

   mode -- CLI mode, passed through to doOptimizeRsvpTunnel.
   args -- parsed command arguments; the keys read here are 'all',
           'FILTERS' and 'IP_ADDR'.

   Does nothing when lerClientDir is unset or empty. With 'all', every
   tunnel spec of the 'cli' and 'mvpn' clients is optimized. Otherwise each
   ( 'tunnelFilter', ( 'tunnelName', <name>, ... ) ) entry in FILTERS
   optimizes the specs whose tunnelName equals <name>, forwarding IP_ADDR as
   the hops to avoid; every other filter entry is skipped.
   '''
   filters = args.get( 'FILTERS' )
   if not lerClientDir:
      return

   def eligibleTunnelSpecs():
      # We optimize tunnels from only these sources; a source that has no
      # entry in the directory is skipped
      for source in ( 'cli', 'mvpn' ):
         client = lerClientDir.get( source )
         if client:
            yield from client.tunnelSpec.values()

   if 'all' in args:
      # Optimize all the tunnels
      for spec in eligibleTunnelSpecs():
         doOptimizeRsvpTunnel( mode, spec )
      return

   if not filters:
      return

   # Optimize only those tunnels that match the passed filter.
   for filterType, arg in filters:
      # TODO BUG982943: Handle all filter types
      isNameFilter = ( filterType == 'tunnelFilter' and len( arg ) >= 2 and
                       arg[ 0 ] == 'tunnelName' )
      if not isNameFilter:
         continue
      wantedName = arg[ 1 ]
      for spec in eligibleTunnelSpecs():
         if spec.tunnelName == wantedName:
            doOptimizeRsvpTunnel( mode, spec, args.get( 'IP_ADDR' ) )

#-----------------------------------------------------------------------------------
# 'refresh rsvp tunnel optimization
#    ( all | ( FILTERS [ avoid ecmp hop { IP_ADDR } ] ) )'
# in privileged EXEC mode
#-----------------------------------------------------------------------------------
class RsvpOptimizationExecCmd( CliCommand.CliCommandClass ):
   # CLI command definition for 'refresh rsvp tunnel optimization ...'.
   # The toggle checks below run while the class body is executed, so the
   # syntax and keyword set are fixed when this module is loaded.
   if Toggles.RsvpToggleLib.toggleRsvpLerRefreshAvoidEcmpHopEnabled():
      # Toggle on: FILTERS may be followed by 'avoid ecmp hop' and a list of
      # IP addresses.
      syntax = (
         ''' refresh rsvp tunnel optimization
         ( all | ( FILTERS [ avoid ecmp hop { IP_ADDR } ] ) ) ''' )
   else:
      # Toggle off: only 'all' or FILTERS is accepted.
      syntax = 'refresh rsvp tunnel optimization ( all | FILTERS )'
   # Maps each token in the syntax to its help text or matcher.
   data = {
      'refresh': refreshMatcherForExec,
      'rsvp': 'Refresh RSVP information',
      'tunnel': 'Refresh RSVP tunnels',
      'optimization': 'Refresh RSVP tunnels by optimizing them',
      'all': 'Optimize all RSVP tunnels',
      'FILTERS': RsvpLerTunnelFilterArg,
   }
   if Toggles.RsvpToggleLib.toggleRsvpLerRefreshAvoidEcmpHopEnabled():
      # Extra tokens that only appear in the toggle-enabled syntax.
      data |= {
         'avoid': 'Avoid a hop (or a set of hops) on the paths',
         'ecmp': 'Move the tunnel away from a certain set of ECMP hops',
         'hop': 'Recompute the path for tunnel avoiding certain ECMP hops',
         # NOTE(review): maxMatches=255 presumably caps how many addresses
         # the { IP_ADDR } repetition accepts -- confirm against
         # CliCommand.Node.
         'IP_ADDR': CliCommand.Node(
            matcher=IpAddrMatcher( helpdesc='The set of ECMP hops to be avoided' ),
            maxMatches=255 )
      }

   # Function invoked as handler( mode, args ) -- see optimizeRsvpTunnels.
   handler = optimizeRsvpTunnels

# Register the command with the enable (privileged EXEC) mode.
BasicCliModes.EnableMode.addCommandClass( RsvpOptimizationExecCmd )

#--------------------------------------------------------------------------------
# Plugin
#--------------------------------------------------------------------------------
def Plugin( entityManager ):
   '''Plugin entry point: set up the module-level lerClientDir.

   Mounts 'mpls/rsvp/lerClient' as a Tac::Dir through LazyMount, using
   mount flags 'wi', and stores the result in the lerClientDir global that
   optimizeRsvpTunnels reads.
   '''
   global lerClientDir
   mountPath = 'mpls/rsvp/lerClient'
   lerClientDir = LazyMount.mount( entityManager, mountPath, 'Tac::Dir', 'wi' )
