#!/usr/bin/env python3
# # Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from Arnet import IpGenAddr
from CliDynamicSymbol import CliDynamicPlugin
from CliPlugin import MvpnLibCliModels as MLCM
from CliPlugin.RoutingBgpShowCli import (
   arBgpShowCmdDict,
   ArBgpShowOutput,
   BgpVrfRoutingContextDefaultOnly
)
from CliPlugin.MvpnCli import (
   mvpnRsvpLeafStatus,
   mvpnTunnelEncapStatus,
   vrfIdStatus,
   vrfNameStatus,
   mvpnIntfStatus,
   mldpTunnelDecapStatus,
   rsvpTunnelDecapStatus
)
from CliPlugin import BgpCliModels
import Tac
from TypeFuture import TacLazyType

# Dynamically loaded CLI plugins used by the handlers in this file. (A duplicate
# second load of "BgpVpnModels" has been removed.)
BgpVpnModels = CliDynamicPlugin( "BgpVpnModels" )
MvpnHelperCli = CliDynamicPlugin( "MvpnHelperCli" )
MvpnCliModels = CliDynamicPlugin( "MvpnCliModels" )
ArBgpCliHandler = CliDynamicPlugin( "ArBgpCliHandler" )

# Register the real 'show bgp mvpn' implementation from MvpnHelperCli under the
# same key that the ArBgpShowOutput( 'doShowBgpMvpn' ) stub below is keyed on.
arBgpShowCmdDict[ 'doShowBgpMvpn' ] = MvpnHelperCli.doShowBgpMvpn

# Lazily resolved Tac types used to decode PMSI tunnel identifiers.
PmsiTunnelType = TacLazyType( "Routing::Multicast::PmsiTunnelType" )
LdpP2mpFecElement = TacLazyType( "Ldp::LdpP2mpFecElement" )
MvpnRsvpSessionId = TacLazyType( "Routing::Multicast::MvpnRsvpSessionId" )
MldpOpaqueValTlvParser = TacLazyType( 'Mpls::MldpOpaqueValTlvParser' )

#-------------------------------------------------------------------------------
# Show-command handlers and their helper methods
#-------------------------------------------------------------------------------
@ArBgpShowOutput( 'doShowBgpMvpn', arBgpModeOnly=True )
@BgpVrfRoutingContextDefaultOnly( cliModel=MvpnCliModels.MvpnRoutesVrfModel )
def doShowBgpMvpn( mode, **kwargs ):
   """Abstract placeholder for the 'show bgp mvpn' implementation.

   The real implementation (MvpnHelperCli.doShowBgpMvpn) is registered in
   arBgpShowCmdDict under the 'doShowBgpMvpn' key that the ArBgpShowOutput
   decorator is given. If this body runs, that plugin did not load correctly.

   Raises:
      NotImplementedError: always.
   """
   message = ( 'This method is abstract. '
               'If this was reached then something went wrong '
               'with MvpnHelperCli loading' )
   raise NotImplementedError( message )

def handlerMvpnV4IntraAsIpmsi( mode, args ):
   """Show IPv4 MVPN intra-AS I-PMSI routes (nlriType 'ipv4MvpnType1')."""
   showKwargs = {
      'nlriType': 'ipv4MvpnType1',
      'rdValue': args.get( 'RD_VAL' ),
      'mvpnOrigIp': args.get( 'ORIG_IP' ),
      'detail': args.get( 'detail' ),
   }
   return doShowBgpMvpn( mode, **showKwargs )

def handlerMvpnV4Spmsi( mode, args ):
   """Show IPv4 MVPN S-PMSI routes (nlriType 'ipv4MvpnType3')."""
   showKwargs = {
      'nlriType': 'ipv4MvpnType3',
      'rdValue': args.get( 'RD_VAL' ),
      'srcGrp': args.get( 'SRC_GRP' ),
      'mvpnOrigIp': args.get( 'ORIG_IP' ),
      'detail': args.get( 'detail' ),
   }
   return doShowBgpMvpn( mode, **showKwargs )

def handlerMvpnV4SourceActive( mode, args ):
   """Show IPv4 MVPN source-active routes (nlriType 'ipv4MvpnType5')."""
   showKwargs = {
      'nlriType': 'ipv4MvpnType5',
      'rdValue': args.get( 'RD_VAL' ),
      'srcGrp': args.get( 'SRC_GRP' ),
      'detail': args.get( 'detail' ),
   }
   return doShowBgpMvpn( mode, **showKwargs )

