# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=protected-access
# pylint: disable=consider-merging-isinstance

#-------------------------------------------------------------------------------
# This module implements common utility functions for MPLS
#-------------------------------------------------------------------------------
'''Common utility functions for MPLS'''
import random

import CliCommand
import CliMatcher
import CliModel
import Tac

# Inclusive range of label values used by the CLI label matcher and by
# randomLabel(); unassignedMin presumably excludes the reserved low labels --
# confirm against the Arnet::MplsLabel definition.
labelMin = Tac.Type( 'Arnet::MplsLabel').unassignedMin
labelMax = Tac.Type( 'Arnet::MplsLabel').max
MplsStackEntryIndex = Tac.Type( 'Arnet::MplsStackEntryIndex' )
# Number of entries a label stack can hold: one more than the largest index.
MAX_LABEL_STACK_SIZE = MplsStackEntryIndex.max + 1
UnboundedMplsLabelOperation = Tac.Type( 'Arnet::UnboundedMplsLabelOperation' )
BoundedMplsLabelStack = Tac.Type( 'Arnet::BoundedMplsLabelStack' )
DynamicMplsLabelStack = Tac.Type( 'Arnet::DynamicMplsLabelStack' )
# CLI matchers for one MPLS label token: any integer in [ labelMin, labelMax ],
# or one of the explicit-null keywords '0' (IPv4) / '2' (IPv6).
labelValMatcher = CliMatcher.IntegerMatcher( labelMin, labelMax,
                                             helpdesc='Value of the MPLS label' )
zeroLabelMatcher = CliMatcher.KeywordMatcher(
   '0', helpdesc='0: Explicit IPV4 MPLS null label' )
twoLabelMatcher = CliMatcher.KeywordMatcher(
   '2', helpdesc='2: Explicit IPV6 MPLS null label' )
# This sharedMatchObj is used to limit the number of MPLS labels that are matched
# by a given LabelValsWithExpNullExprFactory: all three label nodes share it, so
# their maxMatches limit applies to the combined count.
labelSharedMatchObj = object()
# Special label values: the entropy label indicator, and implicit null (which
# constructMplsHeader refuses to encode).
ELI = Tac.Type( "Arnet::MplsLabel" ).entropyLabelIndicator
IMP_NULL = Tac.Type( "Arnet::MplsLabel" ).implicitNull
# Size in bytes of one MPLS label stack entry; used as the per-label offset
# when building headers.
MplsHdrSize = 4

class LabelValsWithExpNullExprFactory( CliCommand.CliExpressionFactory ):
   '''Factory for CLI expressions that match a sequence of MPLS label values.

   The generated expression accepts any mix of the explicit-null keywords
   '0' and '2' and integer labels in [ labelMin, labelMax ]. All three nodes
   share labelSharedMatchObj, so the match limit applies to their combined
   count. The limit is maxMatchesFunc() when a callback is supplied, otherwise
   the fixed maxMatches value.
   '''

   def __init__( self, maxMatches=0, maxMatchesFunc=None ):
      CliCommand.CliExpressionFactory.__init__( self )
      self.maxMatches = maxMatches
      self.maxMatchesFunc = maxMatchesFunc

   def getMaxMatches( self ):
      '''Return the match limit, preferring the callback when one is set.'''
      return self.maxMatchesFunc() if self.maxMatchesFunc else self.maxMatches

   def generate( self, name ):
      '''Build the expression class; matched labels are stored in args[ name ].'''
      labelKeys = ( 'ZERO', 'TWO', 'INTERNAL_LABEL' )

      def labelNode( matcher ):
         # The limit is looked up afresh for every node that is built.
         return CliCommand.Node( matcher=matcher,
                                 maxMatches=self.getMaxMatches(),
                                 sharedMatchObj=labelSharedMatchObj )

      class LabelValWithExpNullExpr( CliCommand.CliExpression ):
         expression = '{ ZERO | TWO | INTERNAL_LABEL }'
         data = {
            'ZERO' : labelNode( zeroLabelMatcher ),
            'TWO' : labelNode( twoLabelMatcher ),
            'INTERNAL_LABEL' : labelNode( labelValMatcher ),
         }

         @staticmethod
         def adapter( mode, args, argsList ):
            # A value already stored under `name` is left untouched.
            if name in args:
               return
            # Collect matched label values in the order they were parsed.
            args[ name ] = [ entry[ 1 ] for entry in argsList
                             if entry[ 0 ] in labelKeys ]

      return LabelValWithExpNullExpr

# Every label ever returned by randomLabel(), whether or not uniqueness was
# requested for that call.
labelsUsed = set()
def randomLabel( uniqueLabel=True ):
   """Return a random label in [ labelMin, labelMax ] and record it.

   With uniqueLabel, the draw is repeated until it yields a label not yet in
   labelsUsed. Before retrying, an assertion checks that at least one unused
   label remains, because otherwise the retry loop could never finish.
   """
   candidate = random.randint( labelMin, labelMax )
   if uniqueLabel:
      assert len( labelsUsed ) <  labelMax - labelMin + 1
      while candidate in labelsUsed:
         candidate = random.randint( labelMin, labelMax )
   labelsUsed.add( candidate )
   return candidate

