# Copyright (c) 2010, 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import CliSave, Tac
from CliSavePlugin import Management
import SysMgrLib
import SshCertLib
from IpLibConsts import DEFAULT_VRF
from CliMode.VrfConfig import VrfConfigMode
from CliMode.SshAuthenticationX509 import SshAuthenticationX509Mode
from CliMode.SshTunnel import SshTunnelMode
from CliMode.SshTunnel import SshTunnelVrfMode
from CliMode.SshUser import ( SshUserMode, SshKeyMode, SshPrincipalMode )
import os

from SysMgrLib import authenTacTypeToCliToken

# Port type for SSH server/tunnel ports. Its 'invalid' attribute is used below as
# the "port not configured" sentinel and 'defaultPort' as the default SSH
# server port.
serverPort = Tac.Type( "Mgmt::Ssh::ServerPort" )

# Authentication-method enum; saveSsh() checks the configured method lists for
# its password, keyboardInteractive and publicKey members.
AuthenMethod = Tac.Type( "Mgmt::Ssh::AuthenMethod" )

class SshConfigMode( Management.MgmtConfigMode ):
   """CliSave mode for 'management ssh' (rendered as a singleton)."""

   def __init__( self, param ):
      # 'param' is ignored: the mode is always constructed with the "ssh" key.
      Management.MgmtConfigMode.__init__( self, "ssh" )

class SshTunnelConfigMode( SshTunnelMode, CliSave.Mode ):
   """
   CliSave mode for an SSH tunnel configured directly under 'management ssh'
   (the older syntax). 'param' is the tunnel name.
   """

   def __init__( self, param ):
      SshTunnelMode.__init__( self, param )
      # NOTE(review): longModeKey is presumably set by SshTunnelMode.__init__;
      # it is used as the CliSave mode key.
      CliSave.Mode.__init__( self, self.longModeKey )

class SshVrfTunnelConfigMode( SshTunnelVrfMode, CliSave.Mode ):
   """
   CliSave mode for an SSH tunnel configured under a 'vrf' sub-mode of
   'management ssh'. 'param' is a ( vrfName, tunnelName ) tuple.
   """

   def __init__( self, param ):
      SshTunnelVrfMode.__init__( self, param )
      CliSave.Mode.__init__( self, self.longModeKey )

# 'management ssh' hangs off the global config and owns the 'Mgmt.ssh' commands.
CliSave.GlobalConfigMode.addChildMode( SshConfigMode )
SshConfigMode.addCommandSequence( 'Mgmt.ssh' )

# Old-style tunnel sub-modes directly under 'management ssh'.
SshConfigMode.addChildMode( SshTunnelConfigMode )
SshTunnelConfigMode.addCommandSequence( 'Mgmt.ssh.tunnel' )

class SshVrfConfigMode( VrfConfigMode, CliSave.Mode ):
   """
   CliSave mode for a 'vrf' sub-mode of 'management ssh'. saveSsh() creates it
   with a ( vrfName, 'ssh', sshConfig ) tuple as 'param'.
   """
   def __init__( self, param ):
      VrfConfigMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

SshConfigMode.addChildMode( SshVrfConfigMode )
SshVrfConfigMode.addCommandSequence( 'Mgmt.ssh.vrf' )
# Tunnel sub-modes inside a VRF; presumably 'after' places them after the VRF's
# own 'Mgmt.ssh.vrf' commands in the output.
SshVrfConfigMode.addChildMode( SshVrfTunnelConfigMode, after=[ 'Mgmt.ssh.vrf' ] )
SshVrfTunnelConfigMode.addCommandSequence( 'Mgmt.ssh.vrf.tunnel' )

class SshUserConfigMode( SshUserMode, CliSave.Mode ):
   """CliSave mode for a 'user' sub-mode of 'management ssh'; 'param' is the
   user name."""
   def __init__( self, param ):
      SshUserMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

SshConfigMode.addChildMode( SshUserConfigMode )
SshUserConfigMode.addCommandSequence( 'Mgmt.ssh.user' )

