# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import re
import functools
import Tac
import CliParser
import LazyMount

from CliDynamicSymbol import CliDynamicPlugin

# Import the modules themselves rather than their global variables: importing
# a global variable object directly can yield a stale value.
# `from module import obj` copies a reference to the object into this module.
# If the other module later rebinds that name to a new value, only its own
# binding changes; our copied reference keeps pointing to the old object.
# Python FAQ: https://docs.python.org/3/faq/programming.html#how-do-i-share-
# global-variables-across-modules
from CliPlugin import ControllerCli
from CliPlugin import NetworkTopologyService

from CliPlugin.NetworkTopologyModels import findHostname
from CliPlugin.NetworkTopologyModels import DirectedEdge
from CliPlugin.NetworkTopologyModels import Host
from CliPlugin.NetworkTopologyModels import LogicalPort
from CliPlugin.NetworkTopologyModels import Port
from CliPlugin.NetworkTopologyModels import PortGroup
from CliPlugin.NetworkTopologyModels import TopologyHosts
from CliPlugin.NetworkTopologyModels import TopologyNeighbors

# Handle to the "NetworkTopologyService" dynamic CLI plugin; gotoServiceTopology
# reads TopologyConfigMode from it (presumably the submode classes are loaded
# lazily through this handle -- confirm against CliDynamicSymbol).
TopologyDynamicSubmodes = CliDynamicPlugin( "NetworkTopologyService" )

def forceMountWithActivityLock( func ):
   """Decorate `func` so it runs on a mounted status under the activity lock.

   Each call of the returned wrapper first forces the lazy mount of
   NetworkTopologyService.status. It then invokes `func` with the same
   arguments while holding a Tac.ActivityLockHolder, and returns whatever
   `func` returned. The wrapper carries `func`'s metadata (functools.wraps).
   """
   def wrapper( *args, **kwargs ):
      LazyMount.force( NetworkTopologyService.status )
      with Tac.ActivityLockHolder():
         result = func( *args, **kwargs )
      return result

   return functools.wraps( func )( wrapper )

def gotoServiceTopology( mode, args ):
   """Enter the topology configuration submode below `mode`.

   `args` is accepted for the CLI handler signature but not used.
   """
   topologyMode = mode.childMode( TopologyDynamicSubmodes.TopologyConfigMode )
   session = mode.session_
   session.gotoChildMode( topologyMode )

def clearServiceTopology( mode, args=None ):
   """Remove every configured static topology edge.

   Also registered below as a NoCVX callback. That hook invokes it with only
   `mode`, which is why `args` defaults to None.
   """
   staticEdges = NetworkTopologyService.config.staticEdge
   staticEdges.clear()

ControllerCli.addNoCvxCallback( clearServiceTopology )

# Pattern matching a physical Ethernet interface name on fixed and modular
# switches, e.g. "Ethernet1", "Et1/2", "eth3/1/4".
# Group 1 is the interface type: at least two letters, starting with "E" or
# "e". getSwitchIntfFullName additionally requires it to be a prefix of
# "ethernet". Group 2 is the interface id: one, two or three slash-separated
# numbers.
switchIntfPattern = r'(^[Ee][a-zA-Z]+)((\d+)|(\d+\/\d+)|(\d+\/\d+\/\d+))$'
_switchIntfRe = re.compile( switchIntfPattern )

def getSwitchIntfFullName( intf ):
   """Expand an Ethernet interface name to its canonical full form.

   Args:
      intf: interface name as typed on the CLI. The type may be abbreviated
            and in any letter case, e.g. "et1", "Eth1/2", "ETHERNET3/1/1".

   Returns:
      The canonical name ("Ethernet1", "Ethernet1/2", ...), or None when
      `intf` does not match switchIntfPattern or its type is not a prefix of
      "ethernet".
   """
   ethPhyType = 'ethernet'
   match = _switchIntfRe.match( intf )
   if not match:
      return None
   intfType, intfId = match.group( 1, 2 )
   # The pattern accepts any letters after the leading E; only genuine
   # abbreviations of "ethernet" (e.g. "Et", "eth") are valid here.
   if not ethPhyType.startswith( intfType.lower() ):
      return None
   return ethPhyType.title() + intfId

def _staticEdgeFromArgs( args ):
   """Build the StaticEdge described by the parsed CLI `args`.

   Reads 'SWITCH', 'INTERFACE', 'NEIGHBOR-HOST' and the optional
   'NEIGHBOR-INTERFACE' (defaults to "").

   Raises:
      CliParser.InvalidInputError: 'INTERFACE' is not an Ethernet interface
         name accepted by getSwitchIntfFullName.
   """
   neighIntf = args.get( 'NEIGHBOR-INTERFACE', "" )
   switchIntf = getSwitchIntfFullName( args[ 'INTERFACE' ] )
   if switchIntf is None:
      raise CliParser.InvalidInputError()
   return Tac.newInstance( "NetworkTopologyAggregatorV3::StaticEdge",
         args[ 'SWITCH' ], switchIntf, args[ 'NEIGHBOR-HOST' ],
         neighIntf )

