#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
from TypeFuture import TacLazyType
import Tac
from XcvrLib import ( ConfigToModelMapping,
                      getXcvrSlotName,
                      getXcvrStatus,
                      internalLaneIdToCli,
                      isPrimaryIntf,
                      selectCollection )
from CliPlugin.XcvrCmisLoopbackCli import IntfRangeKW
from CliPlugin.XcvrCli import getAllIntfsWrapper
from CliPlugin.XcvrConfigCli import ( getXcvrConfigCliDir,
                                      getXcvrConfigCliForConfigCommand,
                                      xcvrCmisGuardFn )

import CliPlugin
import collections
import functools
from typing import Any, Callable
from collections.abc import Iterable

from CliDynamicSymbol import CliDynamicPlugin
# Model plugin handle obtained through CliDynamicPlugin; the aliases below
# re-export the model classes and helpers this module uses from it.
XcvrCmisLoopbackModel = CliDynamicPlugin( "XcvrCmisLoopbackModel" )

Capabilities = XcvrCmisLoopbackModel.Capabilities
Capability = XcvrCmisLoopbackModel.Capability
IntfCapabilities = XcvrCmisLoopbackModel.IntfCapabilities
LaneInfo = XcvrCmisLoopbackModel.LaneInfo
SlotInfo = XcvrCmisLoopbackModel.SlotInfo
TrafficLoopback = XcvrCmisLoopbackModel.TrafficLoopback
createCapabilitiesModel = XcvrCmisLoopbackModel.createCapabilitiesModel
notSupportedStr = XcvrCmisLoopbackModel.notSupportedStr

CmisLoopbackMode = TacLazyType( 'Xcvr::CmisLoopback::Mode' )
XcvrPresence = TacLazyType( "Xcvr::XcvrPresence" )

# Set by Plugin() at load time; read by xcvrStatusForPrimary().
xcvrStatusDir = None

# ( kind, side ) pair describing one loopback capability; used to fill the
# Capability model entries in setCapabilities().
Descr = collections.namedtuple( 'Descr', ( 'kind', 'side' ) )

# Maps each attribute of the Xcvr::PrbsLoopbackCapabilities value (read with
# getattr in setCapabilities) to the loopback kind and module side it
# advertises. 'system'/'network' name the source of the looped traffic
# (compare sourceSystemLoopbackType / sourceNetworkLoopbackType).
capabilityMapping = {
   'perLaneMediaSideLoopback': Descr( 'perLane', 'media' ),
   'perLaneHostSideLoopback': Descr( 'perLane', 'host' ),
   'hostSideInputLoopback': Descr( 'system', 'host' ),
   'hostSideOutputLoopback': Descr( 'network', 'host' ),
   'mediaSideInputLoopback': Descr( 'network', 'media' ),
   'mediaSideOutputLoopback': Descr( 'system', 'media' ),
}

# Lane count assumed when no transceiver status is available
# (see getOpLaneCount and isAllLanes).
defaultLaneCnt = 8

def xcvrStatusForPrimary( intfName: str ):
   '''
   Looks up the slot of `intfName` in xcvrStatusDir and returns the entry as
   resolved by getXcvrStatus(), or None when `intfName` is not the primary
   interface of its slot.
   '''
   slotName = getXcvrSlotName( intfName )
   if isPrimaryIntf( slotName ):
      slotStatus = xcvrStatusDir.xcvrStatus.get( slotName )
      return getXcvrStatus( slotStatus )
   return None

def getAllIntfsWrapperBetter( *args, **kwargs ):
   '''
   Calls getAllIntfsWrapper() with the given arguments and replaces each falsy
   element of its 2-tuple result with an empty list, so callers can always
   iterate both parts.
   '''
   intfs, intfNames = getAllIntfsWrapper( *args, **kwargs )
   if not intfs:
      intfs = []
   if not intfNames:
      intfNames = []
   return ( intfs, intfNames )

