#!/usr/bin/env python3
# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Ark
import ArnetModel
from ArPyUtils import naturalOrderKey
from CliModel import Bool
from CliModel import Dict
from CliModel import Enum
from CliModel import Float
from CliModel import Int
from CliModel import List
from CliModel import Model
from CliModel import Str
from CliModel import Submodel
from CliPlugin.TeCli import adminGroupToStr, adminGroupDecimalListToDict
from CliPlugin.CspfShowPathModel import (
   excludeAdminGroupInt,
   excludeAdminGroupsExtendedList,
   includeAllAdminGroupInt,
   includeAllAdminGroupsExtendedList,
   includeAnyAdminGroupInt,
   includeAnyAdminGroupsExtendedList,
)
import IntfModels
import TableOutput
from Toggles import TeToggleLib
import Tac

class FlexAlgoViaModel( Model ):
   '''
   A single via of a flex algo path: a next hop IP address together with the
   L3 interface of that next hop.
   '''
   nexthop = ArnetModel.IpGenericAddress( help="Next hop IP address" )
   # Rendered via its stringValue (see FlexAlgoPathEntryModel.viaGenerator)
   intf = IntfModels.Interface( help="L3 interface of next hop" )

class FlexAlgoPathConstraintModel( Model ):
   '''
   Constraints applied when computing a flex algo path: metric type, excluded
   SRLGs and administrative-group include/exclude rules.
   '''
   metricType = Enum( values=( 'igp', 'minDelay', 'te' ), help="Metric type" )
   excludeSrlg = List( valueType=int, optional=True,
                       help="List of SRLG IDs excluded" )
   # Hidden attribute to render srlg name
   _srlgIdToNameMap = Dict( keyType=int, valueType=str, valueOptional=True,
                            help="Map of srlgId to srlgName", optional=True )
   includeAllAdminGroup = includeAllAdminGroupInt()
   includeAllAdminGroupsExtended = includeAllAdminGroupsExtendedList()
   includeAnyAdminGroup = includeAnyAdminGroupInt()
   includeAnyAdminGroupsExtended = includeAnyAdminGroupsExtendedList()
   excludeAdminGroup = excludeAdminGroupInt()
   excludeAdminGroupsExtended = excludeAdminGroupsExtendedList()

   def getSrlgName( self, srlgId ):
      '''
      Return the display text for an SRLG ID.

      When _srlgIdToNameMap holds a (non-None) name for srlgId the result is
      "<name> (<id>)"; otherwise it is just the ID as a string.
      '''
      knownName = self._srlgIdToNameMap.get( srlgId )
      if knownName is not None:
         return f"{knownName} ({srlgId})"
      return str( srlgId )

   def renderText( self, algoName ):
      '''
      Print the constraint section of a detailed flex algo path.

      The first line is "Path constraints: algo <algoName>". Every following
      line is indented to align under the text after that label.
      '''
      label = "Path constraints: "
      print( f"{label}algo {algoName}" )
      pad = " " * len( label )

      # Anything other than minDelay / te (including 'igp') shows as IGP
      metricLabel = { 'minDelay': 'MIN-DELAY', 'te': 'TE' }.get(
         self.metricType, 'IGP' )
      print( f"{pad}metric type {metricLabel}" )

      for excludedId in self.excludeSrlg or ():
         print( f"{pad}srlg exclude {self.getSrlgName( excludedId )}" )

      # ( keyword, legacy integer value, extended decimal list )
      adminGroupRules = (
         ( "include all", self.includeAllAdminGroup,
           self.includeAllAdminGroupsExtended ),
         ( "include any", self.includeAnyAdminGroup,
           self.includeAnyAdminGroupsExtended ),
         ( "exclude", self.excludeAdminGroup,
           self.excludeAdminGroupsExtended ),
      )
      for keyword, legacyValue, extendedList in adminGroupRules:
         if not ( legacyValue or extendedList ):
            continue
         shown = legacyValue
         # With extended admin groups enabled, render from the decimal list
         if TeToggleLib.toggleExtendedAdminGroupEnabled():
            shown = adminGroupDecimalListToDict( extendedList )
         print( f"{pad}administrative-group {keyword} "
                f"{adminGroupToStr(shown)}" )

