#!/usr/bin/env python3
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

from ArnetModel import IpGenericAddress
from Arnet import IpGenAddr
from CliModel import Model
from CliModel import ( Dict,
                       GeneratorDict,
                       Bool,
                       Enum,
                       List,
                       Int,
                       Submodel )
from TableOutput import ( createTable,
                          Format )
from Toggles.SrTePolicyToggleLib import toggleReplicationSegmentEnabled
from CliPlugin.SrTePolicyLibModel import ( SrTeSegmentListVia,
                                           invalidReasonEnum )

# Trailing columns common to every statistics table: per-state counts + total.
_stateCountHeader = ( 'Active', 'Valid', 'Invalid', 'total' )

# Headings for the per-source, shared and non-shared statistics tables.
summaryHeader = ( 'Source', ) + _stateCountHeader
sharedHeader = ( 'Rep ID', ) + _stateCountHeader
nonSharedHeader = ( 'Root', 'Tree ID', 'Path Instance' ) + _stateCountHeader

# State strings shown for a segment list's validity.
SL_STATE_VALID = 'valid'
SL_STATE_INVALID = 'invalid'

# Display names for the source-protocol strings used as keys of
# ReplicationSegmentStatisticsModel.sources.
_protocolDisplayNames = { 'static': 'Static',
                          'pcep': 'PCEP' }

def protocolStrMap( protocol ):
   """Return the CLI display name for a replication segment source protocol.

   Protocols without an entry in the table are returned unchanged instead of
   raising KeyError, so a newly added source does not abort rendering of the
   whole statistics table.
   """
   return _protocolDisplayNames.get( protocol, protocol )

class ReplicationSegmentStateModel( Model ):
   """Counts of replication segments in the active, valid and invalid states."""
   __public__ = toggleReplicationSegmentEnabled()
   active = Int( default=0,
                 help='Number of replication segments in active state' )
   valid = Int( default=0,
                help='Number of replication segments in valid state' )
   invalid = Int( default=0,
                  help='Number of replication segments in invalid state' )

class ReplicationSegmentSharedStatModel( Model ):
   """Per-state counters of shared replication segments, by replication ID."""
   __public__ = toggleReplicationSegmentEnabled()
   replicationIds = Dict(
      keyType=int,
      valueType=ReplicationSegmentStateModel,
      help='Replication segments, keyed by their replication id' )

   def render( self ):
      print( '\nShared replication segment statistics:' )
      table = createTable( sharedHeader, tableWidth=120 )
      rightJustified = Format( justify='right' )
      # Every column, the replication ID included, is right-justified.
      table.formatColumns( *( [ rightJustified ] * len( sharedHeader ) ) )
      # One row per replication ID in ascending order; last cell is the sum.
      for repId, counts in sorted( self.replicationIds.items() ):
         total = counts.active + counts.valid + counts.invalid
         table.newRow( repId, counts.active, counts.valid, counts.invalid,
                       total )
      print( table.output() )

class ReplicationSegmentPath( Model ):
   """Per-state counters for one tree, keyed by path instance number."""
   __public__ = toggleReplicationSegmentEnabled()
   paths = Dict(
      valueType=ReplicationSegmentStateModel,
      keyType=int,
      help='Replication segments, keyed by their path instance number')

class ReplicationSegmentTree( Model ):
   """Non-shared replication segment counters for one root, keyed by tree ID."""
   __public__ = toggleReplicationSegmentEnabled()
   trees = Dict(
      valueType=ReplicationSegmentPath,
      keyType=int,
      help='Replication segments, keyed by their tree ID')

