# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

from __future__ import absolute_import, division, print_function
import hashlib
import os
import shutil
import tempfile
import time
import Tracing
import zipfile
import SignatureFile
import SignatureRequest
import SwiSignLib

# Trace handles for this module: logInfo is used for informational messages
# and logWarn for warnings (e.g. when a layout or compat script is overridden).
logInfo = Tracing.Handle( 'AufLib' ).trace0
logWarn = Tracing.Handle( 'AufLib' ).trace1

# Development-CA signing material. AUF_DEV_SIGNING_CERT/KEY form the key pair
# handed to SignatureRequest.getDataFromDevCA by Auf.genSignedAuf when
# useDevCA is set; AUF_DEV_INT_CA_CERT is the root used to verify dev-signed
# aufs (Auf.isSigned( useDevCA=True )).
AUF_DEV_SIGNING_CERT="/etc/swi-signing-devCA/aufsign.crt"
AUF_DEV_SIGNING_KEY="/etc/swi-signing-devCA/aufsign.key"
AUF_DEV_INT_CA_CERT="/etc/swi-signing-devCA/aufint.crt"
AUF_DEV_INT_CA_KEY="/etc/swi-signing-devCA/aufint.key"

# Placeholders for BUG516648
# Production (non-dev) CA certificates. AUF_ARISTA_INT_CA_CERT is the root
# used by Auf.isSigned() when useDevCA is False.
AUF_ARISTA_SIGNING_CA_CERT="/etc/swi-signature-aufSignCa.crt"
AUF_ARISTA_INT_CA_CERT="/etc/swi-signature-aufIntCa.crt"

# Certificates tried, in order, when Auf.__str__ reports which cert a signed
# auf verifies against.
certs = [ AUF_DEV_INT_CA_CERT,
          AUF_ARISTA_INT_CA_CERT ]

# Signing-server endpoint used by Auf.genSignedAuf when not using the dev CA.
AUF_SIGN_URL = "https://license.aristanetworks.com/sign/v2/aboot/release/"

def calcSha( _file, chunkSize=1024 * 1024 ):
   """Return the hex SHA-256 digest of the file at path `_file`.

   The file is hashed `chunkSize` bytes at a time so that large inputs (ROM
   section images, whole .auf archives) are never loaded into memory at once.

   Raises ValueError if `chunkSize` is not positive.
   """
   if chunkSize <= 0:
      # f.read( 0 ) returns b'' immediately, which would end the loop below
      # and silently hash nothing.
      raise ValueError( "chunkSize must be positive" )
   h = hashlib.sha256()
   with open( _file, mode='rb' ) as f:
      for chunk in iter( lambda: f.read( chunkSize ), b'' ):
         h.update( chunk )
   return h.hexdigest()

def copyFileSection( fromFile, toFile, start, length, seek=True ):
   """Write `length` bytes of `fromFile` into `toFile` (created/truncated).

   When `seek` is true the bytes are taken from offset `start`; otherwise
   `start` is ignored and the bytes are taken from the beginning of the file.
   """
   with open( fromFile, 'rb' ) as src:
      # Destination is opened (and truncated) before reading, as before.
      with open( toFile, 'wb' ) as dst:
         if seek:
            src.seek( start )
         section = src.read( length )
         dst.write( section )

class Payload( object ): # pylint: disable=useless-object-inheritance
   """One payload section of an auf: checksum, image path and ROM range.

   `romOffsets` is indexed as ( offset, size ); only its first two entries
   are used.
   """

   def __init__( self, sha, imgPath, romOffsets ):
      self.sha, self.imgPath = sha, imgPath
      self.offset, self.size = romOffsets[ 0 ], romOffsets[ 1 ]

