# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string
# pylint: disable=chained-comparison

import Tac
import BasicCli
import BasicCliUtil
import BasicCliModes
import CliCommand
import CliMatcher
import LazyMount
import Url
import CliPlugin.TechSupportCli
from CliPlugin.LicenseStatusModel import FeatureModel
from CliPlugin.LicenseStatusModel import LicenseFeatureModel
from CliPlugin.LicenseStatusModel import LicenseFilesModel
from CliPlugin.LicenseStatusModel import LicenseInstance
from CliPlugin.LicenseStatusModel import LicenseStatusModel
import CliToken.Clear
import ShowCommand
from datetime import datetime
import os
import re
import json
import tempfile
import zlib
import base64

# Licensing constants; this file reads licFileMinSizeBytes and
# licFileMaxSizeBytes from it (see isFileWithinLimits).
licConsts = Tac.Type( "License::LicenseConstants" )

# Mount-backed globals. They are None at import time and are bound via
# LazyMount.mount() in Plugin().
# License::ConfigRequest, mounted writable ("w"); this plugin writes
# licenseStoreUpdated and licenseConfigured and reads licenseStorePath.
licConfig = None
# License::Status (customerName, sysInfo, loadedLicenseSet).
licStatus = None
# License::HwStatus, handed to License::LicenseHelper.
hwStatus = None
# EntityMib::Status, handed to License::LicenseHelper and used for platform.
entMib = None
# Tac::Dir at sys/license/featureLicense; its entries expose featureName,
# loadedLicenseSet, licenseRequired and suppressor.
licenseInput = None
# Tac::Dir at hardware/olive/hwStatus/slice; entries expose
# postLicenseInstallAction.
hwOliveStatusSliceDir = None

# Human-readable description of a license feature, keyed by the feature name
# as displayed by `show license` (e.g. 'vEOS' is displayed as 'CloudEOS').
featureDesc = { 'CloudEOS' : 'Virtualized EOS' }

# This CLI plugin defines the following commands
#    sw# license import <URL>
#    sw# license update
#    sw# clear license all | ( file <name> ) [ no-update ]
#    sw# show license [ expired | all ]
#    sw# show license features
#    sw# show license files [ compressed ]

#--------------------------------------------------------------
# Tokens
#--------------------------------------------------------------
# 'license' keyword for the enable-mode commands defined here
# (license import, license update, clear license).
matcherLicenseConfig = CliMatcher.KeywordMatcher( 'license',
      helpdesc='License configuration' )
# 'license' keyword for the `show license ...` commands.
matcherLicenseShow = CliMatcher.KeywordMatcher( 'license', helpdesc='Show licenses' )

#--------------------------------------------------------------
# Utilities
#--------------------------------------------------------------
def readJson( filename ):
   '''
   Parse the JSON document stored in `filename`.

   The file is decoded as UTF-8, the standard encoding for JSON text, rather
   than the locale default. A license containing non-ASCII text therefore
   parses the same way regardless of the CLI process's locale.

   Returns the decoded Python object. Returns None if the file cannot be
   opened or read, is not valid UTF-8, or is not valid JSON.
   '''
   try:
      with open( filename, encoding='utf-8' ) as f:
         return json.load( f )
   except ( ValueError, OSError ):
      # json.JSONDecodeError and UnicodeDecodeError are both ValueError
      # subclasses, so undecodable and malformed files land here too.
      return None

def isFileWithinLimits( filename ):
   '''
   Return True when the size of `filename` in bytes lies within the inclusive
   range [ licConsts.licFileMinSizeBytes, licConsts.licFileMaxSizeBytes ].

   Raises OSError if the file cannot be stat'ed.
   '''
   sizeInBytes = os.path.getsize( filename )
   lowerBound = licConsts.licFileMinSizeBytes
   upperBound = licConsts.licFileMaxSizeBytes
   return lowerBound <= sizeInBytes <= upperBound

def updateLicenseStoreTS():
   '''
   Record the current Tac.now() time in licConfig.licenseStoreUpdated.

   Called after the license store changes and by `license update`, whose help
   text describes this as triggering a check for license.
   '''
   currentTime = Tac.now()
   licConfig.licenseStoreUpdated = currentTime