class ReplicationSegmentRoot( Model ):
   """Non-shared replication segment counters, nested root -> tree -> path."""
   __public__ = toggleReplicationSegmentEnabled()
   roots = Dict(
      valueType=ReplicationSegmentTree,
      keyType=IpGenericAddress,
      help='Replication segments, keyed by their root address')

   def render( self ):
      print( '\nNon-shared replication segment statistics:' )
      table = createTable( nonSharedHeader, tableWidth=120 )
      rightJustified = Format( justify='right' )
      table.formatColumns( *( [ rightJustified ] * len( nonSharedHeader ) ) )
      # Walk the three nesting levels in sorted key order; one row per path.
      for rootAddr, treeModel in sorted( self.roots.items() ):
         for treeId, pathModel in sorted( treeModel.trees.items() ):
            for pathInstance, counts in sorted( pathModel.paths.items() ):
               total = counts.active + counts.valid + counts.invalid
               table.newRow( rootAddr, treeId, pathInstance, counts.active,
                             counts.valid, counts.invalid, total )
      print( table.output() )

class ReplicationSegmentStatisticsModel( Model ):
   """Per-source-protocol replication segment counters plus an overall total."""
   __public__ = toggleReplicationSegmentEnabled()
   sources = Dict( keyType=str,
                   valueType=ReplicationSegmentStateModel,
                   help='Replication segments, keyed by their source' )
   totalStat = Submodel( valueType=ReplicationSegmentStateModel,
                         help='Total replication segment statistics for all '
                         'sources' )

   def render( self ):
      def countsRow( label, counts ):
         # Cells of one row: label, the three per-state counts, then their sum.
         return ( label, counts.active, counts.valid, counts.invalid,
                  counts.active + counts.valid + counts.invalid )

      print( '\nPer source protocol statistics:' )
      table = createTable( summaryHeader, tableWidth=120 )
      rightJustified = Format( justify='right' )
      table.formatColumns( *( [ rightJustified ] * len( summaryHeader ) ) )
      # Sources appear in the dict's own iteration order, then the total row.
      for protocol, counts in self.sources.items():
         table.newRow( *countsRow( protocolStrMap( protocol ), counts ) )
      table.newRow( *countsRow( 'Total', self.totalStat ) )
      print( table.output() )

class ReplicationSegmentSummaryModel( Model ):
   """Replication segment summary for one VRF.

   Groups per-source-protocol counters, shared and non-shared replication
   segment counters, and the number of replication segments whose RSID
   conflicts.
   """
   __public__ = toggleReplicationSegmentEnabled()
   protocolStats = Submodel( valueType=ReplicationSegmentStatisticsModel,
                             help='Per source protocol Segment '
                             'Routing Traffic Engineering Replication'
                             ' segment statistics' )
   sharedStatistics = Submodel( valueType=ReplicationSegmentSharedStatModel,
                                help='Segment Routing Traffic Engineering'
                                ' shared Replication Segment related '
                                'statistics' )
   nonSharedStatistics = Submodel( valueType=ReplicationSegmentRoot,
                                   help='Segment Routing Traffic Engineering'
                                   ' not shared Replication Segment related'
                                   ' statistics' )
   replicationSegmentsRsidConflictCount = Int(
      default=0,
      help='Number of replication segments with RSID conflict' )

   def render( self ):
      # Three statistics tables, in the order per-source, shared, non-shared.
      self.protocolStats.render()
      self.sharedStatistics.render()
      self.nonSharedStatistics.render()
      print( '\nNumber of replication segments with RSID conflicts:',
              self.replicationSegmentsRsidConflictCount )
      # print( '\n' ) emits two newlines, separating consecutive summaries.
      print( '\n' )

class ReplicationSegmentSummaryVrfModel( Model ):
   """Replication segment summaries, one per VRF (keyed by a string)."""
   __public__ = toggleReplicationSegmentEnabled()
   vrfs = Dict( valueType=ReplicationSegmentSummaryModel,
                keyType=str,
                help='Per VRF Segment Routing Replication Segments' )

   def render( self ):
      # Summaries are printed back to back in the dict's iteration order.
      summaries = self.vrfs.values()
      for summary in summaries:
         summary.render()

