# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from Ark import utcTimeRelativeToNowStr
from Arnet import sortIntf
from CliModel import (
                       Bool,
                       Dict,
                       Int,
                       List,
                       Model,
                       Str,
                       Submodel,
                       Float
                     )
from IntfModels import Interface
from ArnetModel import IpGenericAddress, IpGenericPrefix, Ip6Address
from TableOutput import createTable, Format
import Tac
import Vlan
from IpLibConsts import DEFAULT_VRF
from natsort import natsorted
import Toggles.DhcpLibToggleLib as DhcpLibToggleLib

def printState( condition ):
   """Map a truthy/falsy setting to the word shown in CLI output.

   Returns "enabled" for a truthy `condition` and "disabled" otherwise.
   """
   if condition:
      return "enabled"
   return "disabled"

class Helper( Model ):
   """One DHCP server (helper) target, configured by address or by hostname."""
   serverAddr = IpGenericAddress( help="Helper address", optional=True )
   serverHostname = Str( help="Helper hostname", optional=True )
   vrf = Str( help="Server VRF" )
   srcAddr = IpGenericAddress( help="Source address", optional=True )
   srcIntf = Interface( help="Source interface", optional=True )
   linkAddr = Ip6Address( help="Link address", optional=True )

   def render( self ):
      """Print the server on one line.

      Output is the address (or hostname when no address is set), then
      "vrf <name>" for a non-default VRF, then "link address <addr>" when a
      link address is set.
      """
      # Fall back to an empty string so a helper with neither an address nor
      # a hostname cannot make str.join() raise TypeError on None.
      if self.serverAddr:
         toPrint = [ str( self.serverAddr ) ]
      else:
         toPrint = [ self.serverHostname or '' ]

      if self.vrf != DEFAULT_VRF:
         toPrint += [ 'vrf', self.vrf ]

      if self.linkAddr:
         toPrint += [ 'link address', str( self.linkAddr ) ]

      print( " ".join( toPrint ) )

   def renderVssServerList( self ):
      """Return the server as "<address-or-hostname>[ vrf <name>]".

      The VRF suffix is added only for a non-default VRF.
      """
      # Mirror render(): a hostname-configured server has no serverAddr, and
      # str( None ) would otherwise appear in the output as the text "None".
      if self.serverAddr:
         server = str( self.serverAddr )
      else:
         server = self.serverHostname or ''
      if self.vrf != DEFAULT_VRF:
         server += ' vrf ' + self.vrf
      return server
      
def printHelper( helper, helperCount ):
   """Render one DHCP server via helper.render() and return helperCount + 1.

   The first server (helperCount == 0) is printed directly after the
   "  DHCPvN servers:" heading that the caller has already emitted. Each later
   server is preceded by padding as wide as that heading plus one space, so it
   lines up under the first server (assuming the previous render() ended its
   line).
   """
   isContinuation = helperCount != 0
   if isContinuation:
      padding = " " * len( "  DHCPvX servers:" )
      print( padding, end=' ' )
   helper.render()
   return helperCount + 1

def getAllHelpersAsLists( helpers ):
   """Split helpers into (DHCPv4 servers, DHCPv6 servers), each sorted.

   A helper counts as DHCPv4 when its address is IPv4 or when it is configured
   by hostname; every other helper is treated as DHCPv6. Both lists follow the
   ordering defined by compareHelper.
   """
   v4Servers, v6Servers = [], []
   for server in sorted( helpers, key=compareHelper ):
      address = server.serverAddr
      isV4 = ( address and address.af == 'ipv4' ) or server.serverHostname
      ( v4Servers if isV4 else v6Servers ).append( server )
   return ( v4Servers, v6Servers )

# We sort the output by VRF and then by serverHostname / serverAddr, so the
# order of CLI output does not depend on the iteration order of a hash table.
def compareHelper( helper ):
   """Return the sort key ( vrf, hostname, serverAddr ) for a DHCP server.

   Missing (falsy) values are replaced by '' so they sort first. Because
   address-configured servers have an empty hostname, they sort ahead of
   hostname-configured servers within the same VRF.
   """
   parts = ( helper.vrf, helper.serverHostname, helper.serverAddr )
   return tuple( part or '' for part in parts )

