# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

# This library stores common items used in SysMgr

import QuickTrace, Tac
import ManagedSubprocess
from SuperServer import defaultTimeout
import os
import re
import Tracing
from MgmtSecurityLib import mgmtSecurityConfigType
from SslCertKey import USE_FIPS_FLAG
from IpLibConsts import DEFAULT_VRF

# pkg-deps: import EpochConsts

qv = QuickTrace.Var
qt0 = QuickTrace.trace0
t0 = Tracing.trace0
# Each tN alias must map to Tracing level N; binding t3 to trace0 would emit
# the level-3 messages of this module at level 0.
t3 = Tracing.trace3
t8 = Tracing.trace8

# ssh-keygen binary used for key generation and fingerprinting
keygen = '/usr/bin/ssh-keygen'
# Hooks checked by generateSshKeysWithComment (presumably for tests): when
# keyGenFaultPoint is set no process is started and None is returned; when
# keyGenSleepFunc is set a "sleep 600" process is returned instead.
keyGenFaultPoint = False
keyGenSleepFunc = False
# Timeout for each scrub command run by shredSshKey
shredTimeout = defaultTimeout * 3

# Directory holding every ssh host key; each public key sits next to its
# private key with a ".pub" suffix.
sshKeyFolder = '/persist/secure/'

# RSA host key (private, public)
rsaKeyPath = f'{sshKeyFolder}ssh_host_rsa_key'
rsaKeyPathPublic = f'{rsaKeyPath}.pub'

# DSA host key (private, public)
dsaKeyPath = f'{sshKeyFolder}ssh_host_dsa_key'
dsaKeyPathPublic = f'{dsaKeyPath}.pub'

# ED25519 host key (private, public)
ed25519KeyPath = f'{sshKeyFolder}ssh_host_ed25519_key'
ed25519KeyPathPublic = f'{ed25519KeyPath}.pub'

# ECDSA host keys (private, public): the nistp521 key keeps the historical
# plain "ecdsa" file name, the nistp256 key names its curve explicitly.
ecdsa521KeyPath = f'{sshKeyFolder}ssh_host_ecdsa_key'
ecdsa521KeyPathPublic = f'{ecdsa521KeyPath}.pub'
ecdsa256KeyPath = f'{sshKeyFolder}ssh_host_ecdsa_nistp256_key'
ecdsa256KeyPathPublic = f'{ecdsa256KeyPath}.pub'

AuthenMethod = Tac.Type( "Mgmt::Ssh::AuthenMethod" )

# Tac AuthenMethod enum -> method name as written in the sshd config
opensshMethodMap = {
   AuthenMethod.password: 'password',
   AuthenMethod.publicKey: 'publickey',
   AuthenMethod.keyboardInteractive: 'keyboard-interactive',
}

# CLI authentication-method token -> Tac AuthenMethod enum
authenCliTokenToTacType = {
   "password": AuthenMethod.password,
   "public-key": AuthenMethod.publicKey,
   "keyboard-interactive": AuthenMethod.keyboardInteractive,
}
# Tac AuthenMethod enum -> CLI token (inverse of the map above)
authenTacTypeToCliToken = dict( zip( authenCliTokenToTacType.values(),
                                     authenCliTokenToTacType.keys() ) )

# Key type (Tac representation) -> private host key path
keyTypeToPath = {
   'rsa': rsaKeyPath,
   'dsa': dsaKeyPath,
   'ed25519': ed25519KeyPath,
   'ecdsa_nistp521': ecdsa521KeyPath,
   'ecdsa_nistp256': ecdsa256KeyPath,
}

# Legal key-generation parameters per key type.
# The first value in each list is the default.
keyTypeToLegalParams = {
   'rsa': {
      "keysize": [ '2048', '4096' ],
   },
}

