# Copyright (c) 2007, 2008, 2009, 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import collections
from datetime import datetime
from SysConstants.in_h import IPPROTO_UDP
import itertools
import re

import CliPlugin.TechSupportCli
from CliPlugin import AclCli
from CliPlugin import ClockCli
from CliPlugin import IntfCli
from CliPlugin.NtpModels import NtpPeer, NtpPeers
import AclCliLib
import AclLib
import Cell
import ConfigMount
import DscpCliLib
import HostnameCli
from IpLibConsts import DEFAULT_VRF
import LazyMount
import NtpLib
from NtpLib import authModeNone, authModeAll, authModeServers
from PyWrappers.NtpStat import ntpstatBinary
import Tac

from ArnetModel import IpGenericAddress
from CliModel import (
   Enum,
   Int,
   Model,
   Str
   )

from Arnet import IpGenAddr

# Mount points filled in by Plugin(); they stay None until the plugin loads.
ntpConfig = ntpStatus = ntpServerStatus = None
allVrfStatusLocal = None
aclConfig = aclCpConfig = aclStatus = aclCheckpoint = None
dscpConfig = None

def validAddress( address ):
   """Return True if `address` can be parsed by IpGenAddr, else False.

   Only ValueError (IpGenAddr's signal for an unparsable address) is treated
   as "invalid"; any other exception propagates.
   """
   try:
      IpGenAddr( address )
   except ValueError:
      return False
   else:
      return True

def referenceIsLocal():
   """Return True when NTP local reference mode is in effect.

   That requires a local stratum other than the 'disabled' sentinel and no
   configured NTP server.
   """
   if ntpConfig.localStratum == ntpConfig.localStratumDisabled:
      return False
   return not ntpConfig.server

# ------------------------------------------
# Ntp show commands :
#    show ntp status
#    show ntp associations
# from enable mode

def ntpEnabled( mode ):
   """Return True if NTP should be considered enabled for this CLI session.

   All of the following must hold:
     * the clock source is configured as 'ntp';
     * a server is configured, or local reference mode is on;
     * the session (if any) is not on a 'standby' redundancy status.
   """
   if ClockCli.clockConfig.source != 'ntp':
      return False
   if not ( ntpConfig.server or referenceIsLocal() ):
      return False
   session = mode.session_
   return not ( session and
                session.entityManager_.redundancyStatus().mode == 'standby' )

def ntpStat( mode ):
   """Return the text output of ntpstat, or a short status message.

   * 'NTP is disabled.' when ntpEnabled() is False;
   * ntpstat's output when it exits with status 0 or 1 (1 = unsynchronised);
   * 'NTP starting...' when ntpstat exits with status 2;
   * a VRF error message when NtpLib.runMaybeInNetNs raises NetNsError.
   Any other ntpstat failure re-raises Tac.SystemCommandError.
   """
   if not ntpEnabled( mode ):
      return 'NTP is disabled.'
   try:
      return NtpLib.runMaybeInNetNs(
         ntpConfig, allVrfStatusLocal,
         [ ntpstatBinary() ], stdout=Tac.CAPTURE, stderr=Tac.CAPTURE )
   except Tac.SystemCommandError as e:
      if e.error == 1:
         # Unsynchronised: the output is still the answer we want.
         return e.output
      if e.error == 2:
         # ntpstat could not reach ntpd, probably because ntpd is still
         # being started.
         return 'NTP starting...'
      raise
   except NtpLib.NetNsError:
      return "NTP is configured in a non-existing VRF."

# Matches the start of ntpstat's "synchronised" output. The time source is
# captured as `server` for "NTP server (<addr>)" and as `other` otherwise.
# There is no re.MULTILINE, so '^' anchors to the very start of the text.
synchronizedRe = re.compile(
   r"^synchronised to "
   r"(?:NTP server \((?P<server>.*?)\)|(?P<other>.*?)) at stratum" )

def ntpServerClockSourceHook( mode ):
   """Describe the NTP time source the clock is synchronised to.

   Returns "NTP server (<address>)" when ntpStat() reports synchronisation to
   a source that parses as an address. Returns "" otherwise.
   """
   found = synchronizedRe.search( ntpStat( mode ) )
   if not found:
      return ""
   source = found.group( "server" ) or found.group( "other" )
   if not validAddress( source ):
      return ""
   return f"NTP server ({source})"

