# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import SharedMem
import Smash
import SmashLazyMount
import Tac
from TypeFuture import TacLazyType

# Tac type handles (wrapped with TacLazyType) used by the helpers below.
DynamicTunnelIntfId = TacLazyType( 'Arnet::DynamicTunnelIntfId' )
TunnelIdConstants = TacLazyType( "Tunnel::TunnelTable::TunnelIdConstants" )
TunnelTableMounter = TacLazyType( "Tunnel::TunnelTable::TunnelTableMounter" )
TunnelType = TacLazyType( 'Tunnel::TunnelTable::TunnelType' )
TunnelViaStatus = TacLazyType( 'Tunnel::Hardware::TunnelViaStatus' )

# Address-family-specific tunnel type strings. getTunnelTypeEnumVal collapses
# every entry of each list to the generic enum name ('staticTunnel' /
# 'srTunnel'); getTunnelIdFromIndex additionally treats the 'V6' variants as
# IPv6 and sets the address-family bit for them.
staticTunnelTypeStrs = [ 'staticTunnel', 'staticV4Tunnel', 'staticV6Tunnel' ]
srTunnelTypeStrs = [ 'srTunnel', 'srV4Tunnel', 'srV6Tunnel' ]

# Translates TunnelViaStatus enum values into the plain strings returned by
# getTunnelViaStatusFromTunnelId (the Capi-facing form, per the table name).
tunnelViaStatusTacToCapi = {
   TunnelViaStatus.unresolved: 'unresolved',
   TunnelViaStatus.usingPrimaryVias: 'usingPrimaryVias',
   TunnelViaStatus.usingBackupVias: 'usingBackupVias',
   TunnelViaStatus.unknown: 'unknown',
}

# The 44th bit is used to differentiate between tunnel endpoints of the same
# type but different address families.
tunnelIdAfBitMask = TunnelIdConstants.tunnelAfMask
# Mask for the tunnel index bits. Presumably these are the low 43 bits that the
# getTunnelIdFromIndex docstring describes; confirm against TunnelIdConstants.
# Not referenced in this module section; kept for importers.
tunnelIndexMask = TunnelIdConstants.tunnelIndexMask

def getNhAndIntfStrs( via ):
   """Return the ( nexthop, interface ) display strings for a via.

   A tunnel via ( via.type == 'tunnel' ) shows its rendered tunnel id as the
   nexthop and '-' as the interface. Any other via shows str( via.nexthop )
   and the interface's stringValue.
   """
   isTunnelVia = via.type == 'tunnel'
   if isTunnelVia:
      return via.tunnelId.renderStr(), '-'
   return str( via.nexthop ), via.interface.stringValue

def isDyTunIntfId( intfId ):
   '''Tell whether intfId ( taken from a Via Model ) is a dynamic tunnel intf.

   Delegates to Arnet::DynamicTunnelIntfId.isDynamicTunnelIntfId and returns
   its result unchanged.
   '''
   checkDynamic = DynamicTunnelIntfId.isDynamicTunnelIntfId
   return checkDynamic( intfId )

def getDyTunTidFromIntfId( intfId ):
   '''Return the tunnel id encoded in a dynamic tunnel intfId from a Via Model.

   The id is extracted with Arnet::DynamicTunnelIntfId.tunnelId. The caller
   must pass a dynamic tunnel intfId ( see isDyTunIntfId ). This precondition
   is only checked by an assert, so it is not enforced under python -O.
   '''
   assert isDyTunIntfId( intfId )
   tunnelId = DynamicTunnelIntfId.tunnelId( intfId )
   return tunnelId

def getTunnelTypeEnumVal( tunnelType ):
   """Map a tunnel type string to its TunnelType enum name.

   The static and SR address-family variants ( e.g. 'staticV4Tunnel',
   'srV6Tunnel' ) collapse to 'staticTunnel' / 'srTunnel'. Any other value is
   returned unchanged. The TunnelType enum is defined in TunnelBasicTypes.tac.
   """
   # Static is checked before SR, matching the original precedence.
   for enumName, variants in ( ( "staticTunnel", staticTunnelTypeStrs ),
                               ( "srTunnel", srTunnelTypeStrs ) ):
      if tunnelType in variants:
         return enumName
   return tunnelType

def getTunnelIdFromIndex( tunnelType, index, af=None ):
   """ Given a tunnelType and a (unique) index number, returns a
   Tunnel::TunnelTable::TunnelId.

   INPUTS
   - tunnelType (str): an enum name as defined in TunnelBasicTypes.tac, or one
                       of the more specific strings { 'staticV4Tunnel',
                       'staticV6Tunnel', ... } defined in this file. It is
                       reduced to its enum name with getTunnelTypeEnumVal.
   - index (int/long): an integer represented by at most 43 bits. The 43-bit
                       limit comes from the implementation of TunnelId. The 44th
                       bit of index is used to denote whether the tunnel is
                       IPv4 or IPv6.
   - af (str or None): pass 'ipv6' to set the address-family bit in index.
                       It is forced to 'ipv6' when tunnelType is
                       'staticV6Tunnel' or 'srV6Tunnel', whatever the caller
                       passed.

   OUTPUTS
   - A Tunnel::TunnelTable::TunnelId instance. This is the unique identifier of
     all tunnels.
   """
   enumVal = getTunnelTypeEnumVal( tunnelType )
   # The V6-specific type strings imply IPv6 regardless of the af argument.
   if tunnelType in ( 'staticV6Tunnel', 'srV6Tunnel' ):
      af = 'ipv6'

   # FIXME BUG200023: shift this logic into TunnelId.convertToTunnelValue.
   # IPv6 tunnels are distinguished from IPv4 ones by setting the 44th bit
   # that is reserved in the index.
   if af == 'ipv6':
      index = index | tunnelIdAfBitMask

   tunnelIdVal = Tac.Value( "Tunnel::TunnelTable::TunnelId" )
   return tunnelIdVal.convertToTunnelValue( enumVal, index )