class VersionField( object ): # pylint: disable=useless-object-inheritance
   '''
   One component of a version number (e.g. the line, major or minor part of
   an Aboot version) which may be a wildcard.

   Accepted constructor inputs:
     * 'x' (the default)  -> wildcard; `v` is None
     * a non-string value -> converted with int()
     * "N-R[-R...]"       -> value N with revision list [ R, ... ] in `r`
     * any other string   -> converted with int()

   Comparisons: a wildcard matches everything, so <, <=, >, >= and == all
   return True and != returns False. For a concrete value only `v` takes part
   in comparisons; the revision list `r` is ignored.
   '''

   def __init__( self, v='x' ):
      self.r = None

      if not isinstance( v, str ):
         self.v = int( v )
      elif v == 'x':
         self.v = None
      elif "-" in v:
         t = v.split( "-" )
         self.v = int( t[ 0 ] )
         # Keep every revision component so that str() round-trips the input.
         self.r = [ int( x ) for x in t[ 1: ] ]
      else:
         self.v = int( v )

   def __call__( self ):
      """Return `v`, or [ v ] + r when a revision list is present."""
      if self.r:
         return [ self.v ] + self.r
      return self.v

   def __len__( self ):
      """Number of numeric components (1 + number of revisions)."""
      if self.r:
         return len( self.r ) + 1
      return 1

   def __str__( self ):
      if self.r:
         return "-".join( [ str( x ) for x in [ self.v ] + self.r ] )

      return 'x' if self.v is None else str( self.v )

   # NOTE: __add__/__sub__ drop the revision list and raise TypeError on a
   # wildcard (None +/- int).
   def __add__( self, x ):
      return VersionField( self.v + x )

   def __sub__( self, x ):
      return VersionField( self.v - x )

   def __lt__( self, other ):
      return True if self.v is None else self.v < other

   def __le__( self, other ):
      return True if self.v is None else self.v <= other

   def __gt__( self, other ):
      return True if self.v is None else self.v > other

   def __ge__( self, other ):
      return True if self.v is None else self.v >= other

   def __eq__( self, other ):
      return True if self.v is None else self.v == other

   def __hash__( self ):
      # `r` is a list, which is unhashable; hash it as a tuple instead.
      # NOTE: because a wildcard equals everything and a field equals a bare
      # int, hash/eq consistency cannot hold across types; this only keeps
      # hashing itself from failing.
      return hash( ( self.v, tuple( self.r ) if self.r else None ) )

   def __ne__( self, other ):
      return False if self.v is None else self.v != other