class SshKeyConfigMode( SshKeyMode, CliSave.Mode ):
   """CliSave mode for an 'ssh-key' sub-mode of an SSH user; 'param' is a
   ( userName, keyName ) tuple."""
   def __init__( self, param ):
      SshKeyMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

SshUserConfigMode.addChildMode( SshKeyConfigMode )
SshKeyConfigMode.addCommandSequence( 'Mgmt.ssh.user.ssh-key' )

class SshPrincipalConfigMode( SshPrincipalMode, CliSave.Mode ):
   """CliSave mode for an 'ssh-principal' sub-mode of an SSH user; 'param' is a
   ( userName, principal ) tuple."""
   def __init__( self, param ):
      SshPrincipalMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

SshUserConfigMode.addChildMode( SshPrincipalConfigMode )
SshPrincipalConfigMode.addCommandSequence( 'Mgmt.ssh.user.ssh-principal' )

class SshAuthenticationX509ConfigMode( SshAuthenticationX509Mode, CliSave.Mode ):
   """CliSave mode for the 'authentication x509' sub-mode of 'management ssh'
   (rendered as a singleton)."""
   def __init__( self, param ):
      # The CLI-mode base takes no arguments; 'param' only keys the CliSave mode.
      SshAuthenticationX509Mode.__init__( self )
      CliSave.Mode.__init__( self, param )

SshConfigMode.addChildMode( SshAuthenticationX509ConfigMode )
SshAuthenticationX509ConfigMode.addCommandSequence( 'Mgmt.ssh.auth.x509' )

def saveKnownHosts( config, options, cmds, cliSaveMode=SshVrfConfigMode ):
   """
   Add one 'known-host <host> <key-type> <public-key>' command to 'cmds' for
   every entry of config.knownHost that belongs in 'cliSaveMode'.

   Entries with 'configuredInSshMode' set are skipped when 'cliSaveMode' is
   SshVrfConfigMode; all other entries are skipped when it is SshConfigMode.
   'options' is accepted for signature symmetry and is not used.
   """
   for hostName in config.knownHost:
      entry = config.knownHost[ hostName ]
      if entry.configuredInSshMode:
         skippedMode = SshVrfConfigMode
      else:
         skippedMode = SshConfigMode
      if cliSaveMode == skippedMode:
         continue
      cliKeyType = SysMgrLib.tacKeyTypeToCliKey[ entry.type ]
      cmds.addCommand( 'known-host %s %s %s' %
                       ( entry.host, cliKeyType, entry.publicKey ) )

