#!/usr/bin/env python3
# Copyright (c) 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import CliParser
import CliCommand
import CliMatcher
import CliSaveBlock
from ctypes import CDLL, Structure, c_char, c_ulong, c_int, c_char_p, \
                   create_string_buffer, sizeof, byref, POINTER
from ctypes.util import find_library
import re
import Tracing


# Trace handle for this module; debug() is the handle's trace0 method.
traceHandle_ = Tracing.Handle( "SecretCli" )
debug = traceHandle_.trace0

# Callbacks registered through setDefaultHashFunc() / setMinPasswordLengthFunc()
# (the config in Sysdb to get the hash algorithm and minimum password length).
# defaultHashFunc(): returns the hash algorithm name used for cleartext
#                    secrets; while unset, 'md5' is used.
# minPasswordLengthFunc(): returns the minimum cleartext length; a falsy value
#                          (or leaving it unset) means no length check.
defaultHashFunc = None
minPasswordLengthFunc = None

# Load native libcrypt.so.
# A CDLL instance is always truthy, so asserting on the object itself can never
# detect a failed lookup: when find_library() returns None, CDLL( None ) resolves
# against the running process (dlopen(NULL)) instead of raising. Verify that the
# entry points this module actually calls can be resolved.
libcrypt = CDLL( find_library( "crypt" ) )
assert all( hasattr( libcrypt, symbol )
            for symbol in ( "crypt_rn", "crypt_gensalt_rn" ) ), \
   "Failed loading libcrypt.so library"

# Libcrypt defines, structures and functions.
# These sizes mirror the corresponding #defines in libcrypt's <crypt.h>. They
# must match the installed library: CRYPT_DATA's layout and the output buffer
# handed to crypt_gensalt_rn() are derived from them.
CRYPT_OUTPUT_SIZE = 384
CRYPT_MAX_PASSPHRASE_SIZE = 512
CRYPT_DATA_RESERVED_SIZE = 767
CRYPT_DATA_INTERNAL_SIZE = 30720
CRYPT_GENSALT_OUTPUT_SIZE = 192

# ctypes mirror of the C structure:
# struct crypt_data {
#     char output[CRYPT_OUTPUT_SIZE];
#     char setting[CRYPT_OUTPUT_SIZE];
#     char input[CRYPT_MAX_PASSPHRASE_SIZE];
#     char reserved[CRYPT_DATA_RESERVED_SIZE];
#     char initialized;
#     char internal[CRYPT_DATA_INTERNAL_SIZE];
# };
# Field order and sizes must match the C definition exactly, since an instance
# is passed by reference to crypt_rn() together with its sizeof().
class CRYPT_DATA( Structure ):
   """Work area for crypt_rn(); encrypt() reads the resulting hash from output."""
   _fields_ = [
      ( "output", c_char * CRYPT_OUTPUT_SIZE ),
      ( "setting", c_char * CRYPT_OUTPUT_SIZE ),
      ( "input", c_char * CRYPT_MAX_PASSPHRASE_SIZE ),
      ( "reserved", c_char * CRYPT_DATA_RESERVED_SIZE ),
      ( "initialized", c_char ),
      ( "internal", c_char * CRYPT_DATA_INTERNAL_SIZE )
   ]

# char * crypt_rn(const char * phrase, const char * setting,
#                 struct crypt_data * data, int size);
# With restype c_char_p a NULL return arrives as None; encrypt() treats that
# as a failure.
libcrypt.crypt_rn.argtypes = c_char_p, c_char_p, POINTER( CRYPT_DATA ), c_int
libcrypt.crypt_rn.restype = c_char_p

# char * crypt_gensalt_rn(const char * prefix, unsigned long count,
#                         const char * rbytes, int nrbytes, char * output,
#                         int output_size);
# encrypt() passes a create_string_buffer() of CRYPT_GENSALT_OUTPUT_SIZE bytes
# as output and treats a NULL (None) return as a failure.
libcrypt.crypt_gensalt_rn.argtypes = c_char_p, c_ulong, c_char_p, \
                                     c_int, c_char_p, c_int