class ShowNtpStatusCmdModel( Model ):
   # CLI model returned by doShowNtpStatus ('show ntp status').
   # The optional fields are only meaningful when doShowNtpStatus could
   # extract them; otherwise they are passed as None or left unset.
   #
   # "local" is produced for local reference mode; the other values are the
   # status keywords matched in the ntpStat() text.
   status = Enum( values=( "disabled",
                           "synchronised",
                           "unsynchronised",
                           "starting",
                           "local",
                           "non-existing" ),
                  help="Status of NTP synchronisation service" )
   # Only set when the reported source parses as an address.
   server = IpGenericAddress( help="Time source (NTP server or reference clock) "
                              "to which the system clock is currently synchronised",
                              optional=True )
   stratum = Int( help="Indicates the level of the server in the NTP hierarchy "
                  "- see RFC 5905, 7.3",
                  optional=True )
   pollingInterval = Int( help="Polling interval of the peer (in seconds)",
                          optional=True )
   maxEstimatedError = Int( help="Maximum estimated error of the clock "
                            "(in milliseconds)",
                            optional=True )
   # Raw ntpStat() text (or the local-reference-mode message).
   _details = Str( help="Full details of NTP synchronisation service" )

   def render( self ):
      # Text output is just the raw details string.
      print( self._details )

# Patterns used by doShowNtpStatus to pick fields out of ntpStat() text.
# In statusRe exactly one named group takes part in a match, and its text is
# itself a ShowNtpStatusCmdModel.status value. The leading '.*' keeps the
# per-line alternative priority (e.g. "unsynchronised" beats "synchronised").
statusRe = re.compile(
   r".*(?P<Unsynchronised>unsynchronised)"
   r"|.*(?P<Synchronised>synchronised)"
   r"|.*(?P<Starting>starting)"
   r"|.*(?P<Error>non-existing)"
   r"|.*(?P<Disabled>disabled)" )
stratumRe = re.compile( r".*stratum (?P<Stratum>\d+)" )
accuracyRe = re.compile( r".*time correct to within (?P<Accuracy>\d+).*" )
freqRe = re.compile( r".*polling server every (?P<Freq>\d+).*" )

def doShowNtpStatus( mode, args ):
   """Handler for 'show ntp status'; returns a ShowNtpStatusCmdModel.

   In local reference mode (with NTP enabled), the model is built from the
   configured local stratum alone. Otherwise the ntpStat() text is kept as
   the details, and the following are extracted from it:
     * the status keyword ("disabled" if none matches);
     * the time source, kept only if it parses as an address;
     * the stratum, polling interval and maximum estimated error.
   Values that cannot be found are passed as None.
   """
   if ntpEnabled( mode ) and referenceIsLocal():
      localStratum = ntpConfig.localStratum
      return ShowNtpStatusCmdModel(
         status='local',
         stratum=localStratum,
         _details=( "NTP is operating in local reference mode "
                    f"(stratum {localStratum})." ) )

   details = ntpStat( mode )

   def _intField( regex, groupName ):
      # int() of `groupName` from the first match of `regex`, or None.
      found = regex.search( details )
      return int( found.group( groupName ) ) if found else None

   status = "disabled"
   statusMatch = statusRe.search( details )
   if statusMatch:
      status = next( ( statusMatch.group( name ) for name in
                       ( "Unsynchronised", "Synchronised", "Starting", "Error",
                         "Disabled" )
                       if statusMatch.group( name ) ), None )

   source = None
   syncMatch = synchronizedRe.search( details )
   if syncMatch:
      candidate = syncMatch.group( "server" ) or syncMatch.group( "other" )
      if validAddress( candidate ):
         source = candidate

   return ShowNtpStatusCmdModel( status=status,
                                 server=source,
                                 stratum=_intField( stratumRe, "Stratum" ),
                                 pollingInterval=_intField( freqRe, "Freq" ),
                                 maxEstimatedError=_intField( accuracyRe,
                                                              "Accuracy" ),
                                 _details=details )