def networkPhysicalTopologyConfig( mode, args ):
   """Add the static edge described by `args` to the topology config.

   Raises CliParser.InvalidInputError for an unrecognised switch interface.
   """
   edge = _staticEdgeFromArgs( args )
   NetworkTopologyService.config.staticEdge[ edge ] = True

def networkPhysicalTopologyConfigNoOrDef( mode, args ):
   """Remove the static edge described by `args` (the "no"/"default" form).

   Raises CliParser.InvalidInputError for an unrecognised switch interface.
   """
   edge = _staticEdgeFromArgs( args )
   del NetworkTopologyService.config.staticEdge[ edge ]

def showEdgesHelper( edgeKey, edge ):
   """Convert one status edge into a DirectedEdge CLI model.

   Args:
      edgeKey: key of the edge in the status edge collection. Its `name` is
         the local port name and `host()` the local host.
      edge: the edge entry; each element of `edge.toPort` is a connected port.

   Returns:
      (portName, hostname, directedEdge), where hostname is the resolved
      name of the local host.
   """
   portName = edgeKey.name
   localHost = edgeKey.host()
   hostname = findHostname( localHost )

   directedEdge = DirectedEdge()
   directedEdge.fromPort = Port( name=portName, hostname=hostname,
                                 hostid=localHost.name, portChannel='' )
   for peer in edge.toPort:
      peerHost = peer.host()
      directedEdge.toPort.append( Port( name=peer.name,
                                        hostid=peerHost.name,
                                        hostname=findHostname( peerHost ),
                                        portChannel='' ) )
   return portName, hostname, directedEdge

@forceMountWithActivityLock
def showNeighbors( mode, args ):
   """Return a TopologyNeighbors model built from the topology status edges.

   When 'HOSTNAME' is given, only edges whose local host resolves (via
   findHostname) to that name are included. Each entry is keyed by
   "<hostname>-<portName>", or just "<hostname>" when the port name is empty.
   """
   wantedHost = args.get( 'HOSTNAME' )
   result = TopologyNeighbors()

   for edgeKey, edge in NetworkTopologyService.status.edge.items():
      if not wantedHost or findHostname( edgeKey.host() ) == wantedHost:
         portName, hostname, directedEdge = showEdgesHelper( edgeKey, edge )
         key = f'{hostname}-{portName}' if portName else hostname
         result.neighbors[ key ] = directedEdge

   return result

def _addPhysicalPorts( hostModel, host ):
   """Fill hostModel.port and hostModel.portChannels from host.port."""
   # Port entries use findHostname's default lookup (unlike the Host model
   # itself, which uses useName=False). The value does not depend on the
   # port, so compute it once.
   portHostname = findHostname( host )
   seenPortGroups = set()
   for port in host.port.values():
      portGroup = port.portGroup
      portChannel = portGroup.name if portGroup else ""
      hostModel.port[ port.name ] = Port( name=port.name,
                                          hostid=host.name,
                                          hostname=portHostname,
                                          portChannel=portChannel )
      # Every member port refers to the same group; build the PortGroup
      # model only the first time the group is seen instead of once per member.
      if not portGroup or portChannel in seenPortGroups:
         continue
      seenPortGroups.add( portChannel )
      hostModel.portChannels[ portChannel ] = PortGroup( name=portChannel,
                                                         memberPorts=[] )
      for memberPort in portGroup.memberPort.values():
         hostModel.portChannels[ portChannel ].memberPorts.append(
            memberPort.name )

def _addLogicalPorts( hostModel, host ):
   """Fill hostModel.logicalPorts from host.logicalPort.

   Member ports are listed only for logical ports of type 'portGroupType'.
   """
   for logicalPort in host.logicalPort.values():
      portName = logicalPort.name
      hostModel.logicalPorts[ portName ] = LogicalPort( name=portName )
      if logicalPort.type == 'portGroupType':
         for memberPort in logicalPort.memberPort.values():
            hostModel.logicalPorts[ portName ].memberPorts.append(
               memberPort.name )

@forceMountWithActivityLock
def showHosts( mode, args ):
   """Return a TopologyHosts model describing hosts in the topology status.

   Args:
      mode: CLI mode (unused).
      args: parsed CLI arguments. The optional 'HOSTNAME' restricts output to
         hosts whose findHostname( host, useName=False ) equals it. The
         presence of 'details' sets each Host model's `details` flag.

   Returns:
      TopologyHosts whose `hosts` are keyed by each host's `name`.
   """
   name = args.get( 'HOSTNAME' )
   details = 'details' in args
   topoHosts = TopologyHosts()

   for host in NetworkTopologyService.status.host.values():
      hostname = findHostname( host, useName=False )
      if name and hostname != name:
         continue
      h = Host( name=host.name, hostname=hostname )
      h.details = details
      _addPhysicalPorts( h, host )
      _addLogicalPorts( h, host )
      topoHosts.hosts[ host.name ] = h

   return topoHosts