libcrypt.crypt_gensalt_rn.restype = c_char_p


# Regular expression that matches all the valid characters in a
# cleartext password (ie one entered at the 'Password:' prompt).
# Spaces are not matched; a multi-word cleartext is assembled from separate
# tokens (joined with ' ') by the secretCliExpression adapter.
# pylint: disable=W1401
cleartextPasswdRe = ( r'[0-9a-zA-Z\!\@\#\$\%\^\&\*\(\)\-\_\=\+' +
                      '\\[\\]\\{\\}\\\\;\\:\'\"\\<\\>\\,\\.\\?\\/\\`\\~]+' )

# Regular expression matching various types of encrypted strings
# returned by crypt(3).
# One character of the hash/salt alphabet: letters, digits, '.' and '/'.
hashedStringCharsRe = r'[a-zA-Z0-9\.\/]'

# MD5: "$1$" + 1-8 salt characters + "$" + 22 hash characters.
md5EncryptedPasswdRe = r'\$1\$%s{1,8}[$]%s{22}' % \
                       ( hashedStringCharsRe, hashedStringCharsRe )

# SHA-512: "$6$" + 1-16 salt characters + "$" + 86 hash characters.
sha512EncryptedPasswdRe = r'\$6\$%s{1,16}[$]%s{86}' % \
                          ( hashedStringCharsRe, hashedStringCharsRe )

# scrypt: "$7$" + 11-97 parameter/salt characters + "$" + 43 hash characters.
scryptEncryptedPasswdRe = r'\$7\$%s{11,97}\$%s{43}' % \
                           ( hashedStringCharsRe, hashedStringCharsRe )

# yescrypt: "$y$" + parameters + "$" + 0-86 salt characters + "$" + 43 hash
# characters.
yescryptEncryptedPasswdRe = r'\$y\$%s+\$%s{0,86}\$%s{43}' % \
                           ( hashedStringCharsRe, hashedStringCharsRe,
                             hashedStringCharsRe )

# Characters counted as "special" by SecretValue's password-policy checks
# (membership is tested per character). The same string is also shown to the
# user in minSpecialError. Because it is a raw string, the backslash before the
# single quote is part of the set, as are the enclosing parentheses.
specialCharacters = r'(!"#$%&\'()*+,-./:;<=>?@[^]_`{|}~)'
# %-style templates for SecretValue's validation errors; each %d is filled with
# the limit that was violated.
passwordTooShortError = "Password too short: at least %d characters required."
minDigitsError = "Password doesn't meet minimum digits: at least %d required."
minLowerError = ( "Password doesn't meet minimum lowercase characters: "
                  "at least %d required." )
minSpecialError = ( "Password doesn't meet minimum special characters: at least %d"
                    " are required. The special characters are %s" )
minUpperError = ( "Password doesn't meet minimum uppercase characters: "
                  "at least %d required." )
minChangedError = ( "Password doesn't meet minimum changed characters: at least %d "
                    "are required." )
maxRepeatedError = ( "Password violated maximum repeated characters: "
                     "at most %d are allowed." )
maxSequentialError = ( "Password violated maximum sequential characters: "
                       "at most %d are allowed." )

# Errors appended by SecretValue.validateDenyUsername() and
# SecretValue.validateDenyLastPasswd().
denyUsernameError = ( "Password must be different from the username." )
denyLastPasswdError = ( "Password must be different with recent passwords recorded "
                        "by password history" )

def setDefaultHashFunc( func ):
   """
   Register func as the source of the default hash algorithm name.

   func is called with no arguments whenever a cleartext secret is hashed
   without an explicit algorithm. Passing None (or any falsy value) makes
   callers fall back to 'md5'.
   """
   global defaultHashFunc
   defaultHashFunc = func