def handlerMvpnV4SharedTree( mode, args ):
   """Show IPv4 MVPN shared-tree join routes (nlriType 'ipv4MvpnType6')."""
   showKwargs = {
      'nlriType': 'ipv4MvpnType6',
      'rdValue': args.get( 'RD_VAL' ),
      'srcGrp': args.get( 'SRC_GRP' ),
      'mvpnSourceAs': args.get( 'AS_NUM' ),
      'detail': args.get( 'detail' ),
   }
   return doShowBgpMvpn( mode, **showKwargs )

def handlerMvpnV4SourceTree( mode, args ):
   """Show IPv4 MVPN source-tree join routes (nlriType 'ipv4MvpnType7')."""
   showKwargs = {
      'nlriType': 'ipv4MvpnType7',
      'rdValue': args.get( 'RD_VAL' ),
      'srcGrp': args.get( 'SRC_GRP' ),
      'mvpnSourceAs': args.get( 'AS_NUM' ),
      'detail': args.get( 'detail' ),
   }
   return doShowBgpMvpn( mode, **showKwargs )

def handlerMvpnV4VrfSpmsi( mode, args ):
   """Show IPv4 MVPN S-PMSI routes of a VRF (nlriType 'ipv4MvpnVrfType3')."""
   showKwargs = {
      'nlriType': 'ipv4MvpnVrfType3',
      'vrfName': args.get( 'VRF' ),
      'srcGrp': args.get( 'SRC_GRP' ),
      'mvpnOrigIp': args.get( 'ORIG_IP' ),
      'detail': args.get( 'detail' ),
   }
   return doShowBgpMvpn( mode, **showKwargs )

def handlerMvpnV4VrfIntraAsIpmsi( mode, args ):
   """Show IPv4 MVPN intra-AS I-PMSI routes of a VRF ('ipv4MvpnVrfType1')."""
   showKwargs = {
      'nlriType': 'ipv4MvpnVrfType1',
      'vrfName': args.get( 'VRF' ),
      'mvpnOrigIp': args.get( 'ORIG_IP' ),
      'detail': args.get( 'detail' ),
   }
   return doShowBgpMvpn( mode, **showKwargs )

def handlerMvpnV4VrfSourceActive( mode, args ):
   """Show IPv4 MVPN source-active routes of a VRF ('ipv4MvpnVrfType5')."""
   showKwargs = {
      'nlriType': 'ipv4MvpnVrfType5',
      'vrfName': args.get( 'VRF' ),
      'srcGrp': args.get( 'SRC_GRP' ),
      'detail': args.get( 'detail' ),
   }
   return doShowBgpMvpn( mode, **showKwargs )

def handlerMvpnV4VrfSharedTree( mode, args ):
   """Show IPv4 MVPN shared-tree joins of a VRF ('ipv4MvpnVrfType6').

   In the VRF variant, RD_VAL is forwarded as umhRdValue rather than rdValue.
   """
   showKwargs = {
      'nlriType': 'ipv4MvpnVrfType6',
      'vrfName': args.get( 'VRF' ),
      'umhRdValue': args.get( 'RD_VAL' ),
      'srcGrp': args.get( 'SRC_GRP' ),
      'mvpnSourceAs': args.get( 'AS_NUM' ),
      'detail': args.get( 'detail' ),
   }
   return doShowBgpMvpn( mode, **showKwargs )

def handlerMvpnV4VrfSourceTree( mode, args ):
   """Show IPv4 MVPN source-tree joins of a VRF ('ipv4MvpnVrfType7').

   In the VRF variant, RD_VAL is forwarded as umhRdValue rather than rdValue.
   """
   showKwargs = {
      'nlriType': 'ipv4MvpnVrfType7',
      'vrfName': args.get( 'VRF' ),
      'umhRdValue': args.get( 'RD_VAL' ),
      'srcGrp': args.get( 'SRC_GRP' ),
      'mvpnSourceAs': args.get( 'AS_NUM' ),
      'detail': args.get( 'detail' ),
   }
   return doShowBgpMvpn( mode, **showKwargs )

def handlerMvpnV4Summary( mode, args ):
   """Show the BGP summary for the 'ipv4Mvpn' AFI/SAFI; args is unused."""
   summaryImpl = ArBgpCliHandler.doShowIpBgpSummaryAcrImpl
   return summaryImpl( mode, nlriAfiSafi='ipv4Mvpn' )