def gerPrimaryIntfData( mode, args ):
   '''
   Returns a generator of ( intf, intfName, status ) tuples for the interfaces
   selected by the command arguments, keeping only those that have a truthy
   primary-interface status (xcvrStatusForPrimary) and for which
   xcvrCmisGuardFn() returns None.

   The interface list is fetched when this function is called; the filtering
   itself happens lazily while the result is iterated.
   '''
   intfs, intfNames = getAllIntfsWrapperBetter( mode, args.get( IntfRangeKW ),
                                                args.get( 'MOD' ) )
   candidates = zip( intfs, intfNames )

   def _primaryOnly():
      for intf, intfName in candidates:
         status = xcvrStatusForPrimary( intfName )
         if status and xcvrCmisGuardFn( intf ) is None:
            yield ( intf, intfName, status )

   return _primaryOnly()

def setCapabilities( model: Capabilities, caps ):
   '''
   Fills `model` from the PRBS loopback capabilities value `caps` and returns
   the same model. When `caps` is None the model is returned untouched.
   Every attribute listed in capabilityMapping becomes one Capability entry,
   supported when the attribute equals 1.
   '''
   if caps is None:
      return model
   model.simultaneousHostMedia = caps.simultaneousHostAndMediaLoopback == 1
   capabilityList = []
   for attrName, descr in capabilityMapping.items():
      isSupported = getattr( caps, attrName ) == 1
      capabilityList.append( Capability( kind=descr.kind,
                                         side=descr.side,
                                         supported=isSupported ) )
   model.capabilities = capabilityList
   return model

def getLoopbackCaps( status ):
   '''
   Returns an Xcvr::PrbsLoopbackCapabilities value built from the EEPROM test
   pattern contents of `status`, or None when `status` exposes no
   eepromContents or no testPatternContents.
   '''
   eepromContents = getattr( status, "eepromContents", None )
   testPattern = eepromContents.testPatternContents if eepromContents else None
   if not testPattern:
      return None
   return Tac.Value( "Xcvr::PrbsLoopbackCapabilities",
                     testPattern.prbsLoopbackCaps )

def shouldOmitStatusFromModel( status, xcvrConfigCli, si: SlotInfo = None ) -> bool:
   '''
   Decides whether a slot should be left out of the show output.

   A present module is never omitted. A module that is not present is omitted
   when:
     - it has no CLI config or no CMIS loopback config, or
     - `si` is given and every host and media lane in it is "none".
   '''
   if status.presence == XcvrPresence.xcvrPresent:
      return False
   if not xcvrConfigCli or not xcvrConfigCli.cmisLoopbackConfig:
      return True
   if si is None:
      return False
   return allLanesAreNone( si, 'host' ) and allLanesAreNone( si, 'media' )

def setLanesModel( status, xcvrConfigCli, si: SlotInfo ):
   '''
   Populates si.hostLanes and si.mediaLanes with the per-lane models built by
   createLanesModel(), honoring each side's "all lanes configured" flag.
   '''
   hostAllCfg, mediaAllCfg = getAllLanesFromConfig( xcvrConfigCli )
   hostModel = createLanesModel( status, xcvrConfigCli, 'host', hostAllCfg )
   si.hostLanes = hostModel
   mediaModel = createLanesModel( status, xcvrConfigCli, 'media', mediaAllCfg )
   si.mediaLanes = mediaModel

def getAllLanesFromConfig( xcvrConfigCli ):
   '''
   Returns the ( allHostLanesConfigured, allMediaLanesConfigured ) flags of the
   CMIS loopback config, or ( False, False ) when there is no config.
   '''
   loopbackCfg = xcvrConfigCli.cmisLoopbackConfig if xcvrConfigCli else None
   if not loopbackCfg:
      return ( False, False )
   return ( loopbackCfg.allHostLanesConfigured,
            loopbackCfg.allMediaLanesConfigured )

def setAllLanesFromConfig( xcvrConfigCli, si: SlotInfo ):
   '''
   Copies the config's "all lanes configured" flags into the private
   si._hostLanesAllCfg / si._mediaLanesAllCfg attributes (host first).
   '''
   allLanesFlags = getAllLanesFromConfig( xcvrConfigCli )
   for side, allCfg in zip( ( 'host', 'media' ), allLanesFlags ):
      setattr( si, f'_{side}LanesAllCfg', allCfg )