def setMinPasswordLengthFunc( func ):
   """
   Register func as the source of the global minimum cleartext length.

   func is called with no arguments when a SecretValue is built from a
   cleartext secret. A configured password policy's minPasswordLength takes
   precedence over it.
   """
   global minPasswordLengthFunc
   minPasswordLengthFunc = func

class SecretValue:
   """
   Stores a hashed password. If a cleartext password is provided, SecretValue
   also validates it: against the minimum length and, when a password policy
   is supplied, against the policy's character-class and run-length rules.
   Violations are collected in secretErrors and reported through the CLI mode
   when hash() is called.
   """

   def __init__( self, mode, secretHash, secretCleartext=None, policyFn=None ):
      """
      mode: CLI mode whose addError() receives validation errors in hash().
      secretHash: the hashed secret (or '*') that hash() returns.
      secretCleartext: optional cleartext the hash was derived from; when given,
                       it is validated against the length/policy rules.
      policyFn: optional callable returning the password policy object (or a
                falsy value when no policy applies).
      """
      # secretErrors is a list of error messages for the secret
      self.secretErrors = []
      self.mode = mode
      self.secretHash = secretHash
      # The cleartext is retained only when a rule checked later
      # (minChanged / denyLastPasswd) is configured and needs it.
      self.newPass = None
      if not secretCleartext:
         return

      minPasswordLength = None
      if minPasswordLengthFunc:
         minPasswordLength = minPasswordLengthFunc()

      # Evaluate the policy callback once and reuse the result.
      policy = policyFn() if policyFn else None
      if policy:
         if ( policy.minChanged != policy.defaultMinChanged or
              policy.denyLastPasswd != policy.defaultDenyLastPasswd ):
            self.newPass = secretCleartext
         # A configured policy overrides the global minimum length.
         minPasswordLength = policy.minPasswordLength
         self._checkPolicy( secretCleartext, policy )

      if minPasswordLength and len( secretCleartext ) < minPasswordLength:
         self.secretErrors.append( passwordTooShortError % minPasswordLength )

   def _checkPolicy( self, secretCleartext, policy ):
      """
      Append an error for every character-class minimum or run-length
      maximum of policy that secretCleartext violates.
      """
      digits, lower, special, upper, repeated, sequential = 0, 0, 0, 0, 1, 1
      for i, character in enumerate( secretCleartext ):
         if character.isdigit():
            digits += 1
         elif character.islower():
            lower += 1
         elif character in specialCharacters:
            special += 1
         elif character.isupper():
            # Only genuine uppercase letters count. Multi-token cleartext is
            # joined with ' ', and a space must not satisfy the uppercase rule.
            upper += 1
         if i > 0:
            previous = secretCleartext[ i - 1 ]
            # The run counters are only reset while still within the allowed
            # maximum, so once a run exceeds it the counter stays above the
            # limit and is reported after the loop.
            if previous == character:
               repeated += 1
            elif repeated <= policy.maxRepeated:
               repeated = 1
            if ord( previous ) + 1 == ord( character ):
               sequential += 1
            elif sequential <= policy.maxSequential:
               sequential = 1
      if digits < policy.minDigits:
         self.secretErrors.append( minDigitsError % policy.minDigits )
      if lower < policy.minLower:
         self.secretErrors.append( minLowerError % policy.minLower )
      if special < policy.minSpecial:
         self.secretErrors.append( minSpecialError %
               ( policy.minSpecial, specialCharacters ) )
      if upper < policy.minUpper:
         self.secretErrors.append( minUpperError % policy.minUpper )
      if repeated > policy.maxRepeated:
         self.secretErrors.append( maxRepeatedError % policy.maxRepeated )
      if sequential > policy.maxSequential:
         self.secretErrors.append( maxSequentialError % policy.maxSequential )

   def validateMinChangedCharacters( self, currPass, minChanged ):
      """
      Record an error if the new cleartext differs from currPass in fewer than
      minChanged characters. Characters are compared position by position, and
      any length difference counts as changed characters. This is a no-op
      unless both passwords are available.
      """
      if not currPass or not self.newPass:
         return
      changed = sum( 1 for newChar, currChar in zip( self.newPass, currPass )
                     if newChar != currChar )
      changed += abs( len( self.newPass ) - len( currPass ) )
      if changed < minChanged:
         self.secretErrors.append( minChangedError % minChanged )

   def validateDenyUsername( self, username ):
      """
      Record an error if hashing username with the stored hash's settings
      reproduces the stored hash, i.e. the password equals the username.
      """
      encryptedUsername = encrypt( username, self.secretHash )
      if encryptedUsername == self.secretHash:
         self.secretErrors.append( denyUsernameError )

   def validateDenyLastPasswd( self, passwordHistory ):
      """
      Check the new cleartext against each hash in passwordHistory.

      Returns True if it matches none of them. Returns False (recording an
      error on a match) if it matches one, or if no cleartext was retained.
      """
      if not self.newPass:
         return False
      for oldPass in passwordHistory:
         if encrypt( self.newPass, oldPass ) == oldPass:
            self.secretErrors.append( denyLastPasswdError )
            return False
      return True

   def hash( self ):
      """
      Return the hashed secret. If any validation failed, report every
      error to the mode and raise CliParser.AlreadyHandledError instead.
      """
      if not self.secretErrors:
         return self.secretHash
      for errorMsg in self.secretErrors:
         self.mode.addError( errorMsg )
      raise CliParser.AlreadyHandledError()