def doShowNtpAssociations( mode, args ):
   """Handler for 'show ntp associations'.

   Queries the local ntpd through ntpq (run via NtpLib.runMaybeInNetNs).
   Returns an NtpPeers model keyed by each peer's srcadr as printed by a
   non-numeric ntpq run. If NTP is disabled, ntpd is not reachable, or NTP is
   configured in a non-existing VRF, prints a message and returns an empty
   NtpPeers().
   """
   ntpq = '/usr/sbin/ntpq'

   def getNtpPeersDict( fields=None ):
      """Yield one {variable: value-string} dict per association id.

      Issues one 'rv <assoc id> <fields>' command per entry of `conditions`
      in a single numeric (-n) ntpq run. Splits the combined output back into
      one dict per peer, in the order ntpq printed them.
      """
      if fields is None:
         fields = ( 'srcadr', 'srcport', 'refid', 'stratum', 'rec', 'ppoll',
                    'reach', 'delay', 'offset', 'jitter' )

      ntpq_peers = NtpLib.runMaybeInNetNs( ntpConfig, allVrfStatusLocal,
         # pylint: disable-next=consider-using-f-string
         [ ntpq, '-n' ] + [ '-c rv {} {}'.format( aid, ','.join( fields ) ) for
                            aid in conditions ] + [ '127.0.0.1' ],
         stdout=Tac.CAPTURE, stderr=Tac.CAPTURE )
      # Every peer's output starts with the first requested variable, so
      # splitting on its name gives one chunk per peer. Anything before the
      # first occurrence is dropped by [1:].
      ntpq_peer_list = ( fields[0] + p for p in
                         ntpq_peers.replace( '\n', ' ' ).split( fields[0] )[1:] )
      for peer_vars in ntpq_peer_list:
         # Splitting on '=' leaves pieces of the form "<value>, <next name>".
         # rsplit( ', ', 1 ) cuts only at the LAST ', ', so values that
         # contain ', ' themselves (e.g. the date inside 'rec') stay whole.
         peer_var_tokens = itertools.chain( *( var.rsplit( ', ', 1 ) for var in
                                               peer_vars.split( '=' ) ) )
         # group peer_var_tokens two by two
         yield dict( zip( *[ iter( peer_var_tokens ) ] * 2 ) )

   def getNtpPeerFromDict( assoc_id, peer_dict ):
      """Build an NtpPeer model from one dict produced by getNtpPeersDict."""
      peer = NtpPeer()
      peer.condition = conditions[ assoc_id ]
      peer.peerIpAddr = peer_dict[ 'srcadr' ]
      peer.refid = peer_dict[ 'refid' ]
      # A 4-character refid is shown in dotted form ('.XXXX.').
      if len( peer.refid ) == 4:
         peer.refid = f'.{peer.refid}.'
      peer.stratumLevel = int( peer_dict[ 'stratum' ] )
      if peer.peerIpAddr.isMulticast:
         peer.peerType = 'broadcast'
      elif peer.peerIpAddr.isLinkLocal:
         peer.peerType = 'local'
      else:
         peer.peerType = 'unicast'
      # 'rec' is expected to be either "(no time)" or
      # "<timestamp>  <date matching date_format>". A missing 'rec', or one
      # without the double-space separator, falls back to the current time
      # like "(no time)" does. Previously these cases raised
      # AttributeError/IndexError.
      rec_time_str = peer_dict.get( 'rec' )
      date_str = None
      if rec_time_str and rec_time_str != "(no time)":
         _, sep, rest = rec_time_str.partition( '  ' )
         if sep:
            date_str = rest
      date_format = '%a, %b %d %Y %H:%M:%S.%f'
      # NOTE(review): '%s' is a platform (glibc) strftime extension, not a
      # documented Python directive. It treats the naive datetime as local
      # time and truncates to whole seconds.
      peer.lastReceived = float( ( datetime.strptime( date_str, date_format )
                                   if date_str else datetime.now() ).
                                 strftime( '%s' ) )
      # ppoll is used as a power-of-two exponent.
      peer.pollInterval = 2 ** int( peer_dict['ppoll'] )
      # 'reach' is parsed as octal. NOTE(review): bin() drops leading zeros,
      # so this list can be shorter than 8 entries; confirm this is intended.
      peer.reachabilityHistory = [ x == '1' for x in
                                   bin( int( peer_dict['reach'], 8 ) )[2:] ]
      peer.delay = float( peer_dict['delay'] )
      peer.offset = float( peer_dict['offset'] )
      peer.jitter = float( peer_dict['jitter'] )
      return peer

   if ntpEnabled( mode ):
      # ntpq does not return with an error code if it is unable to
      # connect to ntpd. Instead, it prints out:
      #    'ntpq: read: Connection refused'.
      # This is not the most helpful error message, so we have to do the
      # following output matching.
      try:
         # ntpq will try to resolve 'localhost' if the server's IP address
         # is not specified. We want to bypass resolution because it can
         # cause problems for IPv6 and because it's unnecessary given that
         # we always know the server's address.
         associations = NtpLib.runMaybeInNetNs(
            ntpConfig, allVrfStatusLocal,
            [ ntpq, '-c', 'associations', '127.0.0.1' ],
            stdout=Tac.CAPTURE, stderr=Tac.CAPTURE )
         if re.search( 'Connection refused', associations ):
            # Could not connect to the NTP daemon, probably because
            # ntpd is in the process of being started
            print( "NTP starting..." )
         else:
            # Maps association id -> condition, taken from the 2nd and 7th
            # columns of each association row.
            # OrderedDict avoids any uncertainty about consistent iteration
            # order over conditions in comprehensions and zip.
            conditions = collections.OrderedDict(
               re.findall( r'^\s*\d+\s+(\d+)\s+\S+\s+\S+\s+\S+\s+\S+\s+(\S+)',
                           associations, re.MULTILINE ) )
            if not conditions:
               return NtpPeers( peers={} )

            # Without '-n', ntpq reports srcadr as a name where it can; these
            # become the keys of the returned model.
            ntpq_hostnames = NtpLib.runMaybeInNetNs(
               ntpConfig, allVrfStatusLocal,
               [ ntpq ] + [ f'-c rv {aid} srcadr' for aid in conditions ]
                        + [ '127.0.0.1' ],
               stdout=Tac.CAPTURE, stderr=Tac.CAPTURE )
            hostnames = dict( zip( conditions,
                                   re.findall( r'^srcadr=(.+?)$', ntpq_hostnames,
                                               re.MULTILINE ) ) )
            return NtpPeers( peers={ hostnames[aid]:
                                     getNtpPeerFromDict( aid, peer_dict ) for
                                     aid, peer_dict in
                                     zip( conditions, getNtpPeersDict() ) } )
      except NtpLib.NetNsError:
         print( "NTP is configured in a non-existing VRF." )
   else:
      print( "NTP is disabled." )
   return NtpPeers()

