#!/usr/bin/env python3
# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=superfluous-parens

import collections
import Tac, Tracing
from IpLibConsts import DEFAULT_VRF
import SharedMem
import Arnet
import queue

# Trace handle for this library. t0 and t1 are shorthands for the handle's
# trace0 and trace1 functions and are used throughout the file for logging.
handle = Tracing.Handle( "MfibSmashTestLib" )
t0, t1 = handle.trace0, handle.trace1

# Integer interface-flag values.
# NOTE(review): presumably these match the forward/accept/notify bits of
# Routing::Multicast::Fib::IntfFlags accepted by addIntf()/delIntf() - confirm.
FORWARD, ACCEPT, NOTIFY = 1, 2, 4

class MfibSmashTestEnv:
   """Test environment around one Smash::Multicast::Fib::Status table.

   The table is either supplied by the caller, mounted as a Smash writer
   (when an entityManager is given) or created as a local Tac instance.
   OIF handling is delegated to a mapper object (InternalMapper or BitMapper)
   whose callables are copied onto this instance by loadMapper().
   """
   # Incremented once for every environment that finishes construction.
   _id = 0

   def __init__( self, mfibSmash=None, entityManager=None, vrfName=DEFAULT_VRF,
                 appName="sparsemode", bitMapperConfig=None, bitMapperStatus=None,
                 bitMapperSmashStatus=None, af="ipv4" ):
      """Set up the table and load the OIF mapper.

      Args:
         mfibSmash: existing status table to use; when given nothing is mounted.
         entityManager: when given, a SharedMem entity manager is created from
            it and (if mfibSmash is not given) the table is mounted as writer.
         vrfName, appName, af: used to build the Smash mount path.
         bitMapperConfig, bitMapperStatus, bitMapperSmashStatus: passed to
            loadMapper(); a BitMapper is used only if bitMapperSmashStatus
            is truthy.
      """
      self.smashMount = None
      self.mapper = None
      if entityManager:
         self.smashMount = SharedMem.entityManager( sysdbEm=entityManager )

      self.appName = appName
      self.vrfName = vrfName
      self.af = af

      self.typeStr = "Smash::Multicast::Fib::Status"
      # Stays empty unless this environment performs the mount itself.
      self.path = ""

      if mfibSmash:
         self.mfibSmash = mfibSmash
      elif self.smashMount is not None:
         t1( "Mounting Smash as write" )
         self.entityManager = entityManager
         # create the Smash mount group
         tableInfo = Tac.Value( "Smash::Multicast::Fib::TableInfo" )
         mfibMountInfo = tableInfo.mfibMountInfo( "writer" )

         self.path = Tac.Type( self.typeStr ).mountPath( self.af, self.vrfName,
                                                         self.appName )
         self.mfibSmash = self.smashMount.doMount( self.path, self.typeStr,
                                                   mfibMountInfo )
      else:
         self.mfibSmash = Tac.newInstance( self.typeStr, "Mfib" )

      self.mfib = self.mfibSmash
      # loadMapper() itself falls back to the internal mapper when
      # bitMapperSmashStatus is falsy, so no branching is needed here.
      self.loadMapper( bitMapperConfig, bitMapperStatus, bitMapperSmashStatus )

      MfibSmashTestEnv._id += 1

   def __del__( self ):
      """Unmount the table, but only if this environment mounted it."""
      # getattr() keeps this safe on a partially constructed instance; the
      # path check avoids unmounting "" when the table was supplied by the
      # caller and nothing was mounted here.
      smashMount = getattr( self, "smashMount", None )
      path = getattr( self, "path", "" )
      if smashMount is not None and path:
         smashMount.doUnmount( path )

   def clear( self ):
      """Remove every route from the table."""
      if self.mfibSmash:
         self.mfibSmash.route.clear()

   @staticmethod
   def ipGenRouteKey( s, g ):
      """Build an IpGenRouteKey from source string s and group prefix string g.

      The wildcard sources '0.0.0.0' and '::' get a /0 length; any other
      source gets a host length (/32 if it contains a '.', otherwise /128).
      """
      if s in ( '0.0.0.0', '::' ):
         prefixLen = 0
      else:
         prefixLen = 32 if '.' in s else 128
      # pylint: disable-next=consider-using-f-string
      source = '%s/%d' % ( s, prefixLen )
      return Tac.Value( "Routing::Multicast::Fib::IpGenRouteKey",
                        Arnet.IpGenPrefix( source ), Arnet.IpGenPrefix( g ) )

   @staticmethod
   def _setRouteFlags( route, **flags ):
      """Set the named attributes on a copy of route.routeFlags and store it."""
      newRouteFlags = Tac.Value( "Routing::Multicast::Fib::RouteFlags",
                                 route.routeFlags.value )
      for name, value in flags.items():
         setattr( newRouteFlags, name, value )
      route.routeFlags = newRouteFlags

   @staticmethod
   def _asIntfFlags( intfFlags ):
      """Return intfFlags as an IntfFlags value, converting a plain int."""
      if isinstance( intfFlags, int ):
         return Tac.Value( "Routing::Multicast::Fib::IntfFlags", intfFlags )
      return intfFlags

   def iif( self, key ):
      """Return the iif of the route at key, or None if there is no route."""
      r = self.route( key )
      return r.iif if r else None

   def notify( self, key ):
      """Return the route's notify interface, or None if absent or unset."""
      r = self.route( key )
      if r and r.notify:
         return r.notify
      return None

   def ptr( self ):
      """Return Tac.addressOf() the status table."""
      return Tac.addressOf( self.mfibSmash )

   def route( self, key ):
      """Return the route stored at key (via the collection's get())."""
      return self.mfibSmash.route.get( key )

   def routeIter( self ):
      """Return the values of the table's route collection."""
      return self.mfibSmash.route.values()

   def oifs( self, key ):
      """Return a list of the interface ids yielded by the mapper's
      RouteOifIter for the route at key."""
      # RouteOifIter is replaced on the instance by loadMapper().
      # pylint: disable-next=not-an-iterable
      return list( self.RouteOifIter( self, key ) )

   def isBlankRoute( self, key ):
      """Return a truthy value if the route has OIFs, an iif other than
      "Null0" or a non-empty notify interface.

      NOTE(review): as written this is truthy for a route that is NOT blank,
      which looks inverted relative to the name. Behaviour is deliberately
      left unchanged because callers are not visible here - confirm.
      """
      return ( self.oifs( key ) or
               self.route( key ).iif != "Null0" or
               self.route( key ).notify != "" )

   def newRoute( self, routeKey, *args, **kwargs ):
      """Build a route for routeKey, add it to the table and return it.

      Optional positional args, in order: iif, notify, oifs, staticFastdrop,
      notifyProgrammable, frr, rpaId. Missing trailing ones default to
      "", "", [], "", False, "" and 0; extras beyond these are ignored.

      Keyword args: priority (default 1, also used when falsy), programmable
      (default True), iifPmsi (default False), toCpu (default False),
      iifFlags and routeFlags (default None). A truthy routeFlags replaces
      the flags computed from the other arguments.

      Returns None without touching the table when routeKey is falsy.
      """
      if not routeKey:
         return None
      route = Tac.Value( "Smash::Multicast::Fib::Route", routeKey )

      # Pad the positional arguments with their defaults.
      defaults = ( "", "", [], "", False, "", 0 )
      iif, notify, oifs, staticFastdrop, notifyProgrammable, frr, rpaId = \
            tuple( args[ : len( defaults ) ] ) + defaults[ len( args ) : ]

      priority = kwargs.get( 'priority' ) or 1
      programmable = kwargs.get( 'programmable', True )
      iifPmsi = kwargs.get( 'iifPmsi', False )
      toCpu = kwargs.get( 'toCpu' ) or False
      iifFlags = kwargs.get( 'iifFlags', None )
      routeFlags = kwargs.get( 'routeFlags', None )

      if programmable:
         self._setRouteFlags( route, programmable=True )

      if toCpu:
         self._setRouteFlags( route, toCpu=True )

      if iif == "Register0":
         route.iif = "Register0"
         self._setRouteFlags( route, iifRegister=True )
      elif iif:
         route.iif = Tac.Value( "Arnet::IntfId", iif )

      if notify:
         route.notify = Tac.Value( "Arnet::IntfId", notify )
         route.notifyProgrammable = notifyProgrammable
         # Software programming is skipped unless the iif is Register0.
         self._setRouteFlags( route,
                              skipSoftwareProgramming=( iif != "Register0" ) )

      if staticFastdrop:
         route.fastdrop = Tac.Value( "Arnet::IntfId", staticFastdrop )

      if frr:
         route.iifFrr = Tac.Value( "Arnet::IntfId", frr )

      if iifPmsi:
         self._setRouteFlags( route, iifPmsi=True )

      if iifFlags:
         route.iifFlags = Tac.Value( "Routing::Multicast::Fib::IifFlags", iifFlags )

      if routeFlags:
         route.routeFlags = Tac.Value( "Routing::Multicast::Fib::RouteFlags",
                                       routeFlags )

      if oifs:
         for intfName in oifs:
            # pylint: disable-next=assignment-from-no-return
            route = self.addOif( route, intfName.strip() )
      route.rpaId = int( rpaId )
      route.routePriority = priority
      self.mfibSmash.addRoute( route )
      return route

   def deleteRoute( self, routeKey ):
      """Delete the route stored at routeKey."""
      del self.mfibSmash.route[ routeKey ]

   def addIntf( self, routeKey, intfName, intfFlags ):
      """Add intfName to the route at routeKey in the roles given by intfFlags.

      intfFlags may be an int or an IntfFlags value. forward adds an OIF;
      otherwise notify sets the notify interface; accept (checked
      independently) sets the iif. The modified copy is written back.
      """
      intfName = intfName.strip()
      copyRoute = Tac.nonConst( self.route( routeKey ) )
      intfFlags = self._asIntfFlags( intfFlags )

      if intfFlags.forward:
         # pylint: disable-next=assignment-from-no-return
         copyRoute = self.addOif( copyRoute, intfName )
      elif intfFlags.notify:
         copyRoute.notify = Tac.Value( "Arnet::IntfId", intfName )
         # Software programming is skipped unless the iif is Register0.
         self._setRouteFlags(
               copyRoute,
               skipSoftwareProgramming=( copyRoute.iif != "Register0" ) )
      if intfFlags.accept:
         copyRoute.iif = Tac.Value( "Arnet::IntfId", intfName )

      self.mfibSmash.addRoute( copyRoute )

   def delIntf( self, routeKey, intfName, intfFlags ):
      """Remove intfName from the route at routeKey in the given roles.

      forward removes the OIF; otherwise notify resets the notify interface
      to the empty IntfId; accept (checked independently) resets the iif.
      The modified copy is written back.
      """
      intfName = intfName.strip()
      copyRoute = Tac.nonConst( self.route( routeKey ) )
      intfFlags = self._asIntfFlags( intfFlags )

      if intfFlags.forward:
         # pylint: disable-next=assignment-from-no-return
         copyRoute = self.delOif( copyRoute, intfName )
      elif intfFlags.notify:
         copyRoute.notify = Tac.Value( "Arnet::IntfId", '' )
      if intfFlags.accept:
         copyRoute.iif = Tac.Value( "Arnet::IntfId", '' )

      self.mfibSmash.addRoute( copyRoute )

   def oifIndexIntf( self, key ):
      """Return the oifIndexIntf entry stored at key."""
      return self.mfibSmash.oifIndexIntf[ key ]

   def oifIndexIntfIter( self ):
      """Return the values of the table's oifIndexIntf collection."""
      return self.mfibSmash.oifIndexIntf.values()

   # loadMapper() and the stubs that follow make this class polymorphic at
   # runtime: users simply instantiate MfibSmashTestEnv and the appropriate
   # mapper functions are loaded depending on the parameters specified. It is
   # also possible to load a different mapper later by calling loadMapper()
   # again. This keeps the BitMapper functions isolated in a separate class
   # without changing the existing breadth tests that use the legacy internal
   # mapper.

   def loadMapper( self, bitMapperConfig=None, bitMapperStatus=None,
                   bitMapperSmashStatus=None ):
      """Create the mapper and copy its callables onto this instance.

      A BitMapper is used when bitMapperSmashStatus is truthy, otherwise an
      InternalMapper bound to this environment's table.
      """
      if bitMapperSmashStatus:
         self.mapper = BitMapper( bitMapperConfig, bitMapperStatus,
                                  bitMapperSmashStatus )
      else:
         self.mapper = InternalMapper( self.mfibSmash )

      # Import every callable attribute of the mapper (dunder names excluded)
      # as an instance attribute, shadowing the stubs below.
      for func in dir( self.mapper ):
         if func.startswith( '__' ):
            continue
         member = getattr( self.mapper, func )
         if callable( member ):
            setattr( self, func, member )

   def createMapping( self, index, intf ):
      """Stub; shadowed by the mapper's implementation after loadMapper()."""

   def getOifMapping( self, intf ):
      """Stub; shadowed by the mapper's implementation, if it has one."""

   def addOif( self, route, intf ):
      """Stub; shadowed by the mapper's implementation after loadMapper()."""

   def delOif( self, route, intf ):
      """Stub; shadowed by the mapper's implementation after loadMapper()."""

   class RouteOifIter:
      """Stub; shadowed by the mapper's RouteOifIter after loadMapper()."""