class SrP2mpReplicationSid( Model ):
   """Replication SID of a replication segment.

   Only the MPLS form has a value field in this model (mplsLabelSid, which is
   optional); no field for an IPv6 SID value is defined here.
   """
   __public__ = toggleReplicationSegmentEnabled()
   replicationSidType = Enum( values=( 'mpls', 'ipv6' ),
                              default='mpls',
                              help='Replication Segment Identifier Type' )
   mplsLabelSid = Int( help='Replication Segment Identifier in the form of an MPLS '
                       'label',
                       optional=True )

# Common models shared by all SrTeP2mp show commands
class SrP2mpSegment( Model ):
   """One segment of a segment list.

   Identified by a next-hop IP address and/or an MPLS label; both fields are
   optional.
   """
   __public__ = toggleReplicationSegmentEnabled()
   nexthopAddr = IpGenericAddress( optional=True,
                                   help='Segment identifier in the form of '
                                   'next hop IP address' )
   mplsLabelSid = Int( optional=True,
                       help='Segment identifier in the form of an MPLS label' )

class SrP2mpRepSegSegmentList( Model ):
   """A segment list of a replication segment and its resolved next hops."""
   __public__ = toggleReplicationSegmentEnabled()
   segmentListId = Int( help='Internal ID of the segment list '
                        'for valid candidate paths',
                        optional=True )
   segmentListValid = Bool( help='Validity of the segment list '
                            'for valid candidate paths',
                            optional=True )
   segments = List( valueType=SrP2mpSegment,
                    help='Segments that make up the segment list' )
   vias = List( valueType=SrTeSegmentListVia,
               help='List of next hops for the primary path of the segment list, '
               'if the segment list is valid' )
   backupVias = List( valueType=SrTeSegmentListVia,
                     help='List of next hops for the backup path '
                     'of the segment list, if the segment list is valid' )
   invalidReason = invalidReasonEnum

   def _firstSegmentNexthop( self ):
      """Return the first segment's next-hop address, or None if it has none.

      A next hop that is unset, falsy, or equal to a default-constructed
      IpGenAddr() counts as absent. render() uses the same rule when it
      decides which segments contribute MPLS labels to the label stack.
      """
      if not self.segments:
         return None
      nexthop = self.segments[ 0 ].nexthopAddr
      if not nexthop or nexthop == IpGenAddr():
         return None
      return nexthop

   def stateStr( self ):
      """Return ' State: valid'/' State: invalid', or '' if validity is unset."""
      if self.segmentListValid is None:
         return ''
      if self.segmentListValid:
         return ' State: ' + SL_STATE_VALID
      else:
         return ' State: ' + SL_STATE_INVALID

   def segmentStr( self ):
      """Describe the kind of segment this list holds, for use in messages.

      Returns 'segment' for an empty list. Otherwise returns 'next hop IP
      address' if the first segment has a next hop, else 'MPLS label'.
      """
      if not self.segments:
         return 'segment'
      if self._firstSegmentNexthop() is not None:
         return 'next hop IP address'
      return 'MPLS label'

   def invalidReasonStr( self ):
      """Return a human-readable description of self.invalidReason.

      Reasons without a description below are returned as their raw enum
      value. This avoids a KeyError if invalidReasonEnum gains new values.
      """
      topSegmentType = self.segmentStr()
      invalidReasonMap = {
         'unresolvedTopLabel': 'Top ' + topSegmentType + ' is not resolved',
         'segmentListMissing': 'Segment list is not present',
         'noResolvedLabels': 'Resolved segment list has no segments',
         'resolvedLabelsExceedPlatformMsd':
            'Resolved segment list exceeds Platform Limit',
      }
      return invalidReasonMap.get( self.invalidReason, self.invalidReason )

   def render( self ):
      def viaStr( prefix, via ):
         # One resolved via: its label stack, next-hop address and interface.
         labels = ' '.join( '%u' % label for label in via.mplsLabels )
         return ( f'{prefix}{labels}], Next hop: {via.nexthop}, Interface: '
                  f'{via.interface.stringValue}' )

      print( 'Segment List:%s' % ( self.stateStr()  ) )
      # Segments without a next hop contribute their MPLS label to the stack.
      # A next hop on the first segment is appended after the stack.
      labels = ' '.join( str( segment.mplsLabelSid )
                         for segment in self.segments
                         if segment.nexthopAddr is None or
                         segment.nexthopAddr == IpGenAddr() )
      lblStkStr = '\tLabel Stack: [' + labels + ']'
      nexthop = self._firstSegmentNexthop()
      if nexthop is not None:
         lblStkStr += ', Next hop: %s' % nexthop.stringValue
      print( lblStkStr )
      if self.invalidReason:
         print( '\tInvalid Reason: %s' % ( self.invalidReasonStr() ) )
      for via in self.vias:
         print( viaStr( '\tResolved Label Stack: [', via ) )
      for backupVia in self.backupVias:
         print( viaStr( '\tBackup Resolved Label Stack: [', backupVia ) )