# CLI key type token -> Tac enum representation
cliKeyTypeToTac = {
   'rsa': 'rsa',
   'dsa': 'dsa',
   'ed25519': 'ed25519',
   'ecdsa-nistp521': 'ecdsa_nistp521',
   'ecdsa-nistp256': 'ecdsa_nistp256',
}
# Tac enum representation -> CLI key type token. This is the inverse of the map
# above and is only valid while that mapping is one-to-one.
tacKeyTypeToCliKey = dict( zip( cliKeyTypeToTac.values(),
                                cliKeyTypeToTac.keys() ) )

# Legacy key algorithm alias -> ( current key algorithm name, key parameters )
legacyKeyTypeToAliasedParams = {
   'ecdsa': ( 'ecdsa-nistp521', {} ),
}

def updateAuthenticationMethods( sshConfig, newMethods ):
   """
   Make sshConfig.authenticationMethodList mirror newMethods.

   newMethods is an iterable of method lists (presumably sequences of
   AuthenMethod values). It is sorted, and the i-th sorted list is stored
   under key i:
   - an entry whose methods already equal the wanted list is left untouched;
   - any other entry is (re)created, cleared, and refilled in order;
   - keys at or beyond the number of new lists are deleted afterwards.
   """
   wantedLists = sorted( newMethods )
   for index, wanted in enumerate( wantedLists ):
      existing = sshConfig.authenticationMethodList.get( index )
      unchanged = ( existing is not None and
                    existing.method.values() == wanted )
      if unchanged:
         continue
      entry = sshConfig.authenticationMethodList.newMember( index )
      entry.method.clear()
      for authMethod in wanted:
         entry.method.enq( authMethod )

   # Collect the stale keys first so the collection is not mutated while
   # it is being iterated.
   wantedCount = len( wantedLists )
   staleKeys = [ key for key in sshConfig.authenticationMethodList
                 if key >= wantedCount ]
   for key in staleKeys:
      del sshConfig.authenticationMethodList[ key ]

def shredSshKey( path ):
   """
   Securely erase an ssh private key and its associated public key.

   Runs `scrub --remove --no-signature` as root on `path` and then on
   `path + '.pub'`. Files that do not exist are skipped. Each scrub call is
   bounded by shredTimeout, and the function returns once the scrub calls
   are done.

   Returns True if every existing file was scrubbed successfully. Returns
   False as soon as a scrub call times out (Tac.Timeout) or fails
   (Tac.SystemCommandError); any file not yet processed is then left as is.
   """
   t8( "shredSshKey: path:%s" % ( path, ) )
   pubPath = path + '.pub'
   try:
      for filepath in ( path, pubPath ):
         if os.path.exists( filepath ):
            t3( "Shredding the file at %s" % filepath )
            Tac.run( [ 'scrub', '--remove', '--no-signature', filepath ],
                     asRoot=True,
                     timeout=shredTimeout )
         else:
            t3( "The file being checked for scrubbing at %s DNE" % filepath )
   except Tac.Timeout:
      # Pass the path to QuickTrace as a separate qv() argument, as the
      # SystemCommandError branch does. '%'-formatting qv( path ) into the
      # message embeds the Var wrapper itself, which may not render the path.
      qt0( 'Warning: timedout while trying to scrub ssh keys', qv( path ) )
      t0( 'Warning: timedout while trying to scrub ssh keys %s' % ( path, ) )
      return False
   except Tac.SystemCommandError as e:
      qt0( 'Warning, command failed:', qv( str( e ) ) )
      t0( 'Warning, command failed:', str( e ) )
      return False
   return True

def generateNewSshKeyComment( chassisAddr, oldComment, cvxaddr="" ):
   """
   Build an ssh key comment that records the chassis (and optional CVX)
   address in front of the previous comment.

   Result: "chassisAddr=<chassisAddr>[ cvxAddr=<cvxaddr>] <oldComment>".
   The cvxAddr field is only included when cvxaddr is truthy.
   """
   # The leading space inside the optional field preserves the previous format
   cvxPart = " cvxAddr=%s" % cvxaddr if cvxaddr else ""
   return f"chassisAddr={chassisAddr}{cvxPart} {oldComment}"