#--------------------------------------------------------------
# license import URL
#--------------------------------------------------------------
def verifyLicenseFile( mode, filename ):
   '''
   Load `filename` into a freshly created License::LicenseHelper and return
   the helper's verdict.

   mode: CLI mode of the invoking command (not used here).
   filename: path of the license file to load.
   Returns the value of LicenseHelper.loadLicense(). The caller treats ""
   as success and any other value as an error message to display.
   '''
   # Force mount on objects hwStatus and entMib.
   # Without this, if we just pass these lazily-mounted objects to
   # Tac.newInstance, the CLI can hang on an object whose mount is incomplete.
   LazyMount.force( hwStatus )
   LazyMount.force( entMib )

   externalLicenseIndicator = Tac.newInstance( 'License::ExternalLicenseIndicator' )
   licenseHelper = Tac.newInstance( "License::LicenseHelper", hwStatus, entMib,
                                    externalLicenseIndicator )
   licenseHelper.initialize()
   # Hand the helper the customer name already recorded in the license
   # status; presumably loadLicense() checks the file against it (confirm in
   # LicenseHelper).
   licenseHelper.licCustomerName = licStatus.customerName
   errMsg = licenseHelper.loadLicense( filename )
   # Drop the helper reference explicitly before returning.
   licenseHelper = None
   return errMsg

def importLicenseFromUrl( mode, args ):
   '''
   Handler for `license import URL`.

   Steps:
     1. If any hardware slice requires a forwarding-plane restart after a
        license install, ask the user to confirm first.
     2. Download args[ 'URL' ] into a temporary file.
     3. Sanity check it: size limits, valid JSON, and a customer name that
        matches licStatus.customerName when one is configured.
     4. Write a key-sorted copy into licConfig.licenseStorePath under a
        timestamped name.
     5. Validate the copy with verifyLicenseFile().

   All failures are reported through mode.addError(); nothing is raised to
   the caller. A license that fails verifyLicenseFile() is deliberately left
   in the store (see the comment at that check).
   '''
   if any( sliceItem.postLicenseInstallAction == "restartForwardingPlane"
           for sliceItem in hwOliveStatusSliceDir.values() ):
      prompt = (
         "WARNING\r\n"
         "This command will cause links to flap and impact traffic forwarding. "
         "Proceed (y/[N])? "
      )
      if not BasicCliUtil.confirm( mode, prompt, answerForReturn=False ):
         return

   url = args[ 'URL' ]
   context = Url.Context( *Url.urlArgsFromMode( mode ) )
   # Timestamp prefix keeps successive imports of the same file distinct,
   # e.g. "2024-01-02 03:04:05.123456" -> "2024.01.02.03.04.05.123456".
   prefix = re.sub( r'[\s:-]+', '.', str( datetime.now() ) )
   dstUrl = Url.parseUrl( 'file:{}/{}_{}'.format( licConfig.licenseStorePath, prefix,
                          os.path.basename( str( url ) ) ), context )

   # copy license file to temporary location and do a sanity check before
   # copying it to the actual location
   with tempfile.NamedTemporaryFile( prefix="lic", suffix='.json' ) as inFile:
      try:
         url.get( inFile.name )
      except OSError as e:
         # strerror may be None for some OSErrors; fall back to the exception
         # text rather than printing "None".
         mode.addError( "Failed to get a license from the specified URL: %s" %
                        ( e.strerror or e ) )
         return

      if not isFileWithinLimits( inFile.name ):
         mode.addError( "Failed to install License: File size is invalid" )
         return

      licenseJson = readJson( inFile.name )
      if licenseJson is None:
         mode.addError( "Failed to install License: Could not parse as JSON" )
         return

      # Look up the customer name without raising: a document that is not a
      # JSON object, or that lacks CustomerName, simply has no customer name.
      # It then fails the match below instead of aborting the command with
      # KeyError/TypeError.
      licenseCustomer = None
      if isinstance( licenseJson, dict ):
         licenseCustomer = licenseJson.get( 'CustomerName' )
      if licStatus.customerName and licStatus.customerName != licenseCustomer:
         mode.addError( "Failed to install License: Customer Name did not match" )
         return

      try:
         if not os.path.exists( licConfig.licenseStorePath ):
            os.makedirs( licConfig.licenseStorePath )
            os.chmod( licConfig.licenseStorePath, 0o777 )
         # Re-arrange file to be in sorted order
         with tempfile.NamedTemporaryFile( prefix="lic", suffix='.json',
                                           mode="w" ) as outFile:
            outFile.write( json.dumps( licenseJson, sort_keys=True ) )
            outFile.flush()
            dstUrl.put( outFile.name ) # pylint: disable-msg=E1103
            updateLicenseStoreTS()
            licConfig.licenseConfigured = True
            errMsg = verifyLicenseFile( mode, outFile.name )
            if errMsg != "":
               # Raise error to fail the command *BUT* install the license in store
               # LicenseHelper traces will give more information about failure.
               # This helps us debug licenses in the field.
               mode.addError( errMsg )
               return
      except OSError as e:
         mode.addError( "Failed to write license file '%s': %s" %
                        ( dstUrl, e.strerror or e ) )