def isLaneInfoNone( laneInfo: LaneInfo ) -> bool:
   '''
   True when the lane has no configured loopback and its operational state is
   'none' or 'notPresent'.
   '''
   if laneInfo.configured != 'none':
      return False
   return laneInfo.operational in ( 'none', 'notPresent' )

def laneInfoMatch( lhs: LaneInfo, rhs: LaneInfo ) -> bool:
   '''True when both lanes report the same configured and operational state.'''
   return all( getattr( lhs, field ) == getattr( rhs, field )
               for field in ( 'configured', 'operational' ) )

def allLanesAreNone( si, side: str, key='Lanes' ):
   '''
   True when every LaneInfo in the `{side}{key}` collection of `si` satisfies
   isLaneInfoNone() (vacuously True for an empty collection).
   '''
   laneInfos = getattr( si, f'{side}{key}' ).values()
   return all( map( isLaneInfoNone, laneInfos ) )

def adjustAllLanesConfigured( si: SlotInfo, side: str ):
   '''
   Recomputes the private "all lanes configured" flag of one side of `si`:
     - forced to True when every lane of the side is "none"
     - otherwise, if the flag is currently set, kept only when all lanes report
       an identical configured/operational state
   '''
   if allLanesAreNone( si, side ):
      setModelAllLanes( si, side, True )
   elif getattr( si, f'_{side}LanesAllCfg' ):
      laneInfos = list( getattr( si, f'{side}Lanes' ).values() )
      isHomogeneous = all( laneInfoMatch( info, laneInfos[ 0 ] )
                           for info in laneInfos )
      setModelAllLanes( si, side, isHomogeneous )

def setModelAllLanes( si: SlotInfo, side: str, homogeneousLaneCfg: bool ):
   '''Stores `homogeneousLaneCfg` in the private `_{side}LanesAllCfg` of `si`.'''
   flagName = f'_{side}LanesAllCfg'
   setattr( si, flagName, homogeneousLaneCfg )

def getConfigStatus( xcvrConfig, side, lane ):
   '''
   Returns the model string of the loopback configured on `lane` for `side`,
   or 'none' when there is no loopback config or the lane has no entry.
   '''
   loopbackCfg = xcvrConfig.cmisLoopbackConfig if xcvrConfig else None
   if not loopbackCfg:
      return 'none'
   configuredLanes = selectCollection( loopbackCfg, side )
   if lane in configuredLanes:
      return ConfigToModelMapping[ side ][ configuredLanes[ lane ] ]
   return 'none'

def getOpStatus( configured: str, xcvrStatus, side: str, lane: int ):
   '''
   Returns the operational loopback state string of one lane:
     - 'notPresent' when there is no status or the module is not present, or
       when the lane has no status entry and lies beyond the operational lane
       count
     - 'none' when there is no CMIS loopback status at all
     - 'notSupported' when the lane's state (including an implicit 'none' for a
       missing entry within the lane count) differs from a non-'none'
       `configured` value
     - otherwise the lane's state mapped through ConfigToModelMapping
   '''
   if xcvrStatus is None or xcvrStatus.presence != XcvrPresence.xcvrPresent:
      return 'notPresent'
   loopbackStatus = xcvrStatus.cmisLoopbackStatus
   if loopbackStatus is None:
      return 'none'
   laneStatus = selectCollection( loopbackStatus, side )
   if lane in laneStatus:
      operational = ConfigToModelMapping[ side ][ laneStatus[ lane ] ]
   elif lane < getOpLaneCount( xcvrStatus, side ):
      operational = 'none'
   else:
      operational = 'notPresent'
   if operational == 'notPresent' or configured in ( operational, 'none' ):
      return operational
   return 'notSupported'