def generateSshKeysWithComment( keyType, keyParams, path, comment, fipsMode=False ):
   """
   Start ssh-keygen to create a new passphrase-less ssh key of keyType at
   path, with comment embedded in the key.

   This function does NOT remove an existing key. The path must not exist
   yet, because ssh-keygen freezes up if the file already exists. Callers
   should scrub any old key first, e.g. with shredSshKey.

   keyType: a key of keyTypeToPath ('rsa', 'dsa', 'ed25519',
      'ecdsa_nistp521' or 'ecdsa_nistp256').
   keyParams: key parameters. For 'rsa', keyParams[ 'keysize' ] is passed
      to ssh-keygen -b. If keyParams is None or has no 'keysize', the
      default (first) size from keyTypeToLegalParams is used.
   fipsMode: when True and USE_FIPS_FLAG is also set, --fips is added to
      the ssh-keygen command line.

   Returns the ManagedSubprocess running ssh-keygen, with stdout and stderr
   piped. The caller is responsible for making sure the process completes
   and for handling time-outs. When keyGenFaultPoint is set, nothing is
   started and None is returned. When keyGenSleepFunc is set, a "sleep 600"
   process is returned instead of ssh-keygen.
   """
   t8( "Inside generateSshKeysWithComment keyType:", keyType,
      " keyParams:", str( keyParams ), " path:", path, " comment:", comment,
      " fipsMode:", fipsMode )
   # ssh-keygen only knows a generic "ecdsa" type; the curve is chosen via -b
   keyTypeKeygen = "ecdsa" if "ecdsa" in keyType else keyType
   assert keyType in keyTypeToPath
   # ssh-keygen will freeze up if the file already exists
   assert not os.path.isfile( path )
   if keyGenFaultPoint:
      t3( "keyGenFaultPoint hit, returning early" )
      return None
   if keyGenSleepFunc:
      t3( "keyGenSleepFunc hit, returning sleep function" )
      sleepFunc = ManagedSubprocess.Popen( [ "sleep", "600" ],
            stdout=ManagedSubprocess.PIPE,
            stderr=ManagedSubprocess.PIPE )
      return sleepFunc
   # -q: quiet, -m PEM: PEM key format, -N "": empty passphrase
   keygenArgs = [ keygen, "-q", "-t", keyTypeKeygen, "-f", path,
                  "-m", "PEM", "-N", "", "-C", comment ]
   if keyType == "rsa":
      try:
         keysize = keyParams[ 'keysize' ]
      except ( KeyError, TypeError ):
         # No explicit size (or no params at all): fall back to the default,
         # which is the first legal RSA key size.
         keysize = keyTypeToLegalParams[ 'rsa' ][ 'keysize' ][ 0 ]
      # argv entries must be strings for Popen
      keygenArgs.extend( [ "-b", str( keysize ) ] )
   elif keyType == "ecdsa_nistp256":
      keygenArgs.extend( [ "-b", "256" ] )
   elif keyType == "ecdsa_nistp521":
      keygenArgs.extend( [ "-b", "521" ] )
   if fipsMode and USE_FIPS_FLAG:
      keygenArgs.append( "--fips" )
   keygenProc = ManagedSubprocess.Popen( keygenArgs,
                                         stdout=ManagedSubprocess.PIPE,
                                         stderr=ManagedSubprocess.PIPE )
   return keygenProc

