# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import re

import Arnet
import BasicCli
import BasicCliUtil
import CliCommand
import CliMatcher
from CliPlugin import Ssl
from CliPlugin import IpGenAddrMatcher
from CliPlugin import Security
import CliMode.Pki as PkiProfileMode
import CommonGuards
import SslCertKey
import Tac
import Tracing
import six

# TAC types used by the commands below: Constants supplies the certificate
# validity range, Digest the digest choices and PublicKeyAlgo the key algorithm
# values stored by AlgorithmExpression.adapter.
Constants = Tac.Type( 'Mgmt::Security::Ssl::Constants' )
Digest = Tac.Type( 'Mgmt::Security::Ssl::Digest' )
PublicKeyAlgo = Tac.Type( "Mgmt::Security::Ssl::PublicKeyAlgorithm" )

# Tracing for this plugin goes through the 'PkiCli' handle.
__defaultTraceHandle__ = Tracing.Handle( 'PkiCli' )
t0 = Tracing.trace0

# Keyword matchers and nodes shared by the command classes in this file.
pkiEnKwMatcher = CliMatcher.KeywordMatcher( 'pki',
   helpdesc='Configure PKI related options' )
keyKwMatcher = CliMatcher.KeywordMatcher( 'key', helpdesc='modify keys used' )
# NOTE(review): standbyGuard presumably rejects the keyword on a standby
# supervisor and noResult presumably keeps the keyword out of the parsed
# arguments -- confirm against CommonGuards / CliCommand.Node.
generateNode = CliCommand.Node(
      matcher=CliMatcher.KeywordMatcher( 'generate', helpdesc='create new item' ),
      guard=CommonGuards.standbyGuard,
      noResult=True )
rsaMatcher = CliMatcher.KeywordMatcher( 'rsa', helpdesc='Use RSA algorithm' )
# Accepted RSA key sizes; the enum matcher is keyed by the size written as a
# string ( '2048', '3072', '4096' ) with a generated help text for each.
validRsaKeySizes = [ 2048, 3072, 4096 ]
rsaKeySizeMap = { f'{bitLength}': f'Use {bitLength}-bit keys'
                  for bitLength in validRsaKeySizes }
rsaKeySizeMatcher = CliMatcher.EnumMatcher( rsaKeySizeMap )
ecdsaMatcher = CliMatcher.KeywordMatcher( 'ecdsa', helpdesc='Use ECDSA algorithm' )
# Accepted elliptic curve names and their help texts.
ecCurveNameMap = { 'prime256v1':
                        'X9.62/SECG curve over a 256 bit prime field (P-256)',
                   'secp384r1':
                        'NIST/SECG curve over a 384 bit prime field (P-384)',
                   'secp521r1':
                        'NIST/SECG curve over a 521 bit prime field (P-521)' }
ecCurveNameMatcher = CliMatcher.EnumMatcher( ecCurveNameMap )
certKwMatcher = CliMatcher.KeywordMatcher( 'certificate',
      helpdesc='work with x509 certificate' )
signRequestKwMatcher = CliMatcher.KeywordMatcher( 'signing-request',
      helpdesc='Certificate Signing Request ( CSR )' )
digestKwMatcher = CliMatcher.KeywordMatcher( 'digest',
      helpdesc='Digest to sign with' )
# Digest choices offered by digestMatcher ( defined below ), keyed by the
# Digest enum value.
digests = { Digest.sha256: 'Use 256 bit SHA',
            Digest.sha384: 'Use 384 bit SHA',
            Digest.sha512: 'Use 512 bit SHA' }
validityMatcher = CliMatcher.KeywordMatcher( 'validity',
      helpdesc='Validity of certificate' )
# Characters accepted in a DNS name: letters, digits, '_', '.' and '-'.
dnsNameRe = r'[0-9a-zA-Z_\.-]+'
# URI regex from RFC3986 Appendix B
# Every component in it is optional, so by itself the pattern splits a URI into
# its parts rather than rejecting malformed input.
uriAddrRe = r'^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?'
digestMatcher = CliMatcher.EnumMatcher( digests )