class ConfiguredInterface( Model ):
   """Per-interface DHCP relay settings and the DHCP servers it relays to.

   For revision-1 consumers, degrade() exposes 'allSubnets' under the older
   key 'smartRelay'.
   """
   __revision__ = 2

   circuitId = Str( help="Circuit ID enabled interface", default='' )
   remoteId = Str( help="Remote ID enabled interface", default='' )
   allSubnets = Bool( help="DHCP all-subnets relaying enabled interface",
                      default=False )
   allSubnetsV6 = Bool( help="DHCPv6 all-subnets relaying enabled interface",
                        default=False )
   helpers = List( valueType=Helper,
                   help="Helper addresses/hostnames for an interface",
                   optional=True )
   usingGlobalModeServersV4 = Bool( help="Using DHCPv4 Relay mode configurations",
                                    default=False )
   usingGlobalModeServersV6 = Bool( help="Using DHCPv6 Relay mode configurations",
                                    default=False )
   disabledV4 = Bool( help="DHCPv4 relay is disabled", default=False )
   disabledV6 = Bool( help="DHCPv6 relay is disabled", default=False )

   def degrade( self, dictRepr, revision ):
      """Adapt the current-revision dict representation to `revision`.

      Only revision 1 needs a change: its key for 'allSubnets' was
      'smartRelay'. The dict is modified in place and returned.
      """
      if revision != 1:
         return dictRepr
      dictRepr[ 'smartRelay' ] = dictRepr.pop( 'allSubnets' )
      return dictRepr

   def render( self ):
      """Print the per-interface section of the DHCP relay output."""
      optionIds = ( ( "  Option 82 Circuit ID:", self.circuitId ),
                    ( "  Option 82 Remote ID:", self.remoteId ) )
      for label, value in optionIds:
         if value:
            print( label, value )
      print( "  DHCP all subnet relaying is", printState( self.allSubnets ) )
      print( "  DHCPv6 all subnet relaying is", printState( self.allSubnetsV6 ) )

      # The default-server notice is skipped when relay is disabled for that
      # address family.
      if self.usingGlobalModeServersV4 and not self.disabledV4:
         print( "  Using default DHCPv4 servers" )
      if self.usingGlobalModeServersV6 and not self.disabledV6:
         print( "  Using default DHCPv6 servers" )

      if self.helpers:
         v4Servers, v6Servers = getAllHelpersAsLists( self.helpers )
         # Explicit servers are hidden when the interface uses the default
         # servers or relay is disabled for that address family.
         sections = (
            ( "  DHCPv4 servers:", v4Servers,
              self.usingGlobalModeServersV4 or self.disabledV4 ),
            ( "  DHCPv6 servers:", v6Servers,
              self.usingGlobalModeServersV6 or self.disabledV6 ) )
         for heading, servers, hidden in sections:
            if not servers or hidden:
               continue
            print( heading, end=' ' )
            printed = 0
            for server in servers:
               printed = printHelper( server, printed )

      if self.disabledV4:
         print( "  DHCPv4 relay is disabled" )
      if self.disabledV6:
         print( "  DHCPv6 relay is disabled" )

# Model to display dhcp relay global mode configurations
class DefaultServersShowModel( Model ):
   """Default (relay-mode) DHCP servers, split by address family."""
   helpersV4 = List( valueType=Helper,
                     help="DHCP Relay mode default IPv4 addresses/hostnames" )
   helpersV6 = List( valueType=Helper,
                     help="DHCP Relay mode default IPv6 addresses" )

   def render( self ):
      """Print the default servers; prints nothing when there are none."""
      if not ( self.helpersV4 or self.helpersV6 ):
         return
      print( "Default L3 interface DHCP servers:" )
      sections = ( ( "  DHCPv4 servers:", self.helpersV4 ),
                   ( "  DHCPv6 servers:", self.helpersV6 ) )
      for heading, servers in sections:
         if not servers:
            continue
         print( heading, end=' ' )
         printed = 0
         for server in servers:
            printed = printHelper( server, printed )

# Model to display DHCP servers that do not support the VSS sub-option
class VssUnsupportedServersModel( Model ):
   """DHCPv4 / DHCPv6 servers listed as not supporting the VSS sub-option."""
   vssHelpersV4 = List( valueType=Helper,
         help="IPv4 DHCP Servers not supporting VSS sub-option" )
   vssHelpersV6 = List( valueType=Helper,
         help="IPv6 DHCP Servers not supporting VSS sub-option" )

   def render( self ):
      """Print one comma-separated line per non-empty address family."""
      sections = (
         ( "DHCP VSS control sub-option (152) unsupported servers:",
           self.vssHelpersV4 ),
         ( "DHCPv6 VSS sub-option (68) unsupported servers:",
           self.vssHelpersV6 ) )
      for heading, servers in sections:
         if not servers:
            continue
         print( heading, end=" " )
         print( ", ".join( [ s.renderVssServerList() for s in servers ] ) )

