#!/usr/bin/env python3
# Copyright (c) 2010-2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac, Arnet
import os, re
from SysdbHelperUtils import SysdbPathHelper
from TypeFuture import TacLazyType
import inspect

# Sysdb entities; both are filled in by mountSysdb()
dhcpStatus = None
sysLoggingConfig = None

# Every name listed in the two tables below is installed as a module global by
# getEnvVars() (the value from dhclient's environment, or None when unset), so
# pylint cannot see their definitions; disable its undefined-variable check.
# pylint: disable=E0602

# Variables read for DHCPv4 reasons (the v6 address/option names are listed
# here as well as in dhclient6EnvVars).
dhclientEnvVars = [
   'reason',
   'interface',
   'new_ip_address',
   'new_subnet_mask',
   'new_domain_name',
   'new_domain_name_servers',
   'new_log_servers',
   'new_domain_search',
   'old_ip_address',
   'alias_ip_address',
   'new_broadcast_address',
   'new_interface_mtu',
   'new_routers',
   'new_static_routes',
   'new_rfc3442_classless_static_routes',
   'new_host_name',
   'new_time_offset',
   'new_tftp_server_name',
   'new_bootfile_name',
   'new_ip6_address',
   'new_ip6_prefixlen',
   'new_dhcp6_name_servers',
   'new_dhcp6_domain_search',
   'new_dhcp6_bootfile_url',
   'new_sztp_redirect',
   'new_ntp_servers',
]

# Variables read for DHCPv6 reasons.
dhclient6EnvVars = [
   'new_ip6_address',
   'new_ip6_prefixlen',
   'new_dhcp6_name_servers',
   'new_dhcp6_domain_search',
   'new_dhcp6_bootfile_url',
   'new_dhcp6_sztp_redirect',
   'new_dhcp6_sntp_servers',
]

def getFnName():
   """Return the name of the function that called getFnName().

   Only the caller's frame is inspected. inspect.stack() would also build a
   FrameInfo (including source-context lines) for every frame on the stack,
   which is needlessly expensive for a helper used on every trace line. The
   frame is never bound to a local, so no reference cycle is created.
   """
   return inspect.currentframe().f_back.f_code.co_name

logFd = 0  # trace file object; replaced by main() via openLogFile(), used by t6()

def openLogFile():
   """Open (truncating) this run's dhclient-script trace file for writing.

   The file is /var/log/dhclient-script-<interface>.log when the 'interface'
   environment variable is set, otherwise /var/log/dhclient-script.log.

   The encoding is fixed to UTF-8 rather than inherited from the process
   locale. Unencodable characters are written as backslash escapes, for
   example the surrogate escapes os.environ uses for undecodable bytes in
   values that are later traced, so writing a trace line cannot raise
   UnicodeEncodeError.
   """
   intfName = os.environ.get( 'interface' )
   intfTag = '' if intfName is None else '-' + intfName

   # pylint: disable=consider-using-with
   return open( f'/var/log/dhclient-script{intfTag}.log', 'w',
                encoding='utf-8', errors='backslashreplace' )

def t6( msg ):
   """Append *msg* as one line to the script's trace file (module global logFd).

   One trace file exists per interface; the agent reads it and redirects the
   contents into its own trace files. The file is flushed after every line so
   nothing is lost when the script leaves through os._exit().
   """
   traceFile = logFd
   traceFile.write( msg + "\n" )
   traceFile.flush()

def RouteKey( prefix, preference ):
   """Build a Routing::RouteKey value for *prefix* with the given *preference*."""
   routeKeyArgs = { 'prefix': prefix, 'preference': preference }
   return Tac.Value( "Routing::RouteKey", **routeKeyArgs )

def Via( hop, intf ):
   """Build a Routing::Via towards next hop *hop*.

   *intf* becomes the via's interface id; any falsy value (None, '') is
   stored as the empty interface id.
   """
   intfId = intf if intf else ''
   return Tac.Value( "Routing::Via", hop=Arnet.IpGenAddr( hop ), intfId=intfId )

def IpAddr( addr ):
   """Wrap *addr* in an Arnet::IpAddr value."""
   ipAddrTypeName = "Arnet::IpAddr"
   return Tac.Value( ipAddrTypeName, addr )