class Auf( object ): # pylint: disable=useless-object-inheritance
   """ Class that represents an .auf

   There are two parts: a temporary directory with the unpacked contents of an
   auf, and the auf file itself. All operations work on the temporary
   directory; call genAuf to update (remake) the auf file.

   If the passed in file already exists (and is non-empty) it is parsed and its
   contents are unpacked into the temporary directory. The temporary directory
   is removed once the object is destroyed.
   """

   compatScript = "compat_check.sh"
   layoutFile = "layout"

   def __init__( self, aufFile ):
      """Create the working directory and, if `aufFile` is an existing
      non-empty file, unpack it there and parse its layout and info.

      Raises IOError if `aufFile` is non-empty but not a zip archive.
      """
      self.aufFile = aufFile
      self.sections = {}
      self.layout = {}
      self.line = VersionField()
      self.major = VersionField()
      self.minor = VersionField()
      self.version = 1
      self.name = ""
      self.extraFiles = []

      # Version 2
      self.abootInstallerMinRev_ = 0
      self.abootOnly_ = 0
      self.aufRevision_ = 0

      self.tmpdir = tempfile.mkdtemp()
      logInfo( "Working in temporary directory %s" % self.tmpdir )

      if os.path.isfile( self.aufFile ) and os.path.getsize( self.aufFile ) != 0:
         if not zipfile.is_zipfile( self.aufFile ):
            raise IOError( 'Invalid file format' )
         logInfo( "auf exists, parsing" )

         # The context manager closes the archive even if extraction fails.
         with zipfile.ZipFile( self.aufFile, mode='r' ) as zf:
            zf.extractall( self.tmpdir )
         # Layout must be parsed first: parseInfo looks sections up in it.
         self.parseLayout()
         with open( self.toTmpDir( "info" ), "r" ) as f:
            self.parseInfo( f.readlines() )

      if not os.path.isdir( self.toTmpDir( "payloads" ) ):
         os.makedirs( self.toTmpDir( "payloads" ) )

   def aufVersionCheck( self, fieldName, minVersion ):
      """Raise the AUF format version to `minVersion` (printing a notice) if
      `fieldName` needs a newer format than the current one."""
      if self.version < minVersion:
         print( "'%s' introduced in AUF version %d. Bumping version." % (
                  fieldName, minVersion ) )
         self.version = minVersion

   @property
   def abootInstallerMinRev( self ):
      """Aboot installer minimum revision (AUF version 2 field)."""
      return self.abootInstallerMinRev_

   @abootInstallerMinRev.setter
   def abootInstallerMinRev( self, value ):
      self.aufVersionCheck( "abootInstallerMinRev", 2 )
      self.abootInstallerMinRev_ = value

   @property
   def aufRevision( self ):
      """AUF revision (AUF version 2 field)."""
      return self.aufRevision_

   @aufRevision.setter
   def aufRevision( self, value ):
      self.aufVersionCheck( "aufRevision", 2 )
      self.aufRevision_ = value

   @property
   def abootOnly( self ):
      """Whether the auf needs to run in Aboot (AUF version 2 field)."""
      return self.abootOnly_

   @abootOnly.setter
   def abootOnly( self, value ):
      self.aufVersionCheck( "abootOnly", 2 )
      self.abootOnly_ = value

   def __del__( self ):
      # tmpdir is absent if __init__ failed before mkdtemp() returned; errors
      # are ignored because a destructor cannot usefully report them.
      tmpdir = getattr( self, 'tmpdir', None )
      if tmpdir:
         shutil.rmtree( tmpdir, ignore_errors=True )

   def delete( self ):
      """Remove the auf file itself (the temporary directory is untouched)."""
      os.remove( self.aufFile )

   def copyToTmpDir( self, from_, to ):
      """Copy file `from_` to path `to` relative to the temporary directory,
      creating any missing parent directories."""
      dest = self.toTmpDir( to )
      destDir = os.path.dirname( dest )
      # Create the parent up front instead of retrying on IOError: calling
      # makedirs on an already existing directory would mask the real copy
      # error (e.g. a missing source file) with "File exists".
      if not os.path.isdir( destDir ):
         os.makedirs( destDir )
      shutil.copyfile( from_, dest )

   def toTmpDir( self, p ):
      """Return `p` joined onto the temporary directory."""
      return os.path.join( self.tmpdir, p )

   def parseInfo( self, lines ):
      """
      Parse the lines of an info file, e.g.
      version: 1
      line: 0
      major: 0
      minor: 0
      name: ...
      sections: section1 0ffe1..., section2 f6e0a...

      Version 2 files may also carry aufRevision, abootInstallerMinRev and
      abootOnly. Each line is split on its first ':' only, so values (e.g. the
      name) may contain ':'; blank lines are ignored. Every section's payload
      must already exist as payloads/<name>.img in the temporary directory with
      a matching sha256 and must be named in the parsed layout.

      Raises SyntaxError for malformed lines or unknown keys, TypeError for a
      hash that is not 64 characters long, ValueError on a hash mismatch and
      KeyError for a section missing from the layout.
      """
      for rawLine in lines:
         line = rawLine.strip()
         if not line:
            continue
         key, sep, value = line.partition( ":" )
         if not sep:
            raise SyntaxError( "Invalid info line %s" % line )
         key = key.strip()
         value = value.strip()

         if key == "version":
            self.version = int( value )
         elif key == "line":
            self.line = VersionField( value )
         elif key == "major":
            self.major = VersionField( value )
         elif key == "minor":
            self.minor = VersionField( value )
         elif key == "name":
            self.name = value
         elif key == "abootInstallerMinRev":
            self.abootInstallerMinRev_ = int( value )
         elif key == 'aufRevision':
            self.aufRevision_ = int( value )
         elif key == 'abootOnly':
            self.abootOnly_ = int( value )
         elif key == "sections":
            sections = [ x.strip() for x in value.split( "," ) ]
            for s in sections:
               if s == "":
                  continue
               parts = s.split()
               if len( parts ) != 2:
                  raise SyntaxError( "Invalid section format %s" % s )
               name, sha = parts
               if len( sha ) != 64:
                  raise TypeError( "Invalid hash" )

               imgLoc = self.toTmpDir( "payloads/%s.img" % name )
               if calcSha( imgLoc ) != sha:
                  raise ValueError( "hash mismatch" )

               offsets = self.layout[ name ]
               self.sections[ name ] = Payload( sha, imgLoc, offsets )
         else:
            raise SyntaxError( "Unknown key %s in info" % key )

   def updateInfo( self ):
      """Write the current metadata to the `info` file in the temporary
      directory, in the format read by parseInfo. Version 2 fields are only
      written when the AUF version is at least 2."""
      sectionStr = ", ".join(
            [ "%s %s" % ( sectionName, payload.sha )
               for sectionName, payload in self.sections.items() ] )

      info = ""
      info += "version: %d\n" % self.version
      info += "line: %s\n" % self.line
      info += "major: %s\n" % self.major
      info += "minor: %s\n" % self.minor
      info += "name: %s\n" % self.name
      info += "sections: %s\n" % sectionStr
      if self.version >= 2:
         info += "aufRevision: %d\n" % self.aufRevision
         info += "abootInstallerMinRev: %d\n" % self.abootInstallerMinRev
         info += "abootOnly: %d\n" % self.abootOnly

      with open( self.toTmpDir( 'info' ), 'w' ) as i:
         i.write( info )

   def parseLayout( self ):
      """
      Parse the layout file to get the address range of each section, e.g.
      21000:3FFFFF me
      2000:10FFF prefdl
      1000:1FFF mac
      20000:20FFF mfgdata
      1000:20FFF pdr
      A00000:FFFFFF normal
      400000:9EFFFF fallback

      Each line is "START:END NAME" with inclusive hex bounds; it is stored as
      self.layout[ NAME ] = ( START, END - START + 1 ). Fields may be separated
      by any whitespace and blank lines are ignored. Any current payloads are
      dropped, since the new layout may move them.

      Raises SyntaxError for a malformed line or one whose END precedes START.
      """

      # Remove existing layout defs
      self.layout = {}
      # As the new layout may redefine where current payloads are located, clear
      # current payloads
      self.sections = {}

      with open( self.toTmpDir( self.layoutFile ) ) as f:
         # split() (not split( ' ' )) so repeated spaces cannot produce an
         # empty section name.
         sections = [ x.split() for x in f.readlines() ]

      for section in sections:
         if not section:
            continue
         try:
            addrRange = section[ 0 ].split( ':' )
            start, end = (
                  int( addrRange[ 0 ], 16 ),
                  int( addrRange[ 1 ], 16 ) )
            if end < start:
               # A negative size would make copyFileSection read to EOF.
               raise ValueError( "end before start" )
            self.layout[ section[ 1 ] ] = ( start, end - start + 1 )
         except ( IndexError, ValueError ):
            # pylint: disable-next=raise-missing-from
            raise SyntaxError( "Error parsing layout, line %s" % section )

   def addSectionFromRom( self, sectionName, rom, seek=True ):
      """Copy section `sectionName` (located via the layout) out of ROM image
      `rom` into payloads/<sectionName>.img and register it as a payload.

      With seek=False the bytes are taken from the start of `rom` instead of
      the layout offset. Raises IOError if no layout has been added and
      KeyError if the section is not in the layout.
      """
      if not os.path.exists( self.toTmpDir( self.layoutFile ) ):
         raise IOError( "missing layout" )

      start, size = self.layout[ sectionName ]

      sectionImg = self.toTmpDir( "payloads/%s.img" % sectionName )
      copyFileSection( rom, sectionImg, start, size, seek )

      checksum = calcSha( sectionImg )
      self.sections[ sectionName ] = Payload( checksum, sectionImg, ( start, size ) )

   def addCompatibilityScript( self, compat ):
      """Install `compat` as the auf's compatibility-check script."""
      if os.path.exists( self.toTmpDir( self.compatScript ) ):
         logWarn( "Overriding compatability script" )
      self.copyToTmpDir( compat, self.compatScript )

   def getCompatibilityScript( self ):
      """Return the path of the compatibility script, or None if absent."""
      if os.path.exists( self.toTmpDir( self.compatScript ) ):
         return self.toTmpDir( self.compatScript )
      return None

   def getSections( self ):
      """Return the dict of section name -> Payload."""
      return self.sections

   def getExtraFiles( self ):
      """Return the extra files included by the last genAuf call."""
      return self.extraFiles

   def addNotIncludedFiles( self, includedFiles, zf ):
      """Add to `zf` every file under the temporary directory whose absolute
      path is not in the set `includedFiles`.

      Returns the added files' paths relative to the temporary directory,
      sorted so the archive contents are written in a deterministic order.
      """
      allFiles = set()
      for dirpath, _, filenames in os.walk( self.tmpdir ):
         allFiles.update( os.path.join( dirpath, f ) for f in filenames )

      notIncludedFiles = sorted( os.path.relpath( f, self.tmpdir )
                                 for f in allFiles - includedFiles )
      for f in notIncludedFiles:
         zf.write( self.toTmpDir( f ), f )
         logInfo( "Extra file %s included in auf" % f )

      return notIncludedFiles

   def genBuildDate( self ):
      """Return the current UTC time formatted as YYYYmmddTHHMMSSZ."""
      return time.strftime( "%Y%m%dT%H%M%SZ", time.gmtime() )

   def genAuf( self ):
      """(Re)write the auf archive from the temporary directory.

      The archive holds info, layout, the compat script (if any), every
      registered payload, a freshly generated "version" file and any other
      file found in the temporary directory (recorded in self.extraFiles).
      """
      includedFiles = set()
      self.updateInfo()
      # The context manager closes the archive even if a write fails.
      with zipfile.ZipFile( self.aufFile, mode='w' ) as zf:
         zf.write( self.toTmpDir( "info" ), "info" )
         zf.write( self.toTmpDir( self.layoutFile ), self.layoutFile )
         includedFiles.add( self.toTmpDir( "info" ) )
         includedFiles.add( self.toTmpDir( self.layoutFile ) )
         if os.path.exists( self.toTmpDir( self.compatScript ) ):
            zf.write( self.toTmpDir( self.compatScript ), self.compatScript )
            includedFiles.add( self.toTmpDir( self.compatScript ) )
         for sectionName in self.sections:
            loc = "payloads/%s.img" % sectionName
            zf.write( self.toTmpDir( loc ), loc )
            includedFiles.add( self.toTmpDir( loc ) )

         zf.writestr( "version", "BUILD_DATE={}\n".format( self.genBuildDate() ) )
         # The generated entry supersedes any "version" file extracted from a
         # previously parsed auf, so keep that copy out of the extras.
         includedFiles.add( self.toTmpDir( "version" ) )

         # Prior to rev 8, Aboot AUF installer processed "compat-check.sh"
         # instead of "compat_check.sh".
         # Copy the script if AUF doesn't require an installer with the fix;
         # the copy is then picked up as an extra file below.
         if self.abootInstallerMinRev < 8 and \
            os.path.exists( self.toTmpDir( self.compatScript ) ) \
            and not os.path.exists( self.toTmpDir( "compat-check.sh" ) ):
            self.copyToTmpDir( self.toTmpDir( self.compatScript ),
                               "compat-check.sh" )

         self.extraFiles = self.addNotIncludedFiles( includedFiles, zf )

   def _removeStaleSignature( self ):
      """Delete a swi-signature left in the temporary directory (e.g. from a
      parsed signed auf) so genAuf does not repackage it as an extra file."""
      if os.path.exists( self.toTmpDir( "swi-signature" ) ):
         os.remove( self.toTmpDir( "swi-signature" ) )

   def genSignedAuf( self, useDevCA=False, user=None, passwd=None ):
      """Rebuild the auf and sign it, with the dev CA key pair when `useDevCA`
      is set, otherwise through the signing server at AUF_SIGN_URL using
      `user`/`passwd`.

      On SigningServerError the auf is regenerated unsigned and the error is
      re-raised.
      """
      self._removeStaleSignature()
      self.genAuf()

      sig = SignatureFile.Signature()
      aufData = SignatureFile.prepareDataForServer( self.aufFile, self.aufRevision,
                                                    sig, product=self.name )
      try:
         if useDevCA:
            sigData = SignatureRequest.getDataFromDevCA( self.aufFile, aufData,
                    devCaKeyPair=( AUF_DEV_SIGNING_CERT, AUF_DEV_SIGNING_KEY ) )
         else:
            sigData = SignatureRequest.getDataFromServer( self.aufFile, aufData,
                    licenseServerUrl=AUF_SIGN_URL,
                    user=user, passwd=passwd )
         SignatureFile.generateSigFileFromServer( sigData, self.aufFile, sig )
      except SignatureRequest.SigningServerError:
         self._removeStaleSignature()
         self.genAuf()
         # Bare raise keeps the original traceback.
         raise

   def _isSigned( self, cert ):
      """Return whether the auf's signature verifies against root `cert`."""
      valid = SwiSignLib.verifySwiSignature( self.aufFile, rootCA=cert )
      return valid[ 0 ]

   def isSigned( self, useDevCA=False ):
      """Return whether the auf is signed by the dev or production CA."""
      cert = AUF_DEV_INT_CA_CERT if useDevCA else AUF_ARISTA_INT_CA_CERT
      return self._isSigned( cert )

   def addLayout( self, layout ):
      """Install `layout` as the layout file and parse it (this drops any
      currently registered sections)."""
      if os.path.exists( self.toTmpDir( self.layoutFile ) ):
         logWarn( "Overriding layout" )
      self.copyToTmpDir( layout, self.layoutFile )
      self.parseLayout()

   def getLayout( self ):
      """Return the path of the layout file, or None if absent."""
      if os.path.exists( self.toTmpDir( self.layoutFile ) ):
         return self.toTmpDir( self.layoutFile )
      return None

   def setAbootVersion( self, version ):
      """Set line/major/minor from a "LINE.MAJOR.MINOR" string; each part may
      be 'x' or "N-R". Raises SyntaxError (leaving the fields untouched) if
      fewer than three parts are given or a part is not a valid field."""
      v = version.split( "." )
      try:
         line = VersionField( v[ 0 ] )
         major = VersionField( v[ 1 ] )
         minor = VersionField( v[ 2 ] )
      except ( IndexError, ValueError ):
         # pylint: disable-next=raise-missing-from
         raise SyntaxError( "Invalid version string %s" % version )
      # Assign only after all parts parsed, so a bad string is all-or-nothing.
      self.line, self.major, self.minor = line, major, minor

   def getName( self ):
      return self.name

   def setName( self, name ):
      self.name = name

   def aufSha( self ):
      """Return the sha256 of the auf file."""
      return calcSha( self.aufFile )

   def __str__( self ):
      s = "File %s\nAuf version %d, Aboot version %s.%s.%s\n" % \
          ( self.aufFile, self.version, self.line, self.major, self.minor )
      s += "Name: %s\n" % self.name
      if self.version >= 2:
         s += "AUF revision: %d\n" % self.aufRevision
         s += "Aboot installer minimum revision: %d\n" % self.abootInstallerMinRev
         s += "Needs to run in Aboot: %d\n" % self.abootOnly

      # Checksum and signature can only be reported for an existing file.
      if os.path.isfile( self.aufFile ):
         s += "Checksum: %s\n" % self.aufSha()
         if SwiSignLib.swiSignatureExists( self.aufFile ):
            for cert in certs:
               if self._isSigned( cert ):
                  s += "Auf signed:\n"
                  s += "\tCert %s\n" % ( cert )

      s += "Sections:\n"
      for sectionName, payload in self.sections.items():
         s += "\t%s: offset: 0x%x bytes, size 0x%x bytes\n\t\tsha256(%s)\n" % \
               ( sectionName, payload.offset, payload.size, payload.sha )
      return s

