#!/usr/bin/env python3
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from calendar import timegm
import itertools
import datetime
from time import mktime, strftime
import Toggles.MgmtSecurityToggleLib as MgmtSecurityToggle
import BasicCli
import ShowCommand
import CliCommand
import CliMatcher
import CliParser
import ConfigMount
import DateTimeRule
import LazyMount
import Tac
from TypeFuture import TacLazyType
from CliPlugin import ConfigMgmtMode
from CliPlugin.Security import SecurityConfigMode, registerSecurityCleanupCallback
from CliPlugin.SharedSecretProfileModel import ( SharedSecretProfiles,
                                                 SharedSecretProfile, SecretModel )
from CliMode.SharedSecretProfileMode import SharedSecretProfileCliMode
from MgmtSecurityLib import mgmtSecurityConfigType
from SharedSecretProfileLib import ( infiniteLifetime, infiniteLifetimeEnd, Secret,
                                     isInfiniteEnd )
import ReversibleSecretCli

# Mount handles for this module; all three are bound by Plugin() at load time.
config = status = securityConfig = None
SecretConfigLocation = TacLazyType(
   'Mgmt::Security::SharedSecretProfile::SecretConfigLocation' )

# Keyword string -> SecretConfigLocation value. SharedSecretHiddenCmd looks up
# the keyword it matched ('memory-only' or 'encrypted-config') in this table.
secretConfigLocations = dict( [
   ( 'memory-only', SecretConfigLocation.memoryOnly ),
   ( 'encrypted-config', SecretConfigLocation.encryptedConfig ),
   ( 'running-config', SecretConfigLocation.runningConfig ),
] )

# The 'security' keyword shared by the show commands in this module.
securityKwMatcher = CliMatcher.KeywordMatcher(
   'security', helpdesc="Show security status" )

class SharedSecretProfileMode( SharedSecretProfileCliMode, BasicCli.ConfigModeBase ):
   """Configuration mode for one shared-secret profile.

   Entering the mode creates the profile in `config` when it is not already
   present; `profileConfig` then refers to that profile entity."""
   name = "Shared-Secret Profile Configuration"

   def __init__( self, parent, session, profileName ):
      self.profileName = profileName
      self.session_ = session

      profiles = config.profile
      if self.profileName in profiles:
         self.profileConfig = profiles[ self.profileName ]
      else:
         self.profileConfig = config.newProfile( self.profileName )

      SharedSecretProfileCliMode.__init__( self, self.profileName )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

# ------------------------------------------------------------
# [no|default] session shared-secret profile PROFILE_NAME
# ------------------------------------------------------------

class SharedSecretProfileModeCommandClass( CliCommand.CliCommandClass ):
   """Enter, or remove, a shared-secret profile.

      From management security mode,
         session shared-secret profile NAME
      enters the configuration mode for profile NAME, creating the profile
      first if it does not exist yet.

      From management security mode,
         ( no | default ) session shared-secret profile NAME
      deletes profile NAME, together with every secret configured in it.
      Naming a profile that does not exist is a no-op."""

   syntax = "session shared-secret profile PROFILE_NAME"
   noOrDefaultSyntax = syntax
   data = {
      'session': 'Configure session settings',
      'shared-secret': 'Configure settings involving a shared secret',
      'profile': 'Configure a profile of shared secret lifetimes',
      'PROFILE_NAME': CliMatcher.DynamicNameMatcher(
         lambda mode: config.profile,
         'shared-secret profile name' )
   }

   @staticmethod
   def handler( mode, args ):
      profileMode = mode.childMode( SharedSecretProfileMode,
                                    profileName=args[ 'PROFILE_NAME' ] )
      mode.session_.gotoChildMode( profileMode )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      name = args[ 'PROFILE_NAME' ]
      if name not in config.profile:
         return
      del config.profile[ name ]

SecurityConfigMode.addCommandClass( SharedSecretProfileModeCommandClass )

def noSecurityConfig():
   """Return all shared-secret configuration to its defaults.

   Deletes every profile and restores the default secret location. It is
   registered as a security cleanup callback so that it runs when, for
   example, the parent mode is removed."""
   cfg = config
   cfg.profile.clear()
   cfg.secretLocation = cfg.defaultSecretLocation

registerSecurityCleanupCallback( noSecurityConfig )

# ----------------------------------------------------------------
# session shared-secret hidden ( memory-only | encrypted-config )
# ----------------------------------------------------------------