# CLI mode for configuring a PKI profile. It combines the mode definition from
# CliMode.Pki with the profile behaviour of Ssl.ProfileConfigModeBase.
class PkiProfileConfigMode( PkiProfileMode.PkiProfileMode,
                            Ssl.ProfileConfigModeBase ):
   name = 'PKI profile configuration'

   def __init__( self, parent, session, profileName ):
      # The two bases take different arguments, so each initialiser is called
      # explicitly. The profile is registered with type 'profileTypePki'.
      PkiProfileMode.PkiProfileMode.__init__( self, profileName )
      Ssl.ProfileConfigModeBase.__init__( self, parent, session, profileName,
                                          'profileTypePki' )

#--------------------------------------------------
# [no|default] pki profile PROFILE_NAME
#--------------------------------------------------
class GotoPkiProfileModeCmd( CliCommand.CliCommandClass ):
   syntax = 'pki profile PROFILE_NAME'
   # The no/default form takes exactly the same tokens as the positive form.
   noOrDefaultSyntax = syntax
   data = {
            'pki': pkiEnKwMatcher,
            'profile': Ssl.profileMatcher,
            'PROFILE_NAME': Ssl.profileNameMatcher
          }
   # The config command 'pki profile' is still hidden.
   # We are only giving enable mode 'pki' commands.
   hidden = True
   # Handlers are named by string; their implementation is in PkiHandler, not
   # in this file.
   handler = "PkiHandler.GotoPkiProfileModeCmd_handler"
   noOrDefaultHandler = "PkiHandler.GotoPkiProfileModeCmd_noOrDefaultHandler"

# The command is available from the security configuration mode.
Security.SecurityConfigMode.addCommandClass( GotoPkiProfileModeCmd )

# Every command of the PKI profile mode is a command class defined in
# CliPlugin.Ssl; the statements below only attach them to PkiProfileConfigMode.

#--------------------------------------------------
# [no|default] certificate CERT_NAME key KEY_NAME
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.CertificateCmd )

#--------------------------------------------------
# [no|default] trust certificate CERT_NAME
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.TrustCertificateCmd )
PkiProfileConfigMode.addCommandClass( Ssl.NoTrustCertificateCmd )

#--------------------------------------------------
# [no|default] chain certificate CERT_NAME
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.ChainCertificateCmd )
PkiProfileConfigMode.addCommandClass( Ssl.NoChainCertificateCmd )

#--------------------------------------------------
# [no|default] crl command
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.CrlCmd )
PkiProfileConfigMode.addCommandClass( Ssl.NoCrlCmd )

#--------------------------------------------------
# [no|default] revocation crl command
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.RevocationCrlCmd )
PkiProfileConfigMode.addCommandClass( Ssl.RevocationNoCrlCmd )

#--------------------------------------------------
# [no|default] certificate requirement hostname match
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.VerifyHostnameActionCmd )

#--------------------------------------------------------
# [no|default] certificate requirement extended-key-usage
#--------------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.VerifyExtendedParametersCmd )

#----------------------------------------------------------------------------
# [no|default] [trust|chain] certificate requirement basic-constraint ca true
#----------------------------------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.VerifyChainBasicConstraintTrustCmd )

# ----------------------------------------------------------------------------
# [no|default] [trust|chain] certificate requirement hostname fqdn
# ----------------------------------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.VerifyTrustHostnameFqdnCmd )

#--------------------------------------------------
# [no|default] chain certificate requirement include root-ca
#--------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.ChainIncludeRootCACmd )

#---------------------------------------------------
# [no|default] certificate policy expiry-date ignore
#---------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.IgnoreExpiryDateCmd )

#---------------------------------------------------------
# [no|default] trust certificate policy expiry-date ignore
#---------------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.IgnoreExpiryDateTrustCmd )

#---------------------------------------------------------
# [no|default] crl policy expiry-date ignore
#---------------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.IgnoreExpiryDateCrlCmd )

#---------------------------------------------------------
# [no|default] revocation crl policy expiry-date ignore
#---------------------------------------------------------
PkiProfileConfigMode.addCommandClass( Ssl.RevocationIgnoreExpiryDateCrlCmd )