class MgmtSecurityConfigReactor( Tac.Notifiee ):
   """
   Tac notifiee on the management security config (notifier type
   mgmtSecurityConfigType) for its entropy-source attributes.

   Every handler simply forwards to mgr.handleMgmtSecurity(). For a
   MgmtSecurityMgr, that re-checks entropy readiness and calls sync() when
   the configured entropy sources are ready.
   """
   notifierTypeName = mgmtSecurityConfigType

   def __init__( self, notifier, mgr ):
      # mgr must provide handleMgmtSecurity() (e.g. a MgmtSecurityMgr).
      Tac.Notifiee.__init__( self, notifier )
      self.mgr_ = mgr

   @Tac.handler( 'entropySourceHardware' )
   def handleEntropySourceHardware( self ):
      """Tac.handler for 'entropySourceHardware'; forwards to the manager."""
      self.mgr_.handleMgmtSecurity()

   @Tac.handler( 'entropySourceHaveged' )
   def handleEntropySourceHaveged( self ):
      """Tac.handler for 'entropySourceHaveged'; forwards to the manager."""
      self.mgr_.handleMgmtSecurity()

   @Tac.handler( 'entropySourceJitter' )
   def handleEntropySourceJitter( self ):
      """Tac.handler for 'entropySourceJitter'; forwards to the manager."""
      self.mgr_.handleMgmtSecurity()

class MgmtSecurityStatusReactor( Tac.Notifiee ):
   """
   Tac notifiee on Mgmt::Security::Status for the *Enabled entropy-source
   attributes.

   Every handler simply forwards to mgr.handleMgmtSecurity(). This lets the
   manager re-evaluate readiness once the status side catches up with the
   config.
   """
   notifierTypeName = "Mgmt::Security::Status"

   def __init__( self, notifier, mgr ):
      # mgr must provide handleMgmtSecurity() (e.g. a MgmtSecurityMgr).
      Tac.Notifiee.__init__( self, notifier )
      self.mgr_ = mgr

   @Tac.handler( 'entropySourceHardwareEnabled' )
   def handleEntropySourceHardware( self ):
      """Tac.handler for 'entropySourceHardwareEnabled'; forwards to the
      manager."""
      self.mgr_.handleMgmtSecurity()

   @Tac.handler( 'entropySourceHavegedEnabled' )
   def handleEntropySourceHaveged( self ):
      """Tac.handler for 'entropySourceHavegedEnabled'; forwards to the
      manager."""
      self.mgr_.handleMgmtSecurity()

   @Tac.handler( 'entropySourceJitterEnabled' )
   def handleEntropySourceJitter( self ):
      """Tac.handler for 'entropySourceJitterEnabled'; forwards to the
      manager."""
      self.mgr_.handleMgmtSecurity()

class MgmtSecurityMgr:
   """
   Base class for managers whose sync() must wait for the entropy sources.

   The config and status reactors created in __init__ call
   handleMgmtSecurity() from their entropy-source handlers. sync() then runs
   only if entropySourceOk() reports that every configured entropy source is
   enabled in status. Subclasses must implement sync().
   """

   def __init__( self, mgmtSecurityConfig, mgmtSecurityStatus, agent ):
      self.config_ = mgmtSecurityConfig
      self.status_ = mgmtSecurityStatus
      self.configReactor_ = MgmtSecurityConfigReactor(
         mgmtSecurityConfig, self )
      self.statusReactor_ = MgmtSecurityStatusReactor(
         mgmtSecurityStatus, self )
      self.agent_ = agent

   def sync( self ):
      """Apply the managed configuration. Subclasses must override this."""
      raise NotImplementedError

   def handleMgmtSecurity( self ):
      """Run sync() if the entropy sources are ready; otherwise do nothing."""
      if not self.entropySourceOk():
         return
      self.sync()

   def entropySourceOk( self ):
      """
      Return whether every configured entropy source is reported enabled.

      Ssh services should be disabled until entropy source status is
      consistent with config. A source that is not configured never blocks.
      When the agent is not active, this always returns True.
      """
      if not self.agent_.active():
         # TODO: remove this check if hardware entropy works on standby
         return True
      cfg = self.config_
      st = self.status_
      # Per source: ready if not configured, otherwise whatever status says
      return ( ( not cfg.entropySourceHardware or
                 st.entropySourceHardwareEnabled ) and
               ( not cfg.entropySourceHaveged or
                 st.entropySourceHavegedEnabled ) and
               ( not cfg.entropySourceJitter or
                 st.entropySourceJitterEnabled ) )