def handleBgpMvpnVrfRsvpLeaf( vrfId, vrfLeafStatus ):
   """Build an MvpnVrfRsvpLeafStatus model for a single VRF.

   Args:
      vrfId: VRF id recorded in the model.
      vrfLeafStatus: per-VRF leaf status. Its rsvpSessionIdLeafSet maps an RSVP
         session id (with p2mpId, tunnelId and extTunnelId fields) to a leaf set
         whose .leaf entries become the model's leafNodes.

   Returns:
      The populated model, with leaves nested as
      p2mpIds[ p2mpId ].tunnelIds[ tunnelId ].leafNodes.
   """
   vrfStatus = MvpnCliModels.MvpnVrfRsvpLeafStatus()
   vrfStatus.vrfId = vrfId
   vrfStatus.tunnelProfile = vrfLeafStatus.tunnelProfile
   for sessionId, leafSet in vrfLeafStatus.rsvpSessionIdLeafSet.items():
      leafSetModel = MvpnCliModels.MvpnRsvpSessionIdLeafSet()
      for ip in leafSet.leaf:
         leafSetModel.leafNodes.append( ip )
      # Sessions that share a p2mpId but differ in tunnelId must be merged into
      # the same tunnel-id model. Replacing the entry would keep only the last
      # tunnelId seen for that p2mpId.
      if sessionId.p2mpId not in vrfStatus.p2mpIds:
         vrfStatus.p2mpIds[ sessionId.p2mpId ] = (
            MvpnCliModels.MvpnVrfRsvpLeafTunnelIdModel() )
      tunnelIdModel = vrfStatus.p2mpIds[ sessionId.p2mpId ]
      tunnelIdModel.tunnelIds[ sessionId.tunnelId ] = leafSetModel
      # The model holds a single extendedTunnel per VRF. The value from the
      # last session iterated wins.
      vrfStatus.extendedTunnel = sessionId.extTunnelId
   return vrfStatus

def handleBgpMvpnRsvpLeaf( mode, args ):
   """Build an MvpnRsvpLeafStatus model from the RSVP leaf status in Sysdb.

   If args carries 'VRF', only that VRF is reported. Nothing is reported if its
   name has no id or it has no leaf status. Otherwise every VRF in the leaf
   status whose id resolves to a name in vrfIdStatus is reported.
   """
   leafStatus = mvpnRsvpLeafStatus.rsvpVrfLeafStatus
   model = MvpnCliModels.MvpnRsvpLeafStatus()
   requestedVrf = args.get( 'VRF' )

   if requestedVrf is None:
      # No filter: walk every VRF that has leaf status.
      for vrfId, vrfLeafStatus in leafStatus.items():
         vrfEntry = vrfIdStatus.vrfIdToName.get( vrfId )
         if vrfEntry is None:
            continue
         model.vrfs[ vrfEntry.vrfName ] = handleBgpMvpnVrfRsvpLeaf(
            vrfId, vrfLeafStatus )
      return model

   # Filtered: resolve the single requested VRF.
   vrfId = vrfNameStatus.nameToIdMap.vrfNameToId.get( requestedVrf )
   if vrfId is not None and vrfId in leafStatus:
      model.vrfs[ requestedVrf ] = handleBgpMvpnVrfRsvpLeaf(
         vrfId, leafStatus[ vrfId ] )
   return model

def _decapVrfModel( model, vrfName, vrfId ):
   """Return model.vrfs[ vrfName ], creating it with vrfId if it is absent."""
   if vrfName not in model.vrfs:
      model.vrfs[ vrfName ] = MLCM.PmsiTunnelDecapModel()
      model.vrfs[ vrfName ].vrfId = vrfId
   return model.vrfs[ vrfName ]

def _validDecapIntfIds( vrfName ):
   """Map each displayable pmsiIntfId to its ( vrfName, vrfId ) pair.

   With no VRF filter (vrfName is None), include every pmsiIntfId in
   mvpnIntfStatus whose vrfId has a name in vrfIdStatus. Otherwise include only
   the pmsiIntfId of the named VRF, resolved through vrfNameStatus. Return an
   empty dict if mvpnIntfStatus is unavailable.
   """
   validIntfIds = {}
   if mvpnIntfStatus is None:
      return validIntfIds
   if vrfName is None:
      for vrfId, vrfIntf in mvpnIntfStatus.vrfIntfId.items():
         if vrfId in vrfIdStatus.vrfIdToName:
            name = vrfIdStatus.vrfIdToName[ vrfId ].vrfName
            # Keep the id we already have. Re-resolving it by name through a
            # different map could fail if the two maps are momentarily out of
            # sync.
            validIntfIds[ vrfIntf.pmsiIntfId ] = ( name, vrfId )
   else:
      vrfId = vrfNameStatus.nameToIdMap.vrfNameToId.get( vrfName )
      if vrfId is not None and vrfId in mvpnIntfStatus.vrfIntfId:
         validIntfIds[ mvpnIntfStatus.vrfIntfId[ vrfId ].pmsiIntfId ] = (
            vrfName, vrfId )
   return validIntfIds

