# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliPrint
from CliPrint import (
   IntAttrInfo,
   StringAttrInfo,
)
import Tac

# Module-level Bessctl reference. It is None at import time; the user of the
# Bessctl feature is expected to initialize it before it is used.
bessctl = None

class TacEnum:
   """Expose every enumerator of a TAC enum type as a Python attribute.

   For a TAC type name such as "WanTEUtils::AvtUtilErrorCode", each name listed
   in the type's `attributes` becomes an attribute of this object whose value is
   the corresponding `Tac.enumValue`.
   """

   def __init__( self, typeName ):
      enumType = Tac.Type( typeName )
      for enumeratorName in enumType.attributes:
         enumeratorValue = Tac.enumValue( enumType, enumeratorName )
         setattr( self, enumeratorName, enumeratorValue )

ntpWarning = ( 'Warning: NTP synchronization is required for '
               '1-Way time measurement accuracy.\n' )
# Bit 31 of a path ID marks a multihop path; the remaining bits hold the path
# index (see pathStr / strPathId, which mask this bit off).
mhPathFlag = 1 << 31
# TAC types used to build and parse AvtUtil TLVs.
PingReqTlv = Tac.Type( "WanTEUtils::AvtUtilPingReqTlv" )
TraceRouteReqTlv = Tac.Type( "WanTEUtils::AvtUtilTraceRouteRequestTlv" )
TlvParser = Tac.Type( "WanTEUtils::AvtUtilTlvParser" )
# Enumerator namespaces built from TAC enum types.
TlvCode = TacEnum( "WanTEUtils::AvtUtilTlvType" )
ErrorCode = TacEnum( "WanTEUtils::AvtUtilErrorCode" )

# pylint: disable=no-member
# ErrorCode's members are attached dynamically via setattr in TacEnum, which
# pylint cannot see; hence the suppression above.
#
# Human-readable text for each AvtUtil error code.
errorStr = {
   ErrorCode.ttlExpired: "TTL expired",
   ErrorCode.invalidPathId: "path ID is invalid",
   ErrorCode.pathIdNotPartOfAvt: (
      "path ID is no longer accessible in the selected VRF and AVT" ),
   ErrorCode.nextHopLabelDown: "next hop label is down",
   ErrorCode.noPathToDestination: "no path is available to reach next hop",
   ErrorCode.lbLookupFailure: (
      "destination not reachable in the selected VRF and AVT" ),
   ErrorCode.noPktSpaceLeft: "no space left in packet",
   ErrorCode.missingInfo: "missing data in request",
   ErrorCode.pktLost: "destination not reachable",
   # NOTE(review): empty text; presumably the message comes from elsewhere.
   ErrorCode.configError: "",
}

def checkDest( client, dest ):
   """Return True if `dest` is in the agent's WAN TE remote IP list.

   Builds a short Python snippet, runs it through `client.execute()` in the
   agent rooted at `client.agentRoot()`, and reports whether the snippet's
   output contains 'True'.

   NOTE(review): `dest` is interpolated verbatim into executed code; callers
   must pass a trusted, well-formed address.
   """
   script = (
      '\n'
      'from Arnet import IpGenAddr\n'
      'destVtep = IpGenAddr( "%s" )\n'
      'remoteList = entity( "%s/Sfe" ).l3Comp.l3AgentSm.wanTECommonDir.'
      'wanTERemoteIpList\n'
      'print( destVtep in remoteList )\n'
   )
   rootName = client.agentRoot().fullName
   output = client.execute( script % ( dest, rootName ) )
   return 'True' in output