def getIp4SrcIntfAddr( networkUrlConfig, ipStatus, protocol, vrfName=None ):
   """
   Return the primary IPv4 address of the source interface configured for
   protocol in vrfName (DEFAULT_VRF when vrfName is None).

   Returns None when:
   - the protocol has no config;
   - the VRF has no source interface;
   - the interface name is empty;
   - the interface has no IP status;
   - the active address is "0.0.0.0".
   """
   vrf = DEFAULT_VRF if vrfName is None else vrfName
   protocolConfig = networkUrlConfig.protocolConfig.get( protocol )
   if not protocolConfig:
      return None
   srcIntf = protocolConfig.srcIntf.get( vrf )
   if not srcIntf:
      return None
   intfName = srcIntf.intfName
   if not intfName:
      return None
   intfStatus = ipStatus.ipIntfStatus.get( intfName )
   if not intfStatus:
      return None
   address = intfStatus.activeAddrWithMask.address
   # "0.0.0.0" means the interface has no usable address
   return None if address == "0.0.0.0" else address

def getIp6SrcIntfAddr( networkUrlConfig, ip6Status, protocol, vrfName=None ):
   """
   Return an IPv6 address (as a string) of the source interface configured
   for protocol in vrfName (DEFAULT_VRF when vrfName is None).

   The first address in the interface's address collection that is not
   link-local is returned; the intent is a global-scope address. Returns
   None when:
   - the protocol has no config;
   - the VRF has no source interface;
   - the interface name is empty;
   - the interface has no IPv6 status;
   - every address is link-local.
   """
   vrf = DEFAULT_VRF if vrfName is None else vrfName
   protocolConfig = networkUrlConfig.protocolConfig.get( protocol )
   if not protocolConfig:
      return None
   srcIntf = protocolConfig.srcIntf.get( vrf )
   if not srcIntf:
      return None
   intfName = srcIntf.intfName
   if not intfName:
      return None
   intfStatus = ip6Status.intf.get( intfName )
   if not intfStatus:
      return None
   for entry in intfStatus.addr:
      if not entry.address.isLinkLocal:
         return entry.address.stringValue
   return None

def checkRekeyDataLimit( amount, rekeyUnit ):
   """
   Check whether a rekey data limit is within openssh limits.

   This is a result of openssh project bug 2264, fixed in openssh-7.1.
   amount is expressed in rekeyUnit, which must be "kbytes", "mbytes" or
   "gbytes" (powers of 1024).

   Returns True if 16 < amount-in-bytes < 0xFFFFFFFF and False otherwise.
   Raises AssertionError for an unknown rekeyUnit.
   """
   minRekeyData = 16
   maxRekeyData = 0xFFFFFFFF
   unitExponent = { "kbytes": 1, "mbytes": 2, "gbytes": 3 }
   if rekeyUnit not in unitExponent:
      # Explicit raise instead of `assert False`: a bare assert is stripped
      # under `python -O`, which would silently treat the unit as kbytes.
      # AssertionError keeps the exception type callers saw before.
      raise AssertionError( "Unknown rekey unit passed in" )
   fullAmount = amount * 1024 ** unitExponent[ rekeyUnit ]
   return minRekeyData < fullAmount < maxRekeyData