class SharedSecretHiddenCmd( CliCommand.CliCommandClass ):
   """session shared-secret hidden ( memory-only | encrypted-config )

   Selects the SecretConfigLocation used for secrets. The no/default form
   restores config.defaultSecretLocation. The command is registered only when
   the hide-secrets toggle is enabled."""
   syntax = "session shared-secret hidden ( memory-only | encrypted-config )"
   noOrDefaultSyntax = "session shared-secret hidden ..."

   data = {
      'session': 'Configure session settings',
      'shared-secret': 'Configure settings involving a shared secret',
      'hidden': 'Hide secrets, not showing them in the configuration',
      'memory-only': 'Do not persist secrets, keep them only in volatile memory',
      'encrypted-config': 'Persist secrets in the encrypted-config',
   }

   @staticmethod
   def handler( mode, args ):
      keyword = args.get( 'memory-only' )
      if not keyword:
         keyword = args.get( 'encrypted-config' )
      config.secretLocation = secretConfigLocations[ keyword ]

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      config.secretLocation = config.defaultSecretLocation

if MgmtSecurityToggle.toggleHideSecretsEnabled():
   SecurityConfigMode.addCommandClass( SharedSecretHiddenCmd )

# ------------------------------------------------------------
# [no|default] secret ID SECRET LIFETIME ...
# ------------------------------------------------------------

# Matcher shared by the START_TIME_* and END_TIME_* tokens of every lifetime
# expression built by lifetimeExpression().
timeMatcher = DateTimeRule.ValidTimeMatcher()

def dateRuleValue( mode, date ):
   """Turn a parsed date ordered [ month, day, year ] into a dictionary with
   'year', 'month' and 'day' keys. `mode` is accepted but unused."""
   month = date[ 0 ]
   day = date[ 1 ]
   year = date[ 2 ]
   return dict( year=year, month=month, day=day )

def timeRuleValue( mode, time ):
   """Turn a parsed ( hour, minute, second ) sequence into a dictionary with
   'hour', 'minute' and 'second' keys. `mode` is accepted but unused."""
   fieldNames = ( 'hour', 'minute', 'second' )
   return { field: time[ index ] for index, field in enumerate( fieldNames ) }

def datetimeRuleValue( mode, date, time ):
   """Convert date/time dictionaries to seconds since the epoch, read as UTC.

   `date` carries 'year', 'month' and 'day'; `time` carries 'hour', 'minute'
   and 'second' (the shapes produced by dateRuleValue / timeRuleValue).

   Raises CliParser.InvalidInputError when the date is not a valid calendar
   date or timegm rejects the fields. Adds an error to `mode` and raises
   CliParser.AlreadyHandledError when the result is negative, when
   isInfiniteEnd() considers it infinite, or when isInfiniteEnd() cannot
   handle it.
   """
   ymd = ( date[ 'year' ], date[ 'month' ], date[ 'day' ] )
   try:
      # datetime.date raises ValueError for only *some* invalid dates, so the
      # round-trip comparison double-checks what it accepted.
      parsed = datetime.date( *ymd )
      if ( parsed.year, parsed.month, parsed.day ) != ymd:
         raise CliParser.InvalidInputError()
      hms = ( time[ 'hour' ], time[ 'minute' ], time[ 'second' ] )
      # timegm may raise ValueError as well.
      seconds = int( timegm( ymd + hms ) )
   except ValueError:
      raise CliParser.InvalidInputError()

   try:
      outOfBounds = seconds < 0 or isInfiniteEnd( seconds )
   except TypeError:
      # Very large values can make isInfiniteEnd raise TypeError.
      outOfBounds = True
   if not outOfBounds:
      return seconds
   mode.addError( 'Date must be after 1970 and before 10000' )
   raise CliParser.AlreadyHandledError()

def lifetimeRuleValue( mode, start, end ):
   """Build a Lifetime value from start/end epoch seconds.

   Unless `start` is strictly before `end`, adds an error to `mode` and
   raises CliParser.AlreadyHandledError."""
   if not start < end:
      mode.addError( 'lifetime must end after it starts' )
      raise CliParser.AlreadyHandledError()
   return Tac.Value( 'Mgmt::Security::SharedSecretProfile::Lifetime', start, end )