# ------------------------------------------
# Ntp config commands :
#    no ntp
#    ntp server [vrf <vrf-name>] <hostname> [prefer] [version <versionNum>]
#       [source <interface>] [source-address <address>] [burst] [iburst] \
#       [minpoll <interval>] [maxpoll <interval>] [key <key id>] [refresh]
#    no ntp server <hostname>
#    no ntp bind
#    ntp bind <interfaces>
#    no ntp bind <interfaces>
#    ntp serve all
#    no ntp serve all
#    ntp authentication-key <key id> (md5|sha1) [0|7] <password>
#    no ntp authentication-key <key id>
#    ntp authenticate [servers]
#    no ntp authenticate [servers]
#    ntp trusted-key { <key id 1>, <key id 2>, ... }
#    no ntp trusted-key
#    ntp local [stratum STRATUM]
#    no ntp local
# from global config mode

def doDisableNtp( mode, args ):
   """Handler for 'no ntp'.

   * Removes every configured server.
   * Clears per-interface serve settings.
   * Turns off the default 'serve all' mode.
   * Disables local reference mode.
   * Resets the DSCP setting via noDscp().
   """
   ntpConfig.server.clear()
   for intfColl in ( ntpConfig.serverModeDisabledIntf,
                     ntpConfig.serverModeEnabledIntf ):
      intfColl.clear()
   ntpConfig.serverModeEnabledDefault = False
   # NOTE(review): per-VRF 'serve all' entries (serveVrfName) are not cleared
   # here; confirm that is intended.
   ntpConfig.localStratum = ntpConfig.localStratumDisabled
   noDscp( mode )