def classBits( ip ):
   """Return the prefix length implied by a dotted-quad destination address.

   Trailing all-zero octets are treated as host bits, and each one removes
   8 bits from 32:
   '10.0.0.0' -> 8, '172.16.0.0' -> 16, '192.168.1.0' -> 24,
   '192.168.1.1' -> 32, '0.0.0.0' -> 0.
   Input without exactly four dot-separated fields yields 0. A non-numeric
   field raises ValueError.
   """
   octets = ip.split( '.' )
   if len( octets ) != 4:
      return 0

   value = 0
   for octet in octets:
      value = ( value << 8 ) + int( octet )

   bits = 32
   # Walk from the least significant octet upwards, stopping at the first
   # non-zero one
   for shift in ( 0, 8, 16, 24 ):
      if value & ( 0xff << shift ):
         break
      bits -= 8
   return bits

def getPrefix( ip, nm ):
   """Return the prefix length for address *ip* and netmask *nm*, as a string.

   The length comes from running ``ipcalc -s -p <ip> <nm>`` and matching
   ``PREFIX=<digits>`` at the start of its output. Returns None when either
   argument is empty, the command fails, or the output does not match.
   """
   if not ( ip and nm ):
      return None

   try:
      output = Tac.run( [ 'ipcalc', '-s', '-p', ip, nm ], stdout=Tac.CAPTURE )
      match = re.match( r'PREFIX=(\d+)', output )
   except Tac.SystemCommandError:
      return None

   return str( match.group( 1 ) ) if match else None

def parseClasslessStaticRoutes( bytesStr ):
   """Decode a DHCP option 121 (RFC 3442 classless static routes) value.

   The value is expected as whitespace-separated decimal byte values. Each
   route is encoded as

      <width> <significant destination octets> <4 router octets>

   where the number of significant destination octets is ceil( width / 8 ).
   A width of 0 (the default route) therefore carries no destination octets.

   Args:
      bytesStr: the raw option value, e.g. "24 192 168 1 10 0 0 1".

   Returns:
      A list of ( Arnet.IpGenPrefix, next-hop string ) tuples in option
      order. An empty list is returned if any token is not an integer in
      0..255. Decoding stops at the first route with a width above 32 or with
      too few remaining bytes; routes decoded before that point are kept.
   """
   t6( f'{getFnName()}' )
   try:
      octets = [ int( b ) for b in bytesStr.split() ]
   except ValueError:
      return []
   if any( not 0 <= octet <= 255 for octet in octets ):
      t6( f'{getFnName()} byte value out of range, ignoring option' )
      return []

   routes = []
   pos = 0
   while pos < len( octets ):
      width = octets[ pos ]
      pos += 1
      if width > 32:
         t6( f'{getFnName()} invalid prefix width {width}, ignoring remainder' )
         break

      # ceil( width / 8 ): 0 -> 0, 1..8 -> 1, 9..16 -> 2, 17..24 -> 3, 25..32 -> 4
      destLen = ( width + 7 ) // 8

      # not enough bytes in input!
      if len( octets ) - pos < destLen + 4:
         break

      destOctets = octets[ pos : pos + destLen ] + [ 0 ] * ( 4 - destLen )
      pos += destLen
      hopOctets = octets[ pos : pos + 4 ]
      pos += 4

      destNetwork = '.'.join( str( octet ) for octet in destOctets )
      nextHop = '.'.join( str( octet ) for octet in hopOctets )
      routes.append( ( Arnet.IpGenPrefix( f'{destNetwork}/{width}' ), nextHop ) )
   return routes

def getEnvVars( dhclientGenEnvVars, isV4=True ):
   """Install each named dhclient environment variable as a module global.

   A name that is present in os.environ gets its value, and the assignment
   is traced. A missing name is set to None so handlers can test it
   uniformly. *isV4* only selects the "v4"/"v6" tag in the trace header.
   """
   familyTag = "v4" if isV4 else "v6"
   t6( f"{getFnName()} ({familyTag}):" )
   moduleGlobals = globals()
   for name in dhclientGenEnvVars:
      if name not in os.environ:
         moduleGlobals[ name ] = None
         continue
      value = os.environ.get( name, '' )
      t6( f'   {name}={value}' )
      moduleGlobals[ name ] = value