class InternalMapper:
   """Mapper that allocates OIF indices itself from a local free list and
   records the index-to-interface mapping in the Mfib status table."""

   def __init__( self, mfibSmash ):
      """Bind to mfibSmash and seed the free list with indices 0..511."""
      t0( "Initializing Internal Mapper Test Env" )
      self.mfibSmash = mfibSmash
      self.oifIndexFreeList = queue.Queue( 1024 )
      for i in range( 0, 512 ):
         oifIndex = Tac.Value( "Smash::Multicast::Fib::OifIndex", i )
         self.oifIndexFreeList.put( oifIndex )

      # interface name -> OifIndex value
      self.intfIdOifIndex = {}

   def createMapping( self, idx, intfName ):
      """Map intfName to index idx locally and publish it to the table."""
      oifIndex = Tac.Value( "Smash::Multicast::Fib::OifIndex", idx )
      self.intfIdOifIndex[ intfName ] = oifIndex
      intfId = Tac.Value( "Arnet::IntfId", intfName )
      oifIndexIntf = Tac.Value( "Smash::Multicast::Fib::OifIndexIntf",
                                oifIndex, intfId )
      self.mfibSmash.addOifIndexIntf( oifIndexIntf )

   def getOifMapping( self, intfName ):
      """Return the OifIndex mapped to intfName, or None if the interface is
      unmapped or mapped to the default-constructed OifIndex."""
      if intfName not in self.intfIdOifIndex:
         return None
      oifIndex = self.intfIdOifIndex[ intfName ]
      defaultOifIndex = Tac.Value( "Smash::Multicast::Fib::OifIndex" )
      if oifIndex == defaultOifIndex:
         return None
      return oifIndex

   def _allocOifIndex( self ):
      """Pop the next free OifIndex that is not already mapped.

      Indices pinned through createMapping() are never removed from the free
      list, so they are skipped here instead of being handed out twice.

      Raises:
         RuntimeError: the free list is exhausted. A blocking get() would
            hang the test forever instead.
      """
      inUse = { idx.value for idx in self.intfIdOifIndex.values() }
      while True:
         try:
            oifIndex = self.oifIndexFreeList.get_nowait()
         except queue.Empty as err:
            raise RuntimeError( "OifIndex free list exhausted" ) from err
         if oifIndex.value not in inUse:
            return oifIndex

   def addOif( self, route, intfName ):
      """Mark intfName as an OIF of route, allocating an index if the
      interface has none yet, and return the route."""
      if intfName in self.intfIdOifIndex:
         oifIndex = self.intfIdOifIndex[ intfName ]
      else:
         oifIndex = self._allocOifIndex()
         self.createMapping( oifIndex.value, intfName )

      route.oif[ oifIndex ] = True
      # maxOifIndex is raised when exceeded, and also replaced when it is 512
      # (the same value RouteOifIter uses as its null index).
      if oifIndex.value > route.maxOifIndex or route.maxOifIndex == 512:
         route.maxOifIndex = oifIndex.value
      return route

   def delOif( self, route, intfName ):
      """Remove intfName from route's OIFs if it is mapped; return route."""
      if intfName in self.intfIdOifIndex:
         oifIndex = self.intfIdOifIndex[ intfName ]
         del route.oif[ oifIndex ]
      return route

   class RouteOifIter:
      """Iterator over the interface ids of the OIFs of one route."""

      def __init__( self, mfibSmashEnv, key ):
         self.key_ = key
         self.mfibSmashEnv_ = mfibSmashEnv
         self.route_ = self.mfibSmashEnv_.route( key )
         # 512 is used both as the starting cursor and as the end marker.
         self.oifIndexNull_ = Tac.Value( "Smash::Multicast::Fib::OifIndex", 512 )
         self.curOifIndex_ = self.oifIndexNull_

      def __iter__( self ):
         return self

      def __next__( self ):
         # A missing route simply yields nothing.
         if not self.route_:
            raise StopIteration()
         self.curOifIndex_ = self.route_.nextOif( self.curOifIndex_ )
         if self.curOifIndex_ < 512:
            return self.mfibSmashEnv_.oifIndexIntf( self.curOifIndex_ ).intf
         raise StopIteration()

      next = __next__ # for py2 compatibility