######### Password helper functions ##########
# Maps the internal algorithm name to the crypt(3) hash id that encrypt() uses
# as the "$<id>$" prefix when generating settings, and that getHashAlgorithm()
# reverses. MD5's internal name is "5" (its CLI keyword); see
# getAlgorithmInternalString() / getAlgorithmDisplayString().
hashAlgorithmsToIDs = {
   "5": "1", # md5
   "sha512": "6",
   "scrypt": "7",
   "yescrypt": "y"
}

def getAlgorithmInternalString( hashAlgorithm ):
   """
   Translate a user-facing algorithm name to its internal key in
   hashAlgorithmsToIDs: 'md5' becomes '5'; anything else is returned as is.
   """
   if hashAlgorithm == "md5":
      return "5"
   return hashAlgorithm

def getAlgorithmDisplayString( hashAlgorithm ):
   """
   Translate an internal algorithm key to its display name: '5' becomes 'md5';
   anything else is returned as is.
   """
   if hashAlgorithm == "5":
      return "md5"
   return hashAlgorithm

def encrypt( cleartextPasswd, settings=None, hashAlgorithm=None ):
   """
   Hash cleartextPasswd with libcrypt's crypt_rn() and return the hash as str.

   settings: an existing crypt setting / hashed string whose algorithm,
             parameters and salt are reused. Hashing a candidate password with
             a stored hash as settings reproduces that hash when they match.
   hashAlgorithm: only consulted when settings is empty. It names the
                  algorithm ('md5'/'5', 'sha512', 'scrypt', 'yescrypt') for
                  which a fresh setting string is generated by
                  crypt_gensalt_rn(). That call passes count=0 and no random
                  bytes; per libxcrypt's crypt_gensalt(3) this selects the
                  default cost and an OS-supplied salt.

   An unknown algorithm or a libcrypt failure raises AssertionError.
   """
   if settings:
      settingArg = settings.encode()
   else:
      algorithm = getAlgorithmInternalString( hashAlgorithm )
      assert algorithm in hashAlgorithmsToIDs, \
         "Algorithm given was not an implemented type"
      prefix = "$" + hashAlgorithmsToIDs[ algorithm ] + "$"
      settingArg = create_string_buffer( CRYPT_GENSALT_OUTPUT_SIZE )
      generated = libcrypt.crypt_gensalt_rn( prefix.encode(), 0, None, 0,
                                             settingArg,
                                             CRYPT_GENSALT_OUTPUT_SIZE )
      assert generated, \
         "Failed generating settings (salt & parameters) for encryption"

   scratch = CRYPT_DATA()
   hashed = libcrypt.crypt_rn( cleartextPasswd.encode(), settingArg,
                               byref( scratch ), sizeof( scratch ) )
   assert hashed, "Failed encrypting cleartext secret"
   # The resulting NUL-terminated hash is read back from the output field.
   return scratch.output.decode()