class LicenseImportUrlCmd( CliCommand.CliCommandClass ):
   # `license import URL`: import a license file from a URL; the work is done
   # by importLicenseFromUrl().
   syntax = 'license import URL'
   data = {
      'license': matcherLicenseConfig,
      'import': 'Import license from a URL',
      # Only filesystems that support reading are accepted. allowAllPaths is
      # presumably there to lift UrlMatcher's default path restrictions;
      # confirm in Url.UrlMatcher.
      'URL': Url.UrlMatcher( lambda fs: fs.supportsRead(),
                             helpdesc='Path to license file',
                             allowAllPaths=True ),
   }
   handler = importLicenseFromUrl

# Registered in enable mode.
BasicCliModes.EnableMode.addCommandClass( LicenseImportUrlCmd )

#--------------------------------------------------------------------------------
# license update
#--------------------------------------------------------------------------------
class LicenseUpdateCmd( CliCommand.CliCommandClass ):
   # `license update`: refresh the license-store timestamp to trigger a check
   # for license without changing any stored license file.
   syntax = 'license update'
   data = {
      'license': matcherLicenseConfig,
      'update': 'Trigger a check for license',
   }

   @staticmethod
   def handler( mode, args ):
      # mode/args are part of the CLI handler signature but are unused here.
      updateLicenseStoreTS()

# Registered in enable mode.
BasicCliModes.EnableMode.addCommandClass( LicenseUpdateCmd )

#--------------------------------------------------------------
# clear license  all | ( file name ) [ no-update ]
#--------------------------------------------------------------
def licenseFilesList():
   '''
   Return completion candidates for `clear license file <name>`.

   The result is a dict keyed by every entry in licConfig.licenseStorePath,
   with None values. When the store is empty or does not exist yet (it is
   created on demand by `license import`), the placeholder { 'None' : None }
   is returned instead, presumably because the keyword matcher needs at least
   one candidate. ClearLicenseCmd ignores that placeholder name.
   '''
   try:
      files = os.listdir( licConfig.licenseStorePath )
   except FileNotFoundError:
      # No license has been imported yet, so there is nothing to offer.
      files = []
   if not files:
      return { 'None' : None }
   return dict.fromkeys( files )

class ClearLicenseCmd( CliCommand.CliCommandClass ):
   # `clear license ...`: delete either every license file in the store or one
   # named file, then trigger a check for license unless 'no-update' is given.
   syntax = 'clear license  all | ( file name ) [ no-update ]'
   data = {
         'clear': CliToken.Clear.clearKwNode,
         'license': matcherLicenseConfig,
         'all': 'Clear all license files',
         'file': 'Clear specific license file',
         'name': CliMatcher.DynamicKeywordMatcher( lambda mode: licenseFilesList() ),
         'no-update': 'Do not trigger a check for license'
   }

   @staticmethod
   def handler( mode, args ):
      targetFiles = []
      if 'all' in args:
         try:
            targetFiles = os.listdir( licConfig.licenseStorePath )
         except FileNotFoundError:
            # The store is created on demand by `license import`; if it does
            # not exist there is nothing to clear.
            targetFiles = []
         except OSError:
            mode.addError( "Failed to remove one or more license file(s)" )
            return
      elif 'file' in args and args[ 'name' ] != 'None':
         # 'None' is the placeholder licenseFilesList() offers for an empty
         # store; it never names a real file to delete.
         targetFiles = [ args[ 'name' ] ]
      for licenseFile in targetFiles:
         path = os.path.join( licConfig.licenseStorePath, licenseFile )
         try:
            # Only regular files are removed; other entries are skipped.
            if os.path.isfile( path ):
               os.unlink( path )
         except OSError:
            mode.addError( "Failed to remove one or more license file(s)" )
            return
      if 'no-update' not in args:
         updateLicenseStoreTS()