def saveSshTunnels( config, parentMode, tunnelMode, options, vrfName="" ):
   """
   Create an SSH tunnel sub-mode for each tunnel in config.tunnel and save the
   tunnel's settings into it.

   A default-VRF tunnel can be configured either directly in 'management ssh'
   (old syntax, 'configuredInSshMode' set) or in 'vrf default' (new syntax).
   Tunnels with 'configuredInSshMode' set are skipped when 'tunnelMode' is
   SshVrfTunnelConfigMode; all others are skipped when it is
   SshTunnelConfigMode.

   With a non-empty 'vrfName' the sub-mode is keyed by ( vrfName, tunnel name )
   and written to 'Mgmt.ssh.vrf.tunnel'; otherwise it is keyed by the tunnel
   name and written to 'Mgmt.ssh.tunnel'. 'options.saveAll' forces the
   server-alive values and 'shutdown' to be written even at their defaults.
   """
   for tunnelName in config.tunnel:
      tunnel = config.tunnel[ tunnelName ]
      if tunnel.configuredInSshMode:
         skippedMode = SshVrfTunnelConfigMode
      else:
         skippedMode = SshTunnelConfigMode
      if tunnelMode == skippedMode:
         continue

      if vrfName:
         modeKey = ( vrfName, tunnel.name )
         sequenceName = 'Mgmt.ssh.vrf.tunnel'
      else:
         modeKey = tunnel.name
         sequenceName = 'Mgmt.ssh.tunnel'
      tunnelCliMode = parentMode[ tunnelMode ].getOrCreateModeInstance( modeKey )
      cmds = tunnelCliMode[ sequenceName ]

      # Endpoints are only written once fully specified (port not 'invalid').
      if ( tunnel.sshServerAddress and tunnel.sshServerUsername and
           tunnel.sshServerPort != serverPort.invalid ):
         cmds.addCommand( 'ssh-server %s user %s port %d' %
                          ( tunnel.sshServerAddress, tunnel.sshServerUsername,
                            tunnel.sshServerPort ) )
      if tunnel.localPort != serverPort.invalid:
         cmds.addCommand( 'local port %d' % tunnel.localPort )
      if tunnel.remoteHost and tunnel.remotePort != serverPort.invalid:
         cmds.addCommand( 'remote host %s port %d' %
                          ( tunnel.remoteHost, tunnel.remotePort ) )

      if ( tunnel.serverAliveInterval != tunnel.serverAliveIntervalDefault or
           options.saveAll ):
         cmds.addCommand( 'server-alive interval %d' %
                          tunnel.serverAliveInterval )
      if ( tunnel.serverAliveMaxLost != tunnel.serverAliveMaxLostDefault or
           options.saveAll ):
         cmds.addCommand( 'server-alive count-max %d' %
                          tunnel.serverAliveMaxLost )

      if tunnel.unlimitedRestarts:
         cmds.addCommand( 'unlimited-restarts' )

      if tunnel.enable:
         cmds.addCommand( "no shutdown" )
      elif options.saveAll:
         cmds.addCommand( "shutdown" )

def saveSshUserConfig( config, parentMode, userMode, options, vrfName="" ):
   """
   Go through the config and generate an SSH user sub-mode to save each
   user's settings.

   Only users whose 'userModeEntered' flag is set are rendered. Aaa may also
   create entries in the SSH user collection just to store SSH artifacts, and
   those must not produce a 'user' sub-mode on their own.

   Args:
      config: entity holding the 'user' collection.
      parentMode: CliSave mode instance under which user sub-modes are created.
      userMode: the user sub-mode class (SshUserConfigMode in this file).
      options: CliSave options; 'saveAll' forces 'no tcp forwarding' to be
         written when forwarding is at its default.
      vrfName: unused.
   """
   for userName in config.user:
      userConfig = config.user[ userName ]
      if not userConfig.userModeEntered:
         continue
      mode = parentMode[ userMode ].getOrCreateModeInstance( userName )
      cmds = mode[ 'Mgmt.ssh.user' ]
      if userConfig.userTcpForwarding != userConfig.userTcpForwardingDefault:
         cmds.addCommand( "tcp forwarding %s" % userConfig.userTcpForwarding )
      elif options.saveAll:
         cmds.addCommand( "no tcp forwarding" )

def _saveSshOptions( sshOptions, cmds ):
   """
   Add an 'option' command to 'cmds' for every entry of 'sshOptions'.

   Options are written in sorted order. An option with values produces one
   'option <name> <value>' command per value (values sorted); an option
   without values produces a bare 'option <name>' command.
   """
   for option in sorted( sshOptions ):
      sshValues = sshOptions[ option ].sshValues
      if not sshValues:
         cmds.addCommand( f"option {option}" )
         continue
      for value in sorted( sshValues ):
         # If there was a space in the value, quote it (presumably so it is
         # read back as a single token).
         if " " in value:
            value = f"\"{value}\""
         cmds.addCommand( f"option {option} {value}" )

def saveSshKeyConfig( userConfig, parentMode, options ):
   """
   Create an 'ssh-key' sub-mode for every entry of userConfig.sshAuthKeys and
   save the key's public-key contents and options.

   Args:
      userConfig: SSH user entry whose 'sshAuthKeys' collection is rendered.
      parentMode: the user's CliSave mode instance.
      options: CliSave options (unused).
   """
   for key in userConfig.sshAuthKeys:
      keyConfigMode = parentMode[ SshKeyConfigMode ].getOrCreateModeInstance(
         ( userConfig.name, key ) )
      cmds = keyConfigMode[ "Mgmt.ssh.user.ssh-key" ]
      keyConfig = userConfig.sshAuthKeys[ key ]
      if keyConfig.keyContents:
         cmds.addCommand( f"public-key {keyConfig.keyContents}" )
      _saveSshOptions( keyConfig.sshOptions, cmds )