def setCleartextSecret( mode, secretString, hashAlgorithm=None, **kargs ):
   """
   Hash the cleartext secretString and wrap it in a SecretValue that also
   validates the cleartext.

   hashAlgorithm: algorithm name to hash with. When omitted, the value
                  returned by the registered defaultHashFunc is used, or
                  'md5' if none is registered.
   kargs: only 'policyFn' is consulted; it is forwarded to SecretValue.
   """
   algorithm = hashAlgorithm
   if not algorithm:
      # Make sure we get the hash algorithm from Sysdb
      algorithm = defaultHashFunc() if defaultHashFunc else 'md5'

   debug( f"{ getAlgorithmDisplayString( algorithm ) }-encrypting password" )
   hashed = encrypt( secretString, hashAlgorithm=algorithm )
   return SecretValue( mode, hashed, secretCleartext=secretString,
                       policyFn=kargs.get( 'policyFn' ) )


# Cleartext token after the explicit '0' keyword: any valid cleartext.
cleartextZeroMatcher = CliMatcher.PatternMatcher(
   cleartextPasswdRe,
   helpname='LINE',
   helpdesc='The UNENCRYPTED (cleartext) password' )

# Cleartext given without a leading keyword.
cleartextMatcher = CliMatcher.PatternMatcher(
   # Reject a token that is exactly '0', '5', '*', 'sha512', 'scrypt' or
   # 'yescrypt', so it is not confused with the keyword alternatives.
   r'(?!^([05\*]|sha512|scrypt|yescrypt)$)%s' % cleartextPasswdRe,
   helpname='LINE',
   helpdesc='The UNENCRYPTED (cleartext) password' )

# Matchers for already-hashed secrets, each following its keyword
# ('5', 'sha512', 'scrypt', 'yescrypt').
md5Matcher = CliMatcher.PatternMatcher(
   md5EncryptedPasswdRe,
   helpname='LINE',
   helpdesc='The MD5 ENCRYPTED password' )

sha512Matcher = CliMatcher.PatternMatcher(
   sha512EncryptedPasswdRe,
   helpname='LINE',
   helpdesc='The SHA512 ENCRYPTED password' )

sha512KwMatcher = CliMatcher.KeywordMatcher(
   'sha512',
   helpdesc='Specify an ENCRYPTED SHA512 password will follow' )

scryptMatcher = CliMatcher.PatternMatcher(
   scryptEncryptedPasswdRe,
   helpname='LINE',
   helpdesc='The SCRYPT ENCRYPTED password' )

scryptKwMatcher = CliMatcher.KeywordMatcher(
   'scrypt',
   helpdesc='Specify an ENCRYPTED SCRYPT password will follow' )

yescryptMatcher = CliMatcher.PatternMatcher(
   yescryptEncryptedPasswdRe,
   helpname='LINE',
   helpdesc='The YESCRYPT ENCRYPTED password' )

yescryptKwMatcher = CliMatcher.KeywordMatcher(
   'yescrypt',
   helpdesc='Specify an ENCRYPTED YESCRYPT password will follow' )