class BitMapper:
   """Mapper that obtains OIF bit indices from BitMapper config/status
   entities instead of allocating them locally."""

   def __init__( self, bitMapperConfig, bitMapperStatus, bitMapperSmashStatus ):
      """Store the three BitMapper entities.

      A missing entity is only traced; construction continues regardless.
      """
      if not ( bitMapperConfig and bitMapperStatus and bitMapperSmashStatus ):
         t0( "Trying to use BitMapper with insufficient arguments" )

      t0( "Initializing BitMapper Test Env" )
      self.config_ = bitMapperConfig
      self.status_ = bitMapperStatus
      self.smashStatus_ = bitMapperSmashStatus
      self.ipGenAddrDefault = Tac.Value( "Arnet::IpGenAddr", "" )

   def newMapping( self, bmKey ):
      """Request a bit index for bmKey and return status_.bitIndex[ bmKey ].

      The request is made by writing allocRequestKey and bumping
      allocRequestCount on the config.
      NOTE(review): presumably the status is updated synchronously by a
      reactor before the lookup - confirm.
      """
      self.config_.allocRequestKey = bmKey
      self.config_.allocRequestCount = self.config_.allocRequestCount + 1
      return self.status_.bitIndex[ bmKey ]

   def createMapping( self, idx, intfName ):
      """This function will force the mapping specified in its parameters"""
      raise NotImplementedError

   def addOif( self, route, intf ):
      """Mark intf as an OIF of route, requesting a bit index if the
      interface has none yet, and return the route."""
      bmKey = Tac.Value( "BitMapper::BitMapperKey", intf, self.ipGenAddrDefault )
      if bmKey in self.status_.bitIndex:
         bitIndex = self.status_.bitIndex[ bmKey ]
      else:
         bitIndex = self.newMapping( bmKey )

      route.oif[ bitIndex ] = True
      # maxOifIndex is raised when exceeded, and also replaced when it is 512
      # (the same value RouteOifIter uses as its null index).
      if bitIndex > route.maxOifIndex or route.maxOifIndex == 512:
         route.maxOifIndex = bitIndex
      return route

   def delOif( self, route, intf ):
      """Remove intf from route's OIFs if it has a bit index; return route."""
      bmKey = Tac.Value( "BitMapper::BitMapperKey", intf, self.ipGenAddrDefault )
      if bmKey in self.status_.bitIndex:
         bitIndex = self.status_.bitIndex[ bmKey ]
         del route.oif[ bitIndex ]
      return route

   class RouteOifIter:
      """Iterator over the interface ids of the OIFs of one route, resolved
      through the BitMapper smash status of the environment's mapper."""

      def __init__( self, mfibSmashEnv, key ):
         self.key_ = key
         self.mfibSmashEnv_ = mfibSmashEnv
         self.route_ = self.mfibSmashEnv_.route( key )
         # 512 is used both as the starting cursor and as the end marker.
         self.bitIndexNull_ = 512
         self.curBitIndex_ = self.bitIndexNull_
         self.bmSmashStatus_ = None

      def __iter__( self ):
         self.bmSmashStatus_ = self.mfibSmashEnv_.mapper.smashStatus_
         return self

      def __next__( self ):
         # A missing route yields nothing, matching InternalMapper's iterator
         # instead of failing on None.
         if not self.route_:
            raise StopIteration()
         # Allow next() to be called without a prior iter().
         if self.bmSmashStatus_ is None:
            self.bmSmashStatus_ = self.mfibSmashEnv_.mapper.smashStatus_
         self.curBitIndex_ = self.route_.nextOif( self.curBitIndex_ )
         if self.curBitIndex_ < 512:
            biMap = self.bmSmashStatus_.bitIndexMap[ self.curBitIndex_ ]
            # pylint: disable-next=consider-using-f-string
            t0( "Mroute OIF iter returning %s %s" %
                ( biMap.bitMapperKey.intfId, self.curBitIndex_ ) )
            return biMap.bitMapperKey.intfId
         # pylint: disable-next=consider-using-f-string
         t0( "Mroute OIF iter returning terminating %s" % self.curBitIndex_ )
         raise StopIteration()

      next = __next__ # for py2 compatibility

