#!/usr/bin/env python3
# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import Tac
import Tracing
from CliModel import DeferredModel, Model, Submodel, Bool, Int, Float, Str, Enum
from CliModel import GeneratorDict, List, Dict
from IntfModels import Interface
from ArnetModel import IpGenericAddress, IpGenericPrefix
import datetime
import Arnet
import PimModel
from PimLib import AddressFamily

traceHandle = Tracing.defaultTraceHandle()
# Shorthand bindings for the module trace handle's levels 0 through 3.
t0, t1, t2, t3 = ( traceHandle.trace0, traceHandle.trace1,
                   traceHandle.trace2, traceHandle.trace3 )

def iPrint( indent, obj ):
   """Prints obj, converted with str(), preceded by indent spaces.

   obj is always printed as a single value. A tuple is printed as a tuple
   rather than being expanded as %-format arguments.
   """
   print( ' ' * indent + str( obj ) )

def timeElapsedString( utcTime ):
   """Returns a human readable string for the time elapsed since utcTime.

   utcTime is a UTC timestamp in seconds, compared against Tac.utcNow().
   An elapsed time of one day or more is shown as "<days>d<hours>h"
   (e.g. "3d4h"). Shorter intervals are shown as str( timedelta ), i.e.
   "H:MM:SS". A utcTime in the future is shown as "0:00:00" rather than as a
   negative timedelta.
   """
   elapsed = max( 0, int( Tac.utcNow() - utcTime ) )
   td = datetime.timedelta( seconds=elapsed )

   if td.days >= 1:
      # Use ">= 1 day" so that exactly 24h also uses the XdXh form.
      return "%dd%dh" % ( td.days, td.seconds // 3600 )
   # Display time elapsed in H:MM:SS format
   return str( td )

class PimRpCandidate( Model ):
   """One candidate RP (static or BSR-learned) for a group range."""
   mode = Enum( help="PIM mode",
                values=( 'Sparse',
                         'Bidir',
                         'Sparse/Bidir',
                         'None' ), optional=True )
   rpType = Str( help="Type of RP candidate" )
   rpAddress = IpGenericAddress( help="IP address of RP" )
   priority = Int( help="Rp Priority" )
   hashMaskLen = Int( help="Hash Mask Length" )
   creationTime = Float( help="UTC time at which the entry was created" )
   #Following used only for dynamic RPs
   expires = Float( help="Number of seconds after which the entry will expire",
                       optional=True )
   holdTime = Int( help="Number of seconds for which this entry is set to be held",
                   optional=True )
   #Following is used by static rp
   override = Bool( help="Entry is set to override BSR provided RP",
                    optional=True )
   # NOTE(review): the help text duplicates hashMaskLen's; "Hash Mask Value"
   # looks intended, but changing it alters the model schema text.
   hashMaskValue = Int( help="Hash Mask Length", optional=True )

   # pylint: disable-msg=W0221
   def render( self, detail=False ):
      """Prints this candidate RP.

      With detail, each attribute is printed on its own line. Otherwise a
      one-line summary is printed for 'staticRp' and 'bsrRp' candidates.
      """
      indent = 2
      iPrint( indent, 'RP: %s' % self.rpAddress )
      indent += 2
      if detail:
         iPrint( indent, 'Type:        %s' % self.rpType )
         iPrint( indent, 'Priority:    %d' % self.priority )
         iPrint( indent, 'HashMaskLen: %d' % self.hashMaskLen )
         iPrint( indent, 'Uptime:      %s' % timeElapsedString( self.creationTime ) )
      if self.rpType == 'staticRp':
         overrideStr = 'True' if self.override else 'False'
         if detail:
            iPrint( indent, 'Expires:     never' )
            iPrint( indent, 'Override:    %s' % overrideStr )
         else:
            iPrint( indent,
                    'Uptime: %s, Expires: never, Priority: %d, Override: %s' %
                       ( timeElapsedString( self.creationTime ), self.priority,
                         overrideStr ) )
      elif self.rpType == 'bsrRp':
         # expires holds the remaining seconds (see initFromTacc)
         td = datetime.timedelta( seconds=int( self.expires ) )
         if detail:
            iPrint( indent, 'Expires:     %s' % str( td ) )
            iPrint( indent, 'Holdtime:    %d' % self.holdTime )
         else:
            iPrint( indent, 'Uptime: %s, Expires: %s, Priority: %d' %
                       ( timeElapsedString( self.creationTime ),
                         str( td ),
                         self.priority ) )

   def renderForRpHashModulo( self, detail=False ):
      """Prints this candidate's address, uptime and expiry for the modulo
      RP hash display. detail is accepted but not used."""
      indent = 2
      iPrint( indent, 'RP: %s' % self.rpAddress )
      indent += 2
      expireStr = 'never'
      if self.rpType == 'bsrRp':
         expireStr = str( datetime.timedelta( seconds=int( self.expires ) ) )

      iPrint( indent,
              'Uptime: %s, Expires: %s' %
              ( timeElapsedString( self.creationTime ), expireStr ) )

   def renderForRpHash( self, detail=False ):
      """Prints this candidate for the default RP hash display, including the
      computed hash mask value."""
      indent = 2
      iPrint( indent, 'RP: %s' % self.rpAddress )
      indent += 2
      if detail:
         iPrint( indent, 'Type:          %s' % self.rpType )
         iPrint( indent, 'Priority:      %d' % self.priority )
         iPrint( indent, 'HashMaskLen:   %d' % self.hashMaskLen )
         # Previously misspelled as self.hashMasValue (AttributeError).
         iPrint( indent, 'HashMaskValue: %d' % self.hashMaskValue )
         iPrint( indent, 'Uptime:        %s' %
                    timeElapsedString( self.creationTime ) )
      if self.rpType == 'staticRp':
         overrideStr = 'True' if self.override else 'False'
         if detail:
            iPrint( indent, 'Expires:      never' )
            iPrint( indent, 'Override:     %s' % overrideStr )
         else:
            iPrint( indent,
                    'Uptime: %s, Expires: never, Priority: %d, HashMaskLen: %d'
                    ', HashMaskValue: %d, Override: %s' %
                       ( timeElapsedString( self.creationTime ), self.priority,
                         self.hashMaskLen, self.hashMaskValue,
                         overrideStr ) )
      elif self.rpType == 'bsrRp':
         td = datetime.timedelta( seconds=int( self.expires ) )
         if detail:
            iPrint( indent, 'Expires:      %s' % str( td ) )
            iPrint( indent, 'Holdtime:     %d' % self.holdTime )
         else:
            iPrint( indent, 'Uptime: %s, Expires: %s, Priority: %d'
                    ', HashMaskLen: %d, HashMaskValue: %d' %
                       ( timeElapsedString( self.creationTime ),
                         str( td ),
                         self.priority, self.hashMaskLen, self.hashMaskValue ) )

   def convertPimMode( self, mode ):
      """Maps a Tac PIM mode name to the model's display value.

      Raises KeyError for an unknown mode name."""
      typeDict = {
         'modePimSm'         : 'Sparse',
         'modePimBidir'      : 'Bidir',
         'modePimSmAndBidir' : 'Sparse/Bidir',
         'modePimNone'       : 'None'
      }
      return typeDict[ mode ]

   def initFromTacc( self, rpType, rp, hmaskVal=0 ):
      """Populates the model from a Tac RP entry.

      rpType is 'staticRp' or 'bsrRp' and selects which type-specific fields
      are filled. hmaskVal is stored as hashMaskValue.
      """
      self.mode = self.convertPimMode( rp.pimMode )
      self.rpType = rpType
      self.rpAddress = rp.ipAddr
      self.priority = rp.priority
      self.hashMaskLen = rp.hashMaskLen
      # NOTE(review): rp.setupTime appears to be on the Tac.now() clock; the
      # utcNow() - now() offset converts it to UTC. Confirm against the Tac type.
      self.creationTime = rp.setupTime + Tac.utcNow() - Tac.now()
      if rpType == 'bsrRp':
         # Stored as seconds remaining relative to Tac.now().
         self.expires = rp.expires - Tac.now()
         self.holdTime = rp.holdTime
      elif rpType == 'staticRp':
         self.override = rp.override
      self.hashMaskValue = hmaskVal

class PimRpCandidateSet( Model ):
   """Candidate RPs (static and BSR-learned) for a single group prefix."""
   mode = Enum( help="PIM mode",
                values=( 'Sparse',
                         'Bidir',
                         'Sparse/Bidir',
                         'None' ), optional=True )
   prefix = IpGenericPrefix( help="Group Prefix" )
   crp = GeneratorDict( help="Map of candidate RP information for this RP",
                        keyType=IpGenericAddress, valueType=PimRpCandidate )

   def render( self ):
      """Prints only the group header line for this candidate set."""
      if self.mode == 'Bidir':
         header = f'Group: {self.prefix} (bidirectional)'
      else:
         header = f'Group: {self.prefix}'
      iPrint( 0, header )

   def convertPimMode( self, mode ):
      """Translates a Tac PIM mode name into the model's display value.

      Raises KeyError for a name outside the four known modes."""
      return {
         'modePimSm': 'Sparse',
         'modePimBidir': 'Bidir',
         'modePimSmAndBidir': 'Sparse/Bidir',
         'modePimNone': 'None',
      }[ mode ]

   def generateCrp( self, staticRpSet, bsrRpSet, pimMode ):
      """Yields ( rpAddress, PimRpCandidate ) for static then BSR RPs.

      Every static RP is yielded. BSR RPs are yielded only when their mode
      equals self.mode. pimMode is accepted but not consulted.
      """
      for rpType, rpSet in ( ( 'staticRp', staticRpSet ), ( 'bsrRp', bsrRpSet ) ):
         if not rpSet:
            continue
         # Iterate a snapshot of the keys; entries missing at lookup time
         # are skipped.
         for addr in list( rpSet.rp ):
            rpInfo = rpSet.rp.get( addr )
            if not rpInfo:
               continue
            if rpType == 'bsrRp' and \
                  self.convertPimMode( rpInfo.pimMode ) != self.mode:
               continue
            candidate = PimRpCandidate()
            candidate.initFromTacc( rpType, rpInfo )
            yield addr, candidate

   def initFromTacc( self, prefix, pimMode=None, staticRpSet=None, bsrRpSet=None ):
      """Fills the model for prefix.

      When pimMode is not given, the mode is taken from the first static RP,
      else from the first BSR RP, and defaults to 'Sparse'.
      """
      resolvedMode = pimMode
      if not resolvedMode:
         resolvedMode = 'Sparse'
         firstStatic = None
         if staticRpSet:
            firstStatic = next( iter( staticRpSet.rp.values() ), None )
            if firstStatic is not None:
               resolvedMode = self.convertPimMode( firstStatic.pimMode )
         if not firstStatic and bsrRpSet:
            firstBsr = next( iter( bsrRpSet.rp.values() ), None )
            if firstBsr is not None:
               resolvedMode = self.convertPimMode( firstBsr.pimMode )
      self.mode = resolvedMode

      self.prefix = prefix
      self.crp = self.generateCrp( staticRpSet, bsrRpSet, pimMode )

class PimRpCandidates( Model ):
   """All candidate RP sets, keyed by group prefix."""
   crpSet = GeneratorDict( help="Map of Candidate RP Set belonging to "
                           "this group prefix", keyType=IpGenericPrefix,
                           valueType=PimRpCandidateSet )
   detail = Bool( help='ShowDetails' )

   def convertPimMode( self, mode ):
      """Maps a Tac PIM mode name to the model's display value.

      Raises KeyError for an unknown mode name."""
      typeDict = {
         'modePimSm'         : 'Sparse',
         'modePimBidir'      : 'Bidir',
         'modePimSmAndBidir' : 'Sparse/Bidir',
         'modePimNone'       : 'None'
      }
      return typeDict[ mode ]

   def _makeCrpSet( self, prefix, pimMode, staticRpSet=None, bsrRpSet=None ):
      """Builds and initializes the PimRpCandidateSet for one group prefix."""
      model = PimRpCandidateSet()
      model.initFromTacc( prefix, pimMode, staticRpSet=staticRpSet,
                          bsrRpSet=bsrRpSet )
      return model

   def generateCrpSet( self, pimMode, staticStatus=None, bsrStatus=None,
                       groupPrefix=None ):
      """Yields ( prefix, PimRpCandidateSet ) for each group prefix.

      Prefixes with static RPs come first, each merged with its BSR RP set if
      one exists. BSR-only prefixes follow, and are included only when
      pimMode is unset or some BSR RP in the set uses pimMode. With
      groupPrefix, only that prefix is considered.
      """
      staticPrefixes = []
      bsrPrefixes = []

      if staticStatus and staticStatus.rpSetType == 'staticRp':
         staticPrefixes = list( staticStatus.crpSet )

      if bsrStatus and bsrStatus.rpSetType == 'bsrRp':
         bsrPrefixes = list( bsrStatus.crpSet )

      def bsrModeMatches( bsrRpSet ):
         """True if there is no mode filter or a BSR RP uses pimMode."""
         if not pimMode:
            return True
         if not bsrRpSet:
            return False
         return any( self.convertPimMode( rp.pimMode ) == pimMode
                     for rp in bsrRpSet.rp.values() )

      def crpSetFor( prefix, inStatic, inBsr ):
         """Returns the model for prefix, or None if it is filtered out."""
         if inStatic:
            bsrRpSet = bsrStatus.crpSet.get( prefix ) if inBsr else None
            return self._makeCrpSet( prefix, pimMode,
                                     staticRpSet=staticStatus.crpSet[ prefix ],
                                     bsrRpSet=bsrRpSet )
         if inBsr:
            bsrRpSet = bsrStatus.crpSet.get( prefix )
            if bsrModeMatches( bsrRpSet ):
               return self._makeCrpSet( prefix, pimMode, bsrRpSet=bsrRpSet )
         return None

      if groupPrefix:
         # Single lookup; keep list (==) membership as the caller-supplied
         # prefix may not be the same type as the collection keys.
         candidates = [ ( groupPrefix, groupPrefix in staticPrefixes,
                          groupPrefix in bsrPrefixes ) ]
      else:
         # Sets give O(1) membership; the lists preserve iteration order.
         staticPrefixSet = set( staticPrefixes )
         bsrPrefixSet = set( bsrPrefixes )
         candidates = [ ( p, True, p in bsrPrefixSet ) for p in staticPrefixes ]
         candidates += [ ( p, False, True ) for p in bsrPrefixes
                         if p not in staticPrefixSet ]

      for prefix, inStatic, inBsr in candidates:
         model = crpSetFor( prefix, inStatic, inBsr )
         if model is not None:
            yield prefix, model

   def generate( self, static, dynamic, pimMode=None,
         groupPrefix=None, detail=False ):
      """Populates the model from the static and dynamic (BSR) RP statuses."""
      self.detail = bool( detail )
      self.crpSet = self.generateCrpSet( pimMode, static, dynamic, groupPrefix )

   def render( self ):
      """Prints each group header followed by its candidate RPs."""
      for _, crpSet in self.crpSet:
         crpSet.render()
         for _, crp in crpSet.crp:
            crp.render( detail=self.detail )

class PimRpCandidatesAll( Model ):
   """Candidate RP sets split by PIM mode (sparse and bidirectional)."""
   __revision__ = 2
   sparseMode = Submodel( help="RP Candidate Set for Sparse-mode PIM",
                          valueType=PimRpCandidates, optional=True )
   bidirectional = Submodel( help="RP Candidate Set for Bidirectional PIM",
                             valueType=PimRpCandidates, optional=True )

   def initialize( self, crpSets ):
      """Attaches the per-mode models found in crpSets (keyed by Tac mode)."""
      for tacMode, attrName in ( ( 'modePimSm', 'sparseMode' ),
                                 ( 'modePimBidir', 'bidirectional' ) ):
         if tacMode in crpSets:
            setattr( self, attrName, crpSets[ tacMode ] )

   def render( self ):
      """Renders the sparse-mode section, then the bidirectional one."""
      for section in ( self.sparseMode, self.bidirectional ):
         if section:
            section.render()

class IpMroutesAll( DeferredModel ):
   """Mroutes split into a sparse-mode and a bidirectional PIM section."""
   __revision__ = 2
   sparseMode = Submodel( valueType=PimModel.GroupSms,
                          help='IP Mroute for Sparse-mode PIM' )
   bidirectional = Submodel( valueType=PimModel.Groups,
                             help='IP Mroute for Bidirectional PIM' )

class IpMroutesAllVrfs( DeferredModel ):
   """Per-VRF mroutes, keyed by a string (presumably the VRF name)."""
   vrfs = Dict( help='Multicast routes for all VRFs',
                keyType=str, valueType=IpMroutesAll )


class MrouteCountAll( DeferredModel ):
   """Mroute counts split into sparse-mode and bidirectional PIM sections."""
   __revision__ = 2
   sparseMode = Submodel( valueType=PimModel.MrouteCount,
                          help='count for Sparse-mode PIM' )
   bidirectional = Submodel( valueType=PimModel.MrouteBidirCount,
                             help='count for bidirectional PIM' )


class MrouteInterfaceAll( DeferredModel ):
   """Mroute interface details split by PIM mode."""
   __revision__ = 2
   sparseMode = Submodel( valueType=PimModel.MrouteSmInterfaces,
                          help='IP Mroute Interfaces for Sparse-mode PIM' )
   bidirectional = Submodel( valueType=PimModel.MrouteInterfaceDetails,
                             help='IP Mroute Interfaces for Bidirectional PIM' )


class PimRpHash( Model ):
   """The RP selected for a group address, plus the candidate RPs considered."""
   mode = Enum( help="PIM mode",
                values=( "modePimNone","modePimSm", "modePimBidir" ), optional=True )
   prefix = IpGenericPrefix( help="Group Prefix" )
   rp = IpGenericAddress( help="Rendezvous Point Router's IP address",
                      optional=True )
   rpPriority = Int( help="RP Priority", optional=True )
   rpHashMaskLen = Int( help="Hash Mask Length", optional=True )
   crp = GeneratorDict( help="Map of candidate RP info related to the RP Hash",
                        keyType=IpGenericAddress, valueType=PimRpCandidate,
                        optional=True )

   detail = Bool( help="Show details" )
   rpHashAlgorithm = Enum( help="Rp Hash Algorithm",
                           values=( "rpHashAlgorithmDefault",
                                    "rpHashAlgorithmModulo" ),
                           optional=True )

   def _hostPrefix( self, addr ):
      """Returns the host (/32 or /128) prefix for the IpGenAddr addr."""
      maskLen = '/32' if addr.af == AddressFamily.ipv4 else '/128'
      return Arnet.IpGenPrefix( addr.stringValue + maskLen )

   def _matchingRp( self, compute, gAddr, pimMode ):
      """Returns compute's RP info for gAddr, or None when compute is unset,
      the computed address is unknown, or its mode differs from pimMode."""
      if not compute:
         return None
      rpInfo = compute.doComputeRp( gAddr )
      if rpInfo.address.af == 'ipunknown' or rpInfo.mode != pimMode:
         return None
      return rpInfo

   def generateCrp( self, staticCompute, bsrCompute, group ):
      """Yields ( rpAddress, PimRpCandidate ) for the longest-prefix-matching
      static candidate set, then the BSR one, each with its hash mask value."""
      prefix = self._hostPrefix( group )
      for rpType, compute in ( ( 'staticRp', staticCompute ),
                               ( 'bsrRp', bsrCompute ) ):
         if not compute:
            continue
         crpSet = compute.doLongestPrefixGroupMatch( prefix )
         if not crpSet:
            continue
         for crp in crpSet.rp:
            rpInfo = crpSet.rp[ crp ]
            hashMaskValue = compute.doComputeHashMask( rpInfo.pimMode,
                                                       group,
                                                       rpInfo.hashMaskLen,
                                                       crp )
            model = PimRpCandidate()
            model.initFromTacc( rpType, rpInfo, hashMaskValue )
            yield crp, model

   def render( self ):
      """Prints the selected RP, its candidates and the hash algorithm used.

      Prints nothing when no RP was selected.
      """
      if not self.rp:
         return
      isModulo = self.rpHashAlgorithm == 'rpHashAlgorithmModulo'
      indent = 0
      modeStr = ' (bidirectional)' if self.mode == 'modePimBidir' else ''
      iPrint( indent, f'RP {self.rp}{modeStr}' )
      indent += 2
      if self.detail:
         iPrint( indent, 'Pim Mode :    %s' % self.mode )
         iPrint( indent, 'Priority :    %d' % self.rpPriority )
         iPrint( indent, 'HashMaskLen : %d' % self.rpHashMaskLen )
      iPrint( indent, 'PIM v2:' if isModulo else 'PIM v2 Hash Values:' )
      numberOfRps = 0
      for _, rp in self.crp:
         numberOfRps += 1
         if isModulo:
            rp.renderForRpHashModulo( detail=self.detail )
         else:
            rp.renderForRpHash( detail=self.detail )
      indent = 0
      if isModulo:
         iPrint( indent, 'Number of Rps: %s' % str( numberOfRps ) )
         hashAlgorithmStr = ' Modulo '
      else:
         hashAlgorithmStr = ' Default '
      iPrint( indent, 'Hash Algorithm:%s' % hashAlgorithmStr )

   def generate( self, rpHashTable, rpStaticCompute, rpBsrCompute,
                 group, rpHashAlgorithm, pimMode = 'modePimSm', detail=False ):
      """Populates the model for the group address string group.

      The RP comes from rpHashTable when it has an entry for the group's host
      prefix. Otherwise it is computed from the static and BSR computes,
      restricted to pimMode. A static RP is preferred when there is no BSR RP
      or when it is marked override.
      """
      # detail is declared without optional=True, so set it even when
      # returning early. bool() matches PimRpCandidates.generate.
      self.detail = bool( detail )
      if not rpHashTable or not group:
         return

      self.rpHashAlgorithm = rpHashAlgorithm
      gAddr = Arnet.IpGenAddr( group )
      gPrefix = self._hostPrefix( gAddr )
      self.prefix = gPrefix
      rpHash = rpHashTable.rpHash.get( gPrefix )
      if rpHash:
         self.mode = rpHash.pimMode
         self.rp = rpHash.rp
         self.rpPriority = rpHash.rpPriority
         self.rpHashMaskLen = rpHash.rpHashMaskLen
      else:
         crpInfoStatic = self._matchingRp( rpStaticCompute, gAddr, pimMode )
         crpInfoBsr = self._matchingRp( rpBsrCompute, gAddr, pimMode )

         if crpInfoStatic and ( not crpInfoBsr or crpInfoStatic.override ):
            chosen = crpInfoStatic
         else:
            chosen = crpInfoBsr
         if chosen:
            self.rp = chosen.address
            self.mode = chosen.mode
            self.rpPriority = chosen.priority
            self.rpHashMaskLen = chosen.hashMaskLen

      self.crp = self.generateCrp( rpStaticCompute, rpBsrCompute, gAddr )

class PimRpHashAll( Model ):
   """RP hash results split by PIM mode (sparse and bidirectional)."""
   sparseMode = Submodel( help="RP Hash for Sparse-mode PIM",
                          valueType=PimRpHash, optional=True )
   bidirectional = Submodel( help="RP Hash for Bidirectional PIM",
                             valueType=PimRpHash, optional=True )

   def initialize( self, rpHashs ):
      """Attaches the per-mode models found in rpHashs (keyed by Tac mode)."""
      for tacMode, attrName in ( ( 'modePimSm', 'sparseMode' ),
                                 ( 'modePimBidir', 'bidirectional' ) ):
         if tacMode in rpHashs:
            setattr( self, attrName, rpHashs[ tacMode ] )

   def render( self ):
      """Renders the sparse-mode result, then the bidirectional one."""
      for section in ( self.sparseMode, self.bidirectional ):
         if section:
            section.render()

class JoinPruneFlag( Model ):
   """A single join/prune type tag attached to a source entry."""
   flag = Enum( values=( "RPT", "WC", "SPT" ),
                help="Type of join/prune, can be WC => (*,G), "
                     "RPT => (S,G,Rpt) or SPT => (S,G)" )

class JoinPruneSource( Model ):
   """Join/prune state for one source of a group: suppression and flags."""
   joinSuppressed = Bool(
      help="The upstream join for this source is currently suppressed" )
   joinPruneFlagList = List( valueType=JoinPruneFlag, help="Join/prune flags" )


def _joinPruneCandidateSources( jpGroupData, wcSource, source, jpType ):
   """Returns the source addresses to consider for one group and jpType."""
   if source:
      # Only the requested source; (*,G) is represented by the wildcard source.
      if not source.isAddrZero:
         return [ source ]
      return [ wcSource[ 0 ] ] if wcSource else []
   if jpType == 'joins':
      sources = list( set( list( jpGroupData.joinedSource ) +
                           list( jpGroupData.joinedRptSource ) ) )
      # Include (*,G) when it is joined.
      if wcSource and wcSource[ 1 ]:
         sources.append( wcSource[ 0 ] )
      return sources
   if jpType == 'prunes':
      prunedRpt = jpGroupData.prunedRptGroupData.prunedRptSource
      sources = list( set( list( jpGroupData.prunedSource ) +
                           list( prunedRpt ) ) )
      # Include (*,G) when it is not joined.
      if wcSource and not wcSource[ 1 ]:
         sources.append( wcSource[ 0 ] )
      return sources
   return []

def _joinPruneSourceModel( jpGroupData, wcSource, source, jpType ):
   """Builds the JoinPruneSource for one source, or None if it has no
   join/prune of type jpType."""
   isWc = bool( wcSource ) and wcSource[ 0 ] == source
   flagNames = []
   suppressed = False
   if jpType == 'joins':
      if source in jpGroupData.joinedSource:
         flagNames.append( 'SPT' )
         suppressed = source in jpGroupData.joinSuppressSource
      if source in jpGroupData.joinedRptSource:
         flagNames.append( 'RPT' )
      if isWc and wcSource[ 1 ]:
         # The (*,G) suppression state applies only if no (S,G) entry
         # already determined it.
         if not flagNames:
            suppressed = bool( wcSource[ 2 ] )
         flagNames.append( 'WC' )
   elif jpType == 'prunes':
      if isWc and not wcSource[ 1 ]:
         flagNames.append( 'WC' )
      if source in jpGroupData.prunedSource:
         flagNames.append( 'SPT' )
      if source in jpGroupData.prunedRptGroupData.prunedRptSource:
         flagNames.append( 'RPT' )

   if not flagNames:
      return None
   model = JoinPruneSource()
   model.joinSuppressed = suppressed
   for name in flagNames:
      flag = JoinPruneFlag()
      flag.flag = name
      model.joinPruneFlagList.append( flag )
   return model

def _joinPruneSourcePrefix( source, wcSource ):
   """Converts a source address to the prefix used as the external key: the
   wildcard source becomes 0.0.0.0/0 or ::/0, others a host prefix."""
   if wcSource and source == wcSource[ 0 ]:
      if wcSource[ 0 ].af == AddressFamily.ipv4:
         return Tac.newInstance( "Arnet::IpGenPrefix", "0.0.0.0/0" )
      return Tac.newInstance( "Arnet::IpGenPrefix", "::/0" )
   maskLen = "/32" if source.af == AddressFamily.ipv4 else "/128"
   return Arnet.IpGenPrefix( source.stringValue + maskLen )

def joinPruneSourceGen( jpGroupData, source, jpType='joins' ):
   """Yields ( sourcePrefix, JoinPruneSource ) for the group's sources.

   jpType is 'joins' or 'prunes'. If source is given, only that source
   (or (*,G) when it is the zero address) is considered. Sources with no
   join/prune of the requested type are skipped.
   """
   wcData = jpGroupData.wcSourceData
   wcSource = None
   if wcData:
      # ( wildcard source address, joined?, suppressed? )
      wcSource = ( wcData.wcSource, wcData.wcJoined, wcData.wcSuppressed )

   for src in _joinPruneCandidateSources( jpGroupData, wcSource, source, jpType ):
      model = _joinPruneSourceModel( jpGroupData, wcSource, src, jpType )
      if model is None:
         # Skip this source only; previously this ended the whole generator.
         continue
      yield _joinPruneSourcePrefix( src, wcSource ), model

class JoinPruneGroup( Model ):
   """Joined and pruned sources sent for one group."""

   joinedSources = GeneratorDict( help="Map of sources for this group for which "
                          " periodic joins are being sent", keyType=IpGenericPrefix,
                          valueType=JoinPruneSource )
   prunedSources = GeneratorDict( help="Map of sources for this group for which "
                          " periodic prunes are being sent", keyType=IpGenericPrefix,
                           valueType=JoinPruneSource )
   mode = Enum( help="PIM mode",
                values=( 'Sparse',
                         'Bidir',
                         'Sparse/Bidir',
                         'None' ), optional=True )
   def initialize( self, jp, source):
      """Populates the model from group data jp, optionally limited to source."""
      # Bit 0x80 of groupFlags is treated here as the bidirectional marker.
      if ( jp.groupFlags & 0x80 ) == 0x80:
         self.mode = 'Bidir'
      else:
         self.mode = 'Sparse'
      self.joinedSources = joinPruneSourceGen( jp, source, 'joins' )
      self.prunedSources = joinPruneSourceGen( jp, source, 'prunes' )

   def _renderSources( self, indent, title, sources, emptyMsg ):
      """Prints title, then each source with its comma-separated flags, or
      emptyMsg when there are none."""
      iPrint( indent, title )
      printedAny = False
      for src, entry in sources:
         printedAny = True
         flags = ",".join( f.flag for f in entry.joinPruneFlagList )
         iPrint( indent + 2, f"{src} {flags}" )
      if not printedAny:
         iPrint( indent + 2, emptyMsg )

   def render( self ):
      """Prints the joins section followed by the prunes section."""
      indent = 4
      self._renderSources( indent, "Joins:", self.joinedSources,
                           "No joins included" )
      self._renderSources( indent, "Prunes:", self.prunedSources,
                           "No prunes included" )

def joinPruneGroupGen( jps, source, group ):
   """Yields ( group, JoinPruneGroup ) for each join/prune entry in jps.

   If group is given, only that group is considered; otherwise every group
   present in any entry. If source is given, a group is skipped unless that
   source (or (*,G) for the zero address) has join/prune state in it.
   """
   groups = set()
   if not group:
      for jp in jps:
         groups |= set( jp.joinPruneGroupData.keys() )
   else:
      groups.add( group )

   def wanted( groupData ):
      """True if groupData exists and matches the source filter."""
      if not groupData:
         return False
      if not source:
         return True
      if source.isAddrZero:
         return bool( groupData.wcSourceData )
      return bool( groupData.containsSource( source ) )

   for g in groups:
      for jp in jps:
         groupData = jp.joinPruneGroupData.get( g )
         if not wanted( groupData ):
            continue
         groupModel = JoinPruneGroup()
         groupModel.initialize( groupData, source )
         yield g, groupModel

class JoinPruneNeighbor( Model ):
   """An upstream neighbor, the interface used to reach it, and its groups."""

   groups = GeneratorDict( help="Map of groups for which periodic joins/prunes are "
                           "being sent to this neighbor", keyType=IpGenericAddress,
                           valueType=JoinPruneGroup )
   interface = Interface( help="Interface via which joins/prunes are being sent to "
                          "this neighbor" )
   interfaceAddress = IpGenericAddress( help="The join/prune interface's address" )

   def initialize( self, neighbor, intf, jp, source, group ):
      """Fills the model from the PIM interface intf and the entries in jp.

      neighbor is part of the calling convention but is not read here. jp,
      source and group are passed straight to joinPruneGroupGen.
      """
      self.interface, self.interfaceAddress = intf.intfId, intf.address
      self.groups = joinPruneGroupGen( jp, source, group )

def upstreamJoinPruneNeighborsGen( pimStatus, joinPruneStatus,
                                   source=None, group=None, neighbor=None ):
   """Yields ( neighborAddress, JoinPruneNeighbor ) in sorted (intf, neighbor)
   order for active PIM neighbors with matching upstream join/prune entries.

   A neighbor is included only if at least one status in joinPruneStatus has
   a non-empty entry for it that satisfies the group and source filters.
   Nothing is yielded when pimStatus is None.
   """
   if pimStatus is None:
      return

   candidates = set()
   for intfName, pimIntf in pimStatus.pimIntf.items():
      if not neighbor:
         candidates |= { ( intfName, nbr ) for nbr in pimIntf.activeNeighbor }
      elif neighbor in pimIntf.activeNeighbor:
         candidates.add( ( intfName, neighbor ) )

   def groupHasSource( groupData ):
      """True if groupData has state for the requested source."""
      if source.isAddrZero:
         return bool( groupData.wcSourceData )
      return bool( groupData.containsSource( source ) )

   def collectEntries( intfName, nbr ):
      """Returns the join/prune entries for ( intfName, nbr ) that match."""
      entries = []
      for status in joinPruneStatus:
         if status is None:
            continue
         jpIntf = status.intfStatus.get( intfName )
         if not jpIntf:
            continue
         entry = jpIntf.upJoinPrune.get( nbr )
         if entry is None or len( entry.joinPruneGroupData ) == 0:
            continue
         if group is None:
            groupsToCheck = entry.joinPruneGroupData.values()
         else:
            groupData = entry.joinPruneGroupData.get( group )
            if groupData is None:
               continue
            groupsToCheck = [ groupData ]
         if source is not None and \
               not any( groupHasSource( gd ) for gd in groupsToCheck ):
            continue
         entries.append( entry )
      return entries

   for intfName, nbr in sorted( candidates ):
      pimIntf = pimStatus.pimIntf.get( intfName )
      if not pimIntf:
         continue
      pimNeighbor = pimIntf.activeNeighbor.get( nbr )
      if not pimNeighbor:
         continue
      entries = collectEntries( intfName, nbr )
      if not entries:
         continue
      nbrModel = JoinPruneNeighbor()
      nbrModel.initialize( pimNeighbor, pimIntf, entries, source, group )
      yield nbr, nbrModel

class PimUpstreamJoins( Model ):
   """Upstream neighbors to which periodic joins/prunes are being sent."""
   neighbors = GeneratorDict( help="Map of upstream neighbors to which "
                              "periodic join/prunes are being sent",
                              keyType=IpGenericAddress,
                              valueType=JoinPruneNeighbor)
   def initialize( self, pimStatus, joinPruneStatus,
                   source=None, group=None, neighbor=None ):
      """Sets up the lazily generated neighbor map for the given filters."""
      self.neighbors = upstreamJoinPruneNeighborsGen( pimStatus, joinPruneStatus,
                                                      source=source, group=group,
                                                      neighbor=neighbor )

   def render( self ):
      """Prints each neighbor, its interface, and each group's joins/prunes."""
      for nbrAddr, nbrModel in self.neighbors:
         iPrint( 0, "Neighbor address: %s" % nbrAddr )
         iPrint( 1, "Via interface: %s (%s)" %
                 ( nbrModel.interface.stringValue,
                   nbrModel.interfaceAddress ) )
         for groupAddr, groupModel in nbrModel.groups:
            suffix = " (bidirectional)" if groupModel.mode == 'Bidir' else ""
            iPrint( 2, "Group: %s%s" % ( groupAddr, suffix ) )
            groupModel.render()
