# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import Tac, SharedMem, SmashLazyMount
from CliPlugin import RoutingIsisCli
from CliPlugin import RoutingIsisShowSegmentRoutingModel
from CliPlugin.TunnelFibCli import getFecFromIntfId
from CliPlugin.TunnelModels import sortKey
from CliPlugin.TunnelCli import (
   getTunnelEntryAndIdFromIndexAndAf,
   getTunnelIndexFromId,
   getViaModel,
)
from CliPlugin.SegmentRoutingCliShowHelper import (
   getAdjSidOrigin,
   _populateSRGBValues,
   _populateSrCommonHeader,
   _getReachabilityAlgorithm,
   _getSegmentRoutingGlobalBlocksPeersCount,
   _getSelfOriginatedSegmentStatistics
)
import CliPlugin.RoutingIsisShowSegmentRoutingCli
import sys
from CliModel import cliPrinted
from TypeFuture import TacLazyType
from IsisShowLib import showIsisCommand

# Short alias for the RoutingIsisCli plugin module; the helpers below read its
# attributes (isisConfig, flexAlgoConfig, entityManager, ...) at call time.
_routingIsisCli = RoutingIsisCli

# Tac types whose named values (e.g. srDataPlaneNone, interfaceTypeLan, ipv4)
# are compared against or assigned by the helpers below.
_srDataPlaneEnum = Tac.Type( "Routing::SegmentRoutingCli::SrDataPlane" )
_srAdjAllocationType = Tac.Type( "Routing::SegmentRoutingCli::SrAdjAllocationType" )
_interfaceTypeEnum = Tac.Type( 'Routing::Isis::InterfaceType' )
_addrFamilyEnum = Tac.Type( "Arnet::AddressFamily" )

# SR tunnel table owned by RoutingIsisShowSegmentRoutingCli; read by the
# "show isis segment-routing tunnel" helpers at the bottom of this file.
srTunnelTable = CliPlugin.RoutingIsisShowSegmentRoutingCli.srTunnelTable

# Lazily resolved type; used to recognize tunnel vias whose interface id is a
# FEC id interface.
FecIdIntfId = TacLazyType( 'Arnet::FecIdIntfId' )

def _getFlexAlgoConfig():
   """Return the flex-algo configuration held by the RoutingIsisCli plugin."""
   flexAlgoConfig = _routingIsisCli.flexAlgoConfig
   return flexAlgoConfig

def _getIsisIntfConfig():
   """Return the per-interface collection from the plugin's ISIS config."""
   isisConfig = _routingIsisCli.isisConfig
   return isisConfig.intfConfig

def _getIsisConfigAndStatus( mode, instanceId, instanceName ):
   """Look up an ISIS instance's config and status, filtered for usable SR.

   The pair comes from RoutingIsisCli.getIsisConfigAndStatus. If either element
   is falsy, the pair is returned unchanged. Otherwise ( None, None ) is returned
   when SR is not active for the instance or its SR data plane is
   srDataPlaneNone; in all remaining cases the pair is returned as-is.
   """
   instConfig, instStatus = _routingIsisCli.getIsisConfigAndStatus(
      mode, instanceId, instanceName )
   if not ( instConfig and instStatus ):
      return instConfig, instStatus

   srInactive = not _routingIsisCli.isisSrIsActive( mode, instanceName )
   if srInactive or instConfig.srDataPlane == _srDataPlaneEnum.srDataPlaneNone:
      return None, None
   return instConfig, instStatus

def _populateSrCommonModel(model, instConfig, instanceName, srSysdbStatus):
   """Fill the fields shared by the SR summary models.

   Sets the private instance name, the system id (as a string) and the hostname
   resolved from that system id. Then it delegates the remaining common header
   fields to _populateSrCommonHeader.
   """
   systemId = instConfig.systemId
   model._instanceName = instanceName # pylint: disable=W0212
   model.systemId = str( systemId )

   # getHostName takes an Mpls::RouterId, so wrap the system id in one.
   routerId = Tac.Type( "Mpls::RouterId" )()
   routerId.stringValue = systemId
   model.hostname = _routingIsisCli.getHostName( routerId )

   _populateSrCommonHeader( model, instConfig.srDataPlane, srSysdbStatus )