def getSshHostKeyFingerprint( path, hashAlgo='SHA-256' ):
   """
   Return the SHA-256 fingerprint of the ssh key at path.

   The fingerprint is the text following "SHA256:" in the output of
   `ssh-keygen -E sha256 -lf path`, run as root. Only hashAlgo='SHA-256' is
   supported, and path must exist. If no fingerprint is found in the output,
   a "<fail to get ssh host key fingerprint: ...>" placeholder containing
   that output is returned instead.
   """
   t8( "getSshHostKeyFingerprint: path:%s" % path )
   # SHA-256 is the only algorithm supported for now
   assert hashAlgo == 'SHA-256'
   # The host key must already exist
   assert os.path.exists( path )
   keygenOutput = Tac.run( [ keygen, "-E", "sha256", "-lf", path ],
                           asRoot=True, stdout=Tac.CAPTURE,
                           stderr=Tac.CAPTURE )
   found = re.search( r"SHA256:(\S+)", keygenOutput )
   if not found:
      return "<fail to get ssh host key fingerprint: %s>" % keygenOutput
   return found.group( 1 )

def netnsNameWithUniqueId( netnsName ):
   """
   Return netnsName tagged with a per-incarnation ID, e.g. "ns/@<mtime>".

   Xinetd can use the tag to recognize a network namespace that was deleted
   and re-created since the previous generation. The ID is the mtime of
   /var/run/netns/<netnsName>. st_mtime is a float, which loses some
   precision but is good enough here. nsutil's netnsRunPath/setnsByName()
   ignore everything after '/@' in the name.

   Empty/None names, DEFAULT_VRF, and namespaces whose file cannot be
   stat'ed are returned unchanged.
   """
   if not netnsName or netnsName == DEFAULT_VRF:
      return netnsName
   try:
      mtime = os.stat( "/var/run/netns/%s" % netnsName ).st_mtime
   except OSError:
      return netnsName
   return "%s/@%s" % ( netnsName, mtime )

def defaultSshUserConfig( user, sshConfig ):
   '''
      Delete/default the SSH configuration of user in sshConfig.

      1. If the new CLI for keys is in use, clear the user's keys and the
         published key.
      2. If the new CLI for principals is in use, clear the user's principals.
      3. If neither keys nor principals remain, delete the user's SSH entry;
         otherwise reset the remaining non-key/non-principal settings
         (TCP forwarding) to their defaults.
      Does nothing if user has no SSH configuration.
   '''
   userConfig = sshConfig.user.get( user )
   if not userConfig:
      return

   # The operator is deleting all SSH config, so the user config is no longer
   # considered explicitly entered.
   userConfig.userModeEntered = False

   if sshConfig.useNewSshCliKey:
      userConfig.sshAuthKeys.clear()
      userConfig.publishedSshAuthKey = ""

   if sshConfig.useNewSshCliPrincipal:
      userConfig.sshAuthPrincipals.clear()

   stillInUse = userConfig.sshAuthKeys or userConfig.sshAuthPrincipals
   if stillInUse:
      # Keep the entry for its keys/principals; default everything else
      userConfig.userTcpForwarding = userConfig.userTcpForwardingDefault
   else:
      del sshConfig.user[ user ]

def processSshOptions( option, optionValues ):
   '''
      Render the values of one SSH option as "Option: <option> <value>" lines.

      Values are processed in sorted order, and a value containing a space is
      wrapped in double quotes. Returns a list with one string per value.
   '''
   def render( value ):
      # Quote values that contain a space
      shown = f"\"{value}\"" if " " in value else value
      return f"Option: {option} {shown}"

   return [ render( value ) for value in sorted( optionValues ) ]

def outputOptions( name, config, optionsStr ):
   '''
      Append the SSH options of config[ name ] to optionsStr and return it.

      Options are emitted in sorted order. An option with values produces the
      lines from processSshOptions (one "Option: <option> <value>" per value).
      An option without values produces a single "Option: <option>" line.
      Every line ends with a newline.
   '''
   sshOptions = config[ name ].sshOptions
   result = optionsStr
   for optName in sorted( sshOptions ):
      values = sshOptions[ optName ].sshValues
      if values:
         rendered = "\n".join( processSshOptions( optName, values ) )
      else:
         rendered = f"Option: {optName}"
      result += rendered + "\n"
   return result