class FlexAlgoPathEntryDetailModel( Model ):
   '''
   Detailed bookkeeping for one flex algo path, shown only in detail output.
   '''
   pathId = Int( help="Path ID number" )
   constraint = Submodel( valueType=FlexAlgoPathConstraintModel,
                          help="Path constraint" )
   refreshReqSeq = Int( help="Request sequence number" )
   refreshRespSeq = Int( help="Response sequence number" )
   changeCount = Int( help="Number of times path updated" )
   lastUpdatedTime = Float( help="UTC timestamp of the last path update" )
   metric = Int( help="Path metric", optional=True )

   def renderText( self, algoName ):
      '''
      Print the detail lines for this path.

      algoName is forwarded to the constraint submodel for its header line.
      The metric line is printed only when a metric is present.
      '''
      print( f'Path ID: {self.pathId}' )
      self.constraint.renderText( algoName )
      print( f'Request sequence number: {self.refreshReqSeq}' )
      print( f'Response sequence number: {self.refreshRespSeq}' )
      print( f'Number of times path updated: {self.changeCount}' )
      updatedStr = Ark.timestampToStr( self.lastUpdatedTime,
                                       now=Tac.utcNow() )
      print( f'Last updated: {updatedStr}' )
      if self.metric is None:
         return
      print( f'Metric: {self.metric}' )

class FlexAlgoPathEntryModel( Model ):
   '''
   The flex algo path for one algorithm towards one destination.
   '''
   algoName = Str( help="Algorithm name", optional=True )
   vias = List( FlexAlgoViaModel, help="List of flex algo vias" )
   details = Submodel( valueType=FlexAlgoPathEntryDetailModel,
                       help="Detailed path information", optional=True )

   def viaGenerator( self ):
      '''
      Lazily yield ( nexthop, intf.stringValue ) for every via, in list order.
      '''
      yield from ( ( hop.nexthop, hop.intf.stringValue ) for hop in self.vias )

class FlexAlgoPathDestModel( Model ):
   '''
   All flex algo paths towards a single destination.
   '''
   # Optional; when unset, the destination key is displayed instead
   hostname = Str( help="Hostname", optional=True )
   # Keyed by algorithm (an int); each value is that algorithm's path
   paths = Dict( keyType=int, valueType=FlexAlgoPathEntryModel,
                 help="A mapping of an algorithm to the corresponding "
                 "flex algo path" )

class FlexAlgoPathTopoIdModel( Model ):
   '''
   Flex algo paths of one topology, keyed by destination, with summary
   (table) and detail renderers.
   '''
   destinations = Dict( keyType=str, valueType=FlexAlgoPathDestModel,
                        help="A mapping of path destination to all "
                        "flex algo paths for that destination" )

   def destinationIterator( self ):
      '''
      Yield ( dst, dstName ) tuples for every destination.

      dstName is the destination's hostname when set, otherwise dst itself.
      Tuples are ordered by naturalOrderKey( dstName ); ties (for example two
      destinations sharing a hostname) are broken by dst.
      '''
      dstDstnameTuple = [ ( dst, destModel.hostname or dst )
                          for dst, destModel in self.destinations.items() ]

      def sortByDestNameAndDest( elem ):
         ''' First sort by destination name (hostname)
             and if they are same then sort by destination
         '''
         return ( naturalOrderKey( elem[ 1 ] ), elem[ 0 ] )

      yield from sorted( dstDstnameTuple, key=sortByDestNameAndDest )

   def renderSummaryOutput( self ):
      '''
      Print one table covering all destinations and algorithms.

      The first via of a path carries the destination and algorithm columns.
      Subsequent vias leave those columns blank. A path without vias still
      gets a row. An empty row separates consecutive paths.
      '''
      headings = ( ( 'Destination', 'l' ), ( 'Algorithm', 'l' ),
                   ( 'Next Hop', 'l' ), ( 'Interface', 'l' ) )
      table = TableOutput.createTable( headings )
      for dst, dstName in self.destinationIterator():
         for algo, pathEntry in sorted( self.destinations[ dst ].paths.items() ):
            algoName = pathEntry.algoName or algo
            firstVia = True
            for nh, intf in pathEntry.viaGenerator():
               if firstVia:
                  table.newRow( dstName, algoName, nh, intf )
                  firstVia = False
               else:
                  table.newRow( "", "", nh, intf )
            if firstVia:
               # This means there are no vias for this destination
               # Just print dst and algo
               table.newRow( dstName, algoName, "", "" )
            # Print empty line between paths of different algos
            table.newRow( "", "", "", "" )
      print( table.output() )

   def renderDetailOutput( self ):
      '''
      Print, for each destination/algorithm path, a header and its details.
      Then print a next hop / interface table.
      '''
      for dst, dstName in self.destinationIterator():
         for algo, pathEntry in sorted( self.destinations[ dst ].paths.items() ):
            print( 'Destination: ' + dstName )
            # details is an optional submodel; if it was not populated, skip
            # the detail section instead of raising AttributeError and
            # aborting the whole command output.
            if pathEntry.details is not None:
               algoName = pathEntry.algoName or str( algo )
               pathEntry.details.renderText( algoName )
            headings = ( ( 'Next Hop', 'l' ), ( 'Interface', 'l' ) )
            table = TableOutput.createTable( headings )

            colFormat = TableOutput.Format( justify="left" )
            colFormat.padLimitIs( True )
            table.formatColumns( *[ colFormat ] * len( headings ) )

            for nh, intf in pathEntry.viaGenerator():
               table.newRow( nh, intf )
            print( table.output() )

   def renderText( self, detailCmd ):
      '''Dispatch to the detail renderer if detailCmd, else the summary.'''
      if detailCmd:
         self.renderDetailOutput()
      else:
         self.renderSummaryOutput()