# TODO: move this guard into a shared module so other CLI plugins can reuse it.
class _SrGuardMounts:
   def __init__(self, path = 'segmentrouting/isis/default',
         smashType = 'Smash::SegmentRouting::Status'):
      self.shmemEm = None
      self.smi = None
      self.srSmashStatus = None
      self.path = path
      self.smashType = smashType
      self.srSysdbStatus = None
   
   def __enter__(self):
      self.smi = SmashLazyMount.mountInfo( 'reader' )
      self.srSmashStatus = SmashLazyMount.mount(
            _routingIsisCli.entityManager, self.path, self.smashType,self.smi,
            autoUnmount=True )
      self.srSysdbStatus = _routingIsisCli.srSysdbStatusDir.get( 'default' )
      return ( self.srSmashStatus, self.srSysdbStatus )
   def __exit__(self, exc_type, exc_val, exc_tb):
      return

def _populateAdjacencyValues(result, instConfig, instStatus):
   """Fill the Adj-SID allocation mode and, when allocation is enabled, the pool.

   result.adjSidAllocationMode receives the display string for
   instConfig.srAdjSegmentAlloc. When that value is not
   srAdjacencyAllocationNone, result.adjSidPoolBase and result.adjSidPoolSize are
   also copied from instStatus.srAdjBase and instStatus.srAdjRange.
   """
   # The *Backup allocation types are displayed like their non-backup forms.
   srAdjSidAllocationModeEnumMapping = {
      _srAdjAllocationType.srAdjacencyAllocationNone: "None",
      _srAdjAllocationType.srAdjacencyAllocationSrOnly: "SrOnly",
      _srAdjAllocationType.srAdjacencyAllocationAll: "All",
      _srAdjAllocationType.srAdjacencyAllocationSrOnlyBackup: "SrOnly",
      _srAdjAllocationType.srAdjacencyAllocationAllBackup: "All",
   }

   allocType = instConfig.srAdjSegmentAlloc
   # Raises KeyError for an allocation type not listed in the mapping above.
   result.adjSidAllocationMode = srAdjSidAllocationModeEnumMapping[ allocType ]
   if allocType != _srAdjAllocationType.srAdjacencyAllocationNone:
      result.adjSidPoolBase = instStatus.srAdjBase
      result.adjSidPoolSize = instStatus.srAdjRange

#------------------------------------------------------------------------------------
# The "show isis segment-routing" command.
#------------------------------------------------------------------------------------
def getSegmentRoutingsSummaryModel(mode, cmdPrefix=None, instanceId=None,
                                   instanceName=None, vrfName=None):
   """Build the SrSummaryModel for "show isis segment-routing".

   Returns None when the instance has no usable SR config/status (see
   _getIsisConfigAndStatus) or when no default SR sysdb status exists.
   cmdPrefix and vrfName are accepted but not used here.
   """
   instConfig, instStatus = _getIsisConfigAndStatus( mode, instanceId, instanceName )
   if not ( instConfig and instStatus ):
      return None

   model = RoutingIsisShowSegmentRoutingModel.SrSummaryModel()
   with _SrGuardMounts() as ( srSmashStatus, srSysdbStatus ):
      if not srSysdbStatus:
         return None
      _populateSrCommonModel( model, instConfig, instanceName, srSysdbStatus )
      _populateSRGBValues( model, srSysdbStatus )
      _populateAdjacencyValues( model, instConfig, instStatus )
      model.reachabilityAlgorithm = _getReachabilityAlgorithm()
      model.srPeerCount = _getSegmentRoutingGlobalBlocksPeersCount( srSmashStatus )
      model.inspectProxyAttachedFlag = instConfig.inspectProxyAttachedFlag
      model.selfOriginatedSegmentStatistics = _getSelfOriginatedSegmentStatistics(
         srSmashStatus, srSysdbStatus, instConfig.systemId, instConfig.instanceId )
      # This router acts as a mapping server if it originates any proxy node SIDs.
      model.mappingServer = \
         model.selfOriginatedSegmentStatistics.proxyNodeSidCount > 0

   return model