def lifetimeExpression( name ):
   """Return a CliExpression class that parses a lifetime into args[ name ].

   The expression accepts one of:
      START_DATE START_TIME END_DATE END_TIME  -> lifetimeRuleValue( start, end )
      START_DATE START_TIME infinite           -> infiniteLifetimeEnd( start )
      infinite                                 -> infiniteLifetime
   Every token key carries the `name` suffix, so several lifetime expressions
   can appear in the same command without clashing.
   """
   startDateKey, endDateKey = 'START_DATE_' + name, 'END_DATE_' + name
   startTimeKey, endTimeKey = 'START_TIME_' + name, 'END_TIME_' + name
   infKey = 'infinite_' + name

   def toEpoch( mode, args, dateKey, timeKey ):
      # Convert one parsed date/time token pair into epoch seconds.
      return datetimeRuleValue( mode,
                                dateRuleValue( mode, args[ dateKey ] ),
                                timeRuleValue( mode, args[ timeKey ] ) )

   class LifetimeExpression( CliCommand.CliExpression ):
      expression = ( "( {sd} {st} {ed} {et} ) | ( {sd} {st} {inf} ) | {inf}"
                     .format( sd=startDateKey, st=startTimeKey,
                              ed=endDateKey, et=endTimeKey, inf=infKey ) )
      data = {
         startDateKey: DateTimeRule.dateExpression( startDateKey ),
         startTimeKey: timeMatcher,
         endDateKey: DateTimeRule.dateExpression( endDateKey ),
         endTimeKey: timeMatcher,
         infKey: CliMatcher.KeywordMatcher( 'infinite',
                                             helpdesc='infinite lifetime' )
      }

      @staticmethod
      def adapter( mode, args, argsList ):
         infinite = args.get( infKey )
         if startDateKey in args:
            start = toEpoch( mode, args, startDateKey, startTimeKey )
            if infinite:
               lifetime = infiniteLifetimeEnd( start )
            else:
               end = toEpoch( mode, args, endDateKey, endTimeKey )
               lifetime = lifetimeRuleValue( mode, start, end )
         else:
            lifetime = infiniteLifetime if infinite else None
         if lifetime is not None:
            args[ name ] = lifetime

   return LifetimeExpression

def secretIdType():
   """Return a DynamicNameMatcher for secret IDs. Its name source is the
   secret collection of the profile being edited in the current mode."""
   def existingSecrets( mode ):
      return config.profile[ mode.profileName ].secret
   return CliMatcher.DynamicNameMatcher( existingSecrets, 'Identifier for the key' )

# Expression for the SECRET token of SecretConfigCommand. It accepts either
# the cleartext matcher or the type-7 matcher from ReversibleSecretCli, with
# algorithm='DES'. The parsed value supports getClearText(), which
# SecretConfigCommand.handler relies on.
secretExpression = ReversibleSecretCli.ReversiblePasswordCliExpression(
                     cleartextMatcher=ReversibleSecretCli.cleartextAuthMatcher,
                     obfuscatedTextMatcher=ReversibleSecretCli.type7AuthMatcher,
                     algorithm='DES', errorMsg='Invalid encrypted password' )

def getEncryptedPassword( secret ):
   """Encode `secret` for display with ReversibleSecretCli.encodePassword.

   Uses AES-256-GCM when the management security config has common AES
   encryption enabled, and DES otherwise."""
   if securityConfig.commonEncrytionAESEnabled:
      algorithm = 'AES-256-GCM'
   else:
      algorithm = 'DES'
   return ReversibleSecretCli.encodePassword( secret, algorithm=algorithm,
                                              securityConfig=securityConfig )