def saveSshPrincipalConfig( userConfig, parentMode, options ):
   """
   Create an 'ssh-principal' sub-mode for every entry of
   userConfig.sshAuthPrincipals and save the principal's options.

   Args:
      userConfig: SSH user entry whose 'sshAuthPrincipals' collection is
         rendered.
      parentMode: the user's CliSave mode instance.
      options: CliSave options (unused).
   """
   for principal in userConfig.sshAuthPrincipals:
      principalConfigMode = parentMode[
            SshPrincipalConfigMode ].getOrCreateModeInstance(
                  ( userConfig.name, principal ) )
      cmds = principalConfigMode[ "Mgmt.ssh.user.ssh-principal" ]
      principalConfig = userConfig.sshAuthPrincipals[ principal ]
      _saveSshOptions( principalConfig.sshOptions, cmds )

def saveSshAuthX509Config( config, parentMode, authenticationX509Mode, options,
                           vrfName="" ):
   """
   Render the 'authentication x509' sub-mode (always created).

   Writes the server SSL profile and the 'username domain omit' setting when
   they are set; the negated 'no ...' forms are written only when
   'options.saveAll' is set. 'vrfName' is unused.
   """
   x509Cmds = parentMode[ authenticationX509Mode ].getSingletonInstance()[
      'Mgmt.ssh.auth.x509' ]
   profileName = config.x509ProfileName
   # ( is the setting active, command when active, command under saveAll )
   settings = (
      ( profileName, f"server ssl profile {profileName}",
        "no server ssl profile" ),
      ( config.x509UsernameDomainOmit, "username domain omit",
        "no username domain omit" ),
   )
   for isActive, activeCmd, inactiveCmd in settings:
      if isActive:
         x509Cmds.addCommand( activeCmd )
      elif options.saveAll:
         x509Cmds.addCommand( inactiveCmd )

def processHostKeys( configHostKeys ):
   """
   Convert the whitespace-separated host-key list stored in the config into the
   argument string of the 'hostkey server' command.

   Tokens that are keys of SshCertLib.algoDirToSshAlgo are kept verbatim. Any
   other token is treated as a key file path and rendered as
   'ssh-key:<algo>/<file basename>', with <algo> taken from
   SshCertLib.getAlgoFromKeyPath(). Returns "" for an empty list.
   """
   def toCliToken( token ):
      if token in SshCertLib.algoDirToSshAlgo:
         return token
      return "ssh-key:%s/%s" % ( SshCertLib.getAlgoFromKeyPath( token ),
                                 os.path.basename( token ) )

   return " ".join( toCliToken( token ) for token in configHostKeys.split() )