def _populateGlobalBlocks(globalBlocks, srSmashStatus, instConfig, srSysdbStatus):
   """Append GlobalBlockModel items for the local SRGB and then for every peer.

   The local entry uses instConfig.systemId and srSysdbStatus.labelRange. Peer
   entries follow in sorted key order of srSmashStatus.globalBlock.
   """
   def makeBlock( systemId, hostname, base, size ):
      block = RoutingIsisShowSegmentRoutingModel.GlobalBlockModel()
      block.systemId = systemId
      block.hostname = hostname
      block.base = base
      block.size = size
      return block

   # getHostName takes an Mpls::RouterId, so wrap the local system id in one.
   localRouterId = Tac.Type( "Mpls::RouterId" )()
   localRouterId.stringValue = instConfig.systemId
   localRange = srSysdbStatus.labelRange
   globalBlocks.append( makeBlock( instConfig.systemId,
                                   _routingIsisCli.getHostName( localRouterId ),
                                   localRange.base, localRange.size ) )

   for peerRouterId, peerBlock in sorted( srSmashStatus.globalBlock.items() ):
      globalBlocks.append( makeBlock( peerRouterId.stringValue,
                                      _routingIsisCli.getHostName( peerRouterId ),
                                      peerBlock.base, peerBlock.size ) )

#------------------------------------------------------------------------------------
# The "show isis segment-routing global-blocks" command.
#------------------------------------------------------------------------------------
def getSegmentRoutingsGlobalBlocksSummaryModel(mode, cmdPrefix=None, instanceId=None,
                                               instanceName=None, vrfName=None):
   """Build the GlobalBlockSummaryModel for "show isis segment-routing global-blocks".

   Returns None when the instance has no usable SR config/status or when no
   default SR sysdb status exists. cmdPrefix and vrfName are not used.
   """
   instConfig, instStatus = _getIsisConfigAndStatus( mode, instanceId, instanceName )
   if not ( instConfig and instStatus ):
      return None

   model = RoutingIsisShowSegmentRoutingModel.GlobalBlockSummaryModel()
   with _SrGuardMounts() as ( srSmashStatus, srSysdbStatus ):
      if not srSysdbStatus:
         return None
      _populateSrCommonModel( model, instConfig, instanceName, srSysdbStatus )
      _populateSRGBValues( model, srSysdbStatus )
      model.srPeerCount = _getSegmentRoutingGlobalBlocksPeersCount( srSmashStatus )
      _populateGlobalBlocks( model.globalBlocks, srSmashStatus, instConfig,
                             srSysdbStatus )

   return model

def getProtectionMode( protection ):
   """Return the display string for a protection mode.

   "disabled" is shown as "unprotected"; any other value is returned unchanged.
   """
   if protection == "disabled":
      return "unprotected"
   return protection

def getSrlgProtection( srlg ):
   """Map an SRLG protection value to its display string.

   'srlgStrict' becomes 'strict' and 'srlgLoose' becomes 'loose'; every other
   value becomes 'none'.
   """
   for tacName, displayName in ( ( "srlgStrict", "strict" ),
                                 ( "srlgLoose", "loose" ) ):
      if srlg == tacName:
         return displayName
   return "none"

#------------------------------------------------------------------------------------
# The "show isis segment-routing prefix-segments" command.
#------------------------------------------------------------------------------------
def getSegmentRoutingsPrefixSegmentsSummaryModel( mode, cmdPrefix=None,
                                                  instanceId=None, instanceName=None,
                                                  selfOrig=False, afSegments=0,
                                                  vrfName=None, fd=None ):
   """Render "show isis segment-routing prefix-segments" via a Tac helper.

   The Isis::Cli::SrPrefixSegmentHelper writes the output to fd in the session's
   output format. When fd is falsy, stdout is flushed and its descriptor is used.
   Returns the cliPrinted PrefixSegmentSummaryModel schema on success. Returns
   None when the instance has no usable SR config/status or when no default SR
   sysdb status exists. cmdPrefix and vrfName are not used.
   """
   instConfig, instStatus = _getIsisConfigAndStatus( mode, instanceId, instanceName )
   flexAlgoConfig = _getFlexAlgoConfig()
   session = mode.session_
   capiRevision = session.requestedModelRevision()
   # Returning cliPrinted( <modelclass> ) conveys that the output was produced
   # by the cliPrint library rather than by a populated model instance.
   modelSchema = cliPrinted(
      RoutingIsisShowSegmentRoutingModel.PrefixSegmentSummaryModel )
   if not ( instConfig and instStatus ):
      return None

   hostnameMap = _routingIsisCli.isisSystemIdHostnameMap
   with _SrGuardMounts() as ( srSmashStatus, srSysdbStatus ):
      if not srSysdbStatus:
         return None
      if not fd:
         # Flush Python-level buffering before the helper writes to the raw
         # descriptor, presumably so earlier output is not reordered.
         sys.stdout.flush()
         fd = sys.stdout.fileno()
      outputFormat = session.outputFormat()
      if isinstance( srSmashStatus, SharedMem.AutoUnmountEntityProxy ):
         # Force the lazy mount so the Tac helper gets the underlying entity
         # rather than the auto-unmount proxy.
         srSmashStatus = SmashLazyMount.force( srSmashStatus )
      renderer = Tac.newInstance( "Isis::Cli::SrPrefixSegmentHelper",
                                  instConfig, srSysdbStatus, hostnameMap,
                                  flexAlgoConfig, srSmashStatus, afSegments,
                                  selfOrig, capiRevision )
      renderer.render( fd, outputFormat )
   return modelSchema