class AlgorithmExpression( CliCommand.CliExpression ):
   # Key algorithm selection: either 'rsa' followed by a key size or 'ecdsa'
   # followed by a curve name.
   expression = '( rsa RSA_BIT_LENGTH ) | ( ecdsa EC_CURVE_NAME )'
   data = {
            'rsa': rsaMatcher,
            'RSA_BIT_LENGTH': rsaKeySizeMatcher,
            'ecdsa': ecdsaMatcher,
            'EC_CURVE_NAME': ecCurveNameMatcher,
          }

   @staticmethod
   def adapter( mode, args, argsList ):
      # Record which algorithm keyword was typed as a PublicKeyAlgo value under
      # 'KEY_ALGO'. args is left untouched when neither keyword is present.
      if 'rsa' in args:
         args[ 'KEY_ALGO' ] = PublicKeyAlgo.RSA
      elif 'ecdsa' in args:
         args[ 'KEY_ALGO' ] = PublicKeyAlgo.ECDSA

#------------------------------------------------------------------------------------
# security pki key generate ( rsa RSA_BIT_LENGTH ) |
#                           ( ecdsa EC_CURVE_NAME ) KEY_NAME
#------------------------------------------------------------------------------------
class PkiGenerateKey( CliCommand.CliCommandClass ):
   syntax = 'security pki key generate ALGO_PARAMS KEY_NAME'
   data = {
            'security': Security.securityKwMatcher,
            'pki': pkiEnKwMatcher,
            'key': keyKwMatcher,
            'generate': generateNode,
            # Expands to the rsa / ecdsa alternatives of AlgorithmExpression.
            'ALGO_PARAMS': AlgorithmExpression,
            'KEY_NAME': Ssl.keyNameMatcher,
          }
   # Handler is named by string; its implementation is in PkiHandler, not in
   # this file.
   handler = "PkiHandler.PkiGenerateKey_handler"

# Available from enable mode.
BasicCli.EnableMode.addCommandClass( PkiGenerateKey )