def getOpLaneCount( status, side: str, forceDefault=False ) -> int:
   '''
   Returns the operational lane count for `side`:
     - defaultLaneCnt when there is no status
     - capabilities.lineSideChannels for the media side of a present module,
       unless `forceDefault` is set
     - capabilities.maxChannels (the host side channel count) otherwise
   '''
   if status is None:
      return defaultLaneCnt
   mediaOfPresentModule = ( side != 'host' and
                            status.presence == XcvrPresence.xcvrPresent and
                            not forceDefault )
   if mediaOfPresentModule:
      return status.capabilities.lineSideChannels
   return status.capabilities.maxChannels

def getConfLaneCount( config, side: str ) -> int:
   '''
   Returns one past the highest configured lane id for `side` (0 when nothing
   is configured on that side), or -1 when there is no loopback config.
   '''
   loopbackCfg = config.cmisLoopbackConfig if config else None
   if not loopbackCfg:
      return -1
   highestLane = max( selectCollection( loopbackCfg, side ).keys(), default=-1 )
   if highestLane < 0:
      return 0
   return highestLane + 1

def createLanesModel( status, xcvrConfigCli, side: str, allLanes: bool ):
   '''
   Builds the per-lane LaneInfo dict for one side, keyed by
   internalLaneIdToCli( lane ).

   Without `allLanes` the lane range covers the larger of the operational and
   configured lane counts. With `allLanes` only the operational count is used,
   which keeps "not present" noise out of the output.
   '''
   opLaneCount = getOpLaneCount( status, side )
   confLaneCount = getConfLaneCount( xcvrConfigCli, side )
   if allLanes:
      laneCount = opLaneCount
   else:
      laneCount = max( opLaneCount, confLaneCount )
   lanesModel = {}
   for lane in range( laneCount ):
      cliLane = internalLaneIdToCli( lane )
      configured = getConfigStatus( xcvrConfigCli, side, lane )
      lanesModel[ cliLane ] = LaneInfo(
         configured=configured,
         operational=getOpStatus( configured, status, side, lane ) )
   return lanesModel

def getLanes( xcvrStatus, side, lanes ):
   '''
   Converts the CLI lane list to internal ids. When `lanes` is empty or None,
   every CLI lane 1..N is used, where N is getOpLaneCount() with
   forceDefault=True.
   '''
   cliLanes = lanes
   if not cliLanes:
      fullCount = getOpLaneCount( xcvrStatus, side, forceDefault=True )
      cliLanes = range( 1, fullCount + 1 )
   return cliLaneListIdToInternal( cliLanes )

def cliLaneListIdToInternal( lanes ):
   '''
   Converts 1-based CLI lane ids to 0-based internal lane ids, returned as a
   new list in the same order.
   '''
   return list( map( lambda cliLane: cliLane - 1, lanes ) )

def getLoopbackCliConfig( intfName: str, xcvrConfig, create=True ):
   '''
   Returns the CMIS loopback config of `xcvrConfig`. If it is missing and
   `create` is set, it is first instantiated (keyed by `intfName`). A missing
   `xcvrConfig` is only allowed when `create` is False, and yields None.
   '''
   if xcvrConfig:
      if create and xcvrConfig.cmisLoopbackConfig is None:
         xcvrConfig.cmisLoopbackConfig = ( intfName, )
      return xcvrConfig.cmisLoopbackConfig
   assert not create
   return None

def setAllLanes( isAllLanesVal, loopbackConfig, isHost ):
   '''
   Stores `isAllLanesVal` in allHostLanesConfigured (isHost truthy) or
   allMediaLanesConfigured of `loopbackConfig`.
   '''
   if isHost:
      flagName = 'allHostLanesConfigured'
   else:
      flagName = 'allMediaLanesConfigured'
   setattr( loopbackConfig, flagName, isAllLanesVal )

def isAllLanes( lanes, isHost, status ):
   '''
   Tells whether a config command targets every lane of the side.

   Host side: True when `lanes` is None or contains every CLI lane 1..N, where
   N is status.capabilities.maxChannels (defaultLaneCnt without a status).
   Media side: True only when `lanes` is None, because the number of media
   lanes depends on the module actually plugged in.
   '''
   if not isHost:
      return lanes is None
   if status is not None:
      refLaneCnt = status.capabilities.maxChannels
   else:
      refLaneCnt = defaultLaneCnt
   # NOTE: the CLI accepts any lane in [1,8] for every CMIS slot, even on slots
   # with 4 lanes, so lanes beyond N do not prevent an "all lanes" match
   expectedLanes = set( range( 1, refLaneCnt + 1 ) )
   if lanes is None:
      return True
   return expectedLanes <= set( lanes )