# Registered in enable mode.
BasicCliModes.EnableMode.addCommandClass( ClearLicenseCmd )

#--------------------------------------------------------------
# show license [ all | expired ]
#--------------------------------------------------------------
def getLicenseInstance( licenseInstance, active, expired, future ):
   '''
   Build a LicenseInstance model from `licenseInstance`. The value, count,
   start, expiration and source attributes are read.

   The model is returned if the instance falls into one of the requested
   windows, relative to Tac.utcNow():
     future  -- not started yet        (now < start)
     expired -- past its expiration    (now > expiration)
     active  -- currently valid        (start < now < expiration)
   Returns None when it falls into none of the requested windows.
   '''
   model = LicenseInstance()
   model.parameter = licenseInstance.value
   model.count = licenseInstance.count
   startTime = licenseInstance.start
   endTime = licenseInstance.expiration
   model.startTime = startTime
   model.endTime = endTime
   model.source = licenseInstance.source

   now = Tac.utcNow()
   isFuture = now < startTime
   isExpired = now > endTime
   isActive = startTime < now < endTime
   if ( future and isFuture ) or ( expired and isExpired ) or \
         ( active and isActive ):
      return model
   return None

def getLicenseInstances( feature, active, expired, future ):
   '''
   Return the LicenseInstance models for the licenses of `feature` that
   match the requested active/expired/future windows (see
   getLicenseInstance), in feature.license iteration order.
   '''
   candidates = ( getLicenseInstance( lic, active, expired, future )
                  for lic in feature.license.values() )
   return [ model for model in candidates if model is not None ]

def getFeatures( active, expired, future ):
   '''
   Return a dict of FeatureModels keyed by the feature's display name.

   Licenses are gathered from two sources:
     - licStatus.loadedLicenseSet
     - getExternalLicenseSet()
   The feature 'vEOS' is displayed as 'CloudEOS'. Licenses from both sources
   with the same display name are merged into one FeatureModel. A feature is
   included only if at least one of its licenses matches the requested
   active/expired/future windows.
   '''
   features = {}
   licSets = [ licStatus.loadedLicenseSet, getExternalLicenseSet() ]
   for source in licSets:
      for f in source:
         name = 'CloudEOS' if f == 'vEOS' else f
         # Look up by the display name, which is the key the model is stored
         # under. Looking up by the raw name would miss the 'CloudEOS' entry
         # for 'vEOS', so the second source would replace the first source's
         # licenses instead of merging with them.
         fm = features.get( name )
         if fm is None:
            fm = FeatureModel()
         fm.featureDescription = featureDesc.get( name, None )
         fm.licenses += getLicenseInstances( source[ f ], active, expired, future )
         # Include the featureModel only if the feature has at least one
         # license that satisfies the active/expired/future criteria
         if fm.licenses:
            features[ name ] = fm
   return features

def externalLicense( licenseSource ):
   '''
   Return True if `licenseSource` is one of the sources treated as external:
   'Resolving', 'External', 'DPE' or 'PAYG'.
   '''
   externalSources = ( 'Resolving', 'External', 'DPE', 'PAYG' )
   return licenseSource in externalSources

def getExternalLicInstances( feature ):
   '''
   Return new License::LicenseInstance values copied from the licenses of
   `feature` whose source is external (see externalLicense()), in
   feature.license iteration order.

   Each copy carries value, start, expiration, count,
   daysAllowedPastExpiration and source.
   '''
   def copyInstance( original ):
      clone = Tac.Value( "License::LicenseInstance" )
      clone.value = original.value
      clone.start = original.start
      clone.expiration = original.expiration
      clone.count = original.count
      clone.daysAllowedPastExpiration = original.daysAllowedPastExpiration
      clone.source = original.source
      return clone

   return [ copyInstance( lic ) for lic in feature.license.values()
            if externalLicense( lic.source ) ]

