#!/usr/bin/env python3
# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import Tac
import Tracing
from IpLibConsts import DEFAULT_VRF
import SharedMem
import Shark
import Arnet

handle = Tracing.Handle( "MfibSharkTestLib" )

# Short aliases for the trace levels used by this library.
t0, t1 = handle.trace0, handle.trace1

# Interface-role bits that addIntf()/delIntf() accept as a plain int and wrap
# in Routing::Multicast::Fib::IntfFlags; presumably they match that type's
# forward/accept/notify bit layout -- confirm against the Tac definition.
FORWARD, ACCEPT, NOTIFY = 1, 2, 4

# Tac types looked up once at import time and reused by the helpers below.
RouteReadyEnum = Tac.Type( "Shark::Multicast::Fib::RouteReady" )
MfibVrfName = Tac.Type( "Shark::Multicast::Fib::MfibVrfName" )

class MfibSharkTestEnv:
   """Test helper that owns an Mfib VRF status and edits the routes in it.

   The VRF status (``self.mfibShark``, also exposed as ``self.mfib``) comes
   from one of three sources:

   * the ``mfibShark`` constructor argument, used as-is;
   * a Shark writer mount of ``Shark::Multicast::Fib::StatusByVrf`` when an
     ``entityManager`` is supplied;
   * a local ``Tac.newInstance`` of that type otherwise.
   """
   # Number of instances constructed; incremented in __init__, not read here.
   _id = 0

   def __init__( self, mfibShark=None, entityManager=None, vrfName=DEFAULT_VRF,
                 appName="sparsemode", af="ipv4" ):
      """Set up the VRF status this env operates on.

      Args:
         mfibShark: existing VRF status to wrap. When given, nothing is
            mounted or created and ``self.statusByVrf`` is not set.
         entityManager: Sysdb entity manager. When given (and ``mfibShark``
            is not), a SharedMem entity manager is derived from it and used
            to mount StatusByVrf as a Shark writer.
         vrfName: name of the VRF status created under StatusByVrf.
         appName: application name; part of the mount path and ``appKey``.
         af: 'ipv4' selects the ``_v4`` app key; any other value selects
            ``_v6``. Also passed to ``mountPath``.
      """
      self.em = entityManager
      self.vrfName = vrfName
      self.appName = appName
      self.af = af

      self.shmemEm = SharedMem.entityManager( sysdbEm=self.em ) if self.em else None
      self.typeStr = "Shark::Multicast::Fib::StatusByVrf"
      self.path = ""
      self.appKey = "%s_%s" % ( self.appName, 'v4' if self.af == 'ipv4' else 'v6' )

      if mfibShark:
         self.mfibShark = mfibShark
      elif self.shmemEm is not None:
         t1( "app: %s, vrf: %s, Mounting MfibShark writer" % ( appName, vrfName ) )
         self.writerMountInfo = Shark.mountInfo( 'writer', cleanup=True )
         self.path = Tac.Type( self.typeStr ).mountPath( self.af, self.appName )
         t1( "mount path: %s" % self.path )
         shmemMg = self.shmemEm.getMountGroup()
         self.statusByVrf = shmemMg.doMount( self.path, self.typeStr,
                                             self.writerMountInfo )
         shmemMg.doClose()
         # Do not touch the mounted entity until the mount group is ready.
         Tac.waitFor( lambda: shmemMg.ready, description="mount to be ready" )
         t1( "Mount complete" )

         self.statusByVrf.statusInfo = ( self.statusByVrf.name, )
         self.statusByVrf.statusInfo.appType = self.appKey

         self.mfibShark = self.statusByVrf.newVrfStatus(
            MfibVrfName( self.vrfName ) )
         self.mfibShark.appName = self.appName
         self.mfibShark.appType = self.appKey
      else:
         self.statusByVrf = Tac.newInstance( self.typeStr, "Mfib" )
         self.statusByVrf.statusInfo = ( self.statusByVrf.name, )
         self.statusByVrf.statusInfo.appType = self.appKey

         self.mfibShark = self.statusByVrf.newVrfStatus(
               MfibVrfName( self.vrfName ) )
         self.mfibShark.appType = self.appKey
         self.mfibShark.appName = self.appName

      self.mfib = self.mfibShark
      MfibSharkTestEnv._id += 1

   def __del__( self ):
      # Writer unmount is not supported by Shark, so there is nothing to
      # release here.
      pass

   def shadowMount( self, appName ):
      """Mount the StatusByVrf of `appName` with 'shadow' mount info and
      return it.

      Requires the env to have been built with an entityManager (it uses
      ``self.shmemEm``). Unlike the writer mount in __init__, this does not
      wait for the mount group to become ready.
      """
      t1( "app: %s, Mounting MfibShark as reader" % appName )
      mountInfo = Shark.mountInfo( 'shadow' )
      path = Tac.Type( self.typeStr ).mountPath( self.af, appName )
      t1( "mount path: %s" % path )
      shmemMg = self.shmemEm.getMountGroup()
      statusByVrf = shmemMg.doMount( path, self.typeStr, mountInfo )
      shmemMg.doClose()
      return statusByVrf

   def clear( self ):
      """Remove every route from the wrapped VRF status, if there is one."""
      if self.mfibShark:
         t1( "Clearing all routes for %s" % self.mfibShark.appName )
         self.mfibShark.route.clear()

   @staticmethod
   def ipGenRouteKey( s, g ):
      """Build a Routing::Multicast::Fib::IpGenRouteKey from source `s` and
      group `g` (address strings).

      The wildcard sources '0.0.0.0' and '::' get a /0 length. Any other
      source gets a host length: /32 if it contains a '.', otherwise /128.
      `g` is passed to Arnet.IpGenPrefix unchanged.
      """
      if s in ( '0.0.0.0', '::' ):
         prefixLen = 0
      else:
         prefixLen = 32 if '.' in s else 128
      return Tac.Value( "Routing::Multicast::Fib::IpGenRouteKey",
                        Arnet.IpGenPrefix( '%s/%d' % ( s, prefixLen ) ),
                        Arnet.IpGenPrefix( g ) )

   def iif( self, key ):
      """Return the iif of the route for `key`, or None if there is no route."""
      r = self.route( key )
      return r.iif if r else None

   def notify( self, key ):
      """Return the notify interface of the route for `key`, or None if there
      is no route or its notify interface is falsy."""
      r = self.route( key )
      return r.notify if r and r.notify else None

   def ptr( self ):
      """Return Tac.addressOf() of the wrapped VRF status."""
      return Tac.addressOf( self.mfibShark )

   def route( self, key ):
      """Return the route for `key`, or None if it does not exist."""
      return self.mfibShark.route.get( key )

   def routeIter( self ):
      """Return the values of the VRF status's route collection."""
      return self.mfibShark.route.values()

   def oifs( self, key ):
      """Return the oifs of the route for `key` as a list.

      Raises AttributeError (via routeOifIter) if the route does not exist.
      """
      return list( self.routeOifIter( key ) )

   def isBlankRoute( self, key ):
      """Return a truthy value if the route for `key` has any oif, an iif
      other than "Null0", or a non-empty notify interface.

      NOTE(review): despite the name, the result is truthy for a route that
      is *not* blank, and it may be the oif list rather than a bool. Callers
      may depend on this, so the logic is kept as-is -- confirm intent.
      """
      return self.oifs( key ) or self.route( key ).iif != "Null0" or \
                                 self.route( key ).notify != ""

   @staticmethod
   def _setRouteFlag( route, flagName, value=True ):
      """Set one attribute of route.routeFlags.

      The current flags are copied into a new RouteFlags value, the named
      attribute is set on the copy, and the copy is written back to the
      route.
      """
      flags = Tac.Value( "Routing::Multicast::Fib::RouteFlags",
                         route.routeFlags.value )
      setattr( flags, flagName, value )
      route.routeFlags = flags

   def newRoute( self, routeKey, *args, **kwargs ):
      """Create (or recreate) the route for `routeKey` and return it.

      Any existing route with the same key is deleted first. The new route is
      held at RouteReady.notReady while it is populated and set to ready at
      the end. If `routeKey` is falsy, returns None without touching the
      route collection.

      Positional args, all optional, in order:
         iif, notify, oifs, staticFastdrop, notifyProgrammable, frr, rpaId
      Missing ones default to "", "", [], "", False, "", 0. Extra positional
      args are ignored.

      Keyword args (defaults in parentheses):
         priority (1), programmable (True), iifPmsi (False), toCpu (False),
         iifFlags (None), routeFlags (None).
      A truthy `routeFlags` replaces every flag set before it (programmable,
      toCpu, iifRegister, skipSoftwareProgramming). Only iifPmsi is applied
      after it.
      """
      t1( "Initializing new route with key %s" % routeKey )
      # Check the key before deleteRoute(). Otherwise a falsy key would be
      # looked up and deleted from the collection before we bail out.
      if not routeKey:
         return None
      self.deleteRoute( routeKey )
      route = self.mfibShark.newRoute( routeKey )
      route.routeReady = RouteReadyEnum.notReady

      # Positional args fill these slots in order; missing ones keep the
      # defaults and anything past the last slot is ignored.
      defaults = ( "", "", [], "", False, "", 0 )
      given = tuple( args[ : len( defaults ) ] )
      ( iif, notify, oifs, staticFastdrop, notifyProgrammable, frr,
        rpaId ) = given + defaults[ len( given ) : ]

      priority = kwargs.get( 'priority', 1 )
      programmable = kwargs.get( 'programmable', True )
      iifPmsi = kwargs.get( 'iifPmsi', False )
      toCpu = kwargs.get( 'toCpu', False )
      iifFlags = kwargs.get( 'iifFlags', None )
      routeFlags = kwargs.get( 'routeFlags', None )

      if programmable:
         self._setRouteFlag( route, 'programmable' )
      if toCpu:
         self._setRouteFlag( route, 'toCpu' )

      if iif == "Register0":
         route.iif = "Register0"
         self._setRouteFlag( route, 'iifRegister' )
      else:
         iif = iif or ""
         route.iif = Tac.Value( "Arnet::IntfId", iif )

      # notify: an empty notify interface also clears notifyProgrammable.
      if not notify:
         notify = ""
         notifyProgrammable = False
      route.notify = Tac.Value( "Arnet::IntfId", notify )
      route.notifyProgrammable = notifyProgrammable
      # skipSoftwareProgramming is set for every iif except Register0.
      self._setRouteFlag( route, 'skipSoftwareProgramming',
                          bool( iif != "Register0" ) )

      route.fastdrop = Tac.Value( "Arnet::IntfId", staticFastdrop or "" )
      route.iifFrr = Tac.Value( "Arnet::IntfId", frr or "" )

      if iifFlags:
         route.iifFlags = Tac.Value( "Routing::Multicast::Fib::IifFlags", iifFlags )

      if routeFlags:
         # Explicit flags replace everything set above.
         route.routeFlags = Tac.Value( "Routing::Multicast::Fib::RouteFlags",
                                       routeFlags )

      if iifPmsi:
         self._setRouteFlag( route, 'iifPmsi' )

      for intfName in oifs or []:
         # addOif accepts both names and IntfId values; only names get stripped.
         if isinstance( intfName, str ):
            intfName = intfName.strip()
         route = self.addOif( route, intfName )
      route.rpaId = int( rpaId )
      route.routePriority = priority
      t1( "iif=%s, notify=%s, oifs=%s" % ( iif, notify, oifs ) )
      route.routeReady = RouteReadyEnum.ready
      return route

   def deleteRoute( self, routeKey ):
      """Mark the route for `routeKey` notReady (if present) and remove it.

      The delete runs even when no route exists. newRoute() calls this for
      keys that may not exist yet, so the collection's ``del`` presumably
      tolerates absent keys -- confirm against the Tac collection semantics.
      """
      route = self.mfibShark.route.get( routeKey )
      if route:
         route.routeReady = RouteReadyEnum.notReady
      del self.mfibShark.route[ routeKey ]

   def addIntf( self, routeKey, intfName, intfFlags ):
      """Add `intfName` to the existing route for `routeKey` in the roles
      given by `intfFlags`.

      `intfFlags` is an int bitmask (see FORWARD/ACCEPT/NOTIFY) or an
      IntfFlags value. Forward adds an oif. Otherwise, notify sets the notify
      interface and recomputes skipSoftwareProgramming. Notify is therefore
      ignored when forward is also set. Accept, independently, sets the iif.
      The route is notReady during the update and ready afterwards.
      """
      intfName = intfName.strip()
      route = self.route( routeKey )
      route.routeReady = RouteReadyEnum.notReady

      if isinstance( intfFlags, int ):
         intfFlags = Tac.Value( "Routing::Multicast::Fib::IntfFlags", intfFlags )

      if intfFlags.forward:
         route = self.addOif( route, intfName )
      elif intfFlags.notify:
         route.notify = Tac.Value( "Arnet::IntfId", intfName )
         self._setRouteFlag( route, 'skipSoftwareProgramming',
                             bool( route.iif != "Register0" ) )
      if intfFlags.accept:
         route.iif = Tac.Value( "Arnet::IntfId", intfName )
      route.routeReady = RouteReadyEnum.ready

   def delIntf( self, routeKey, intfName, intfFlags ):
      """Remove `intfName` from the existing route for `routeKey` in the
      roles given by `intfFlags` (int bitmask or IntfFlags value).

      Forward removes the oif. Otherwise, notify clears the notify interface.
      Accept, independently, clears the iif. The route is notReady during
      the update and ready afterwards.
      """
      intfName = intfName.strip()
      route = self.route( routeKey )
      route.routeReady = RouteReadyEnum.notReady

      if isinstance( intfFlags, int ):
         intfFlags = Tac.Value( "Routing::Multicast::Fib::IntfFlags", intfFlags )

      if intfFlags.forward:
         route = self.delOif( route, intfName )
      elif intfFlags.notify:
         route.notify = Tac.Value( "Arnet::IntfId", '' )
      if intfFlags.accept:
         route.iif = Tac.Value( "Arnet::IntfId", '' )

      route.routeReady = RouteReadyEnum.ready

   def addOif( self, route, intfName ):
      """Add an oif (name string or IntfId) to `route` if it is not already
      present. Returns `route`."""
      intf = Tac.Value( "Arnet::IntfId", intfName ) \
                              if isinstance( intfName, str ) else intfName
      if intf not in route.oif:
         route.oif.add( intf )
      return route

   def delOif( self, route, intfName ):
      """Remove an oif (name string or IntfId) from `route` if present.
      Returns `route`."""
      intf = Tac.Value( "Arnet::IntfId", intfName ) \
                              if isinstance( intfName, str ) else intfName
      if intf in route.oif:
         del route.oif[ intf ]
      return route

   def routeOifIter( self, key ):
      """Return an iterator over the oifs of the route for `key`.

      Raises AttributeError if the route does not exist.
      """
      route = self.route( key )
      return iter( route.oif )