#------------------------------------------------------------------------------------
# security pki certificate generate { signing-request | self-signed <cert-name> }
#             [ key <key-name>
#                   [ generate ( rsa <2048|3072|4096> ) |
#                              ( ecdsa <prime256v1|secp384r1|secp521r1> ) ] ]
#             [ digest <sha256|sha384|sha512> ]
#             [ validity <days> ]      ( self-signed only )
#             [ parameters common-name <common-name>
#                          [ country <country-code> ]
#                          [ state <state-name> ]
#                          [ locality <locality-name> ]
#                          [ organization <org-name> ]
#                          [ organization-unit <org-unit-name> ]
#                          [ email <email> ]
#                          [ subject-alternative-name [ ip <ip1 ip2 ..> ]
#                                                     [ dns <name1 name2 ..> ]
#                                                     [ email <em1 em2 ..> ]
#                                                     [ uri <uri1 uri2 ..> ] ] ]
#------------------------------------------------------------------------------------
class ParamExpression( CliCommand.CliExpression ):
   # Grammar of the 'parameters ...' clause: a mandatory common name, optional
   # subject fields and an optional subject-alternative-name section holding
   # ip / dns / email / uri value lists.
   expression = (
      'parameters common-name COMMON_NAME'
      '[ country COUNTRY_CODE ]'
      '[ state STATE_NAME ]'
      '[ locality LOCALITY_NAME ]'
      '[ organization ORG_NAME ]'
      '[ organization-unit ORG_UNIT_NAME ]'
      '[ email EMAIL ]'
      '[ subject-alternative-name { ( ip { IP } ) | '
      '( dns { DNS } ) | '
      '( EMAIL_KW { SAN_EMAIL } ) | '
      '( uri { URI } ) } ]'
   )
   # NOTE(review): the DNS pattern refuses the bare words 'email' and 'ip' and
   # the SAN_EMAIL pattern refuses 'dns' and 'ip', presumably so that a sibling
   # keyword is not consumed as a value. Neither refuses 'uri', and the URI
   # pattern refuses nothing -- confirm whether that is intended.
   data = {
            'parameters': 'Signing request parameters',
            'common-name': 'Common name for use in subject',
            'COMMON_NAME': CliMatcher.QuotedStringMatcher(),
            'country': 'Two-Letter Country Code for use in subject',
            'COUNTRY_CODE': CliMatcher.QuotedStringMatcher(),
            'state': 'State for use in subject',
            'STATE_NAME': CliMatcher.QuotedStringMatcher(),
            'locality': 'Locality Name for use in subject',
            'LOCALITY_NAME': CliMatcher.QuotedStringMatcher(),
            'organization': 'Organization Name for use in subject',
            'ORG_NAME': CliMatcher.QuotedStringMatcher(),
            'organization-unit': 'Organization Unit Name for use in subject',
            'ORG_UNIT_NAME': CliMatcher.QuotedStringMatcher(),
            'email': 'Email address for use in subject',
            'EMAIL': CliMatcher.PatternMatcher( r'\S+', helpname='WORD',
               helpdesc='Email address' ),
            'subject-alternative-name': 'Subject alternative name extension',
            'ip': CliCommand.singleKeyword( 'ip',
               helpdesc='IP addresses for use in subject-alternative-name' ),
            'IP': IpGenAddrMatcher.ipGenAddrMatcher,
            'dns': CliCommand.singleKeyword( 'dns',
               helpdesc='DNS names for use in subject-alternative-name' ),
            'DNS': CliMatcher.PatternMatcher( fr'^(?!email$|ip$)({dnsNameRe})',
               helpname='WORD', helpdesc='DNS name' ),
            'EMAIL_KW': CliCommand.singleKeyword( 'email',
               helpdesc='Email addresses for use in subject-alternative-name' ),
            'SAN_EMAIL': CliMatcher.PatternMatcher( r'^(?!dns$|ip$)(\S+)',
               helpname='WORD', helpdesc='Email address' ),
            'uri': CliCommand.singleKeyword( 'uri',
               helpdesc='URIs for use in subject-alternative-name' ),
            'URI': CliMatcher.PatternMatcher( uriAddrRe, helpname='WORD',
               helpdesc='URI string' ),
          }

   @staticmethod
   def adapter( mode, args, argsList ):
      # Collapse the parsed tokens into one dict stored under
      # 'SIGN_REQ_PARAMS'. Skipped when the clause was not typed or when the
      # dict has already been built.
      if 'SIGN_REQ_PARAMS' in args or 'parameters' not in args:
         return

      # ( key in the result, token name in args ); absent tokens become None.
      subjectFields = (
         ( 'commonName', 'COMMON_NAME' ),
         ( 'country', 'COUNTRY_CODE' ),
         ( 'state', 'STATE_NAME' ),
         ( 'locality', 'LOCALITY_NAME' ),
         ( 'orgName', 'ORG_NAME' ),
         ( 'orgUnitName', 'ORG_UNIT_NAME' ),
         ( 'emailAddress', 'EMAIL' ),
      )
      params = { field: args.get( token ) for field, token in subjectFields }
      # Subject-alternative-name values are kept together under 'san'.
      params[ 'san' ] = {
         'sanIp': args.get( 'IP' ),
         'sanDns': args.get( 'DNS' ),
         'sanEmailAddress': args.get( 'SAN_EMAIL' ),
         'sanUri': args.get( 'URI' ),
      }
      args[ 'SIGN_REQ_PARAMS' ] = params

# Taken from RFC5280, Upper Bounds ( page 123 of May 2008 rev. )
upperBounds = { 'country' : 2,
                'state' : 128,
                'locality' : 128,
                'orgName' : 64,
                'orgUnitName' : 64,
                'commonName' : 64,
                'emailAddress' : 128,
              }
 