def _addRsvpDecapEntries( model, validIntfIds ):
   """Add RSVP-TE P2MP entries from rsvpTunnelDecapStatus to model."""
   for treeId, decapStatus in rsvpTunnelDecapStatus.tunnelDecap.items():
      intfId = decapStatus.pmsiIntfId
      if intfId not in validIntfIds:
         continue
      vrfName, vrfId = validIntfIds[ intfId ]
      tunnelDecapModel = _decapVrfModel( model, vrfName, vrfId )
      if not tunnelDecapModel.rsvpTeP2mpLsp:
         tunnelDecapModel.rsvpTeP2mpLsp = MLCM.PmsiRsvpTunnelDecap()
      rsvpTeP2mpLspModel = tunnelDecapModel.rsvpTeP2mpLsp
      # Convert the RSVP treeId to an MvpnRsvpSessionId to obtain the
      # p2mpId / tunnelId / extTunnelId used as model keys.
      sessionId = MvpnRsvpSessionId.fromProtoData( treeId.protoData )
      if sessionId.p2mpId not in rsvpTeP2mpLspModel.p2mpIds:
         rsvpTeP2mpLspModel.p2mpIds[ sessionId.p2mpId ] = (
            MLCM.PmsiRsvpTunnelDecapTunnelId() )
      tunnelIdModel = rsvpTeP2mpLspModel.p2mpIds[ sessionId.p2mpId ]

      if sessionId.tunnelId not in tunnelIdModel.tunnelIds:
         tunnelIdModel.tunnelIds[ sessionId.tunnelId ] = (
            MLCM.PmsiRsvpTunnelDecapExtTunnelId() )
      extTunnelIdModel = tunnelIdModel.tunnelIds[ sessionId.tunnelId ]
      extTunnelIdModel.extTunnelIds[ sessionId.extTunnelId ] = intfId

def _addMldpOpaqueValue( opaqueValues, opaqueVal, intfId, rd ):
   """Record intfId in opaqueValues under the key decoded from opaqueVal.

   rd is a scratch Arnet::RouteDistinguisher, reused to format RD values.
   """
   tlvParser = MldpOpaqueValTlvParser( opaqueVal, True )
   if not tlvParser.valid:
      # Parsing of the opaque value failed. Populate invalidOpaqueValues with
      # the hex string of opaqueVal as key.
      opaqueValues.invalidOpaqueValues[ opaqueVal.hex() ] = intfId
      return

   opaqueStr = tlvParser.getOpaqueString( False, False )
   if 'LSP' in opaqueStr:
      # opaqueStr: 'LSP ID: 1'
      lspIdStrs = opaqueStr.split( ":" )
      opaqueValues.genericLspIds[ int( lspIdStrs[ 1 ] ) ] = intfId
   elif 'RD' in opaqueStr:
      # opaqueStr: 'S: 1.1.1.1 G: 224.1.1.1 RD: 0x0100000001000000'
      # sgr: ['S:', '1.1.1.1', 'G:', '224.1.1.1', 'RD:', '0x0100000001000000']
      sgr = opaqueStr.split( " " )
      source = IpGenAddr( sgr[ 1 ] )
      group = IpGenAddr( sgr[ 3 ] )
      # Convert the RD from its network-byte-order hex string to the display
      # format appropriate for its type.
      rd.rdNbo = int( sgr[ 5 ], 16 )
      rdStr = rd.stringValue

      if source not in opaqueValues.transitV4Src:
         opaqueValues.transitV4Src[ source ] = MLCM.PmsiMldpOpaqueValueGroup()
      opaqueValueGroup = opaqueValues.transitV4Src[ source ]

      if group not in opaqueValueGroup.groups:
         opaqueValueGroup.groups[ group ] = MLCM.PmsiMldpOpaqueValueRd()
      opaqueValueGroup.groups[ group ].rds[ rdStr ] = intfId