def secretCliExpression( name, supportInvalidPassword=False,
                         extraClearTextExclude=None,
                         shaSecretGuard=None,
                         scryptSecretGuard=None,
                         yescryptSecretGuard=None,
                         policyFn=None ):
   """
   Build and return a CliExpression subclass that parses a secret given as
   '0 <cleartext>', '<cleartext> [<more cleartext tokens>...]', '5 <md5 hash>',
   'sha512 <hash>', 'scrypt <hash>', 'yescrypt <hash>' or, optionally, '*'.
   Its adapter replaces the matched tokens in args with a single SecretValue
   stored under args[ name ].

   name: the name in args that holds the encrypted password
   supportInvalidPassword: whether include '*' to disable login
   extraClearTextExclude: clear text password supports white spaces, but it can
                          cause issue if the command supports additional options
                          after the secret. To work around this you can pass in
                          tokens that should not match additional cleartext
                          tokens.
   shaSecretGuard: guard function for the SHA512 option.
   scryptSecretGuard: guard function for the SCRYPT option.
   yescryptSecretGuard: guard function for the YESCRYPT option.
   policyFn: optional callable returning the password policy; it is forwarded
             to SecretValue to validate cleartext secrets.
   """
   # Matcher for the second and later cleartext tokens. When exclusions are
   # given, those exact tokens are refused so they stay available to the rest
   # of the command.
   if extraClearTextExclude:
      cleartextExtraMatcher = CliMatcher.PatternMatcher(
         r'(?!^({})$){}'.format( '|'.join( extraClearTextExclude ),
                             cleartextPasswdRe ),
         helpname='LINE',
         helpdesc='The UNENCRYPTED (cleartext) password' )
   else:
      cleartextExtraMatcher = cleartextZeroMatcher

   class SecretExpression( CliCommand.CliExpression ):
      expression = "( ( 0 $CLEARTEXT0 ) | $CLEARTEXT [ { $CLEARTEXT_EXT } ] ) " \
                   "| ( 5 $MD5 ) | ( sha512 $SHA512 ) | ( scrypt $SCRYPT )" \
                   "| ( yescrypt $YESCRYPT )"
      # Cleartext nodes are marked sensitive=True.
      data = { "$CLEARTEXT0" : CliCommand.Node( cleartextZeroMatcher,
                                                sensitive=True ),
               "$CLEARTEXT" : CliCommand.Node( cleartextMatcher,
                                                sensitive=True ),
               "$CLEARTEXT_EXT" : CliCommand.Node( cleartextExtraMatcher,
                                                   sensitive=True ),
               '0' : 'Specify an UNENCRYPTED password will follow',
               '5' : 'Specify an ENCRYPTED MD5 password will follow',
               'sha512' : CliCommand.Node( sha512KwMatcher,
                                           guard=shaSecretGuard ),
               'scrypt': CliCommand.Node( scryptKwMatcher,
                              guard=scryptSecretGuard ),
               'yescrypt': CliCommand.Node( yescryptKwMatcher,
                              guard=yescryptSecretGuard ),
               '$MD5' : md5Matcher,
               '$SHA512': sha512Matcher,
               '$SCRYPT': scryptMatcher,
               '$YESCRYPT': yescryptMatcher }
      if supportInvalidPassword:
         expression += ' | *'
         data[ '*' ] = 'Specify a password that cannot be used to login'

      @staticmethod
      def adapter( mode, args, argsList ):
         # Converts the matched tokens into a SecretValue under args[ name ],
         # removing the intermediate keys this expression put into args.
         def _getArg( argname ):
            # If the expression is used in a set, we may get an one-elem list
            val = args[ argname ]
            if isinstance( val, list ):
               assert len( val ) == 1
               val = val[ 0 ]
            return val

         secret = None
         if supportInvalidPassword and '*' in args:
            secret = SecretValue( mode, '*' )
            del args[ '*' ]
         elif '$CLEARTEXT0' in args or '$CLEARTEXT' in args:
            # Rebuild the cleartext: the first token followed by any extra
            # tokens, joined with single spaces.
            cleartext = args.pop( '$CLEARTEXT_EXT', [] )
            if len( cleartext ) == 1 and isinstance( cleartext[ 0 ], list ):
               # handle the case where SecretExpression is used in a set
               cleartext = cleartext[ 0 ]
            if '$CLEARTEXT0' in args:
               cleartext.insert( 0, _getArg( '$CLEARTEXT0' ) )
            else:
               cleartext.insert( 0, _getArg( '$CLEARTEXT' ) )

            for arg in ( '$CLEARTEXT0', '$CLEARTEXT', '$CLEARTEXT_EXT', '0' ):
               args.pop( arg, None )
            cleartext = ' '.join( cleartext )

            # Hash with the configured default algorithm, falling back when
            # that algorithm's keyword guard returns a truthy value in this
            # mode (presumably meaning the option is unavailable):
            # sha512 -> md5; scrypt / yescrypt -> sha512, or md5 if sha512
            # is guarded too.
            hashToUse = 'md5' if not defaultHashFunc else defaultHashFunc()

            shaGuarded = ( shaSecretGuard and shaSecretGuard( mode, '0' ) )
            scryptGuarded = ( scryptSecretGuard and scryptSecretGuard( mode, '0' ) )
            yescryptGuarded = ( yescryptSecretGuard and
                                yescryptSecretGuard( mode, '0' ) )

            if hashToUse == 'sha512' and shaGuarded:
               hashToUse = 'md5'
            elif ( hashToUse == 'scrypt' and scryptGuarded ) or \
                 ( hashToUse == 'yescrypt' and yescryptGuarded ):
               if shaGuarded:
                  hashToUse = 'md5'
               else:
                  hashToUse = 'sha512'

            secret = setCleartextSecret( mode, cleartext, hashAlgorithm=hashToUse,
                                          policyFn=policyFn )
         else:
            # Pre-hashed secret: find the algorithm keyword that was given and
            # take the hash from its '$<ALGORITHM>' argument ('5' -> '$MD5').
            for alg in hashAlgorithmsToIDs:
               if alg in args:
                  algVariable = '$' + getAlgorithmDisplayString( alg ).upper()
                  secret = SecretValue( mode, _getArg( algVariable ) )
                  del args[ alg ]
                  del args[ algVariable ]
                  break

         if secret:
            args[ name ] = secret

   return SecretExpression