def getExternalLicenseSet():
   '''
   Collect the external licenses reported by the feature agents in
   licenseInput.

   Returns a dict mapping featureName to a new License::Feature that holds
   copies of that feature's external license instances. Features without any
   external instance are omitted.
   '''
   result = {}
   for agentName in licenseInput:
      featureInput = licenseInput.entityPtr[ agentName ]
      name = featureInput.featureName
      # A feature agent with featureName X only loads the external license
      # for X, so only that entry is examined.
      if name not in featureInput.loadedLicenseSet:
         continue
      instances = getExternalLicInstances( featureInput.loadedLicenseSet[ name ] )
      if not instances:
         continue
      feature = Tac.newInstance( "License::Feature", name )
      result[ name ] = feature
      for inst in instances:
         feature.license.enq( inst )
   return result

class LicenseCmd( ShowCommand.ShowCliCommandClass ):
   # `show license [ all | expired ]`: device identity plus per-feature
   # license instances.
   syntax = 'show license [ all | expired ]'
   data = {
      'license': matcherLicenseShow,
      'all': 'Status for all licenses',
      'expired': 'Status for licenses that have expired',
   }
   cliModel = LicenseStatusModel

   @staticmethod
   def handler( mode, args ):
      # Window selection:
      #   show license          -> active only
      #   show license expired  -> expired only
      #   show license all      -> active, expired and future
      showAll = 'all' in args
      showExpiredOnly = 'expired' in args
      active = not showExpiredOnly
      expired = showExpiredOnly or showAll
      future = showAll

      status = LicenseStatusModel()
      status.deviceSerial = licStatus.sysInfo.deviceSerial
      status.systemMac = licStatus.sysInfo.systemMac
      status.platform = ""
      root = entMib.root
      if root:
         status.platform = root.modelName
         # vEOS is reported as "CloudEOS-<vendorType>" when a vendor is set.
         if root.modelName == "vEOS" and root.vendorType:
            status.platform = "CloudEOS" + "-" + root.vendorType
      status.domainName = "Unknown"
      status.customerName = licStatus.customerName
      status.features = getFeatures( active, expired, future )
      return status

BasicCli.addShowCommandClass( LicenseCmd )

#--------------------------------------------------------------------------------
# show license features
#--------------------------------------------------------------------------------
def activeLicenseExists( featName, licenseSet, now ):
   '''
   Return True if `licenseSet` has an entry for `featName` with at least one
   license whose validity window strictly contains `now`
   (start < now < expiration). Otherwise return False.
   '''
   if featName not in licenseSet:
      return False
   candidates = licenseSet[ featName ].license.values()
   return any( lic.start < now < lic.expiration for lic in candidates )

def getLicenseFeatures():
   '''
   Classify every feature input in licenseInput that has licenseRequired set:
     "active"  -- the feature itself has a currently valid license;
     "missing" -- it has none, and none of its suppressor features has a
                  currently valid license either.
   A required feature without its own license but with a licensed suppressor
   appears in neither list.

   Returns { "missing" : [ featureName, ... ], "active" : [ featureName, ... ] }.
   '''
   now = Tac.utcNow()
   features = [ licenseInput.entityPtr[ agent ] for agent in licenseInput ]

   def ownLicenseActive( featureInput ):
      return activeLicenseExists( featureInput.featureName,
                                  featureInput.loadedLicenseSet, now )

   def suppressorAvailable( suppressorFeature ):
      # Only the first feature input carrying this name is consulted.
      for candidate in features:
         if candidate.featureName == suppressorFeature:
            return activeLicenseExists( suppressorFeature,
                                        candidate.loadedLicenseSet, now )
      return False

   def suppressed( featureInput ):
      return any( suppressorAvailable( name )
                  for name in featureInput.suppressor.keys() )

   requiredFeatures = [ f for f in features if f.licenseRequired ]
   active = [ f.featureName for f in requiredFeatures if ownLicenseActive( f ) ]
   missing = [ f.featureName for f in requiredFeatures
               if not ownLicenseActive( f ) and not suppressed( f ) ]

   return { "missing" : missing, "active" : active }

class LicenseFeaturesCmd( ShowCommand.ShowCliCommandClass ):
   # `show license features`: list the license-required features that are
   # active and those missing a license (see getLicenseFeatures()).
   syntax = 'show license features'
   data = {
      'license': matcherLicenseShow,
      'features': 'show license features',
   }
   cliModel = LicenseFeatureModel

   @staticmethod
   def handler( mode, args ):
      model = LicenseFeatureModel()
      licenseFeatures = getLicenseFeatures()
      model.missing = licenseFeatures[ "missing" ]
      model.active = licenseFeatures[ "active" ]
      return model