# Model to keep track of enabled suboptions under information option 82
class Option82SubOptionsModel( Model ):
   """Which suboptions of DHCP information option 82 are enabled."""
   circuitIdOpt = Bool( help="DHCP relay circuit-id suboption (1) under information "
                             "option (82) is enabled",
                        default=False )
   remoteIdOpt = Bool( help="DHCP relay remote-id suboption (2) under information "
                            "option (82) is enabled",
                       default=False )
   vendorSpecificSubOpt = Bool( help="DHCP relay vendor-specific suboption (9) under"
                                     " information option (82) is enabled",
                                default=False )

   def render( self ):
      """Print a single line listing the enabled suboptions in a fixed order."""
      candidates = ( ( self.circuitIdOpt, "circuit-id (1)" ),
                     ( self.remoteIdOpt, "remote-id (2)" ),
                     ( self.vendorSpecificSubOpt, "vendor-option (9)" ) )
      enabled = [ label for isOn, label in candidates if isOn ]
      print( "Enabled information option suboptions: " + ", ".join( enabled ) )

class IpDhcpRelay( Model ):
   """Global DHCP relay state plus per-interface relay configuration.

   For revision-1 consumers, degrade() exposes 'allSubnets' under the older
   key 'smartRelay'.
   """
   __revision__ = 2

   activeState = Bool( help="DHCP relay is active" )
   alwaysOn = Bool( help="DHCP relay alwaysOn state", default=False )
   option82 = Bool( help="DHCP relay information option (82) is enabled",
                    default=False )
   option82SubOpts = Submodel( valueType=Option82SubOptionsModel,
                               help="Enabled suboptions under DHCP relay information"
                                    " option ( 82 )",
                               default=Option82SubOptionsModel(), optional=True )
   linkLayerAddrOpt = Bool(
         help="DHCPv6 relay link-layer address option (79) is enabled",
         default=False )
   remoteIdEncodingFormat = Str(
         help="DHCPv6 relay remote ID option (37) encoding format",
         default='%m:%i' )
   allSubnets = Bool( help="DHCP all-subnets relaying is enabled", default=False )
   allSubnetsV6 = Bool( help="DHCPv6 all-subnets relaying is enabled",
                        default=False )
   tunnelReqDisable = Bool( help="DHCP tunnel requests are disabled", default=False )
   mlagPeerLinkReqDisable = Bool(
         help="DHCP requests received over MLAG peer link are suppressed",
         default=False )
   reqFloodSuppressionVlanConfig = List( valueType=int,
         help="DHCP client request flooding suppressed VLANs configured" )
   reqFloodSuppressionVlanOperational = List( valueType=int,
         help="DHCP client request flooding suppressed VLANs operational" )
   dhcpReqFloodRestrictDisabledDueToSnooping = Bool(
         help="DHCP client request flooding suppression disabled due to "
              "DHCP snooping configuration", default=False )
   vssSubOptUnsupportedServers = Submodel( valueType=VssUnsupportedServersModel,
         help="DHCP VSS control sub-option (152) and DHCPv6 VSS option (68)"
         " unsupported servers", optional=True )
   configuredInterfaces = Dict( keyType=Interface, valueType=ConfiguredInterface,
                           help="A mapping of interface to its DHCP configuration" )
   defaultServers = Submodel( valueType=DefaultServersShowModel,
                              help="DHCP Relay mode default addresses/hostnames",
                              optional=True )
   vssControlDisable = Bool( help="DHCP VSS control sub-option (152) is disabled",
                             default=False )

   def degrade( self, dictRepr, revision ):
      """Adapt the current-revision dict representation to `revision`.

      Only revision 1 needs a change: its key for 'allSubnets' was
      'smartRelay'. The dict is modified in place and returned.
      """
      if revision == 1:
         dictRepr[ 'smartRelay' ] = dictRepr.pop( 'allSubnets' )
      return dictRepr

   def render( self ):
      """Print the global relay status, default servers, then each interface.

      The detailed global settings are printed only while relay is active;
      default servers and per-interface sections are printed regardless.
      """
      print( "DHCP relay is", "active" if self.activeState else "not active" )
      if self.activeState:
         if self.alwaysOn:
            print( "DHCP relay always-on mode enabled" )
         print( "DHCP relay information option (82) is",
                printState( self.option82 ) )
         # Suboption details are gated by the unnumbered-interface toggle.
         # The submodel is declared optional, so skip it if it was cleared.
         if ( self.option82 and self.option82SubOpts is not None and
              DhcpLibToggleLib.toggleDhcpRelayUnnumberedIntfEnabled() ):
            self.option82SubOpts.render()

         if self.reqFloodSuppressionVlanConfig:
            print( "DHCP/DHCPv6 client requests flooding suppression configured for "
                   "the following VLANs:" )
            outVlanStr = Vlan.vlanSetToCanonicalString(
                           sorted( self.reqFloodSuppressionVlanConfig ) )
            # A non-empty configured VLAN list must yield a non-empty string.
            assert outVlanStr
            print( outVlanStr )
            print( "DHCP/DHCPv6 client requests flooding suppression operational "
                   "for the following VLANs:" )
            outVlanStr = Vlan.vlanSetToCanonicalString(
                           sorted( self.reqFloodSuppressionVlanOperational ) )
            if outVlanStr:
               print( outVlanStr )
            elif self.dhcpReqFloodRestrictDisabledDueToSnooping:
               print( 'none (DHCP/DHCPv6 snooping is configured)' )
            else:
               print( 'none' )

         print( "DHCPv6 relay link-layer address option (79) is",
                printState( self.linkLayerAddrOpt ) )
         # Only these two encodings are described; any other value prints
         # nothing for this line.
         fmtName = { '%m:%i': 'ID',
                     '%m:%p': 'name' }.get( self.remoteIdEncodingFormat )
         if fmtName is not None:
            print( "DHCPv6 relay remote ID option (37) encoding format:"
                   " MAC address:interface", fmtName )
         print( "DHCP all subnet relaying is", printState( self.allSubnets ) )
         print( "DHCPv6 all subnet relaying is", printState( self.allSubnetsV6 ) )
         if self.tunnelReqDisable:
            print( "DHCP tunnel requests are disabled" )
         if self.mlagPeerLinkReqDisable:
            print( "MLAG peer-link requests are disabled" )
         if self.vssControlDisable:
            print( "DHCP relay virtual subnet selection control suboption (152) "
                   "addition is disabled" )

         # DHCPv4 servers lacking VSS control sub-option (152) support and
         # DHCPv6 servers lacking VSS option (68) support.
         if self.vssSubOptUnsupportedServers:
            self.vssSubOptUnsupportedServers.render()

      # DHCP relay mode (default) servers.
      if self.defaultServers:
         self.defaultServers.render()

      for intfName in sortIntf( self.configuredInterfaces ):
         print( "Interface:", intfName )
         self.configuredInterfaces[ intfName ].render()