def getGenEnvVars():
   """Install both the DHCPv4 and DHCPv6 dhclient variables as module globals.

   The v4 table (dhclientEnvVars) is processed first, then the v6 table
   (dhclient6EnvVars).
   """
   t6( f'{getFnName()}' )
   for envVarNames, isV4 in ( ( dhclientEnvVars, True ),
                              ( dhclient6EnvVars, False ) ):
      getEnvVars( envVarNames, isV4 )

def preinitHandler():
   """PREINIT / PREINIT6 handler: nothing to prepare, so report success.

   main() returns before dispatching for these two reasons, so this is
   normally not reached through main().
   """
   t6( f'{getFnName()} nothing to do' )
   status = 0
   return status

def arpcheckHandler():
   """Handle ARPCHECK/ARPSEND by probing the offered address with arping.

   Runs ``arping -D`` (duplicate address detection mode) on the dhclient
   *interface* for *new_ip_address*. Returns 0 without probing when either
   value is missing. Otherwise returns 0 when arping exits non-zero and 1
   when it exits 0.

   NOTE(review): iputils arping -D is documented to exit 0 when no duplicate
   replies, which would make this mapping the inverse of the usual
   "0 = address is safe" convention. It is preserved as-is; confirm against
   the arping shipped on the switch.
   """
   import shlex  # local import: quoting is only needed here

   t6( f'{getFnName()}' )
   if not new_ip_address or not interface:
      return 0

   # Both values come from dhclient's environment (the address originates
   # from the DHCP server); quote them so the shell can never interpret them.
   cmd = ( "/usr/sbin/arping -q -f -c 2 -w 3 -D "
           f"-I {shlex.quote( interface )} {shlex.quote( new_ip_address )}" )

   t6( f'{getFnName()} running cmd: {cmd}' )
   if os.system( cmd ):
      return 0
   return 1

netConfigType = TacLazyType( "System::NetConfig" )

def populateNameServers( new_domain_gen_name_servers, intfDhcpGenStatus ):
   """Record the DNS servers offered for the current interface.

   Servers are stored at consecutive indices starting from 0. At most
   netConfigType.maxNameServers are kept; any further servers are ignored.

   Args:
      new_domain_gen_name_servers: comma- and/or space-separated addresses
         (new_domain_name_servers or new_dhcp6_name_servers).
      intfDhcpGenStatus: dhcpStatus.intfDhcpStatus or intfDhcp6Status,
         indexed here by the module global `interface`.
   """
   t6( f'{getFnName()}' )
   sepRe = re.compile( r'[, ]\s*' )
   # Leading, trailing or doubled separators make split() return empty
   # strings, which aren't valid IpGenAddrs. Drop them *before* numbering so
   # they neither leave holes in the indices nor count against the limit.
   nameServers = [ ns for ns in sepRe.split( new_domain_gen_name_servers ) if ns ]
   configured = 0
   for index, nameServer in enumerate( nameServers ):
      if index >= netConfigType.maxNameServers:
         break
      intfDhcpGenStatus[ interface ].nameServer[ index ] = Arnet.IpGenAddr(
         nameServer )
      configured += 1
   t6( f'{getFnName()} {configured} name-server(s) configured' )

def populateSztpServers( new_sztp_servers, intfDhcpGenStatus ):
   """Record the SZTP redirect URIs offered for the current interface.

   Bootz (bootz://) URIs take priority over SZTP (https://) URIs. When at
   least one bootz:// URI is present, every other URI is ignored. Accepted
   URIs are stored at consecutive indices starting from 0.

   Args:
      new_sztp_servers: whitespace-separated URIs (new_sztp_redirect or
         new_dhcp6_sztp_redirect).
      intfDhcpGenStatus: dhcpStatus.intfDhcpStatus or intfDhcp6Status,
         indexed here by the module global `interface`.
   """
   t6( f'{getFnName()}' )
   # split() without a separator ignores repeated, leading and trailing
   # whitespace instead of producing empty "URIs"
   sztpServersUris = new_sztp_servers.split()

   bootzUrisPresent = any( uri.startswith( "bootz://" ) for uri in sztpServersUris )

   index = 0
   for sztpServerUri in sztpServersUris:
      if bootzUrisPresent and not sztpServerUri.startswith( "bootz://" ):
         t6( f"SZTP(bootz) URIs present, ignoring non-bootz URI {sztpServerUri}" )
         continue
      sztpServer = Tac.Value( "ZeroTouch::SztpServer" )
      sztpServer.uri = sztpServerUri
      intfDhcpGenStatus[ interface ].sztpServers[ index ] = sztpServer
      index += 1

   tag = "SZTP(bootz)" if bootzUrisPresent else "SZTP"
   t6( f'{getFnName()} {index} {tag}-server(s) configured' )

