# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import BothTrace
import Tac
import Tracing
from CliModel import (
   Bool,
   Enum,
   GeneratorSubmodel,
   Int,
   List,
   Model,
   Str,
   Submodel,
)
from MplsPingClientLib import (
   getPingModel,
)
from MplsTracerouteClientLib import (
   getTracerouteModel,
)
from ClientCommonLib import (
   getThreadLocalData,
   LspPingDSTypes,
   LspPingTypes,
   setThreadLocalData,
)

# Attribute (value) sets of the Tac enum types that back the Enum fields below.
dsTypeEnum, retCodeEnum, addrTypeEnum, pwPingCcType = (
   Tac.Type( tacTypeName ).attributes for tacTypeName in (
      'LspPing::LspPingDownstreamMappingType',
      'LspPing::LspPingReturnCode',
      'LspPing::LspPingAddrType',
      'Pseudowire::PwPingCcType',
   ) )

#--------------------------------
# BothTrace shorthand variables
#--------------------------------
__defaultTraceHandle__ = Tracing.Handle( "MplsUtilModel" )
bt8 = BothTrace.trace8

class GenericModel( Model ):
   """Destination-only target (used for the LDP traceroute and generic ping)."""
   prefix = Str( help='Destination' )

class SrModel( Model ):
   """Segment Routing target of an MPLS ping/traceroute request."""
   prefix = Str( help='Destination' )
   algorithm = Str( help='Algorithm', optional=True )
   igpProtocolType = Str( help='Segment Routing IGP Protocol',
                          optional=True )

class NhgModel( Model ):
   """Nexthop-group target; kept non-public via __public__ = False."""
   __public__ = False
   prefix = Str( optional=True, help='Destination' )
   nhgName = Str( help='Nexthop group name' )
   tunnelIndex = Int( optional=True, help='Nexthop tunnel index' )

class PwLdpModel( Model ):
   """LDP pseudowire target of an MPLS ping request."""
   pwLdpName = Str( help="Pseudowire LDP name" )
   pwNeighbor = Str( help="Neighbor IP address" )
   pwId = Str( help="Pseudowire ID" )
   pwLabel = Int( help="Pseudowire VC label" )
   pwCCType = Enum( pwPingCcType,
                    help="Pseudowire CC types" )

class SrteModel( Model ):
   """SR-TE policy target (endpoint + color); non-public for now."""
   __public__ = False
   endpoint = Str( help="SR-TE endpoint address" )
   color = Int( help="SR-TE color" )

class RsvpModel( Model ):
   """RSVP target of an MPLS ping/traceroute request; all fields optional."""
   __revision__ = 2
   # prefix is only required for LSP ping statistic summary
   prefix = Str( optional=True, help="Destination" )
   lspId = Int( optional=True, help='LSP Id' )
   session = Str( optional=True, help="Rsvp session name or session Id" )
   tunnel = Str( optional=True, help="Rsvp tunnel name" )
   subTunnelId = Int( optional=True, help="Rsvp subtunnel Id" )

class BgpLuModel( Model ):
   """BGP labeled-unicast target of an MPLS ping request."""
   prefix = Str( help='Destination' )

class VpnModel( Model ):
   """VPN target of an MPLS ping request: prefix plus route distinguisher."""
   __revision__ = 2
   prefix = Str( help="Destination" )
   # Help text previously carried a stray leading space (' Route ...').
   rd = Str( help='Route Distinguisher' )

class IntfAndLabelStackInfo( Model ):
   """Interface/address and received label stack reported for a hop."""
   addrType = Enum( addrTypeEnum, help="Address type" )
   intfAddr = Str( help="Interface address" )
   intfIndex = Int( help="Interface index" )
   addr = Str( help="IP address" )
   labelStack = List( help="Received label stack", valueType=str )

class FecTypeList( Model ):
   """Wrapper around a list of FEC TLV type names."""
   fecTypeList = List( help="FEC TLV type list", valueType=str )