def setLoopback( collection, loopback, lanes: Iterable[ int ] ):
   '''
   Stores `loopback` in `collection` for every internal lane id in `lanes`,
   overwriting any previous entry.
   '''
   for laneId in lanes:
      collection[ laneId ] = loopback

def commitConfigUpdate( loopbackCfg ):
   '''
   Finalizes a config change. It drops NoLoopback entries from both lane
   collections (cleanupNoLoopbackEntries) and then increments generationId.
   '''
   cleanupNoLoopbackEntries( loopbackCfg )
   loopbackCfg.generationId = loopbackCfg.generationId + 1

def getCliConfig( mode, create=True ):
   '''Returns ( intfName, xcvrConfigCli ) for the interface of the CLI `mode`.'''
   return getCliConfigByName( mode.intf.name, create )

def getCliConfigByName( intfName: str, create=True ):
   '''
   Returns ( intfName, xcvrConfigCli ), where xcvrConfigCli comes from
   getXcvrConfigCliForConfigCommand() on the interface's config CLI dir.
   `create` is passed through unchanged.
   '''
   configCliDir = getXcvrConfigCliDir( intfName )
   return ( intfName,
            getXcvrConfigCliForConfigCommand( intfName, configCliDir, create ) )

def sourceSystemLoopbackType( isHost: bool ):
   '''
   Loopback mode for traffic sourced by the system device:
     - host side  => CMIS Host Input
     - media side => CMIS Media Output
   Only `isHost is True` selects the host mode.
   '''
   if isHost is True:
      return CmisLoopbackMode.Input
   return CmisLoopbackMode.Output

def sourceNetworkLoopbackType( isHost: bool ):
   '''
   Loopback mode for traffic sourced by the network device:
     - host side  => CMIS Host Output
     - media side => CMIS Media Input
   Only `isHost is True` selects the host mode.
   '''
   if isHost is True:
      return CmisLoopbackMode.Output
   return CmisLoopbackMode.Input

##########
# Config delete logic
##########
def deleteConfigBothSide( loopbackConfig, loopbackTypeFunc: Callable[ [ bool ],
                                                                      Any ] ):
   '''
   Removes the loopback mode from every lane of the host side and then of the
   media side. The config is committed once if either side changed.
   '''
   hostRemoved = deleteConfigSingleSideImpl( loopbackConfig, True,
                                             loopbackTypeFunc,
                                             deleteLoopbackCfgFromColl )
   mediaRemoved = deleteConfigSingleSideImpl( loopbackConfig, False,
                                              loopbackTypeFunc,
                                              deleteLoopbackCfgFromColl )
   if hostRemoved or mediaRemoved:
      commitConfigUpdate( loopbackConfig )

def deleteConfigSingleSide( loopbackConfig, isHost, loopbackTypeFunc ):
   '''
   Removes the loopback mode from every lane of one side and commits the
   config if anything was removed.
   '''
   removed = deleteConfigSingleSideImpl( loopbackConfig, isHost,
                                         loopbackTypeFunc,
                                         deleteLoopbackCfgFromColl )
   if removed:
      commitConfigUpdate( loopbackConfig )

def deleteConfigSingleSideOnLanes( loopbackConfig, isHost, loopbackTypeFunc, lanes ):
   '''
   Removes the loopback mode from the given internal lane ids of one side and
   commits the config if anything was removed.
   '''
   laneDeleter = functools.partial( deleteLoopbackCfgFromCollOnLanes,
                                    lanes=lanes )
   removed = deleteConfigSingleSideImpl( loopbackConfig, isHost,
                                         loopbackTypeFunc, laneDeleter )
   if removed:
      commitConfigUpdate( loopbackConfig )