def convertToUtc( mode, localLifetime ):
   """Reinterpret a Lifetime entered in local time as a UTC Lifetime.

   The start/end of `localLifetime` were computed with timegm, as if the
   user's wall-clock date/time were UTC (see datetimeRuleValue). Each value is
   turned back into that wall-clock time and re-read in the system's local
   timezone with mktime, which yields UTC epoch seconds.

   An infinite lifetime is returned as infiniteLifetime. An infinite end is
   kept by returning infiniteLifetimeEnd( convertedStart ).

   If the conversion overflows (OverflowError), adds an error to `mode` and
   raises CliParser.AlreadyHandledError.
   """
   if localLifetime.isInfinite():
      return infiniteLifetime

   def wallClockToUtc( seconds ):
      # datetime.utcfromtimestamp() is deprecated since Python 3.12. An aware
      # UTC datetime with its tzinfo stripped is the exact replacement. The
      # result must stay naive so that timetuple() reports tm_isdst=-1 and
      # mktime decides DST for the local zone itself.
      wallClock = datetime.datetime.fromtimestamp(
         seconds, datetime.timezone.utc ).replace( tzinfo=None )
      return int( mktime( wallClock.timetuple() ) )

   try:
      # Conversion near the epoch can come out negative, but the Lifetime
      # fields are U64, so clamp the start to 0.
      startSecondUtc = max( 0, wallClockToUtc( localLifetime.start ) )

      if isInfiniteEnd( localLifetime.end ):
         return infiniteLifetimeEnd( startSecondUtc )

      # The end time cannot be negative or 0, so clamp it to 1.
      endSecondUtc = max( 1, wallClockToUtc( localLifetime.end ) )
   except OverflowError:
      mode.addError( 'Date in local time must be after 1970 and before 2038' )
      raise CliParser.AlreadyHandledError()

   return Tac.Value( 'Mgmt::Security::SharedSecretProfile::Lifetime',
                      startSecondUtc, endSecondUtc )

class SecretConfigCommand( CliCommand.CliCommandClass ):
   """[ no | default ] secret ID SECRET LIFETIME ... [ local-time ]

   Adds or replaces secret ID in the current profile. A single LIFETIME sets
   both the receive and transmit lifetimes. Otherwise, receive-lifetime and
   transmit-lifetime are given separately, in either order. With local-time,
   both lifetimes are converted from the system's local timezone to UTC
   before being stored. The no/default form removes secret ID; an unknown ID
   is ignored."""
   syntax = "secret ID SECRET LIFETIME | " \
            "( receive-lifetime RECV_LIFETIME " \
            "transmit-lifetime TRANS_LIFETIME ) | " \
            "( transmit-lifetime TRANS_LIFETIME " \
            "receive-lifetime RECV_LIFETIME ) " \
            "[ local-time ]"
   noOrDefaultSyntax = "secret ID ..."

   data = {
      'secret': 'Configure lifetimes for a specified key',
      'receive-lifetime': 'Configure the lifetime for receiving the key',
      'transmit-lifetime': 'Configure the lifetime for transmitting the key',
      'ID': secretIdType(),
      'SECRET': secretExpression,
      'LIFETIME': lifetimeExpression( 'LIFETIME' ),
      'RECV_LIFETIME': lifetimeExpression( 'RECV_LIFETIME' ),
      'TRANS_LIFETIME': lifetimeExpression( 'TRANS_LIFETIME' ),
      'local-time': 'Configuring secrets using the local timezone '
                    'from system clock. Default is UTC'
   }

   @staticmethod
   def handler( mode, args ):
      keyId = str( args[ 'ID' ] )
      # A single LIFETIME applies to both directions.
      recvLifetime = args.get( 'LIFETIME', args.get( 'RECV_LIFETIME' ) )
      transLifetime = args.get( 'LIFETIME', args.get( 'TRANS_LIFETIME' ) )
      secret = args[ 'SECRET' ]
      # The grammar guarantees both lifetimes are present.
      assert recvLifetime is not None
      assert transLifetime is not None

      if 'local-time' in args:
         recvLifetime = convertToUtc( mode, recvLifetime )
         transLifetime = convertToUtc( mode, transLifetime )

      # Reuse the previous ReversibleSecret when the clear text is unchanged,
      # so a freshly randomized salt does not needlessly trigger downstream
      # attribute-change reactors.
      oldSecret = mode.profileConfig.secret.get( keyId )
      if oldSecret and oldSecret.secret.clearTextEqual( secret.getClearText() ):
         secret = oldSecret.secret
      mode.profileConfig.secret.addMember( Secret( keyId,
                                                   secret,
                                                   recvLifetime,
                                                   transLifetime ) )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      keyId = str( args[ 'ID' ] )
      secrets = mode.profileConfig.secret
      # Like the profile-level 'no' command, removing an ID that is not
      # configured is a silent no-op rather than a failed deletion.
      if keyId in secrets:
         del secrets[ keyId ]

SharedSecretProfileMode.addCommandClass( SecretConfigCommand )

# ------------------------------------------------------------
# show management security session shared-secret profile ...
# ------------------------------------------------------------