def doDeleteNtpServer( mode, args ):
   """Handler for 'no ntp server [vrf VRF] HOST'.

   Deletes the server keyed by (VRF, HOST), using the default VRF when none
   is given. Warns if that server is not configured. DSCP rules are
   refreshed in either case.
   """
   host = args[ 'HOST' ]
   explicitVrf = args.get( 'VRF' )
   vrfName = DEFAULT_VRF if explicitVrf is None else explicitVrf
   key = Tac.Value( "Ntp::VrfAndHost", vrfName, host )
   if key in ntpConfig.server:
      del ntpConfig.server[ key ]
   else:
      # Only mention the VRF in the warning if the user typed one.
      vrfPrefix = "" if explicitVrf is None else vrfName + "/"
      mode.addWarning( "server " + vrfPrefix + host + " is not configured" )
   updateDscpRules()

def doConfigAuth( mode, args ):
   """Handler for 'ntp authenticate [servers]'.

   Selects authModeServers when 'servers' is given, otherwise authModeAll.
   """
   if 'servers' in args:
      ntpConfig.authMode = authModeServers
   else:
      ntpConfig.authMode = authModeAll

def doDisableAuth( mode, args ):
   """Handler for 'no ntp authenticate [servers]': select authModeNone."""
   ntpConfig.authMode = authModeNone

def doConfigTrustedKey( mode, args ):
   """Handler for 'ntp trusted-key ...': store the key range as a string."""
   keyRange = args[ 'KEY_RANGE' ]
   ntpConfig.trustedKeys = str( keyRange )

def doNoTrustedKey( mode, args ):
   """Handler for 'no ntp trusted-key': clear the trusted key list."""
   ntpConfig.trustedKeys = ''

def doConfigNtpServer( mode, args ):
   """Handler for 'ntp server [vrf VRF] HOST [options]'.

   Adds an Ntp::Server keyed by (VRF, HOST), using the default VRF when none
   is given, then refreshes the DSCP rules. If an existing server lives in a
   different VRF, reports an error and changes nothing.
   """
   host = args[ 'HOST' ]
   vrfName = args.get( 'VRF', DEFAULT_VRF )

   # The source interface is not validated here: NTP is commonly configured
   # before the interfaces are fully configured, so warnings would mostly be
   # noise.

   # Sanity-check the hostname/address and warn if it looks illegal, but
   # accept it anyway (e.g. the name may not be in DNS yet).
   HostnameCli.resolveHostname( mode, host, doWarn=True )

   # All servers must share one VRF, so comparing against any one of the
   # existing servers is enough.
   existing = next( iter( ntpConfig.server.values() ), None )
   if existing and existing.vrf != vrfName:
      mode.addError( "All NTP servers must be in the same VRF. "
                     "Please remove time servers from VRF "
                     f"{existing.vrf} before continuing." )
      return

   ntpConfig.server.addMember(
      Tac.Value( "Ntp::Server",
                 vrfAndHost=Tac.Value( "Ntp::VrfAndHost", vrfName, host ),
                 **args[ 'OPTIONS' ] ) )
   updateDscpRules()

def doConfigDefaultSourceIntf( mode, args ):
   """Handler that sets the default NTP source interface and its VRF.

   The interface is deliberately not validated (see doConfigNtpServer).
   """
   ntpConfig.defaultSourceIntf = Tac.Value(
      "Ntp::VrfAndIntf",
      vrf=args.get( 'VRF', DEFAULT_VRF ),
      intf=args[ 'INTF' ].name )

def doRemoveDefaultSourceIntf( mode, args ):
   """Handler that resets the default NTP source interface to unset.

   "Unset" is an empty interface name in the default VRF.
   """
   cleared = Tac.Value( "Ntp::VrfAndIntf", vrf=DEFAULT_VRF, intf='' )
   ntpConfig.defaultSourceIntf = cleared

def doConfigAuthenKey( mode, args ):
   """Handler for 'ntp authentication-key KEY_ID (md5|sha1) [0|7] KEY'."""
   keyType = args.get( 'md5' )
   if not keyType:
      keyType = args[ 'sha1' ]
   symmetricKey = Tac.Value( "Ntp::SymmetricKey", args[ 'KEY_ID' ], keyType,
                             args[ 'KEY' ] )
   ntpConfig.symmetricKey.addMember( symmetricKey )