class DownstreamInfo( Model ):
   """Downstream mapping information carried in a traceroute hop reply."""
   addrType = Enum( addrTypeEnum, help="Address type" )
   dsIntfAddr = Str( help="Downstream interface address" )
   dsIntfIndex = Int( help="Downstream interface index" )
   dsIpAddr = Str( help="Downstream IP address" )
   labelStack = List( valueType=str, help="Downstream label stack" )
   retCode = Enum( retCodeEnum, help="Return code" )
   dsType = Enum( values=dsTypeEnum, help="Downstream information type" )
   popFecTypeList = Submodel( valueType=FecTypeList,
                              help="List of FEC types to pop",
                              optional=True )
   pushFecTypeList = Submodel( valueType=FecTypeList,
                               help="List of FEC types to push",
                               optional=True )

class HopPktInfo( Model ):
   """Reply details for a single traceroute hop."""
   replyHost = Str( help="Reply address" )
   hopMtu = Int( help="Hop MTU" )
   roundTrip = Int( help="Roundtrip time (ms)" )
   roundTripUs = Int( help="RoundTrip time (us)" )
   oneWayDelay = Int( help="One way delay (us)" )
   retCode = Enum( retCodeEnum, help="Return code" )
   ttl = Int( help="Time to live" )
   dsInfo = List( valueType=DownstreamInfo, help="Downstream infos" )
   intfAndLabelStackInfo = Submodel( optional=True,
                                     valueType=IntfAndLabelStackInfo,
                                     help="Intf and labelStack info" )

class MplsPath( Model ):
   """One explored traceroute path: its hop replies and the overall result."""
   hopPkt = List( help="HopPktInfos", valueType=HopPktInfo )
   # The enum values are the literal result lines, including indentation
   # and trailing newline.
   pathResult = Enum(
      help='Result of path exploration',
      values=( '  Successfully explored path.\n',
               '  Error: Could not probe further...\n' ) )

class MplsTracerouteStatistics( Model ):
   """Round-trip statistics for one traceroute reply host."""
   destination = Str( help="Reply host address" )
   roundTripTimeMinUs = Int( help="Minimum round trip time in microseconds" )
   roundTripTimeMaxUs = Int( help="Maximum round trip time in microseconds" )
   roundTripTimeAvgUs = Int( help="Average round trip time in microseconds" )
   packetsReceived = Int( help="Packets received from destination" )

class MplsTracerouteStatisticsSummary( Model ):
   """Aggregate packet counts plus optional per-destination traceroute stats."""
   packetsTransmitted = Int( help="Packets transmitted" )
   packetsReceived = Int( help="Packets received" )
   packetLoss = Int( help="Percentage of packets lost" )
   replyTime = Int( help="Reply timestamp time" )
   mplsStatistics = List( optional=True,
                          valueType=MplsTracerouteStatistics,
                          help="MPLS traceroute statistics details" )

class MplsVia( Model ):
   """One traceroute via: nexthop/labels, explored paths and optional stats."""
   nextHopIp = Str( help="NextHop IP" )
   labelStack = List( help="LabelStack", valueType=int )
   interface = Str( help="Interface name", optional=True )
   path = List( help="Via paths", valueType=MplsPath )
   stats = Submodel( optional=True,
                     valueType=MplsTracerouteStatisticsSummary,
                     help="MPLS traceroute statistics summary" )

class ErrInfo( Model ):
   """A single error message reported for a ping/traceroute command."""
   error = Str( help="Error encountered when cmd issued" )

class MplsTracerouteHdr( Model ):
   """Per-request header of MPLS traceroute output, with its vias and errors."""
   __revision__ = 2
   protocol = Enum( values=LspPingTypes, help="Protocol" )
   ldpModel = Submodel( valueType=GenericModel, optional=True,
                        help="MPLS LDP traceroute model" )
   srModel = Submodel( valueType=SrModel, optional=True,
                       help="MPLS SR traceroute model" )
   rsvpModel = Submodel( valueType=RsvpModel, optional=True,
                         help="MPLS RSVP traceroute model" )
   dsInfoType = Enum( values=LspPingDSTypes, help="Downstream information type" )
   errInfo = List( valueType=ErrInfo, help="Errors" )
   viaModel = List( valueType=MplsVia, help="Vias" )
   statsModel = List( valueType=MplsVia, optional=True,
                      help="Statistics summary per Via" )

   def degrade( self, dictRepr, revision ):
      """Rewrite ``dictRepr`` into the revision-1 layout when requested.

      Revision 1 exposes a flat top-level 'prefix' instead of the nested
      'ldpModel' submodel: the LDP prefix is hoisted and 'ldpModel' removed,
      or 'prefix' is set to None when no (truthy) 'ldpModel' is present.
      Any other revision returns ``dictRepr`` untouched.
      """
      if revision != 1:
         return dictRepr
      ldpDict = dictRepr.get( 'ldpModel' )
      if not ldpDict:
         dictRepr[ 'prefix' ] = None
         return dictRepr
      dictRepr[ 'prefix' ] = ldpDict[ 'prefix' ]
      del dictRepr[ 'ldpModel' ]
      return dictRepr