BasicCli.addShowCommandClass( LicenseFeaturesCmd )

#--------------------------------------------------------------------------------
# show license files [ compressed ]
#--------------------------------------------------------------------------------
class LicenseFilesCmd( ShowCommand.ShowCliCommandClass ):
   # `show license files [ compressed ]`: dump each license file in the store,
   # either pretty-printed JSON or base64( zlib( raw file text ) ).
   syntax = 'show license files [ compressed ]'
   data = {
      'license': matcherLicenseShow,
      'files': 'Show license files',
      'compressed': 'Show compressed base64 encoding',
   }
   cliModel = LicenseFilesModel

   @staticmethod
   def handler( mode, args ):
      compressedOpt = 'compressed' in args
      serialSeen = []
      model = LicenseFilesModel()
      storePath = licConfig.licenseStorePath

      if not os.path.exists( storePath ):
         return model

      # Files are visited in name order. For each LicenseSerialNumber only the
      # first file seen is shown and later copies are skipped.
      for fileName in sorted( os.listdir( storePath ) ):
         fullPath = os.path.join( storePath, fileName )
         try:
            if not os.path.isfile( fullPath ) or \
                  not isFileWithinLimits( fullPath ):
               continue
            # Read and parse each file exactly once, so the displayed text and
            # the parsed JSON come from the same read.
            with open( fullPath, encoding='utf-8' ) as f:
               text = f.read()
            parsedJson = json.loads( text )
         except ( ValueError, OSError ):
            # Vanished, unreadable, non-UTF-8 or non-JSON files are skipped so
            # that one bad file cannot break the whole command.
            continue

         # Look up the serial without raising. Documents that are not JSON
         # objects, or that lack a serial, are still shown but never
         # de-duplicated.
         serial = None
         if isinstance( parsedJson, dict ):
            serial = parsedJson.get( 'LicenseSerialNumber' )
         if serial is not None:
            if serial in serialSeen:
               continue
            serialSeen.append( serial )

         if compressedOpt:
            encoded = base64.b64encode( zlib.compress( text.encode( "utf-8" ) ) )
            body = encoded.decode( "ascii" )
         else:
            body = json.dumps( parsedJson, indent=4, sort_keys=True )
         model.licenseText[ fileName ] = body

      return model

BasicCli.addShowCommandClass( LicenseFilesCmd )

#--------------------------------------------------------------
# show tech support
#--------------------------------------------------------------
# Register license commands with TechSupportCli for `show tech support`.
# NOTE(review): the leading timestamp string is presumably the registration's
# ordering/identity key in TechSupportCli; confirm there before changing it.
CliPlugin.TechSupportCli.registerShowTechSupportCmd(
   '2017-09-19 08:56:44',
   cmds=[ 'show license all',
          'show license files compressed' ] )

CliPlugin.TechSupportCli.registerShowTechSupportCmd(
   '2019-11-11 02:36:15',
   cmds=[ 'show license features' ] )

def Plugin( entMan ):
   '''
   Plugin entry point: lazily mount the state used by the commands in this
   file and bind it to the module-level globals.

   entMan: passed through to every LazyMount.mount() call below.
   '''
   global licConfig, licStatus, hwStatus, entMib, licenseInput, hwOliveStatusSliceDir
   # Mounted "w": this plugin sets licenseStoreUpdated and licenseConfigured.
   licConfig = LazyMount.mount( entMan, "sys/license/client/configRequest",
                                "License::ConfigRequest", "w" )
   licStatus = LazyMount.mount( entMan, "sys/license/client/status",
                                "License::Status", "r" )
   hwStatus = LazyMount.mount( entMan, "license/hwStatus",
                               "License::HwStatus", "r" )
   entMib = LazyMount.mount( entMan, "hardware/entmib",
                             "EntityMib::Status", "r" )
   # Per-feature license inputs consumed by getExternalLicenseSet() and
   # getLicenseFeatures().
   licenseInput = LazyMount.mount( entMan, "sys/license/featureLicense",
                              "Tac::Dir", "ri" )
   # Hardware slices; importLicenseFromUrl() checks their
   # postLicenseInstallAction before installing a license.
   hwOliveStatusSliceDir = LazyMount.mount( entMan, "hardware/olive/hwStatus/slice",
                                            "Tac::Dir", "ri" )