class FlexAlgoPathAfModel( Model ):
   '''
   Flex algo paths of one address family, keyed by topology ID.
   '''
   topologies = Dict( keyType=int, valueType=FlexAlgoPathTopoIdModel,
                      help="A mapping of topology ID to all flex algo paths" )

   def renderText( self, detailCmd ):
      '''
      Print each topology that has at least one destination.

      Topologies are visited in ascending topology ID order, so the text
      output does not depend on the order the dict was populated in. VRFs,
      destinations and algorithms are likewise sorted before rendering.
      detailCmd is forwarded to FlexAlgoPathTopoIdModel.renderText.
      '''
      for topoId in sorted( self.topologies ):
         topoModel = self.topologies[ topoId ]
         if not topoModel.destinations:
            continue
         print( f"Topology ID: Level-{topoId}" )
         topoModel.renderText( detailCmd )

class FlexAlgoPathVrfModel( Model ):
   '''
   Flex algo paths of one VRF, split by address family.
   '''
   v4Info = Submodel( valueType=FlexAlgoPathAfModel, optional=True,
                      help="Flex algo path information for IPv4 address family" )
   v6Info = Submodel( valueType=FlexAlgoPathAfModel, optional=True,
                      help="Flex algo path information for IPv6 address family" )

   def renderText( self, detailCmd ):
      '''
      Print IPv4 paths, then IPv6 paths.

      An address family is skipped when its submodel is unset or has no
      topologies.
      '''
      sections = (
         ( "Flex algo paths for IPv4 address family", self.v4Info ),
         ( "Flex algo paths for IPv6 address family", self.v6Info ),
      )
      for banner, afModel in sections:
         if not afModel or not afModel.topologies:
            continue
         print( banner )
         afModel.renderText( detailCmd )

class FlexAlgoPathModel( Model ):
   '''
   Top-level flex algo path model: one entry per VRF.
   '''
   vrfs = Dict( keyType=str, valueType=FlexAlgoPathVrfModel,
                help="A mapping between VRF and information of all flex algo paths"
                " in that VRF" )
   # Hidden attribute
   _detailCmd = Bool( default=False, help="Detail command output requested" )

   def render( self ):
      '''
      Render every VRF in VRF-name order.

      Detail or summary form is selected by the hidden _detailCmd flag.
      '''
      wantDetail = self._detailCmd
      orderedVrfs = sorted( self.vrfs.items(), key=lambda item: item[ 0 ] )
      for _vrfName, vrfModel in orderedVrfs:
         vrfModel.renderText( wantDetail )