def _addMldpDecapEntries( model, validIntfIds ):
   """Add mLDP P2MP entries from mldpTunnelDecapStatus to model."""
   rd = Tac.Value( "Arnet::RouteDistinguisher" )
   for treeId, decapStatus in mldpTunnelDecapStatus.tunnelDecap.items():
      intfId = decapStatus.pmsiIntfId
      if intfId not in validIntfIds:
         continue
      vrfName, vrfId = validIntfIds[ intfId ]
      tunnelDecapModel = _decapVrfModel( model, vrfName, vrfId )
      if not tunnelDecapModel.mldpP2mp:
         tunnelDecapModel.mldpP2mp = MLCM.PmsiMldpTunnelDecap()
      mldpP2mpModel = tunnelDecapModel.mldpP2mp
      # Parse the mLDP treeId into a P2MP FEC element (root IP + opaque value).
      p2mpFecElement = LdpP2mpFecElement.fromProtoData( treeId.protoData )

      rootIp = p2mpFecElement.rootIp
      if rootIp not in mldpP2mpModel.rootAddrs:
         mldpP2mpModel.rootAddrs[ rootIp ] = MLCM.PmsiMldpOpaqueValueModel()
      mldpDecapEntry = mldpP2mpModel.rootAddrs[ rootIp ]

      if mldpDecapEntry.opaqueValues is None:
         mldpDecapEntry.opaqueValues = MLCM.PmsiMldpOpaqueValue()
      _addMldpOpaqueValue( mldpDecapEntry.opaqueValues,
                           p2mpFecElement.opaqueVal, intfId, rd )

def _pruneEmptyMldpCollections( model ):
   """Set empty optional opaque-value collections to None so they are hidden.

   DecapStatus entries do not map one-to-one to opaqueValues, because several
   opaque values can share a rootIp. Every mldpDecapEntry is therefore visited
   once after population is complete.
   """
   for tunnelDecapModel in model.vrfs.values():
      mldpP2mpModel = tunnelDecapModel.mldpP2mp
      if not ( mldpP2mpModel and mldpP2mpModel.rootAddrs ):
         continue
      for mldpDecapEntry in mldpP2mpModel.rootAddrs.values():
         opaqueValues = mldpDecapEntry.opaqueValues
         if not opaqueValues.invalidOpaqueValues:
            opaqueValues.invalidOpaqueValues = None
         if not opaqueValues.genericLspIds:
            opaqueValues.genericLspIds = None
         if not opaqueValues.transitV4Src:
            opaqueValues.transitV4Src = None

def handleMvpnIpv4DecapCmd( mode, args ):
   """Build a PmsiTunnelDecapStatus model of MVPN PMSI decap entries.

   Args:
      mode: CLI mode (unused).
      args: parsed CLI arguments. 'PROTOCOL' ('mldp', 'rsvp' or absent for
         both) selects the tunnel types. 'VRF' optionally restricts output
         to one VRF.

   Returns:
      The populated MLCM.PmsiTunnelDecapStatus. It is empty when no VRF
      qualifies.
   """
   model = MLCM.PmsiTunnelDecapStatus()

   protocol = args.get( 'PROTOCOL' )
   showMldp = protocol in ( 'mldp', None )
   showRsvp = protocol in ( 'rsvp', None )

   # Only DecapStatus entries whose pmsiIntfId appears here are shown.
   validIntfIds = _validDecapIntfIds( args.get( 'VRF' ) )
   if not validIntfIds:
      return model

   if showRsvp and rsvpTunnelDecapStatus:
      _addRsvpDecapEntries( model, validIntfIds )

   if showMldp and mldpTunnelDecapStatus:
      _addMldpDecapEntries( model, validIntfIds )
      _pruneEmptyMldpCollections( model )

   return model