def _populateAdjSegmentProtectionAndFlags( item, entry ):
   """Copy protection, SRLG protection and flag bits from entry onto item.

   item.srlgProtection is only set when entry.protectionSrlg is not 'srlgNone'.
   """
   item.protection = getProtectionMode( entry.protectionMode )
   if entry.protectionSrlg != 'srlgNone':
      item.srlgProtection = getSrlgProtection( entry.protectionSrlg )
   item.flags = RoutingIsisShowSegmentRoutingModel \
       .AdjacencySegmentFlagsModel()
   item.flags.f = entry.flags.isIpv6
   item.flags.b = entry.flags.backup
   item.flags.v = entry.flags.isValue
   item.flags.l = entry.flags.isLocal
   item.flags.s = entry.flags.isSet

def _mergeAdjacencySegments( result, systemId, instanceId, vrfName, entry,
                             srReachability ):
   """Add one adjacency segment entry to the matching list in result.

   getAdjSidOrigin returns None when the SID is from another instance; such
   entries are ignored. If the SID is from the same instance, getAdjSidOrigin
   returns one of the following values:
     1. configured (statically configured adjacency segment)
     2. remote (remote adjacency segment)
     3. dynamic (dynamically allocated adjacency segment)

   A remote SID adds one ReceivedGlobalAdjacencySegmentModel per neighbor router
   id to result.receivedGlobalAdjacencySegments. These items also record
   srReachability. A configured or dynamic SID adds one AdjacencySegmentModel per
   next hop to result.adjacencySegments.
   """
   sidOrigin = getAdjSidOrigin( entry, systemId, instanceId )
   if sidOrigin == "remote":
      for ngbRtrid in entry.ngbRtrid:
         item = RoutingIsisShowSegmentRoutingModel \
                     .ReceivedGlobalAdjacencySegmentModel()
         item.sid = entry.sid.index
         item.systemId = entry.rtrid.stringValue
         item.ngbSystemId = ngbRtrid.stringValue
         item.hostname = _routingIsisCli.getHostName( entry.rtrid, vrfName )
         item.ngbHostname = _routingIsisCli.getHostName( ngbRtrid, vrfName )
         _populateAdjSegmentProtectionAndFlags( item, entry )
         item.srReachability = srReachability

         result.receivedGlobalAdjacencySegments.append( item )
   elif sidOrigin is not None:
      # SID is either configured or dynamically allocated
      for nexthop in entry.nexthop:
         item = RoutingIsisShowSegmentRoutingModel.AdjacencySegmentModel()
         item.ipAddress = str( nexthop.hop )
         item.localIntf = str( nexthop.intfId )
         item.sid = entry.sid.index
         item.sidOrigin = sidOrigin
         item.lan = entry.isLan
         item.level = entry.level
         _populateAdjSegmentProtectionAndFlags( item, entry )

         result.adjacencySegments.append( item )

def _populateAdjacencySegments( result, srSysdbStatus, systemId, instanceId,
                                vrfName ):
   """Merge every adjacency segment of srSysdbStatus into result, in key order.

   Each entry's SR reachability is derived as follows: fecId == 0 gives
   "srUnreachable", partialReach gives "srPartiallyReachable", and anything else
   gives "srReachable".
   """
   def reachabilityOf( adjSeg ):
      if adjSeg.fecId == 0:
         return "srUnreachable"
      if adjSeg.partialReach:
         return "srPartiallyReachable"
      return "srReachable"

   # Display entries in sorted fashion
   for _, adjSeg in sorted( srSysdbStatus.adjacencySegment.items() ):
      _mergeAdjacencySegments( result, systemId, instanceId, vrfName, adjSeg,
                               reachabilityOf( adjSeg ) )