@CliSave.saver( 'Mgmt::Ssh::Config', 'mgmt/ssh/config' )
def saveSsh( sshConfig, root, requireMounts, options ):
   """
   CliSave saver for 'Mgmt::Ssh::Config' (at 'mgmt/ssh/config').

   Adds the global 'management ssh' commands to the 'Mgmt.ssh' sequence, in the
   order of the statements below. Most settings are written only when they
   differ from their default or when 'options.saveAll' is set. It then renders
   the tunnel, user, 'authentication x509' and per-VRF sub-modes.
   """
   mode = root[ SshConfigMode ].getSingletonInstance()
   cmds = mode[ 'Mgmt.ssh' ]

   if ( sshConfig.idleTimeout.timeout !=
        sshConfig.idleTimeout.defaultTimeout or options.saveAll ):
      # Only need to save timeout if different from default.
      # The stored value is divided by 60 for the CLI (presumably seconds to
      # minutes -- TODO confirm units).
      cmds.addCommand( "idle-timeout %s" %
                       ( int( sshConfig.idleTimeout.timeout // 60 ) ) )

   # Collect every configured authentication method; any method list holding two
   # or more methods means multi-factor authentication is configured.
   authenticationMethods = set()
   multiFactorEnabled = False
   for methodList in sshConfig.authenticationMethodList.values():
      methods = methodList.method.values()
      if len( methods ) >= 2:
         multiFactorEnabled = True
      for method in methodList.method.values():
         authenticationMethods.add( method )

   passwordEnabled = AuthenMethod.password in authenticationMethods
   keyboardInteractiveEnabled = AuthenMethod.keyboardInteractive \
           in authenticationMethods
   publicKeyEnabled = AuthenMethod.publicKey in authenticationMethods

   if sshConfig.legacyAuthenticationModeSet and passwordEnabled:
      # Only use the legacy syntax when the command needs to be explicitly saved in
      # that form for backwards compatibility
      cmds.addCommand( 'authentication mode password' )
   else:
      # keyboard-interactive plus public-key, with neither password nor
      # multi-factor, is treated as the default and written only with saveAll.
      defaultSettingsEnabled = ( not passwordEnabled
                                 and keyboardInteractiveEnabled
                                 and publicKeyEnabled
                                 and not multiFactorEnabled )
      if not defaultSettingsEnabled or options.saveAll:
         authenticationProtocolSaveArgs = [ 'authentication protocol' ]

         # Every method list contributes its methods in order; when any list
         # is multi-factor, each list is prefixed with 'multi-factor'.
         for methodList in sshConfig.authenticationMethodList.values():
            methods = methodList.method.values()
            if multiFactorEnabled:
               authenticationProtocolSaveArgs.append( 'multi-factor' )
            for method in methods:
               item = authenTacTypeToCliToken[ method ]
               authenticationProtocolSaveArgs.append( item )

         authenticationProtocolSaveCommand = ' '.join(
            authenticationProtocolSaveArgs )
         cmds.addCommand( authenticationProtocolSaveCommand )

   if sshConfig.serverPort != serverPort.defaultPort or options.saveAll:
      cmds.addCommand( 'server-port %d' % sshConfig.serverPort )

   if sshConfig.cipher != sshConfig.cipherDefault or options.saveAll:
      cmds.addCommand( 'cipher %s' % sshConfig.cipher )

   if sshConfig.kex != sshConfig.kexDefault or options.saveAll:
      cmds.addCommand( 'key-exchange %s' % sshConfig.kex )

   if sshConfig.mac != sshConfig.macDefault or options.saveAll:
      cmds.addCommand( 'mac %s' % sshConfig.mac )

   # Amount and unit are rendered together if either differs from its default.
   if ( sshConfig.rekeyDataAmount != sshConfig.rekeyDataAmountDefault or
        sshConfig.rekeyDataUnit != sshConfig.rekeyDataUnitDefault or
        options.saveAll ):
      cmds.addCommand( 'rekey frequency %d %s' % ( sshConfig.rekeyDataAmount,
                                                   sshConfig.rekeyDataUnit ) )
   if ( sshConfig.rekeyTimeLimit != sshConfig.rekeyTimeLimitDefault or
        sshConfig.rekeyTimeUnit != sshConfig.rekeyTimeUnitDefault or
        options.saveAll ):
      cmds.addCommand( 'rekey interval %d %s' % ( sshConfig.rekeyTimeLimit,
                                                  sshConfig.rekeyTimeUnit ) )

   if sshConfig.hostkey != sshConfig.hostkeyDefault or options.saveAll:
      cmds.addCommand( 'hostkey server %s' % processHostKeys( sshConfig.hostkey ) )

   if ( sshConfig.connLimit != sshConfig.connLimitDefault or
        options.saveAll ):
      cmds.addCommand( 'connection limit %s' % sshConfig.connLimit )

   if ( sshConfig.perHostConnLimit != sshConfig.perHostConnLimitDefault or
        options.saveAll ):
      cmds.addCommand( 'connection per-host %s' % sshConfig.perHostConnLimit )

   if sshConfig.fipsRestrictions:
      cmds.addCommand( 'fips restrictions' )
   elif options.saveAll:
      cmds.addCommand( 'no fips restrictions' )

   if sshConfig.enforceCheckHostKeys:
      cmds.addCommand( 'hostkey client strict-checking' )
   elif options.saveAll:
      cmds.addCommand( 'no hostkey client strict-checking' )

   val = sshConfig.permitEmptyPasswords
   if val != sshConfig.permitEmptyPasswordsDefault:
      cmds.addCommand( 'authentication empty-passwords %s' % val )
   elif options.saveAll:
      cmds.addCommand( 'authentication empty-passwords auto' )

   # Known hosts that were configured directly in 'management ssh' mode.
   saveKnownHosts( sshConfig, options, cmds, cliSaveMode=SshConfigMode )

   val = sshConfig.clientAliveInterval
   if val != sshConfig.clientAliveIntervalDefault:
      cmds.addCommand( 'client-alive interval %s' % val )
   elif options.saveAll:
      cmds.addCommand( 'default client-alive interval' )

   val = sshConfig.clientAliveCountMax
   if val != sshConfig.clientAliveCountMaxDefault:
      cmds.addCommand( 'client-alive count-max %s' % val )
   elif options.saveAll:
      cmds.addCommand( 'default client-alive count-max' )

   if sshConfig.verifyDns:
      cmds.addCommand( 'verify dns' )
   elif options.saveAll:
      cmds.addCommand( 'no verify dns' )

   if sshConfig.serverState == "disabled":
      cmds.addCommand( 'shutdown' )
   elif options.saveAll:
      cmds.addCommand( 'no shutdown' )

   if ( sshConfig.successfulLoginTimeout.timeout !=
        sshConfig.successfulLoginTimeout.defaultTimeout or options.saveAll ):
      # Only need to save timeout if different from default.
      # A timeout of 0 is rendered as 'no login timeout'.
      time = sshConfig.successfulLoginTimeout.timeout
      if time == 0:
         cmds.addCommand( "no login timeout" )
      else:
         cmds.addCommand( "login timeout %d" % ( time, ) )

   if sshConfig.logLevel != sshConfig.logLevelDefault or options.saveAll:
      cmds.addCommand( "log-level %s" % sshConfig.logLevel )

   if sshConfig.loggingTargetEnabled:
      cmds.addCommand( "logging target system" )
   elif options.saveAll:
      cmds.addCommand( "no logging target system" )

   if sshConfig.dscpValue != sshConfig.dscpValueDefault:
      cmds.addCommand( "qos dscp %s" % sshConfig.dscpValue )
   elif options.saveAll:
      cmds.addCommand( "qos dscp %s" % sshConfig.dscpValueDefault )

   # Tunnels configured directly in 'management ssh' (old syntax).
   if len( sshConfig.tunnel ) > 0:
      saveSshTunnels( sshConfig, mode, SshTunnelConfigMode, options )

   if sshConfig.authPrincipalsCmdFile:
      cmds.addCommand( 'authorized-principals command %s %s'
                       % ( sshConfig.authPrincipalsCmdFile,
                           sshConfig.authPrincipalsCmdArgs ) )
   elif options.saveAll:
      cmds.addCommand( 'no authorized-principals command' )

   if sshConfig.caKeyFiles:
      # need to sort manually until TACC supports ordered sets
      caKeyFiles = " ".join( sorted( sshConfig.caKeyFiles ) )
      cmds.addCommand( "trusted-ca key public %s" % caKeyFiles )
   elif options.saveAll:
      cmds.addCommand( "no trusted-ca key public" )

   if sshConfig.hostCertFiles:
      # need to sort manually until TACC supports ordered sets
      hostCertFiles = " ".join( sorted( sshConfig.hostCertFiles ) )
      cmds.addCommand( "hostkey server cert %s" % hostCertFiles )
   elif options.saveAll:
      cmds.addCommand( "no hostkey server cert" )

   if sshConfig.revokedUserKeysFiles:
      # need to sort manually until TACC supports ordered sets
      revokedUserKeysFiles = " ".join( sorted( sshConfig.revokedUserKeysFiles ) )
      cmds.addCommand( "user-keys revoke-list %s" % revokedUserKeysFiles )
   elif options.saveAll:
      cmds.addCommand( "no user-keys revoke-list" )

   if sshConfig.user:
      saveSshUserConfig( sshConfig, mode, SshUserConfigMode, options )

      # Check if we need to render SSH key config in the new CLI
      # NOTE(review): unlike saveSshUserConfig, this creates a user sub-mode for
      # every user, including ones without 'userModeEntered'.
      if sshConfig.useNewSshCliKey:
         for user in sshConfig.user:
            # Store in the new CLI
            userMode = mode[ SshUserConfigMode ].getOrCreateModeInstance( user )
            saveSshKeyConfig( sshConfig.user[ user ], userMode, options )

      # Check if we need to render SSH principal config in the new CLI
      if sshConfig.useNewSshCliPrincipal:
         for user in sshConfig.user:
            # Store in the new CLI
            userMode = mode[ SshUserConfigMode ].getOrCreateModeInstance( user )
            saveSshPrincipalConfig( sshConfig.user[ user ], userMode, options )

   # The 'authentication x509' sub-mode is always created.
   saveSshAuthX509Config( sshConfig, mode, SshAuthenticationX509ConfigMode, options )

   for vrfName, vrfConfig in sshConfig.vrfConfig.items():
      vrfMode = mode[ SshVrfConfigMode ].getOrCreateModeInstance(
         ( vrfName, 'ssh', sshConfig ) )
      vrfCmds = vrfMode[ 'Mgmt.ssh.vrf' ]
      # pylint thinks 'cmds' is a list
      # pylint: disable-msg=E1103
      if vrfConfig.serverState == "enabled":
         vrfCmds.addCommand( "no shutdown" )
      elif vrfConfig.serverState == "disabled":
         vrfCmds.addCommand( "shutdown" )
      elif options.saveAll and vrfConfig.serverState == "globalDefault":
         vrfCmds.addCommand( "default shutdown" )
      # Default-VRF known hosts and tunnels are stored in the global sshConfig;
      # the helpers below keep only those not configured in 'management ssh'.
      conf = sshConfig if vrfName == DEFAULT_VRF else vrfConfig
      saveKnownHosts( conf, options, vrfCmds )
      saveSshTunnels( conf, vrfMode, SshVrfTunnelConfigMode, options,
                      vrfName=vrfName )

@CliSave.saver( 'Acl::Input::CpConfig', 'acl/cpconfig/cli' )
def saveSshIpAclRev1( aclCpConfig, root, requireMounts, options ):
   """
   Save the SSH service access-group commands for IPv4 ('ip') and IPv6.

   For each VRF with an 'ssh' service entry, the configured ACL name is saved.
   If none is configured and 'options.saveAll' is set, the default ACL name is
   saved instead, provided there is one. Default-VRF entries omit the 'vrf'
   keyword.
   """
   for aclType in ( 'ip', 'ipv6' ):
      serviceAcls = aclCpConfig.cpConfig[ aclType ].serviceAcl
      for vrfName, serviceAclVrfConfig in serviceAcls.items():
         serviceConfig = serviceAclVrfConfig.service.get( 'ssh' )
         if not serviceConfig:
            continue

         if serviceConfig.aclName != '':
            aclName = serviceConfig.aclName
         elif options.saveAll and serviceConfig.defaultAclName != '':
            aclName = serviceConfig.defaultAclName
         else:
            continue

         sshCmds = root[ SshConfigMode ].getSingletonInstance()[ 'Mgmt.ssh' ]
         if vrfName == DEFAULT_VRF:
            sshCmds.addCommand( '%s access-group %s in' % ( aclType, aclName ) )
         else:
            sshCmds.addCommand( '%s access-group %s vrf %s in' %
                                ( aclType, aclName, vrfName ) )