def populateNtpServers( new_ntp_servers, intfDhcpGenStatus ):
   """Replace the current interface's NTP server queue with *new_ntp_servers*.

   Args:
      new_ntp_servers: whitespace-separated server addresses
         (new_ntp_servers or new_dhcp6_sntp_servers).
      intfDhcpGenStatus: dhcpStatus.intfDhcpStatus or intfDhcp6Status,
         indexed here by the module global `interface`.
   """
   t6( f'{getFnName()}' )
   # split() without a separator never yields empty strings, which would
   # otherwise be queued as blank NTP servers
   ntpServersIp = new_ntp_servers.split()
   intfDhcpGenStatus[ interface ].ntpServers.clear()
   for ip in ntpServersIp:
      ntpServer = Tac.Value( "ZeroTouch::NtpServer" )
      ntpServer.hostIp = ip
      intfDhcpGenStatus[ interface ].ntpServers.enq( ntpServer )
   t6( f'{getFnName()} {len( ntpServersIp )} NTP-server(s) configured' )

_dhcpListSepRe = re.compile( r'[, ]\s*' )

def _splitDhcpList( value ):
   # Split a comma/space separated dhclient list, dropping the empty tokens
   # produced by leading, trailing or doubled separators.
   return [ token for token in _dhcpListSepRe.split( value ) if token ]

def dhconfigBoundHandler():
   """Publish a DHCPv4 lease (BOUND/REBOOT) into the interface's status.

   Reads the dhclient variables that getEnvVars() installed as module
   globals and writes them into dhcpStatus.intfDhcpStatus[ interface ].
   Nothing is written unless both an address and a netmask that getPrefix()
   accepts are present. genId is bumped last because that is what triggers
   the consuming state machine.

   Returns:
      0 in all cases.
   """
   t6( f'{getFnName()}' )

   prefixLen = getPrefix( new_ip_address, new_subnet_mask )
   if not prefixLen or not new_ip_address:
      return 0

   intfStatus = dhcpStatus.intfDhcpStatus[ interface ]

   # ip address and mask
   intfStatus.addrWithMask = Arnet.IpGenAddrWithMask(
                                                new_ip_address + '/' + prefixLen )
   # Should use Tac.Type( 'ZeroTouch::DhcpAf' ).ipv4 but for BUG744451...
   intfStatus.af = "ipv4"

   # interface mtu: only values above 576 are applied; junk is ignored
   # instead of aborting the handler before genId is bumped
   if new_interface_mtu:
      try:
         mtu = int( new_interface_mtu )
      except ValueError:
         t6( f'{getFnName()} ignoring invalid interface mtu {new_interface_mtu!r}' )
      else:
         if mtu > 576:
            intfStatus.mtu = mtu

   # static routes. RFC 3442: when the classless static routes option (121)
   # is present the client must ignore the classful static routes option
   # (33). Writing both from index 0 would leave stale classful entries
   # behind a shorter classless list. Fall back to option 33 only if option
   # 121 yields no usable route.
   routes = []
   routeKind = None
   if new_rfc3442_classless_static_routes:
      routes = parseClasslessStaticRoutes( new_rfc3442_classless_static_routes )
      routeKind = 'classless static'
   if not routes and new_static_routes:
      tokens = _splitDhcpList( new_static_routes )
      # option 33 is a flat list of ( destination, gateway ) pairs
      routes = [ ( Arnet.IpGenPrefix( f'{target}/{classBits( target )}' ), gateway )
                 for target, gateway in zip( tokens[ ::2 ], tokens[ 1::2 ] ) ]
      routeKind = 'static'
   for index, ( routePrefix, gateway ) in enumerate( routes ):
      intfStatus.staticRoute[ index ] = Tac.Value( "ZeroTouch::Route",
                                                   key=RouteKey( routePrefix, 1 ),
                                                   via=Via( gateway, None ) )
   if routeKind:
      t6( f'{getFnName()} {len( routes )} {routeKind} route(s) configured' )

   # gateways (at most maxGateways are recorded)
   if new_routers:
      maxGateways = 3
      gateways = _splitDhcpList( new_routers )[ :maxGateways ]
      for index, gateway in enumerate( gateways ):
         intfStatus.gateway[ index ] = Arnet.IpGenAddr( gateway )
      t6( f'{getFnName()} {len( gateways )} gateway(s) configured' )

   # host name
   if new_host_name:
      intfStatus.hostname = new_host_name

   # domain name
   if new_domain_name:
      intfStatus.domainName = new_domain_name

   # name servers
   if new_domain_name_servers:
      populateNameServers( new_domain_name_servers, dhcpStatus.intfDhcpStatus )

   # log servers (UDP, default port)
   if new_log_servers:
      loggingHostType = Tac.Type( "LogMgr::LoggingHost" )
      port = loggingHostType().portDefault
      for ipAddrOrHostname in _splitDhcpList( new_log_servers ):
         hostInfo = Tac.Value( "LogMgr::LoggingHost",
                               ipAddrOrHostname=ipAddrOrHostname,
                               protocol="udp",
                               ports={ port: port } )
         intfStatus.loggingHost.addMember( hostInfo )

   # dhcp options 66 and 67
   if new_tftp_server_name:
      intfStatus.serverName = new_tftp_server_name

   if new_bootfile_name:
      intfStatus.bootFileName = new_bootfile_name

   # dhcp sztp-redirect option 143
   if new_sztp_redirect:
      populateSztpServers( new_sztp_redirect, dhcpStatus.intfDhcpStatus )

   # dhcp ntp-servers option 42
   if new_ntp_servers:
      populateNtpServers( new_ntp_servers, dhcpStatus.intfDhcpStatus )

   # Note: genId has to be last assigned attr as this triggers the
   # state machine
   intfStatus.genId += 1

   return 0