class BidirStatusTestEnv:
   """Test helper around a Smash::Multicast::Fib::BidirStatus entity."""

   def __init__( self, status=None ):
      """Use the given status, or create a local instance named "bss" when
      none (or a falsy one) is supplied."""
      self.status = status if status else Tac.newInstance(
            "Smash::Multicast::Fib::BidirStatus", "bss" )

   @staticmethod
   def _asPrefix( g ):
      """Convert a string group to an Arnet.IpGenPrefix; pass others through."""
      return Arnet.IpGenPrefix( g ) if isinstance( g, str ) else g

   def bidirGroup( self, g ):
      """Look up group g (string or prefix) in status.bidirGroup via get()."""
      return self.status.bidirGroup.get( self._asPrefix( g ) )

   def bidirGroupIs( self, g, rpa=0 ):
      """Add a BidirGroup for g with the given rpa to the status and return
      the result of addMember()."""
      entry = Tac.Value( "Smash::Multicast::Fib::BidirGroup", self._asPrefix( g ) )
      entry.rpa = rpa
      return self.status.bidirGroup.addMember( entry )

   def iAmDf( self, intfId, rpId ):
      """Return whether bit rpId is set in the dfBitmap of intfId's profile.

      Returns False when the interface has no (or a falsy) profile.
      """
      profile = self.status.dfProfile.get( intfId )
      if profile:
         return ( profile.dfBitmap & ( 1 << rpId ) ) != 0
      return False

   def iAmDfIs( self, intfId, rpId, iAmDf=True ):
      """Set (iAmDf truthy) or clear bit rpId in intfId's dfBitmap.

      An existing profile is copied with Tac.nonConst(); otherwise a new
      DfProfile is created. The profile is written back with addMember()
      and returned.
      """
      current = self.status.dfProfile.get( intfId )
      if current is None:
         profile = Tac.Value( "Smash::Multicast::Fib::DfProfile", intfId )
      else:
         profile = Tac.nonConst( current )

      mask = 1 << rpId
      bitmap = profile.dfBitmap
      profile.dfBitmap = ( bitmap | mask ) if iAmDf else ( bitmap & ~mask )

      self.status.dfProfile.addMember( profile )
      return profile