def _validateParams( paramName, paramValue ):
   if not paramValue:
      return
   
   printName = { 'country' : 'Country code', 
                 'state' : 'State', 
                 'locality' : 'Locality', 
                 'orgName' : 'Organization name',
                 'orgUnitName' : 'Organization unit name', 
                 'commonName' : 'Common name', 
                 'emailAddress' : 'Email address' }
     
   if paramName in upperBounds:
      if len( paramValue ) > upperBounds[ paramName ]:
         raise SslCertKey.SslCertKeyError( f'{printName[ paramName ]} can be at '
            f'most {upperBounds[ paramName ]} characters.' )
     
   if paramName == 'sanIp':
      for ip in paramValue:
         try:
            Arnet.IpGenAddr( ip )
         except ( IndexError, ValueError ) as e:
            raise SslCertKey.SslCertKeyError( f'IP address \'{ip}\' is not a '
                                              'valid v4 or v6 address' ) from e
        
   if paramName == 'sanDns':
      for dnsName in paramValue:
         if not re.match( f'^{dnsNameRe}$', dnsName ):
            raise SslCertKey.SslCertKeyError( f'DNS name \'{dnsName}\''
                                              ' is not valid.' )
           
   if paramName == 'sanEmailAddress':
      for emailAddress in paramValue:
         if len( emailAddress ) > upperBounds[ 'emailAddress' ]:
            raise SslCertKey.SslCertKeyError( f'{printName[ "emailAddress" ]} '
               f'\'{emailAddress}\' can be at most '
               f'{upperBounds[ "emailAddress" ]} characters.' )

def _getSignRequestParamsInteractive( mode ):
   '''
   Prompt for every signing-request parameter, one input line per parameter.

   Each answer is passed to _validateParams as soon as it is read, so the first
   invalid answer ends the dialogue. An empty answer is stored as None. Answers
   to the subject-alternative-name prompts are split on whitespace into lists.

   Returns a dict with the keys commonName, country, state, locality, orgName,
   orgUnitName, emailAddress, sanIp, sanDns, sanEmailAddress and sanUri.

   Raises SslCertKey.SslCertKeyError when the common name is left empty or when
   _validateParams rejects an answer.
   '''
   # ( result key, prompt text, answer is a space separated list ), listed in
   # the order the questions are asked.
   prompts = (
      ( 'commonName', 'Common Name for use in subject: ', False ),
      ( 'country', 'Two-Letter Country Code for use in subject: ', False ),
      ( 'state', 'State for use in subject: ', False ),
      ( 'locality', 'Locality Name for use in subject: ', False ),
      ( 'orgName', 'Organization Name for use in subject: ', False ),
      ( 'orgUnitName', 'Organization Unit Name for use in subject: ', False ),
      ( 'emailAddress', 'Email address for use in subject: ', False ),
      ( 'sanIp',
        'IP addresses (space separated) for use in subject-alternative-name: ',
        True ),
      ( 'sanDns',
        'DNS names (space separated) for use in subject-alternative-name: ',
        True ),
      ( 'sanEmailAddress',
        'Email addresses (space separated) for use in subject-alternative-name: ',
        True ),
      ( 'sanUri',
        'URIs (space separated) for use in subject-alternative-name: ',
        True ),
   )

   signReqParams = {}
   for key, prompt, isList in prompts:
      inp = BasicCliUtil.getSingleLineInput( mode, prompt )
      # The common name is the only mandatory answer.
      if key == 'commonName' and not inp:
         raise SslCertKey.SslCertKeyError( 'Common Name is needed' )
      if not inp:
         value = None
      elif isList:
         value = inp.split()
      else:
         value = inp
      signReqParams[ key ] = value
      _validateParams( key, value )

   return signReqParams

def getSignRequestParams( mode, signReqParams ):
   '''
   Return the signing-request parameters as one flat, validated dict.

   When signReqParams is empty or None the values are collected interactively
   through _getSignRequestParamsInteractive. Otherwise signReqParams has to
   hold the subject fields plus a 'san' entry: that entry is removed and its
   sanIp, sanDns, sanEmailAddress and sanUri members are merged into the top
   level ( members that are missing become None ). sanIp entries are replaced
   by their stringValue. Every parameter is then checked with _validateParams,
   which raises SslCertKey.SslCertKeyError on invalid input.

   A supplied dict is modified in place and is the object that is returned.
   '''
   if not signReqParams:
      return _getSignRequestParamsInteractive( mode )

   subjectKeys = ( 'commonName', 'country', 'state', 'locality',
                   'orgName', 'orgUnitName', 'emailAddress' )
   sanKeys = ( 'sanIp', 'sanDns', 'sanEmailAddress', 'sanUri' )

   nested = signReqParams.pop( 'san' )
   # Work on a copy so the nested container itself is left alone.
   flatSan = dict( nested ) if nested else {}
   for key in sanKeys:
      flatSan.setdefault( key, None )
   if flatSan[ 'sanIp' ]:
      flatSan[ 'sanIp' ] = [ addr.stringValue for addr in flatSan[ 'sanIp' ] ]

   signReqParams.update( flatSan )

   for key in subjectKeys + sanKeys:
      _validateParams( key, signReqParams[ key ] )

   return signReqParams