def doDAD():
   """Placeholder for IPv6 duplicate address detection (not implemented).

   Does nothing and returns None.
   """

def getDhcp6EnvParams( ):
   """Publish the DHCPv6 options shared by stateful and stateless leases.

   Copies the following dhclient globals into
   dhcpStatus.intfDhcp6Status[ interface ], skipping any that are unset or
   empty:
   - the domain search list, stored as domainName
   - the name servers
   - the bootfile URL (option 59)
   - the SZTP redirect URIs (option 136)
   - the SNTP servers (option 31)
   It then bumps genId.

   Returns:
      0 always.
   """
   t6( f'{getFnName()}' )
   v6Statuses = dhcpStatus.intfDhcp6Status

   if new_dhcp6_domain_search:
      v6Statuses[ interface ].domainName = new_dhcp6_domain_search

   if new_dhcp6_name_servers:
      populateNameServers( new_dhcp6_name_servers, v6Statuses )

   if new_dhcp6_bootfile_url:
      v6Statuses[ interface ].bootFileName = new_dhcp6_bootfile_url

   if new_dhcp6_sztp_redirect:
      populateSztpServers( new_dhcp6_sztp_redirect, v6Statuses )

   if new_dhcp6_sntp_servers:
      populateNtpServers( new_dhcp6_sntp_servers, v6Statuses )

   # genId must be the last attribute written: bumping it is what triggers
   # the state machine
   v6Statuses[ interface ].genId += 1
   return 0

def dhconfig6BoundHandler():
   """Handle a stateful DHCPv6 BOUND6 event.

   Publishes the leased address and prefix length, marks the interface as
   "ipv6Stateful", then hands the remaining options to getDhcp6EnvParams()
   and returns its result. Without both an address and a prefix length,
   nothing is written and 0 is returned.
   """
   t6( f'{getFnName()}' )

   haveLease = new_ip6_address and new_ip6_prefixlen
   if not haveLease:
      return 0

   # XXX-sarangs: duplicate address detection is still to be implemented;
   # Bug 210715 tracks this
   doDAD()

   v6Status = dhcpStatus.intfDhcp6Status[ interface ]
   # ip6 address and prefix length
   v6Status.addrWithMask = Arnet.IpGenAddrWithMask(
                                       f'{new_ip6_address}/{new_ip6_prefixlen}' )
   v6Status.af = "ipv6Stateful"

   return getDhcp6EnvParams()