def deleteConfigSingleSideImpl( loopbackConfig, isHost, loopbackTypeFunc, deleter ):
   '''
   Runs `deleter` on one side's loopback collection. If that removed any entry,
   it also clears the side's "all lanes configured" flag.

   Parameters
   ----------
   loopbackConfig : Entity
      CMIS loopback config entity
   isHost : bool
      selects the host (True) or media (False) collection via selectCollection
   loopbackTypeFunc : Callable[[bool], LoopbackMode]
      maps the side to the loopback mode (given the traffic source) to delete
   deleter : Callable[[collection, LoopbackMode], None]
      performs the actual removal from the collection

   Returns
   -------
   bool
      True if the collection shrank, i.e. at least one entry was removed
   '''
   sideColl = selectCollection( loopbackConfig, isHost )
   sizeBefore = len( sideColl )
   deleter( sideColl, loopbackTypeFunc( isHost ) )
   if len( sideColl ) == sizeBefore:
      return False
   setAllLanes( False, loopbackConfig, isHost )
   return True

def cleanupNoLoopbackEntries( loopbackConfig ):
   '''
   Drops every NoLoopback entry from the media and host loopback collections,
   in that order.
   '''
   for sideColl in ( loopbackConfig.mediaLoopback, loopbackConfig.hostLoopback ):
      deleteLoopbackCfgFromColl( sideColl, CmisLoopbackMode.NoLoopback )

def deleteLoopbackCfgFromCollOnLanes( loopbackConfigColl, loopbackType, lanes ):
   '''
   Deletes the entries of the given lane ids whose value equals
   `loopbackType`. Lanes that are missing or hold another mode are left alone.
   Iteration is over `lanes`, never over the collection being mutated.
   '''
   for laneId in lanes:
      if laneId in loopbackConfigColl:
         if loopbackConfigColl[ laneId ] == loopbackType:
            del loopbackConfigColl[ laneId ]

def deleteLoopbackCfgFromColl( loopbackConfigColl, loopbackType ):
   '''
   Deletes every entry of `loopbackConfigColl` whose value equals
   `loopbackType`.
   '''
   # Snapshot the matching lanes before deleting anything. Deleting while a
   # generator is still walking the collection's live items() view is unsafe:
   # a Python dict raises "dictionary changed size during iteration" (or
   # skips entries).
   staleLanes = [ lane for lane, lbtype in loopbackConfigColl.items()
                  if lbtype == loopbackType ]
   for lane in staleLanes:
      del loopbackConfigColl[ lane ]

#######
# Handler implementations
#######
def CmisLoopbackStatusShowHandler( mode, args ):
   '''
   Builds the TrafficLoopback model. It holds one SlotInfo per selected
   primary interface and skips slots that shouldOmitStatusFromModel()
   rejects, either before or after their lane models are filled in.
   '''
   model = TrafficLoopback()
   for _, primaryIntfName, status in gerPrimaryIntfData( mode, args ):
      intfName, xcvrConfigCli = getCliConfigByName( primaryIntfName,
                                                    create=False )
      assert status is not None
      si = SlotInfo()
      if shouldOmitStatusFromModel( status, xcvrConfigCli ):
         continue
      setLanesModel( status, xcvrConfigCli, si )
      setAllLanesFromConfig( xcvrConfigCli, si )
      for side in ( 'host', 'media' ):
         adjustAllLanesConfigured( si, side )
      # re-check with the lane models known: all lanes may turn out "none"
      if shouldOmitStatusFromModel( status, xcvrConfigCli, si ):
         continue
      model.transceiverLoopback[ intfName ] = si
   return model