def __addMisconfiguredAdjSidEntry( misconfiguredAdjacencySegments, intf, af, label,
                                   cfg, reason ):
   """Append a MisconfiguredAdjacencySegmentModel for one bad static Adj-SID.

   cfg is accepted but currently unused.
   """
   badSid = RoutingIsisShowSegmentRoutingModel.MisconfiguredAdjacencySegmentModel()
   badSid.localIntf, badSid.sid, badSid.af, badSid.reason = intf, label, af, reason
   misconfiguredAdjacencySegments.append( badSid )

def _staticAdjSegmentKeys( cfg ):
   """Yield each statically configured Adj-SID key of an interface config once.

   Keys of srSingleAdjacencySegment come first. They are followed by the keys of
   srMultipleAdjacencySegment that are not also in srSingleAdjacencySegment.
   """
   for adjSegKey in cfg.srSingleAdjacencySegment:
      yield adjSegKey
   for adjSegKey in cfg.srMultipleAdjacencySegment:
      # Already yielded from the single collection; no need to validate twice.
      if adjSegKey not in cfg.srSingleAdjacencySegment:
         yield adjSegKey

def _validateStaticAdjSidsConfiguration( mode, result, instStatus ):
   """Record misconfigured static Adj-SIDs in result and warn about them.

   Each static Adj-SID key of every ISIS interface is checked:
     * 'out-of-range': sid.isValue is True and the label lies outside
       [srSrlbBase, srSrlbBase + srSrlbRange).
     * 'interface-type-mismatch': the interface type is LAN (static Adj-SIDs
       are supported only on P2P interfaces).
   A key can produce both entries. At most one warning of each kind is added to
   mode. Does nothing when instStatus or the interface config is falsy.
   """
   intfConfig = _getIsisIntfConfig()
   if not instStatus or not intfConfig:
      return
   srlbBase = instStatus.srSrlbBase
   srlbRange = instStatus.srSrlbRange
   outOfRangeLabelPresent = False
   interfaceTypeMismatch = False

   for intf, cfg in intfConfig.items():
      for adjSegKey in _staticAdjSegmentKeys( cfg ):
         af = 'ipv6' if adjSegKey.isIpV6 else 'ipv4'
         label = adjSegKey.sid.index
         # Only SIDs flagged with isValue are checked against the SRLB range.
         if adjSegKey.sid.isValue is True and \
               ( label < srlbBase or label >= srlbBase + srlbRange ):
            outOfRangeLabelPresent = True
            __addMisconfiguredAdjSidEntry( result.misconfiguredAdjacencySegments,
                                           intf, af, label, cfg, 'out-of-range' )
         if cfg.interfaceType == _interfaceTypeEnum.interfaceTypeLan:
            # For now we support static Adj SID only for p2p
            __addMisconfiguredAdjSidEntry( result.misconfiguredAdjacencySegments,
                                           intf, af, label, cfg,
                                           'interface-type-mismatch' )
            interfaceTypeMismatch = True

   if outOfRangeLabelPresent:
      mode.addWarning( "Some of configured Adj-SIDs are not within SRLB "
                       "Range [%s, %s]" % ( srlbBase, srlbBase + srlbRange - 1 ) )

   if interfaceTypeMismatch:
      mode.addWarning( "Static Adj-SID is supported only on P2P interface" )