def dhconfig6RenewHandler():
   """Handle RENEW6 for both stateful renewals and stateless DHCPv6.

   In the stateful case BOUND6 has already set af, so it is kept. In the
   stateless case RENEW6 arrives even for the first transaction, while af
   is still "ipUnknown", and it becomes "ipv6Stateless". The options are
   then published via getDhcp6EnvParams(), whose result is returned.
   """
   t6( f'{getFnName()}' )
   v6Status = dhcpStatus.intfDhcp6Status[ interface ]
   if v6Status.af == "ipUnknown":
      v6Status.af = "ipv6Stateless"
   return getDhcp6EnvParams()

def dhconfigHandler():
   """Handle v4 BOUND/RENEW/REBIND/REBOOT; only BOUND and REBOOT publish.

   RENEW and REBIND are traced but otherwise ignored. Always returns 0; the
   result of dhconfigBoundHandler() is not propagated.
   """
   t6( f'{getFnName()}' )
   publishReasons = ( 'BOUND', 'REBOOT' )
   if reason in publishReasons:
      dhconfigBoundHandler()
   return 0

def dhconfig6Handler():
   """Handle v6 BOUND6/RENEW6/REBIND6/DEPREF6 reasons.

   BOUND6 publishes a stateful lease. RENEW6 covers stateful renewals and
   also stateless DHCPv6: dhclient started with -6 -S always invokes the
   script with reason RENEW6. Other reasons are only traced. Always
   returns 0.
   """
   t6( f'{getFnName()}' )
   subHandlers = { 'BOUND6': dhconfig6BoundHandler,
                   'RENEW6': dhconfig6RenewHandler }
   subHandler = subHandlers.get( reason )
   if subHandler is not None:
      subHandler()
   return 0

def downHandler():
   """No-op handler for lease-loss/shutdown reasons; reports success.

   Registered for EXPIRE, FAIL, RELEASE, STOP, EXPIRE6, RELEASE6 and STOP6.
   """
   t6( f'{getFnName()} nothing to do' )
   status = 0
   return status

def timeoutHandler():
   """No-op handler for the v4 TIMEOUT reason; reports success."""
   t6( f'{getFnName()} nothing to do' )
   status = 0
   return status

def cleanupHandler():
   """Handle CLEANUP: delete every DHCPv4 interface status entry.

   Returns:
      0 always.
   """
   t6( f'{getFnName()}' )
   # Snapshot the keys first. Removing entries from a collection while
   # iterating over it is unsafe in general (a plain dict raises).
   intfs = list( dhcpStatus.intfDhcpStatus )
   for intf in intfs:
      t6( f"{getFnName()} del {intf} intfDhcpStatus" )
      del dhcpStatus.intfDhcpStatus[ intf ]

   # Report the number removed, not the (now empty) collection's length
   t6( f'{getFnName()} deleted {len( intfs )} intfDhcpStatus' )
   return 0

def cleanup6Handler():
   """Handle CLEANUP6: delete every DHCPv6 interface status entry.

   Returns:
      0 always.
   """
   t6( f'{getFnName()}' )
   # Snapshot the keys first. Removing entries from a collection while
   # iterating over it is unsafe in general (a plain dict raises).
   intfs = list( dhcpStatus.intfDhcp6Status )
   for intf in intfs:
      t6( f"{getFnName()} del {intf} intfDhcp6Status" )
      del dhcpStatus.intfDhcp6Status[ intf ]

   # Report the number removed, not the (now empty) collection's length
   t6( f'{getFnName()} deleted {len( intfs )} intfDhcp6Status' )
   return 0

# Handler for each DHCPv4 'reason' dhclient passes to this script.
# main() returns before dispatching PREINIT.
dhcpoptions = dict( PREINIT=preinitHandler,
                    ARPCHECK=arpcheckHandler,
                    ARPSEND=arpcheckHandler,
                    BOUND=dhconfigHandler,
                    RENEW=dhconfigHandler,
                    REBIND=dhconfigHandler,
                    REBOOT=dhconfigHandler,
                    EXPIRE=downHandler,
                    FAIL=downHandler,
                    RELEASE=downHandler,
                    STOP=downHandler,
                    TIMEOUT=timeoutHandler,
                    CLEANUP=cleanupHandler )