def constructPam( src, dest, ttl=64, txPort=3503 ):
   """Create and configure an "Arnet::UdpPam" instance named "AvtUtilPam".

   The PAM sends from `src` to `dest` on UDP port `txPort` with IP TTL `ttl`,
   receives on 0.0.0.0, and is put in 'server' mode.
   """
   pam = Tac.newInstance( "Arnet::UdpPam", "AvtUtilPam" )
   # Applied in this exact order, matching the original assignment sequence.
   settings = (
      ( 'txSrcIpAddr', src ),
      ( 'txDstIpAddr', dest ),
      ( 'txPort', txPort ),
      ( 'rxIpAddr', "0.0.0.0" ),
      ( 'txTTL', ttl ),
      ( 'mode', 'server' ),
   )
   for attrName, attrValue in settings:
      setattr( pam, attrName, attrValue )
   return pam

def mhPath( intStr ):
   """Return the multihop path ID for the path index given as `intStr`.

   The index is parsed with int() and the multihop flag bit is OR-ed in.
   """
   pathIndex = int( intStr )
   return pathIndex | mhPathFlag

def getTimestamp():
   """Return the current UTC time as an integer count of microseconds.

   Tac.utcNow() returns fractional seconds (<sec>.<us>). Scaling a float and
   truncating with int() can come out one microsecond short when the product
   is represented as N.99999..., so scale once and round to the nearest
   microsecond instead.
   """
   # int() guards against round() returning a float for non-builtin floats.
   return int( round( Tac.utcNow() * 1000000 ) )

def mhPathLabelStr( labels ):
   """Format a multihop label stack for display.

   Returns "[ l1,l2,... ]" for a non-empty `labels`, or "[-]" when `labels`
   is empty or None.
   """
   if not labels:
      return "[-]"
   joined = ",".join( map( str, labels ) )
   return f"[ {joined} ]"

def isMhPath( pathId ):
   """Return True if `pathId` has the multihop flag bit set."""
   flagBits = pathId & mhPathFlag
   return bool( flagBits )

def pathStr( pathId, printLabel=False, hopList=None ):
   """Describe a path ID as "via <direct|multihop> path <index>".

   When `printLabel` is true and the path is multihop, ", label stack ..."
   (formatted by mhPathLabelStr from `hopList`) is appended.
   """
   multihop = isMhPath( pathId )
   pathIndex = pathId & ~mhPathFlag
   pathType = 'multihop' if multihop else 'direct'
   description = f"via {pathType} path {pathIndex}"
   if not ( printLabel and multihop ):
      return description
   return description + ', label stack ' + mhPathLabelStr( hopList )

def strPathId( pathId ):
   """Render a path ID as "direct:<id>" or "multihop:<index>".

   For multihop IDs the flag bit is masked off before formatting.
   """
   if not isMhPath( pathId ):
      return f"direct:{ pathId }"
   return f"multihop:{ pathId & ~mhPathFlag }"

def printLabels( printer, mhPathId, mhPathLabels ):
   """Emit `mhPathLabels` as integers inside a "labelStack" list on `printer`.

   `mhPathId` is accepted for interface compatibility but is not used.
   """
   # One unnamed IntAttrInfo is reused for every element of the list.
   labelAttr = IntAttrInfo( printer, "" )
   with printer.list( "labelStack" ):
      for lbl in mhPathLabels:
         labelAttr.setValue( int( lbl ) )
         printer.addAttributes( "%d", labelAttr.getArg() )

def printError( printer, errCode, errMsg ):
   """Report an error through `printer`.

   `errMsg` is always passed to addFrills. In JSON output, an "errorCode"
   attribute is also added when `errCode` is not None, and an "errorMessage"
   attribute is added when `errMsg` is non-empty.
   """
   printer.addFrills( "%s\n", errMsg )
   if printer.outputFormat != CliPrint.JSON:
      return

   if errCode is not None:
      codeAttr = IntAttrInfo( printer, "errorCode" )
      codeAttr.setValue( errCode )
      printer.addAttributes( "%d", codeAttr.getArg() )

   if errMsg:
      msgAttr = StringAttrInfo( printer, "errorMessage" )
      msgAttr.setValue( errMsg )
      printer.addAttributes( "%s\n", msgAttr.getArg() )