def getTunnelIdWithAf( tunnelIndex, tunnelType, tunnelTable ):
   """Return the TunnelId for tunnelIndex, choosing its address family.

   For tunnel types that use the address-family bit, a given tunnelIndex
   belongs to either an IPv4 or an IPv6 tunnel, never both. The IPv4 id is
   returned when it is a key of tunnelTable.entry. Otherwise the IPv6 id is
   returned without checking whether it exists in the table.
   """
   v4TunnelId = getTunnelIdFromIndex( tunnelType, tunnelIndex )
   if v4TunnelId not in tunnelTable.entry:
      return getTunnelIdFromIndex( tunnelType, tunnelIndex, af='ipv6' )
   return v4TunnelId

def getFullTunnelTableMountInfo( tunnelTableId ):
   """Return TunnelTableMounter's complete mount info for tunnelTableId.

   Callers in this module read its tableInfo and writerMountInfo fields.
   """
   fullMountInfo = TunnelTableMounter.getMountInfo( tunnelTableId )
   return fullMountInfo

def getTunnelTableMountInfo( tunnelTableId ):
   """Return only the tableInfo part of the mount info for tunnelTableId.

   mountTunnelTable reads its mountPath and tableType.
   """
   fullMountInfo = getFullTunnelTableMountInfo( tunnelTableId )
   return fullMountInfo.tableInfo

def mountTunnelTable( entityManager, tableInfo, mountInfo, lazy=True ):
   """Mount the tunnel table described by tableInfo using mountInfo.

   With lazy=True ( the default ) the mount goes through SmashLazyMount.mount.
   With lazy=False, a SharedMem entity manager is built on top of
   entityManager and the table is mounted with its doMount. Either way the
   mount call's result is returned.
   """
   if not lazy:
      shmemEm = SharedMem.entityManager( sysdbEm=entityManager )
      return shmemEm.doMount( tableInfo.mountPath, tableInfo.tableType, mountInfo )
   return SmashLazyMount.mount( entityManager, tableInfo.mountPath,
                                tableInfo.tableType, mountInfo )

def readMountTunnelTable( tunnelTableId, entityManager, lazy=True ):
   """Mount tunnelTableId as a Smash "reader".

   NOTE: Make sure to set lazy=False if passing handle to a C++ function.
   """
   tableInfo = getTunnelTableMountInfo( tunnelTableId )
   readerMountInfo = Smash.mountInfo( "reader" )
   return mountTunnelTable( entityManager, tableInfo, readerMountInfo,
                            lazy=lazy )

def writeMountTunnelTable( tunnelTableId, entityManager, lazy=True ):
   """Mount tunnelTableId with the writer mount info from TunnelTableMounter.

   NOTE: Make sure to set lazy=False if passing handle to a C++ function.
   """
   fullMountInfo = getFullTunnelTableMountInfo( tunnelTableId )
   return mountTunnelTable( entityManager, fullMountInfo.tableInfo,
                            fullMountInfo.writerMountInfo, lazy=lazy )

def keyshadowMountTunnelTable( tunnelTableId, entityManager, lazy=True ):
   """Mount tunnelTableId with a Smash "keyshadow" mountInfo.

   NOTE: Make sure to set lazy=False if passing handle to a C++ function.
   """
   tableInfo = getTunnelTableMountInfo( tunnelTableId )
   keyshadowMountInfo = Smash.mountInfo( "keyshadow" )
   return mountTunnelTable( entityManager, tableInfo, keyshadowMountInfo,
                            lazy=lazy )

def getTunnelViaStatusFromTunnelId( tunnelId, programmingStatus, multicast=False ):
   """Return the via-status string for tunnelId, or None if not applicable.

   - multicast=True, or a voqFabricTunnel tunnel: None.
   - gueTunnel: always 'usingPrimaryVias'. GUE tunnels do not program tunnel
     status, and all entries in the GUE tunnel fib are resolved and include
     'Primary' vias.
   - anything else: the mapped tunnelViaStatus of the entry for tunnelId in
     programmingStatus.tunnelStatus, or 'notProgrammed' when that lookup
     yields a falsy value ( e.g. no entry ).
   """
   if multicast:
      return None
   tacTunnelId = Tac.Value( "Tunnel::TunnelTable::TunnelId", tunnelId )
   kind = tacTunnelId.tunnelType()
   if kind == TunnelType.gueTunnel:
      return tunnelViaStatusTacToCapi[ TunnelViaStatus.usingPrimaryVias ]
   if kind == TunnelType.voqFabricTunnel:
      return None
   status = programmingStatus.tunnelStatus.get( tunnelId )
   if not status:
      return 'notProgrammed'
   return tunnelViaStatusTacToCapi[ status.tunnelViaStatus ]