hashTypeFinder = re.compile( r'\$(\w+)\$.+' )

def getHashAlgorithm( encrypted ):
   """
   Return the internal algorithm name (a key of hashAlgorithmsToIDs, e.g. '5'
   for MD5) of the crypt(3) string encrypted. Returns None when no
   '$<id>$...' sequence is found or when the id is not a known algorithm.
   """
   found = hashTypeFinder.search( encrypted )
   if found is None:
      return None

   hashId = found.group( 1 )
   return next( ( algorithm
                  for algorithm, knownId in hashAlgorithmsToIDs.items()
                  if knownId == hashId ),
                None )

def getCliSaveCommand( formatStr, encryptedPasswd ):
   """
   Render the running-config line for encryptedPasswd.

   formatStr is a str.format template with one '{}' placeholder for the secret.
   - '*' is inserted literally.
   - A recognized hash yields a CliSaveBlock.SensitiveCommand whose format
     carries the algorithm keyword followed by the hash.
   - An unrecognized hash yields a '! ' comment line containing '<invalid>'.
   """
   assert encryptedPasswd
   if encryptedPasswd == '*':
      return formatStr.format( '*' )

   hashAlgorithm = getHashAlgorithm( encryptedPasswd )
   if not hashAlgorithm:
      # Invalid hash, just add a comment
      return '! ' + formatStr.format( "<invalid>" )

   # Keep a '{}' placeholder after the keyword for the sensitive hash value.
   sensitiveFormat = formatStr.format( hashAlgorithm + " {}" )
   return CliSaveBlock.SensitiveCommand( sensitiveFormat, encryptedPasswd )