def doDeleteAuthenKey( mode, args ):
   """Handler for 'no ntp authentication-key KEY_ID'."""
   keyId = args[ 'KEY_ID' ]
   del ntpConfig.symmetricKey[ keyId ]

def doNtpServeAll( mode, args ):
   """Handler for 'ntp serve all [vrf VRF]'.

   With a VRF, records it in serveVrfName. Without one, enables the default
   serve-all mode.
   """
   vrfName = args.get( 'VRF' )
   if not vrfName:
      ntpConfig.serverModeEnabledDefault = True
   else:
      ntpConfig.serveVrfName[ vrfName ] = True

def doRemoveNtpServeAll( mode, args ):
   """Handler for 'no ntp serve all [vrf VRF]'.

   With a VRF, removes it from serveVrfName. Without one, disables the
   default serve-all mode.
   """
   vrfName = args.get( 'VRF' )
   if not vrfName:
      ntpConfig.serverModeEnabledDefault = False
   else:
      del ntpConfig.serveVrfName[ vrfName ]

def doEnableLocalReferenceMode( mode, args ):
   """Handler for 'ntp local [stratum STRATUM]'.

   Uses ntpConfig.localStratumEnabledDefault when no stratum is given.
   """
   fallbackStratum = ntpConfig.localStratumEnabledDefault
   ntpConfig.localStratum = args.get( 'STRATUM', fallbackStratum )

def doDisableLocalReferenceMode( mode, args ):
   """Handler for 'no ntp local': set the stratum to the disabled sentinel."""
   ntpConfig.localStratum = ntpConfig.localStratumDisabled

# Bind interfaces to NTP.
# Note that this changes the meaning of "ntp bind" slightly. It still works
# as a remedy for the open-file-descriptor limit issue, but a "bound"
# interface is now synonymous with a server-mode-enabled one. Our
# documentation should note that this command should not be combined with
# "ntp serve": both commands write the same per-interface setting
# (serverModeEnabledIntf), so they will step on each other's toes in ways
# that could be very confusing.
def doAddBoundIntfs( mode, args ):
   """Handler for 'ntp bind INTFS': enable server mode on each interface.

   Only interfaces returned by IntfCli.Intf.getAll( ..., config=True ) are
   bound.
   """
   for intf in IntfCli.Intf.getAll( mode, args[ 'INTFS' ], config=True ):
      ntpConfig.serverModeEnabledIntf[ intf.name ] = True

def doRemoveBoundIntfs( mode, args ):
   """Handler for 'no ntp bind [INTFS]'.

   Unbinds the listed interfaces, or every bound interface when none are
   given.
   """
   intfs = args.get( 'INTFS' )
   if not intfs:
      ntpConfig.serverModeEnabledIntf.clear()
      return
   for intf in intfs:
      del ntpConfig.serverModeEnabledIntf[ intf ]

def enableRestarts( mode, args ):
   """Set ntpConfig.restartsForcedOnIntfChanges to True."""
   ntpConfig.restartsForcedOnIntfChanges = True

def disableRestarts( mode, args ):
   """Set ntpConfig.restartsForcedOnIntfChanges to False."""
   ntpConfig.restartsForcedOnIntfChanges = False

def setNtpIpAcl( mode, args ):
   """Handler that applies an IPv4 ('ip') or IPv6 service ACL to NTP.

   The ACL is attached to UDP traffic on the 'ntp' service port and is
   tracked.
   """
   aclType, aclName = ( ( 'ip', args[ 'IP_ACL' ] ) if 'ip' in args
                        else ( 'ipv6', args[ 'IP6_ACL' ] ) )
   assert aclType in args
   # Read VRF straight from args so that it stays None for the default VRF.
   vrfName = args.get( 'VRF' )
   ntpPort = AclLib.getServByName( IPPROTO_UDP, 'ntp' )
   AclCliLib.setServiceAcl( mode, 'ntp', IPPROTO_UDP,
                            aclConfig, aclCpConfig, aclName, aclType=aclType,
                            vrfName=vrfName,
                            port=[ ntpPort ],
                            tracked=True )

def noNtpIpAcl( mode, args ):
   """Handler that removes the NTP service ACL for the ip or ipv6 family."""
   aclType = 'ipv6'
   if 'ip' in args:
      aclType = 'ip'
   assert aclType in args
   AclCliLib.noServiceAcl( mode, 'ntp', aclConfig, aclCpConfig, None, aclType,
                           vrfName=args.get( 'VRF' ) )