# Parent model for Multipath/Non-Multipath Traceroute for all protocols
# with submodels inside.
class MplsTraceroute( Model ):
   """Top-level traceroute model exposing MplsTracerouteHdr via a generator."""
   __revision__ = 2
   traceroute = GeneratorSubmodel( help="Traceroute request",
                                   valueType=MplsTracerouteHdr )

   def render( self ):
      """Repeatedly fetch/render traceroute models until 'cmdBreak' is set.

      Each iteration marks the thread-local 'render' flag, calls
      getTracerouteModel() and then clears the thread-local 'cache'.
      """
      bt8( "Render model for Traceroute" )
      # cmdBreak is only set in getModel function
      # if socket listening times out and subprocess is not running
      while not getThreadLocalData( 'cmdBreak' ):
         setThreadLocalData( 'render', True )
         getTracerouteModel()
         setThreadLocalData( 'cache', None )
      # Reset the flag on exit (presumably so a later render on this thread
      # is not stopped immediately -- confirm against ClientCommonLib).
      setThreadLocalData( 'cmdBreak', False )

class MplsPingReply( Model ):
   """Details of a single MPLS echo reply packet."""
   replyHost = Str( help="Reply host address" )
   sequence = Int( help="Sequence number" )
   roundTripTime = Int( help="RoundTrip time for packet in milliseconds" )
   roundTripTimeUs = Int( help="Round Trip time for packet in microseconds" )
   oneWayDelay = Int( help="One way delay time for packet in microseconds" )
   retCode = Enum( retCodeEnum, help="Return code" )
   errorTlvMap = List( valueType=int, optional=True, help="Error Tlv Map" )
   errorSubTlvMap = List( valueType=int, optional=True,
                          help="Error sub Tlv Map" )

class MplsPingStatistics( Model ):
   """Round-trip and one-way-delay statistics for one ping reply host."""
   destination = Str( help="Reply host address" )
   roundTripTimeMin = Int( help="Minimum round trip time in milliseconds" )
   roundTripTimeMax = Int( help="Maximum round trip time in milliseconds" )
   roundTripTimeAvg = Int( help="Average round trip time in milliseconds" )
   roundTripTimeMinUs = Int( help="Minimum round trip time in microseconds" )
   roundTripTimeMaxUs = Int( help="Maximum round trip time in microseconds" )
   roundTripTimeAvgUs = Int( help="Average round trip time in microseconds" )
   oneWayDelayMin = Int( help="Minimum one way delay in microseconds" )
   oneWayDelayMax = Int( help="Maximum one way delay in microseconds" )
   oneWayDelayAvg = Int( help="Average one way delay in microseconds" )
   packetsReceived = Int( help="Packets received from destination" )

class MplsPingStatisticsSummary( Model ):
   """Aggregate ping packet counts plus optional per-destination stats."""
   packetsTransmitted = Int( help="Packets transmitted" )
   packetsReceived = Int( help="Packets received" )
   packetLoss = Int( help="Percentage of packets lost" )
   replyTime = Int( help="Reply timestamp time" )
   mplsStatistics = List( optional=True, valueType=MplsPingStatistics,
                          help="MPLS ping statistics details" )