# Handler for each DHCPv6 'reason' (dhclient -6 suffixes reasons with '6').
# main() returns before dispatching PREINIT6.
dhcp6options = dict( PREINIT6=preinitHandler,
                     BOUND6=dhconfig6Handler,
                     RENEW6=dhconfig6Handler,
                     REBIND6=dhconfig6Handler,
                     DEPREF6=dhconfig6Handler,
                     EXPIRE6=downHandler,
                     RELEASE6=downHandler,
                     STOP6=downHandler,
                     CLEANUP6=cleanup6Handler )

def mountSysdb():
   """Look up the Sysdb entities this script writes to.

   Uses SysdbPathHelper with the SYSNAME environment variable (default "ar")
   to set the module globals:
   - dhcpStatus, from "zerotouch/dhcp/status"
   - sysLoggingConfig, from "sys/logging/config"

   Raises:
      Exception: when the dhcp status entity cannot be obtained.
   """
   global dhcpStatus, sysLoggingConfig
   t6( f'{getFnName()}' )

   pathHelper = SysdbPathHelper( os.environ.get( "SYSNAME", "ar" ) )

   dhcpStatus = pathHelper.getEntity( "zerotouch/dhcp/status" )
   if not dhcpStatus:
      t6( f'{getFnName()} "Failed to mount dhcpStatus"' )
      raise Exception( "Failed to mount dhcpStatus" )

   sysLoggingConfig = pathHelper.getEntity( "sys/logging/config" )

def main():
   """dhclient-script entry point: trace, load env, dispatch on reason, exit.

   dhclient invokes the script with the event in the 'reason' environment
   variable. v4 events look like BOUND, RENEW, ... and v6 events like
   BOUND6, RENEW6, ... The matching handler runs and the process then
   terminates via os._exit() with that handler's return code (0 when no
   handler matches). Returns normally only for an empty reason or
   PREINIT/PREINIT6.
   """
   global logFd
   logFd = openLogFile()
   t6( f'{getFnName()} starting...' )

   # Pull the known dhclient variables (reason, interface, new_*, ...) into
   # module globals
   getGenEnvVars()

   # XXX: HACK: Don't do anything for 'PREINIT'. PREINIT is called for every
   # interface and mounting sysdb 64 (or 384!) times is just a bad idea.
   if not reason or reason in ( 'PREINIT', 'PREINIT6' ):
      t6( f'{getFnName()} reason={reason}, nothing to do' )
      logFd.close()
      return

   mountSysdb()

   # Only v6 reasons carry a '6' suffix
   isV6 = "6" in reason
   dhclientVersion = 'v6' if isV6 else 'v4'

   # Create the dhcp / dhcp6 interface status, if one doesn't exist
   if not interface:
      t6( f'{getFnName()} no {dhclientVersion} intf specified' )
   elif isV6:
      t6( f'{getFnName()} create v6 intf({interface}) status' )
      dhcpStatus.newIntfDhcp6Status( interface )
   else:
      t6( f'{getFnName()} create v4 intf({interface}) status' )
      dhcpStatus.newIntfDhcpStatus( interface )

   handlers = dhcp6options if isV6 else dhcpoptions
   handler = handlers.get( reason )
   retCode = handler() if handler else 0

   if interface:
      t6( f'{getFnName()} {dhclientVersion} intf({interface}) reason={reason}' )
      # NOTE(review): CLEANUP/CLEANUP6 handlers delete every interface status
      # entry, so this lookup may fail if dhclient also sets 'interface' for
      # those reasons -- confirm.
      intfStatuses = ( dhcpStatus.intfDhcp6Status if isV6
                       else dhcpStatus.intfDhcpStatus )
      intfStatuses[ interface ].reason = reason

   Tac.flushEntityLog()
   logFd.close()

   # pylint: disable=protected-access
   os._exit( retCode )

# Run main() only when executed as a script (dhclient invokes it directly),
# not when imported.
if __name__ == "__main__":
   main()