#------------------------------------------------------------------------------------
# The "show isis segment-routing adjacency-segments" command.
#------------------------------------------------------------------------------------
def getSegmentRoutingsAdjacencySegmentsSummaryModel(mode, cmdPrefix=None,
                                                    instanceId=None,
                                                    instanceName=None, vrfName=None):
   """Build the model for "show isis segment-routing adjacency-segments".

   Returns None when the instance has no usable SR config/status or when no
   default SR sysdb status exists. The static Adj-SID validation may add
   warnings to mode. cmdPrefix is not used.
   """
   instConfig, instStatus = _getIsisConfigAndStatus( mode, instanceId, instanceName )
   if not ( instConfig and instStatus ):
      return None

   model = RoutingIsisShowSegmentRoutingModel.AdjacencySegmentSummaryModel()
   # Only the sysdb status is needed here; the smash status is ignored.
   with _SrGuardMounts() as ( _, srSysdbStatus ):
      if not srSysdbStatus:
         return None
      _populateSrCommonModel( model, instConfig, instanceName, srSysdbStatus )
      _populateAdjacencyValues( model, instConfig, instStatus )
      _populateAdjacencySegments( model, srSysdbStatus, instConfig.systemId,
                                  instConfig.instanceId, vrfName )
      _validateStaticAdjSidsConfiguration( mode, model, instStatus )

   return model

def ShowIsisSegmentRoutingCmdHandler( mode, args ):
   """CLI handler for "show isis segment-routing".

   Removes the command keyword from args, adds the per-instance and per-VRF
   model types, and delegates to showIsisCommand.
   """
   del args[ 'segment-routing' ]
   models = RoutingIsisShowSegmentRoutingModel
   args[ 'instDictModelType' ] = models.SegmentRoutingsSummaryModel
   args[ 'cmdVrfModel' ] = models.SegmentRoutingsSummaryVRFsModel
   return showIsisCommand( mode, getSegmentRoutingsSummaryModel, **args )

def ShowIsisSegmentRoutingGlobalBlocksCmdHandler( mode, args ):
   """CLI handler for "show isis segment-routing global-blocks".

   Removes the command keywords from args, adds the model types, and delegates
   to showIsisCommand.
   """
   for keyword in ( 'segment-routing', 'global-blocks' ):
      del args[ keyword ]
   models = RoutingIsisShowSegmentRoutingModel
   args[ 'instDictModelType' ] = models.SegmentRoutingsGlobalBlocksSummaryModel
   args[ 'cmdVrfModel' ] = models.SegmentRoutingsGlobalBlocksSummaryVRFsModel
   return showIsisCommand( mode, getSegmentRoutingsGlobalBlocksSummaryModel,
                           **args )

def ShowIsisSegmentRoutingPrefixSegmentCmdHandler( mode, args ):
   """CLI handler for "show isis segment-routing prefix-segments".

   Converts the optional 'self-originated' and address-family keywords into
   selfOrig/afSegments arguments and requests cliPrint rendering. It then
   delegates to showIsisCommand.
   """
   for keyword in ( 'segment-routing', 'prefix-segments' ):
      del args[ keyword ]
   args[ 'selfOrig' ] = bool( args.pop( 'self-originated', None ) )
   # At most one address family is honored; 'ipv4' is checked first.
   for afKeyword, afValue in ( ( 'ipv4', _addrFamilyEnum.ipv4 ),
                               ( 'ipv6', _addrFamilyEnum.ipv6 ) ):
      if afKeyword in args:
         args[ 'afSegments' ] = afValue
         del args[ afKeyword ]
         break
   models = RoutingIsisShowSegmentRoutingModel
   args[ 'instDictModelType' ] = models.SegmentRoutingsPrefixSegmentsSummaryModel
   args[ 'cmdVrfModel' ] = models.SegmentRoutingsPrefixSegmentsSummaryVRFsModel
   args[ 'useCliPrint' ] = True
   return showIsisCommand( mode, getSegmentRoutingsPrefixSegmentsSummaryModel,
                           **args )

def ShowIsisSegmentRoutingAdjSegmentCmdHandler( mode, args ):
   """CLI handler for "show isis segment-routing adjacency-segments".

   Removes the command keywords from args, adds the model types, and delegates
   to showIsisCommand.
   """
   for keyword in ( 'segment-routing', 'adjacency-segments' ):
      del args[ keyword ]
   models = RoutingIsisShowSegmentRoutingModel
   args[ 'instDictModelType' ] = models.SegmentRoutingsAdjacencySegmentsSummaryModel
   args[ 'cmdVrfModel' ] = models.SegmentRoutingsAdjacencySegmentsSummaryVRFsModel
   return showIsisCommand( mode, getSegmentRoutingsAdjacencySegmentsSummaryModel,
                           **args )