class SrP2mpCandidates( Model ):
   """Details of a single replication segment (an SrP2mpCandidatesModel entry)."""
   __public__ = toggleReplicationSegmentEnabled()
   source = Enum( values=( 'static', 'pcep'), help='Protocol source'
                  ' of replication segment' )
   state = Enum( values=( 'active', 'valid', 'invalid' ),
                 help='State of the replication segment' )
   replicationSid = Submodel( valueType=SrP2mpReplicationSid,
                              help='Mpls label to identify replication segment',
                              optional=True )
   root = IpGenericAddress( help='Node address of root' )
   treeId = Int( help='Tree identifier' )
   pathInstance = Int( help='Path Instance number' )
   nodeAddress = IpGenericAddress( help='Address of replication node' )
   segmentLists = List( valueType=SrP2mpRepSegSegmentList,
                        help= 'Replication segments information ' )

   def renderRepSeg( self, repId ):
      """Print this replication segment, then each of its segment lists.

      repId is the replication ID this entry is keyed by in its parent model.
      It is not stored on the entry itself.
      """
      print( 'Source: %s' % ( self.source) )
      print( 'Replication ID: %u, Node address: %s'
             %( repId, self.nodeAddress ) )
      print( 'Root address: %s, Tree ID: %u' # pylint: disable=bad-string-format-type
             %( self.root, self.treeId ) )
      print( 'Path instance: %u' % ( self.pathInstance ) )
      print( 'State: %s' % ( self.state ) )
      replicationSid = self.replicationSid
      # mplsLabelSid is optional (presumably unset for an 'ipv6' SID type).
      # Formatting None with '%u' raises TypeError, so only print a set label.
      if replicationSid is not None and replicationSid.mplsLabelSid is not None:
         print( 'Replication SID: %u' % replicationSid.mplsLabelSid )
      for sl in self.segmentLists:
         sl.render()

class SrP2mpCandidatesModel( Model ):
   """Replication segments of one VRF, keyed by replication ID."""
   __public__ = toggleReplicationSegmentEnabled()
   replicationSegments = GeneratorDict( valueType=SrP2mpCandidates,
                                        keyType=int,
                                        help='Replication segments, keyed by their'
                                        ' replication ID')

   def render( self ):
      # Iteration yields (replication ID, SrP2mpCandidates) pairs; print them
      # in ascending replication-ID order.
      orderedSegments = sorted( self.replicationSegments )
      for replicationId, candidate in orderedSegments:
         candidate.renderRepSeg( replicationId )

class SrP2mpCandidatesVrfModel( Model ):
   """Replication segment details, grouped per VRF name."""
   __public__ = toggleReplicationSegmentEnabled()
   vrfs = Dict( valueType=SrP2mpCandidatesModel,
                keyType=str,
                help='A mapping of VRF name to its replication segments')

   def render( self ):
      # Each VRF's segments are printed in the dict's iteration order.
      perVrfModels = self.vrfs.values()
      for candidatesModel in perVrfModels:
         candidatesModel.render()