class Dhcp6InstalledRoute( Model ):
   """One route installed by the DHCPv6 relay: target client and expiry."""
   clientIpAddress = Ip6Address(
         help="DHCP client's IP address" )
   interface = Interface(
         help="DHCP client's interface" )
   expiryTime = Float(
         help="Expiry time of prefix in UTC" )


class Dhcp6InstalledRoutesPerVrf( Model ):
   """DHCPv6 relay installed routes within a single VRF, keyed by prefix."""
   routes = Dict( valueType=Dhcp6InstalledRoute, keyType=IpGenericPrefix,
                  help="DHCP Relay installed routes keyed by IPv6 prefix" )

class Dhcp6InstalledRoutes( Model ):
   """DHCPv6 relay installed routes for one or all VRFs, rendered as a table."""
   _isVrfSpecified = Bool( help="User provided VRF as argument to CLI" )
   vrfs = Dict( keyType=str, valueType=Dhcp6InstalledRoutesPerVrf,
                help="DHCP Relay prefixes keyed by VRF" )

   def render( self ):
      """Print one row per installed route.

      Rows are ordered by VRF name, then by natural order of prefix. A
      trailing VRF column is shown only when no VRF was given on the CLI.
      Remaining lifetime is expiryTime minus the current UTC time, truncated
      to whole seconds.
      """
      showVrf = not self._isVrfSpecified
      headings = [ 'Prefix', 'Client Address', 'Interface',
                   'Remaining Lifetime (seconds)' ]
      if showVrf:
         headings.append( 'VRF' )

      now = Tac.utcNow()
      table = createTable( headings, tableWidth=100 )
      for vrfName in sorted( self.vrfs ):
         vrfRoutes = self.vrfs[ vrfName ].routes
         for prefix in natsorted( vrfRoutes ):
            entry = vrfRoutes[ prefix ]
            row = [ prefix, entry.clientIpAddress, entry.interface.stringValue,
                    int( entry.expiryTime - now ) ]
            if showVrf:
               row.append( vrfName )
            table.newRow( *row )

      leftFmt = Format( justify='left' )
      leftFmt.padLimitIs( True )
      lifetimeFmt = Format( justify='right', maxWidth=18, wrap=True )
      lifetimeFmt.padLimitIs( True )
      columnFormats = [ leftFmt, leftFmt, leftFmt, lifetimeFmt ]
      if showVrf:
         columnFormats.append( leftFmt )
      table.formatColumns( *columnFormats )
      print( table.output() )

   def setVrfSpecified( self, vrfSpecified ):
      """Record whether a VRF was given as a CLI argument.

      When true, render() omits the VRF column.
      """
      self._isVrfSpecified = vrfSpecified