class ShowSharedSecretProfileMixin:
   """Model-building helpers shared by the shared-secret profile show commands."""

   @staticmethod
   def _rotationSequence( getLifetime, secrets ):
      """
      Return, as strings, the order in which `secrets` become the preferred
      secret over time. getLifetime( secret ) selects which lifetime (receive
      or transmit) is considered.

      A given profile's secrets will become "current" in a deterministic sequence
      based on their respective lifetimes. For example, a trivial example is 3
      secrets with disjoint lifetimes:

      Secret 1: now => for 3 months
      Secret 2: 4 months from now => for 3 more months
      Secret 3: 8 months from now => for 3 more months

      These secrets obviously become valid in the order { 1, 2, 3 }. In general,
      the order in which the secrets will be chosen is determined by their start
      times.

      A naive algorithm would simply iterate over all possible times representable
      as an integer [0,2^64-1), determine the set of valid secrets at each time
      instance. Then, for each set determine the most preferred of the valid
      secrets. Compressing the resulting list would yield the rotation order.
      Obviously, this is awfully inefficient. We can significantly improve the
      algorithm by only considering times at which the secret ID can change: the
      start and end of lifetimes. This routine does just that.
      """
      def significantTimes():
         # Every lifetime start and end, deduplicated and ascending.
         times = set()
         for s in secrets:
            lifetime = getLifetime( s )
            times.add( lifetime.start )
            times.add( lifetime.end )
         return sorted( times )

      def mostPreferredSecret( time ):
         """
         Determine the list of secrets which are valid at time 'time'. From this,
         return the "most preferred". That is, the secret whose lifetime began
         most recently. Infinite duration secrets are only chosen if there are no
         other valid secrets. Ties go to the higher ID.

         NOTE(review): apart from the infinite/non-infinite split, the choice
         is delegated entirely to min(), i.e. to Secret::operator<. The
         preference rules above assume that operator implements them; confirm.
         """
         def infinite( s ):
            return getLifetime( s ).isInfinite()
         validSecrets = [ s for s in secrets
                          if getLifetime( s ).isValidAtTime( time ) ]
         if not validSecrets:
            return "No Secret"

         nonInfiniteSecrets = [ s for s in validSecrets if not infinite( s ) ]
         if nonInfiniteSecrets:
            # First non-infinite secret as sorted by tac's Secret::operator<
            secretId = min( nonInfiniteSecrets ).id
         else:
            # First infinite secret as sorted by tac's Secret::operator<
            secretId = min( validSecrets ).id
         return secretId

      sequence = map( mostPreferredSecret, significantTimes() )

      # We may have the same secret multiple times, e.g. { 5, 5, 10, 15, 15, 15,
      # No Secret }. Remove the duplicates. Also, convert to strings.
      sequence = [ str( secretId ) for secretId, _ in itertools.groupby( sequence ) ]

      # A leading "No Secret" carries no information, so trim it.
      if sequence and sequence[ 0 ] == "No Secret":
         sequence = sequence[ 1 : ]

      return sequence

   @staticmethod
   def _secretID( secret ):
      """Return the secret's ID as an int, or 0 when the ID is not numeric."""
      try:
         return int( secret.id )
      except ValueError:
         return 0

   @staticmethod
   def _computeSharedSecretProfileModel( profile, localTimeToken=False ):
      """Build a SharedSecretProfile model for the configured `profile`.

      Returns None when status has no currentSecret entry for the profile.
      Secret values are shown in the form produced by getEncryptedPassword.
      Each secret's `timezone` field is strftime( '%Z' ) when localTimeToken
      is true, and 'UTC' otherwise.
      """
      model = SharedSecretProfile()
      name = profile.name
      currentSecret = status.currentSecret.get( name )
      if currentSecret is None:
         return None
      preferredRxSecret = currentSecret.preferredReceiveSecret()
      rxSecret = preferredRxSecret if preferredRxSecret else Secret()
      txSecret = currentSecret.transmitSecretStatus

      model.profileName = name
      if rxSecret:
         model.currentRxId = ShowSharedSecretProfileMixin._secretID( rxSecret )
         model.rxId = rxSecret.id
         model.currentRxExpiration = rxSecret.receiveLifetime.end
      if txSecret:
         model.currentTxId = ShowSharedSecretProfileMixin._secretID( txSecret )
         model.txId = txSecret.id
         model.currentTxExpiration = txSecret.transmitLifetime.end

      secrets = list( profile.secret.values() )
      # Call through this class rather than a subclass defined later in the
      # module, like the _secretID calls above.
      model.rxRotation = ShowSharedSecretProfileMixin._rotationSequence(
            lambda s: s.receiveLifetime, secrets )
      model.txRotation = ShowSharedSecretProfileMixin._rotationSequence(
            lambda s: s.transmitLifetime, secrets )

      model.secrets = []
      for s in sorted( secrets ):
         secret = SecretModel()

         secret.id = s.id
         secret.secretID = ShowSharedSecretProfileMixin._secretID( s )
         secret.secret = getEncryptedPassword( s.secret )
         secret.rxLifetimeStart = s.receiveLifetime.start
         secret.rxLifetimeEnd = s.receiveLifetime.end
         secret.txLifetimeStart = s.transmitLifetime.start
         secret.txLifetimeEnd = s.transmitLifetime.end
         secret.timezone = strftime( '%Z' ) if localTimeToken else 'UTC'

         model.secrets.append( secret )

      return model