def updateDscpRules():
   """Rebuild the 'ntp' entry of dscpConfig from the current ntpConfig.

   If no DSCP value was configured from the CLI (dscpValue equals
   dscpValueInvalid, 64), the 'ntp' protocol config is deleted, so no DSCP
   (iptables) rules exist for IPv4 or IPv6.

   Otherwise the rule collection is cleared and regenerated on UDP port 123
   in the VRF returned by NtpLib.vrfInUse():
     * an IPv4 rule ('0.0.0.0') and an IPv6 rule ('::') for server
       responses;
     * an IPv4 and an IPv6 rule per configured server for client requests.

   Rules are added even when the value is 48. ntpd already uses 48 by
   default for IPv4 NTP packets, while IPv6 NTP packets default to 0.
   """
   dscpValue = ntpConfig.dscpValue
   if dscpValue == ntpConfig.dscpValueInvalid:
      del dscpConfig.protoConfig[ 'ntp' ]
      return

   rules = dscpConfig.newProtoConfig( 'ntp' ).rule
   rules.clear()

   vrf = NtpLib.vrfInUse( ntpConfig )
   ntpPort = 123

   # Server responses.
   DscpCliLib.addDscpRule( rules, '0.0.0.0', ntpPort,
                           True, vrf, 'udp', dscpValue )
   DscpCliLib.addDscpRule( rules, '::', ntpPort,
                           True, vrf, 'udp', dscpValue, v6=True )

   # Client requests to each server, first IPv4 then IPv6.
   for server in ntpConfig.server.values():
      host = server.ipOrHost
      for familyKwargs in ( {}, { 'v6': True } ):
         DscpCliLib.addDscpRule( rules, host, ntpPort,
                                 False, vrf, 'udp', dscpValue, **familyKwargs )

def setDscp( mode, args ):
   """Handler that sets the NTP DSCP value and regenerates the DSCP rules."""
   dscp = args[ 'DSCP' ]
   ntpConfig.dscpValue = dscp
   updateDscpRules()

def noDscp( mode, args=None ):
   """Mark the NTP DSCP value as not configured and drop its DSCP rules.

   `args` is optional so doDisableNtp can call this directly.
   """
   ntpConfig.dscpValue = ntpConfig.dscpValueInvalid
   updateDscpRules()

def updateVrfInModel( aclListModel, summary ):
   """Set activeVrfs on each ACL of aclListModel and return aclListModel.

   An ACL is active only when its configured VRFs are exactly [ the VRF NTP
   uses ]. `summary` tells whether the entries of aclList are summary models
   themselves or hold one in `.summary`.
   """
   vrfInUse = NtpLib.vrfInUse( ntpConfig )
   for entry in aclListModel.aclList:
      target = entry if summary else entry.summary
      isActive = target.configuredVrfs == [ vrfInUse ]
      target.activeVrfs = [ vrfInUse ] if isActive else []
   return aclListModel

def showNtpIpAcl( mode, args ):
   """Handler that shows the NTP service ACLs of one address family.

   Fetches the service-ACL model, then updates the activeVrfs of the
   requested family's list via updateVrfInModel before returning it.
   """
   aclType = 'ip' if 'ip' in args else 'ipv6'
   aclNameExpr = args[ '<aclNameExpr>' ]
   allAclListModel = AclCli.showServiceAcl( mode, aclCpConfig, aclStatus,
                                            aclCheckpoint, aclType, aclNameExpr,
                                            serviceName='ntp' )
   familyModel = ( allAclListModel.ipAclList if aclType == 'ip'
                   else allAclListModel.ipv6AclList )
   updateVrfInModel( familyModel, aclNameExpr[ 1 ] )
   return allAclListModel

def clearNtpIpAclCounters( mode, args ):
   """Handler that clears service ACL counters for the ip or ipv6 family."""
   aclType = 'ipv6'
   if 'ip' in args:
      aclType = 'ip'
   assert aclType in args
   AclCli.clearServiceAclCounters( mode, aclStatus, aclCheckpoint, aclType )

# ------------------------------------------
# Ntp config-if commands :
#    ntp serve
#    no ntp serve
#    default ntp serve