def handleEncapTunnel( tunnelType, encapTunnel ):
   """Build an EncapTunnelModel describing one PMSI encap tunnel.

   Args:
      tunnelType: PmsiTunnelType of the tunnel. mLDP P2MP and RSVP-TE P2MP
         identifiers are decoded. For any other type the PmsiTunnelId is left
         empty.
      encapTunnel: tunnel key whose protoData encodes the tunnel identifier.

   Returns:
      A BgpVpnModels.EncapTunnelModel with vlans cleared and pmsiTunnel set.
   """
   tunnelModel = BgpVpnModels.EncapTunnelModel()
   tunnelModel.vlans = None
   tunnelModel.pmsiTunnel = BgpCliModels.PmsiTunnelId()

   if tunnelType == PmsiTunnelType.mLdpP2mpLsp:
      mldpModel = BgpCliModels.PmsiMldpTunnelId()
      fec = LdpP2mpFecElement.fromProtoData( encapTunnel.protoData )
      parser = MldpOpaqueValTlvParser( fec.opaqueVal, True )
      mldpModel.rootAddress = fec.rootIp
      if not parser.valid:
         # Undecodable opaque value: show the raw bytes and flag it.
         mldpModel.opaqueValue = fec.opaqueVal.hex()
         mldpModel.invalidOpaqueValue = True
      else:
         text = parser.getOpaqueString( False, False )
         if 'RD' in text:
            # The trailing token is the RD as a hex string. Swap it for the
            # RouteDistinguisher's formatted stringValue.
            hexRd = text.split()[ -1 ]
            rd = Tac.newInstance( "Arnet::RouteDistinguisher" )
            rd.rdNbo = int( hexRd, 0 )
            text = text.replace( hexRd, rd.stringValue )
         mldpModel.opaqueValue = text
      tunnelModel.pmsiTunnel.mldpP2mp = mldpModel
   elif tunnelType == PmsiTunnelType.rsvpTeP2mpLsp:
      rsvpModel = BgpCliModels.PmsiRsvpTunnelId()
      sessionId = MvpnRsvpSessionId.fromProtoData( encapTunnel.protoData )
      for field in ( 'p2mpId', 'tunnelId', 'extTunnelId' ):
         setattr( rsvpModel, field, getattr( sessionId, field ) )
      tunnelModel.pmsiTunnel.rsvpP2mp = rsvpModel
   return tunnelModel

def handleMvpnVrfEncapStatusEntry( vrfId, vrfEncapStatus ):
   """Build an EncapVrfModel for one VRF's tunnel encap status.

   Every entry of vrfEncapStatus.encapTunnel is rendered with handleEncapTunnel.
   Entries are nested as groups[ group ].sources[ source ] and filed under the
   VRF's single pmsiTunnelType.
   """
   vrfModel = BgpVpnModels.EncapVrfModel()
   vrfModel.vrfId = vrfId
   tunnelType = vrfEncapStatus.pmsiTunnelType
   groupModel = BgpVpnModels.EncapGroupModel()
   for routeKey, tunnelEntry in vrfEncapStatus.encapTunnel.items():
      group = routeKey.group
      if group not in groupModel.groups:
         groupModel.groups[ group ] = BgpVpnModels.EncapSourceModel()
      sourceModel = handleEncapTunnel( tunnelType, tunnelEntry.encapTunnel )
      groupModel.groups[ group ].sources[ routeKey.source ] = sourceModel
   vrfModel.tunnelTypes[ tunnelType ] = groupModel
   return vrfModel

def handleMvpnEncapStatus( mode, args ):
   """Build an MvpnEncapStatusModel of per-VRF MVPN tunnel encap status.

   Args:
      mode: CLI mode (unused).
      args: parsed CLI arguments. 'VRF' optionally restricts output to one
         VRF. 'PROTOCOL', when given, keeps only VRFs whose pmsiTunnelType
         starts with that keyword, compared case-insensitively on the first
         four characters.

   Returns:
      The populated BgpVpnModels.MvpnEncapStatusModel.
   """
   model = BgpVpnModels.MvpnEncapStatusModel()
   vrfName = args.get( 'VRF' )
   protocol = args.get( 'PROTOCOL', None )

   def protocolMatches( encapStatus ):
      # No protocol filter accepts everything. Otherwise compare against the
      # lower-cased first four characters of the tunnel type name.
      if not protocol:
         return True
      return protocol == encapStatus.pmsiTunnelType[ : 4 ].lower()

   if vrfName is None:
      for vrfId, vrfEncapStatus in mvpnTunnelEncapStatus.vrfStatus.items():
         vrfEntry = vrfIdStatus.vrfIdToName.get( vrfId )
         if vrfEntry is None or not protocolMatches( vrfEncapStatus ):
            continue
         model.vrfs[ vrfEntry.vrfName ] = handleMvpnVrfEncapStatusEntry(
            vrfId, vrfEncapStatus )
      return model

   vrfId = vrfNameStatus.nameToIdMap.vrfNameToId.get( vrfName )
   if vrfId is not None and vrfId in mvpnTunnelEncapStatus.vrfStatus:
      vrfEncapStatus = mvpnTunnelEncapStatus.vrfStatus[ vrfId ]
      if protocolMatches( vrfEncapStatus ):
         model.vrfs[ vrfName ] = handleMvpnVrfEncapStatusEntry(
            vrfId, vrfEncapStatus )
   return model