def CommonConfigHandler( mode, args, loopbackTypeFunc: Callable[ [ bool ], Any ] ):
   '''
   Applies a loopback config command to the interface of `mode`. It records
   whether all lanes of the side are targeted, writes the mode returned by
   `loopbackTypeFunc` on every requested lane, and commits the change.
   '''
   intfName, xcvrConfigCli = getCliConfig( mode )
   loopbackCfg = getLoopbackCliConfig( intfName, xcvrConfigCli )
   assert loopbackCfg
   status = xcvrStatusForPrimary( intfName )
   isHost = 'host' in args
   side = 'host' if isHost else 'media'
   cliLanes = args.get( 'LANES' )
   # record first whether this command covers every lane of the side
   coversAllLanes = isAllLanes( cliLanes, isHost, status )
   setAllLanes( coversAllLanes, loopbackCfg, isHost )
   # write the loopback mode on every requested (internal) lane
   sideColl = selectCollection( loopbackCfg, isHost )
   loopbackMode = loopbackTypeFunc( isHost )
   internalLanes = getLanes( status, side, cliLanes )
   setLoopback( sideColl, loopbackMode, internalLanes )
   # maintain invariants: NoLoopback cleanup + generationId bump
   commitConfigUpdate( loopbackCfg )

def CommonConfigNoHandler( mode, args, loopbackTypeFunc: Callable[ [ bool ], Any ] ):
   '''
   Handles the "no" form of a loopback config command. It removes the mode
   returned by `loopbackTypeFunc` from:
     - both sides, when no side keyword is given
     - a whole side, when no LANES are given
     - the listed lanes of one side, otherwise
   It does nothing when the interface has no loopback config.
   '''
   intfName, cliConfig = getCliConfig( mode, create=False )
   loopbackCfg = getLoopbackCliConfig( intfName, cliConfig, create=False )
   if not loopbackCfg:
      return
   isHost = 'host' in args
   isMedia = 'media' in args
   lanes = args.get( 'LANES' )
   # NOTE: the CLI syntax allows at most one side keyword at a time
   assert not ( isHost and isMedia )
   if not ( isHost or isMedia ):
      deleteConfigBothSide( loopbackCfg, loopbackTypeFunc )
   elif lanes is not None:
      deleteConfigSingleSideOnLanes( loopbackCfg, isHost, loopbackTypeFunc,
                                     cliLaneListIdToInternal( lanes ) )
   else:
      deleteConfigSingleSide( loopbackCfg, isHost, loopbackTypeFunc )

############
# Entry points from CLI Plugin
############
def CmisLoopbackCapabilitiesShowHandler( mode, args ):
   '''
   Builds the IntfCapabilities model. Each selected primary interface gets a
   capabilities model filled from its EEPROM loopback capabilities.
   '''
   intfCaps = IntfCapabilities()
   for _, intfName, status in gerPrimaryIntfData( mode, args ):
      capsModel = createCapabilitiesModel()
      loopbackCaps = getLoopbackCaps( status )
      intfCaps.loopbackCaps[ intfName ] = setCapabilities( capsModel,
                                                           loopbackCaps )
   return intfCaps

def CmisLoopbackSystemConfigCommandHandler( mode, args ):
   '''Configures a loopback of traffic sourced by the system device.'''
   loopbackTypeFunc = sourceSystemLoopbackType
   CommonConfigHandler( mode, args, loopbackTypeFunc )

def CmisLoopbackSystemConfigCommandNoHandler( mode, args ):
   '''Removes a loopback of traffic sourced by the system device.'''
   loopbackTypeFunc = sourceSystemLoopbackType
   CommonConfigNoHandler( mode, args, loopbackTypeFunc )

def CmisLoopbackNetworkConfigCommandHandler( mode, args ):
   '''Configures a loopback of traffic sourced by the network device.'''
   loopbackTypeFunc = sourceNetworkLoopbackType
   CommonConfigHandler( mode, args, loopbackTypeFunc )

def CmisLoopbackNetworkConfigCommandNoHandler( mode, args ):
   '''Removes a loopback of traffic sourced by the network device.'''
   loopbackTypeFunc = sourceNetworkLoopbackType
   CommonConfigNoHandler( mode, args, loopbackTypeFunc )

# ------------------------------------------------------
# Plugin method
# ------------------------------------------------------
def Plugin( em ):
   '''
   CLI plugin entry point. Resolves the transceiver status directory through
   CliPlugin.XcvrAllStatusDir and stores it in the module-global xcvrStatusDir.
   '''
   global xcvrStatusDir
   allStatusDir = CliPlugin.XcvrAllStatusDir.xcvrAllStatusDir( em )
   xcvrStatusDir = allStatusDir