# Plain-Python description of one Mfib route: incoming interface, list of
# outgoing interfaces, notify interface and the programmable flag.
MfibRoute = collections.namedtuple( "MfibRoute", "iif oifset notify programmable" )
# Every field is optional. NOTE: the default oifset list is a single object
# shared by all instances that rely on the default - do not mutate it.
MfibRoute.__new__.__defaults__ = ( "", [], "", True )

class MfibClientDirState:
   ''' Stateful representation of collection of MfibSmash tables per client
   :name : MfibClient Name
   :af : Arnet::AddressFamily
   :vrfs : dictionary of per VRF Mfibs
         { vrf : { (S,G) : ( Iif, [Oif], notify, programmable )} }
   :env : dictionary { vrf : MfibSmashTestEnv }, filled in by toTacc()
   :appType : maps the client name to the prefix of the dirColl appType

   NOTE(review): the original note here read "This cant be modified to
   account for Sharked Mfib once that is done"; presumably "can" was meant.
   '''

   def __init__( self, name, af, state ):
      ''' Record the client name and address family and convert every route
      tuple in state ( { vrf : { key : tuple } } ) into an MfibRoute. '''
      self.name = name
      self.af = af
      self.vrfs = {
         vrf: { key: MfibRoute( *route ) for key, route in routes.items() }
         for vrf, routes in state.items()
      }
      self.env = {}
      self.appType = { 'static': 'user', 'pimsm': 'sparsemode', 'bidir': 'bidir' }

   def toTacc( self, tacStatus, entityManager, reconcile=True ):
      ''' Write this state into Tac/Smash.

      Creates the client MfibDir under tacStatus and, for each VRF, an
      MfibSmashTestEnv plus a dirColl pointing at its Smash path, then adds
      every route. With reconcile set, routes present in a table but not in
      this state are deleted, as are dirColl entries for unknown VRFs.

      Raises KeyError if a VRF is processed and self.name is not a key of
      self.appType.
      '''
      clientDir = tacStatus.newEntity( "Smash::Multicast::Fib::MfibDir", self.name )
      for vrf, mfib in self.vrfs.items():
         vrfEnv = self.env[ vrf ] = MfibSmashTestEnv( entityManager=entityManager,
                                                      vrfName=vrf,
                                                      appName=self.name,
                                                      af=self.af )
         mfibDirColl = clientDir.newDirColl( vrf )
         # pylint: disable-next=consider-using-f-string
         mfibDirColl.appType = "%s_%s" % ( self.appType[ self.name ],
               'v4' if self.af == 'ipv4' else 'v6' )
         mfibDirColl.smashUrl = vrfEnv.path

         for key, route in mfib.items():
            tacKey = vrfEnv.ipGenRouteKey( *key )
            vrfEnv.newRoute( tacKey,
                             route.iif,
                             route.notify,
                             route.oifset,
                             programmable=route.programmable )
         if reconcile:
            # Iterate over a snapshot: routes are deleted inside the loop.
            for tacRoute in list( vrfEnv.routeIter() ):
               tacKey = tacRoute.key
               key = ( str( tacKey.s.ipGenAddr ), str( tacKey.g.ipGenAddr ) )
               if key not in mfib:
                  vrfEnv.deleteRoute( tacKey )
      if reconcile:
         # Iterate over a snapshot: entries are deleted inside the loop.
         for vrf in list( clientDir.dirColl ):
            if vrf not in self.vrfs:
               clientDir.deleteEntity( vrf )