#------------------------------------------------------------------------------------
# security pki certificate generate signing-request
#             [ key KEY_NAME [ generate ALGO_PARAMS ] ]
#             [ digest DIGEST ]
#             [ SIGN_REQ_PARAMS ]
#
# Uses the same handler as PkiGenerateSelfSignedCertCmd.
#------------------------------------------------------------------------------------
class PkiGenerateCsrOrCertCmd( CliCommand.CliCommandClass ):
   syntax = ( 'security pki certificate generate signing-request '
                                    '[ key KEY_NAME '
                                                '[ generate ALGO_PARAMS ] ] '
                                    '[ digest DIGEST ] '
                                    '[ SIGN_REQ_PARAMS ]' )
   data = {
            'security': Security.securityKwMatcher,
            'pki': pkiEnKwMatcher,
            'certificate': certKwMatcher,
            'generate': generateNode,
            'signing-request': signRequestKwMatcher,
            'key': keyKwMatcher,
            'KEY_NAME': Ssl.keyNameMatcher,
            # rsa / ecdsa alternatives, see AlgorithmExpression.
            'ALGO_PARAMS': AlgorithmExpression,
            'digest': digestKwMatcher,
            'DIGEST': digestMatcher,
            # The 'parameters ...' clause, see ParamExpression.
            'SIGN_REQ_PARAMS': ParamExpression
          }
   # Handler is named by string; its implementation is in PkiHandler, not in
   # this file.
   handler = "PkiHandler.PkiGenerateCsrOrCert_handler"

# Available from enable mode.
BasicCli.EnableMode.addCommandClass( PkiGenerateCsrOrCertCmd )

#------------------------------------------------------------------------------------
# security pki certificate generate self-signed CERT_NAME
#             [ key KEY_NAME [ generate ALGO_PARAMS ] ]
#             [ digest DIGEST ]
#             [ validity VALIDITY ]
#             [ SIGN_REQ_PARAMS ]
#
# Uses the same handler as PkiGenerateCsrOrCertCmd.
#------------------------------------------------------------------------------------
class PkiGenerateSelfSignedCertCmd( CliCommand.CliCommandClass ):
   syntax = ( 'security pki certificate generate self-signed CERT_NAME '
                                    '[ key KEY_NAME '
                                                '[ generate ALGO_PARAMS ] ] '
                                    '[ digest DIGEST ] '
                                    '[ validity VALIDITY ] '
                                    '[ SIGN_REQ_PARAMS ]' )
   data = {
            'security': Security.securityKwMatcher,
            'pki': pkiEnKwMatcher,
            'certificate': certKwMatcher,
            'generate': generateNode,
            'self-signed': 'Self signed certificate',
            'CERT_NAME': Ssl.certificateNameMatcher,
            'key': keyKwMatcher,
            'KEY_NAME': Ssl.keyNameMatcher,
            # rsa / ecdsa alternatives, see AlgorithmExpression.
            'ALGO_PARAMS': AlgorithmExpression,
            'validity': validityMatcher,
            # Validity in days, bounded by the limits defined in Constants.
            'VALIDITY': CliMatcher.IntegerMatcher( Constants.minCertValidity,
               Constants.maxCertValidity, helpdesc='Days' ),
            'digest': digestKwMatcher,
            'DIGEST': digestMatcher,
            # The 'parameters ...' clause, see ParamExpression.
            'SIGN_REQ_PARAMS': ParamExpression
          }
   # Handler is named by string; its implementation is in PkiHandler, not in
   # this file.
   handler = "PkiHandler.PkiGenerateCsrOrCert_handler"

# Available from enable mode.
BasicCli.EnableMode.addCommandClass( PkiGenerateSelfSignedCertCmd )