class MplsPingVia( Model ):
   """One via of an MPLS ping request with its reply and/or statistics."""
   nhgTunnelEntryIndex = Int( optional=True,
                              help="Nexthop tunnel entry index" )
   lspId = Int( optional=True, help="RSVP tunnel LSP ID" )
   subTunnelId = Int( optional=True, help="RSVP sub tunnel ID" )
   resolved = Bool( default=True,
                    help="Resolution status of tunnel/via" )
   nextHopIp = Str( help="NextHop IP", optional=True )
   interface = Str( help="Interface name", optional=True )
   labelStack = List( valueType=int, help="Label stack", optional=True )
   pingReply = Submodel( optional=True, valueType=MplsPingReply,
                         help="MPLS reply packet details" )
   pingStats = Submodel( optional=True, valueType=MplsPingStatisticsSummary,
                         help="MPLS ping statistics" )

class MplsPingHdr( Model ):
   """Per-request header of MPLS ping output.

   Carries the protocol, timing parameters, the optional protocol-specific
   target submodels, and the per-via replies (viaModel) and statistics
   (statsModel), plus any errors.
   """
   __revision__ = 2
   protocol = Enum( values=LspPingTypes, help='Protocol' )
   timeout = Int( help='Timeout for ping' )
   interval = Int( help='Ping intervals' )
   genericModel = Submodel( optional=True, valueType=GenericModel,
                            help="MPLS generic ping model" )
   # nhgModel and srTeModel stay commented out until their cmds are
   # supported. When enabling them, also make NhgModel and SrteModel public
   # by removing their __public__ = False.
   # nhgModel = Submodel( optional=True, valueType=NhgModel,
   #                     help="MPLS NHG model" )
   pwLdpModel = Submodel( optional=True, valueType=PwLdpModel,
                          help="MPLS PW LDP ping model" )
   srModel = Submodel( optional=True, valueType=SrModel,
                       help="MPLS SR ping model" )
   # srTeModel = Submodel( optional=True, valueType=SrteModel,
   #                      help="MPLS SR-TE ping model" )
   bgpLuModel = Submodel( optional=True, valueType=BgpLuModel,
                          help="MPLS BGP LU ping model" )
   rsvpModel = Submodel( optional=True, valueType=RsvpModel,
                         help="MPLS RSVP ping model" )
   vpnModel = Submodel( optional=True, valueType=VpnModel,
                        help="MPLS VPN ping model" )
   viaModel = List( valueType=MplsPingVia, help="MPLS ping vias" )
   statsModel = List( valueType=MplsPingVia,
                      help="MPLS ping statistics" )
   errInfo = List( optional=True, valueType=ErrInfo, help='Errors' )

   def degrade( self, dictRepr, revision ):
      """Rewrite ``dictRepr`` into the revision-1 layout when requested.

      In revision 1 the 'viaModel' and 'statsModel' lists are not top-level:
      they are moved under 'vpnModel' when a (truthy) 'vpnModel' is present
      and dropped otherwise. Any other revision returns ``dictRepr``
      unchanged.
      """
      if revision != 1:
         return dictRepr
      # pop() with a default so an absent list cannot raise KeyError
      # (the previous unconditional `del` did).
      viaModel = dictRepr.pop( 'viaModel', [] )
      statsModel = dictRepr.pop( 'statsModel', [] )
      vpnModel = dictRepr.get( 'vpnModel' )
      if vpnModel:
         vpnModel[ 'viaModel' ] = viaModel
         vpnModel[ 'statsModel' ] = statsModel
      return dictRepr

# Parent model for Ping for all protocols
class MplsPing( Model ):
   """Top-level ping model exposing MplsPingHdr entries via a generator."""
   __revision__ = 2
   ping = GeneratorSubmodel( help="Ping request", valueType=MplsPingHdr )

   def render( self ):
      """Repeatedly fetch/render ping models until 'cmdBreak' is set.

      Each iteration marks the thread-local 'render' flag, calls
      getPingModel() and then clears the thread-local 'cache'.
      """
      bt8( "Render model for Ping" )
      # cmdBreak is only set in getPingModel function
      # if socket listening times out and subprocess is not running
      while not getThreadLocalData( 'cmdBreak' ):
         setThreadLocalData( 'render', True )
         getPingModel()
         setThreadLocalData( 'cache', None )
      # Reset the flag on exit (presumably so a later render on this thread
      # is not stopped immediately -- confirm against ClientCommonLib).
      setThreadLocalData( 'cmdBreak', False )