def doNtpServeIntf( mode, args ):
   """Handler for interface-mode 'ntp serve': explicitly enable serving."""
   intfName = str( mode.intf )
   del ntpConfig.serverModeDisabledIntf[ intfName ]
   ntpConfig.serverModeEnabledIntf[ intfName ] = True

def doRemoveNtpServeIntf( mode, args ):
   """Handler for interface-mode 'no ntp serve': explicitly disable serving."""
   intfName = str( mode.intf )
   del ntpConfig.serverModeEnabledIntf[ intfName ]
   ntpConfig.serverModeDisabledIntf[ intfName ] = True

def doRemoveServerModeConfig( mode, args ):
   """Handler for interface-mode 'default ntp serve'.

   Drops any explicit enable/disable setting for the interface.
   """
   intfName = str( mode.intf )
   for intfColl in ( ntpConfig.serverModeEnabledIntf,
                     ntpConfig.serverModeDisabledIntf ):
      del intfColl[ intfName ]

# class registered in Plugin that will respond with appropriate
# changes to ntpConfig when an interface has been set to default
class NtpIntf( IntfCli.IntfDependentBase ):
   """Interface-dependent hook: clears NTP serve settings on defaulting."""

   def setDefault( self ):
      intfName = self.intf_.name
      for intfColl in ( ntpConfig.serverModeEnabledIntf,
                        ntpConfig.serverModeDisabledIntf ):
         del intfColl[ intfName ]

#-----------------------------------------------------------
# Register show ntp commands into "show tech-support".
#-----------------------------------------------------------
# Each NTP show command is passed as both `cmds` and `summaryCmds`.
# NOTE(review): the leading timestamp string presumably identifies or orders
# each registration; confirm against CliPlugin.TechSupportCli before changing
# it.
CliPlugin.TechSupportCli.registerShowTechSupportCmd(
   '2010-01-01 00:03:00',
   cmds=[ 'show ntp status' ],
   summaryCmds=[ 'show ntp status' ] )

CliPlugin.TechSupportCli.registerShowTechSupportCmd(
   '2012-09-25 13:33:47',
   cmds=[ 'show ntp associations' ],
   summaryCmds=[ 'show ntp associations' ] )

def Plugin( entityManager ):
   """CLI plugin entry point: mount the state this module uses.

   Fills in the module-level globals (ConfigMount mounts are writable, and
   most LazyMount mounts are read-only). Also registers NtpIntf so that
   defaulting an interface clears its NTP serve settings.
   """
   global allVrfStatusLocal
   global ntpConfig, ntpStatus, ntpServerStatus
   global aclConfig, aclCpConfig, aclStatus, aclCheckpoint, dscpConfig
   em = entityManager

   ntpConfig = ConfigMount.mount( em, "sys/time/ntp/config", "Ntp::Config",
                                  "w" )
   allVrfStatusLocal = LazyMount.mount( em, Cell.path( 'ip/vrf/status/local' ),
                                        'Ip::AllVrfStatusLocal', 'r' )

   IntfCli.Intf.registerDependentClass( NtpIntf )

   ntpStatus = LazyMount.mount( em, "sys/time/ntp/status", "Ntp::Status", "r" )
   ntpServerStatus = LazyMount.mount( em, "sys/time/ntp/serverStatus",
                                      "Ntp::ServerStatusDir", "r" )
   aclConfig = ConfigMount.mount( em, "acl/config/cli", "Acl::Input::Config",
                                  "w" )
   aclCpConfig = ConfigMount.mount( em, "acl/cpconfig/cli",
                                    "Acl::Input::CpConfig", "w" )
   aclStatus = LazyMount.mount( em, "acl/status/all", "Acl::Status", "r" )
   aclCheckpoint = LazyMount.mount( em, "acl/checkpoint",
                                    "Acl::CheckpointStatus", "w" )
   dscpConfig = ConfigMount.mount( em, "mgmt/dscp/config",
                                   "Mgmt::Dscp::Config", "w" )

# Register ntpServerClockSourceHook with ClockCli. The hook returns
# "NTP server (<address>)" when synchronised to an address, else "".
# NOTE(review): presumably ClockCli calls it when describing the clock
# source; confirm in ClockCli.
ClockCli.registerClockSourceHook( ntpServerClockSourceHook )