#-------------------------------------------------------------------------------
# show isis segment-routing tunnel [ tunnel-index | tep ]
#-------------------------------------------------------------------------------
def addrMatchTep( endpoint=None , tunnel=None ):
   """Return whether tunnel's endpoint matches the requested endpoint string.

   An empty or None endpoint matches everything, even a missing tunnel.
   Otherwise a falsy tunnel never matches, and a real tunnel matches when
   endpoint equals tunnel.tep.stringValue.
   """
   if endpoint:
      return bool( tunnel ) and bool( endpoint == tunnel.tep.stringValue )
   return True

def getIsisSrTunnelTableEntryModel( tunnelId ):
   """Build an IsisSrTunnelTableEntry for tunnelId, or return None if absent.

   Each tunnel via contributes its label stack, listed from the highest stack
   index down to 0. A via whose interface is a FEC id interface expands into
   one via model per FEC via, in sorted order. Such a via is skipped when its
   FEC cannot be found. Falsy results from getViaModel are dropped, and the
   final list is sorted with sortKey.
   """
   tunnelTableEntry = srTunnelTable.entry.get( tunnelId )
   if not tunnelTableEntry:
      return None

   vias = []
   for via in tunnelTableEntry.via.values():
      labels = [ str( via.labels.labelStack( stackIndex ) )
                 for stackIndex in reversed( range( via.labels.stackSize ) ) ]

      if not FecIdIntfId.isFecIdIntfId( via.intfId ):
         candidates = [ getViaModel( via.nexthop, via.intfId, labels ) ]
      else:
         fec = getFecFromIntfId( via.intfId )
         if fec is None:
            continue
         candidates = [
            getViaModel( Tac.Value( "Arnet::IpGenAddr", str( fecVia.hop ) ),
                         fecVia.intfId, labels )
            for fecVia in sorted( fec.via.values() ) ]
      vias.extend( viaModel for viaModel in candidates if viaModel )

   # Make via rendering order deterministic
   vias.sort( key=sortKey )
   return RoutingIsisShowSegmentRoutingModel.IsisSrTunnelTableEntry(
      endpoint=tunnelTableEntry.tep, vias=vias )

def showIsisSrTunnelTable( tunnelIndex=None, endpoint=None ):
   """Build the IsisSrTunnelTable model, optionally filtered.

   Without tunnelIndex, every entry of srTunnelTable is considered. With it,
   only the entry resolved by getTunnelEntryAndIdFromIndexAndAf is considered.
   Entries are kept when they match endpoint (see addrMatchTep) and yield a
   model. They are keyed by their tunnel index.
   """
   entries = {}

   def addIfMatching( tunnelId, tunnel ):
      if not addrMatchTep( endpoint=endpoint, tunnel=tunnel ):
         return
      entryModel = getIsisSrTunnelTableEntryModel( tunnelId=tunnelId )
      if entryModel:
         entries[ getTunnelIndexFromId( tunnelId ) ] = entryModel

   if tunnelIndex is None:
      for tunnelId, tunnel in srTunnelTable.entry.items():
         addIfMatching( tunnelId, tunnel )
   else:
      tunnel, tunnelId = getTunnelEntryAndIdFromIndexAndAf(
         tunnelIndex, srTunnelTable,
         RoutingIsisShowSegmentRoutingModel.IsisSrTunnelTable )
      if tunnel:
         addIfMatching( tunnelId, tunnel )
   return RoutingIsisShowSegmentRoutingModel.IsisSrTunnelTable( entries=entries )

# Help text for the endpoint filter of the tunnel show command. It is not
# referenced in this part of the file; presumably it is consumed where the
# command's endpoint token is defined.
endpointHelpString = 'Match this endpoint prefix'

def ShowIsisSegmentRoutingTunnelCmdHandler( mode, args ):
   """CLI handler for "show isis segment-routing tunnel [ tunnel-index | tep ]".

   The endpoint comes from IPADDR, or from IP6ADDR when IPADDR is absent or
   falsy. A truthy endpoint is converted to a string. mode is not used.
   """
   endpoint = args.pop( 'IPADDR', None )
   if not endpoint:
      endpoint = args.pop( 'IP6ADDR', None )
   tunnelIndex = args.pop( 'TUNNEL-INDEX', None )
   endpointStr = str( endpoint ) if endpoint else endpoint

   return showIsisSrTunnelTable( tunnelIndex=tunnelIndex, endpoint=endpointStr )