class ShowSharedSecretProfilesCmd( ShowCommand.ShowCliCommandClass,
                                   ShowSharedSecretProfileMixin ):
   """show management security session shared-secret profile [ local-time ]

   Shows every configured profile that has a status entry."""
   syntax = "show management security session shared-secret profile [ local-time ]"
   data = {
      'management': ConfigMgmtMode.managementShowKwMatcher,
      'security': securityKwMatcher,
      'session': 'Show security session information',
      'shared-secret': 'Show shared-secrets',
      'profile': 'Show a shared-secret profile configuration',
      'local-time': 'Show secret lifetime in local timezone. Default is in UTC'
   }
   cliModel = SharedSecretProfiles

   @classmethod
   def handler( cls, mode, args ):
      model = cls.cliModel()
      # The local-time choice is the same for every profile; compute it once.
      localTime = 'local-time' in args
      for profile in config.profile.values():
         profileModel = cls._computeSharedSecretProfileModel( profile, localTime )
         # Profiles without a status entry yield None and are left out.
         if profileModel:
            model.profiles[ profile.name ] = profileModel
      return model

BasicCli.addShowCommandClass( ShowSharedSecretProfilesCmd )

class ShowSharedSecretProfileCmd( ShowCommand.ShowCliCommandClass,
                                  ShowSharedSecretProfileMixin ):
   """show management security session shared-secret profile NAME [ local-time ]

   Shows a single profile. An empty SharedSecretProfile model is returned when
   the profile is not configured or has no status entry yet."""
   syntax = "show management security session shared-secret profile NAME  " \
            "[ local-time ]"
   data = {
      'management': ConfigMgmtMode.managementShowKwMatcher,
      'security': securityKwMatcher,
      'session': 'Show security session information',
      'shared-secret': 'Show shared-secrets',
      'profile': 'Show a shared-secret profile configuration',
      'NAME': CliMatcher.DynamicNameMatcher( lambda mode: config.profile,
                                            'shared-secret profile name' ),
      'local-time': 'Show secret lifetime in local timezone. Default is in UTC'
   }
   cliModel = SharedSecretProfile

   @classmethod
   def handler( cls, mode, args ):
      profile = config.profile.get( args[ 'NAME' ] )
      if not profile:
         return cls.cliModel()
      model = cls._computeSharedSecretProfileModel( profile, 'local-time' in args )
      # The helper returns None while status has no entry for this profile.
      # Return an empty model instead, as for an unknown profile, so the
      # handler never hands None back for a command with a cliModel.
      if model is None:
         return cls.cliModel()
      return model

BasicCli.addShowCommandClass( ShowSharedSecretProfileCmd )

def Plugin( entityManager ):
   """CLI plugin entry point: mount the entities this module uses.

   Binds the module globals `config` (shared-secret profile config, via
   ConfigMount with mode "w"), `status` (profile status, via LazyMount with
   mode "r") and `securityConfig` (management security config, via LazyMount
   with mode "r")."""
   global config, status, securityConfig
   config = ConfigMount.mount(
      entityManager, "mgmt/security/sh-sec-prof/config",
      "Mgmt::Security::SharedSecretProfile::Config", "w" )
   status = LazyMount.mount(
      entityManager, "mgmt/security/sh-sec-prof/status",
      "Mgmt::Security::SharedSecretProfile::Status", "r" )
   securityConfig = LazyMount.mount(
      entityManager, "mgmt/security/config", mgmtSecurityConfigType, "r" )