def genLabelStack( index, maxStackSize, genV4ExpNull=False, genV6ExpNull=False,
                   uniqueLabel=True ):
   """Return a top-of-stack-first list of random labels.

   The depth is ( index % maxStackSize ) + 1, so consecutive indexes cycle
   through depths 1..maxStackSize. The bottom-of-stack (last) entry is
   overwritten with 0 (IPv4 explicit null) when genV4ExpNull is set, or with
   2 (IPv6 explicit null) when genV6ExpNull is set. If both are set, 2 wins.
   """
   depth = index % maxStackSize + 1
   stack = []
   for _ in range( depth ):
      stack.append( randomLabel( uniqueLabel=uniqueLabel ) )

   bosLabel = None
   if genV4ExpNull:
      bosLabel = 0
   if genV6ExpNull:
      bosLabel = 2
   if bosLabel is not None:
      stack[ depth - 1 ] = bosLabel
   return stack

def genLabelStacks( size, maxStackSize ):
   """Generate a list containing 'size' number of MPLS label stacks.

   Stack i is genLabelStack( i, maxStackSize ), so depths cycle through
   1..maxStackSize. Returns None when maxStackSize is falsy, which is the case
   on platforms without MPLS push support (maxStackSize=0).
   """
   if not maxStackSize:
      return None
   # 'size' must be an integer count
   assert isinstance( size, int )
   return [ genLabelStack( i, maxStackSize ) for i in range( size ) ]

def labelListToMplsLabelStack( labelList, const=True, dynamicLabelStack=False ):
   """Pack labelList into a BoundedMplsLabelStack or DynamicMplsLabelStack.

   The list is stored reversed: its last element goes to index 0 and its
   first element to index len( labelList ) - 1. stackSize is set afterwards.
   Returns a Tac.const copy unless const is False.
   """
   if dynamicLabelStack:
      stackType = DynamicMplsLabelStack
   else:
      stackType = BoundedMplsLabelStack
   labelStack = stackType()
   reversedLabels = reversed( labelList )
   for position, label in enumerate( reversedLabels ):
      labelStack.labelIs( position, label )
   labelStack.stackSize = len( labelList )
   if not const:
      return labelStack
   return Tac.const( labelStack )

def boundedMplsLabelStackToLabelList( labelStack ):
   """Return the labels of labelStack as a list, highest index first.

   Entries 0..stackSize-1 are read in ascending order and then reversed.
   This undoes the reversal done by labelListToMplsLabelStack.
   """
   ascending = [ labelStack.label( i ) for i in range( labelStack.stackSize ) ]
   return ascending[ ::-1 ]

def labelStackToMplsLabelOperation( labelList, operation='push', controlWord=False,
                                    entropyLabelPositionBitMap=None,
                                    addFlowLabel=False,
                                    unboundedMplsLabelOperation=False ):
   """Build an (Unbounded)MplsLabelOperation carrying labelList.

   labelList[ i ] is written to labelStack index i unchanged.
   entropyLabelPositionBitMap, despite its name, is iterated as a sequence of
   positions, and each one is passed to entropyLabelAndIndicatorIs(). None is
   treated as empty. operation, controlWord and addFlowLabel are copied onto
   the value as-is, and stackSize is set to len( labelList ).
   """
   # NOTE(review): getLabelStack documents index 0 as bottom of stack, while
   # genLabelStack produces top-first lists that genMplsLabelOperationList
   # passes here unreversed -- confirm the intended ordering.
   if unboundedMplsLabelOperation:
      typeName = 'Arnet::UnboundedMplsLabelOperation'
   else:
      typeName = 'Arnet::MplsLabelOperation'
   labelOp = Tac.newInstance( typeName )
   for position, label in enumerate( labelList ):
      labelOp.labelStackIs( position, label )
   for elPos in ( entropyLabelPositionBitMap or [] ):
      labelOp.entropyLabelAndIndicatorIs( elPos )
   labelOp.operation = operation
   labelOp.controlWord = controlWord
   labelOp.addFlowLabel = addFlowLabel
   labelOp.stackSize = len( labelList )
   return labelOp

def ConvertLabelStack( labelStack, const=True ):
   """Copy labelStack into the other label stack flavour.

   A DynamicMplsLabelStack becomes a BoundedMplsLabelStack. Anything else
   becomes a DynamicMplsLabelStack. The copy gets stackSize first, then the
   labels at indexes 0..stackSize-1, and then both entropy label bitmaps.
   Returns Tac.const of the copy, or Tac.nonConst when const is False.
   """
   if isinstance( labelStack, DynamicMplsLabelStack ):
      targetType = BoundedMplsLabelStack
   else:
      targetType = DynamicMplsLabelStack
   converted = targetType()
   depth = labelStack.stackSize
   converted.stackSize = depth
   for position in range( depth ):
      converted.labelIs( position, labelStack.label( position ) )
   converted.entropyLabelBitMap = labelStack.entropyLabelBitMap
   converted.entropyLabelIndicatorBitMap = labelStack.entropyLabelIndicatorBitMap
   finalize = Tac.nonConst if not const else Tac.const
   return finalize( converted )