class Counters( Model ):
   """Received / forwarded / dropped packet counts for one packet category
   (used in this file for requests, responses and replies)."""
   received = Int( help="Received packets", default=0 )
   forwarded = Int( help="Forwarded packets", default=0 )
   dropped = Int( help="Dropped packets", default=0 )

class GlobalCounters( Model ):
   """Relay-wide request/response counters and when they were last cleared."""
   allRequests = Submodel( help="Global counter for requests",
                           valueType=Counters )
   allResponses = Submodel( help="Global counter for responses",
                            valueType=Counters )
   lastResetTime = Float( help="Last reset time for global counters",
                          default=0.0 )

class InterfaceCounter( Model ):
   """Per-interface request/reply counters and when they were last cleared."""
   requests = Submodel( help="Interface counter requests",
                        valueType=Counters )
   replies = Submodel( help="Interface counter for replies",
                       valueType=Counters )
   lastResetTime = Float( help="Last reset time for interface counter" )

class DhcpRelayCounterModel( Model ):
   """DHCP relay packet counters: global totals plus optional per-interface."""
   globalCounters = Submodel( valueType=GlobalCounters,
                             help="DHCP relay global counters" )
   interfaceCounters = Dict( keyType=Interface, valueType=InterfaceCounter,
                             help="A mapping of interface to its counters",
                             optional=True )

   def render( self ):
      """Print the global counter table, then the per-interface table if any.

      All columns are left-justified; the global table is preceded by a
      blank line.
      """
      leftJustified = Format( justify="left" )
      leftJustified.padLimitIs( True )
      # Received / forwarded / dropped sub-columns shared by every packet group.
      rfd = ( "Rcvd", "Fwdd", "Drop" )

      totals = self.globalCounters
      req, resp = totals.allRequests, totals.allResponses
      globalTable = createTable( [ "Globals",
                                   ( "DHCP Packets", "ch", rfd ),
                                   "Last Cleared" ] )
      globalTable.newRow( "All Req", req.received, req.forwarded, req.dropped,
                          utcTimeRelativeToNowStr( totals.lastResetTime ) )
      globalTable.newRow( "All Resp", resp.received, resp.forwarded,
                          resp.dropped, "" )
      globalTable.formatColumns( *( [ leftJustified ] * 5 ) )
      print( "\n" + globalTable.output() )

      if not self.interfaceCounters:
         return

      intfTable = createTable( [ "Interface",
                                 ( "DHCP Request Packets", "ch", rfd ),
                                 ( "DHCP Reply Packets", "ch", rfd ),
                                 "Last Cleared" ] )
      for intfName in sortIntf( self.interfaceCounters ):
         counters = self.interfaceCounters[ intfName ]
         rq, rp = counters.requests, counters.replies
         intfTable.newRow( intfName,
                           rq.received, rq.forwarded, rq.dropped,
                           rp.received, rp.forwarded, rp.dropped,
                           utcTimeRelativeToNowStr( counters.lastResetTime ) )
      intfTable.formatColumns( *( [ leftJustified ] * 8 ) )
      print( intfTable.output() )