def genMplsLabelOperationList( size, maxStackSize ):
   """Generate 'size' random MplsLabelOperation values (operation 'push').

   Each label stack from genLabelStacks() is converted with
   labelStackToMplsLabelOperation() using its defaults.

   Returns None when maxStackSize is 0, matching genLabelStacks(). Previously
   that case raised TypeError because map() was handed None.
   """
   labelStacks = genLabelStacks( size, maxStackSize )
   if labelStacks is None:
      # Platforms without MPLS push support have maxStackSize=0
      return None
   return [ labelStackToMplsLabelOperation( stack ) for stack in labelStacks ]

def getLabelStack( mplsLabelOperation ):
   """Generate a list of MPLS labels from a MplsLabelOperation value type.

   In the value type, the stack is stored as follows:

      BOTTOM OF STACK (BOS)
      label0: ...
      label1: ...
      label2: ...
      TOP OF STACK (TOS)

   The returned list reads from the highest index down to index 0:
      TOS ---> [ label2, label1, label0 ] <--- BOS

   Raises AssertionError if the argument is not an Arnet::MplsLabelOperation.
   """
   assert isinstance( mplsLabelOperation, Tac.Type( 'Arnet::MplsLabelOperation' ) )
   topIndex = mplsLabelOperation.stackSize - 1
   return [ mplsLabelOperation.labelStack( i )
            for i in range( topIndex, -1, -1 ) ]

def labelStackToString( labelStack ):
   """Convert an MPLS label stack into a space-separated string.

   labelStack may be a list, tuple or CliModel._TypedList, and each label is
   rendered with str(). For example, [ 10, 20, 30 ] is converted to
   "10 20 30". An empty stack yields "".

   Raises AssertionError (when assertions are enabled) for other types.
   """
   # list/tuple are checked first so CliModel is only consulted for other types
   assert ( isinstance( labelStack, ( list, tuple ) ) or
            isinstance( labelStack, CliModel._TypedList ) )
   return ' '.join( map( str, labelStack ) )

def constructMplsHeader( labelStack, mplsTtl=255, mplsCos=None, tc=None,
                         setBos=True, flowLabelPresent=False ):
   """Construct the MPLS header for the given label stack.

   labelStack: frame-order list of labels, i.e. [ Top, ..., ..., Bottom ]. It
               must be non-empty and must not contain IMP_NULL.
   mplsTtl: should be copied from the IP frame, or IP TTL - 1. Default 255.
            An entry that directly follows an ELI entry gets TTL 0 instead.
   mplsCos / tc: At most one of mplsCos / tc may be specified. If tc is
                 specified, it is converted to an MPLS COS value using a static
                 map (tcToMplsCosStatic). If neither is given, COS is 0.
   setBos: when true, only the last entry has the BOS bit set. When false, no
           entry has it set.
   flowLabelPresent: when true and setBos is true, the BOS entry's TTL is
                     forced to 0.

   Returns the stringValue of the Arnet::Pkt holding the encoded entries.
   """
   assert labelStack
   assert mplsCos is None or tc is None
   for entry in labelStack:
      assert entry != IMP_NULL, \
            f'labelStack should not contain IMP_NULL {labelStack}'

   if tc is not None:
      cosValue = tcToMplsCosStatic( tc )
   else:
      cosValue = mplsCos if mplsCos is not None else 0

   depth = len( labelStack )
   pkt = Tac.newInstance( 'Arnet::Pkt' )
   pkt.newSharedHeadData = MplsHdrSize * depth
   for position, label in enumerate( labelStack ):
      hdr = Tac.newInstance( 'Arnet::MplsHdrWrapper', pkt,
                             MplsHdrSize * position )
      hdr.label = label
      # The entry immediately after an entropy label indicator gets TTL 0.
      followsEli = position != 0 and labelStack[ position - 1 ] == ELI
      hdr.ttl = 0 if followsEli else mplsTtl
      hdr.cos = cosValue
      if not setBos:
         hdr.bos = False
         continue
      hdr.bos = ( position == depth - 1 )
      if hdr.bos and flowLabelPresent:
         hdr.ttl = 0

   return pkt.stringValue

def tcToMplsCosStatic( trafficClass ):
   """Map a traffic class to an MPLS COS value using a fixed table.

   TC 0 and TC 1 swap places (0 -> 1, 1 -> 0). Every other traffic class is
   returned unchanged.
   """
   # BUG137916 - Use cos-to-tc map to find the TC first, instead of 1:1 mapping
   swapped = { 0 : 1, 1 : 0 }
   try:
      return swapped[ trafficClass ]
   except KeyError:
      # Values absent from the table map to themselves
      return trafficClass
