# Copyright (c) 2006-2010, 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import errno
import fnmatch
import os
import shutil
import stat
from functools import total_ordering

import CliParser
import FileCliUtil
import FileUrl
import SwiSignLib
import Tac
import Tracing
import Url
from CliGlobal import CliGlobal

# pkgdeps: library SecureBoot

# Shorthand for Tracing.trace0, used for this module's trace output.
t0 = Tracing.trace0

# Holder for the optional secure-boot query hook; ``callback`` stays None until
# registerSecurebootEnabledCallback() installs one.
securebootEnabled = CliGlobal( { 'callback': None } )

def registerSecurebootEnabledCallback( callback ):
   """Install `callback` as the secure-boot query hook, replacing any previous one.

   FlashFilesystem calls it (with no arguments) when validating a SWI that is
   being copied to the boot-config SWI destination: a truthy result turns a
   failed signature check into an OSError instead of a warning.
   """
   hookHolder = securebootEnabled
   hookHolder.callback = callback

@total_ordering
class LocalUrl( Url.Url ):
   """URL naming a path on a local, directory-backed filesystem.

   The pathname is normalized on construction, so two LocalUrls naming the
   same file (e.g. 'flash:/foo' and 'flash:/bar/../foo') compare and hash
   equal.  The pathname exactly as typed is also kept, for filename
   completion.  The remaining rich comparisons are derived from __eq__ and
   __lt__ by functools.total_ordering.
   """
   # Note that our LocalUrls are not strictly URIs as defined by RFC3986.  For
   # example, we allow 'flash://foo' to mean the file 'foo' on the 'flash:'
   # filesystem, whereas according to RFC3986 'foo' would be treated as the
   # authority part of that URI.
   def __init__( self, fs, url, rawpathname, pathname, context ):
      """Create a URL for `pathname` on filesystem `fs`.

      fs: the owning filesystem.
      url: the URL string as given by the caller.
      rawpathname: the pathname exactly as typed (not normalized); used by
         getCompletions().
      pathname: the absolute pathname on `fs`; normalized here.
      context: the parse context (its cliSession is consulted by parseUrl
         and subclasses), or None.
      """
      Url.Url.__init__( self, fs, url )
      self.context = context

      self.pathname = os.path.normpath( pathname )
      if self.pathname.startswith( '//' ):
         # os.path.normpath preserves exactly two leading slashes (but not
         # three or more) -- this is the POSIX rule documented for normpath --
         # but we always want a single leading slash.
         self.pathname = self.pathname[ 1: ]

      # Invariants: absolute, no trailing slash (except for the root itself),
      # and no empty path components.
      assert self.pathname.startswith( '/' )
      assert self.pathname == '/' or not self.pathname.endswith( '/' )
      assert not '//' in self.pathname

      # rawpathname is needed for filename completion (see getCompletions() below).
      self.rawpathname_ = rawpathname

   def __str__( self ):
      """Return the canonical form: scheme followed by the normalized path."""
      return f"{self.fs.scheme}{self.pathname}"

   def __eq__( self, other ):
      # Any two LocalUrls that refer to the same file (e.g. 'flash:/foo' and
      # 'flash:/bar/../foo') are considered equal.
      if isinstance( other, LocalUrl ):
         return self.fs == other.fs and self.pathname == other.pathname
      else:
         return NotImplemented

   def __lt__( self, other ):
      """Order LocalUrls by ( fs, pathname )."""
      # NOTE(review): for URLs on different filesystems this relies on the
      # filesystem objects themselves supporting '<' -- confirm that
      # Url.Filesystem defines an ordering.
      if isinstance( other, LocalUrl ):
         lhs = ( self.fs, self.pathname )
         rhs = ( other.fs, other.pathname )
         return lhs < rhs
      else:
         return NotImplemented

   def __hash__( self ):
      # Hash the same ( fs, pathname ) pair that __eq__ compares.
      return hash( ( self.fs, self.pathname ) )

   def basename( self ):
      """Return the final component of the normalized pathname."""
      return os.path.basename( self.pathname )

   def child( self, f ):
      """Return a URL of the same class for entry `f` inside this directory.

      `f` must be a single path component (it may not contain '/').
      """
      assert not '/' in f
      childPathname = os.path.join( self.pathname, f )
      return self.__class__( self.fs,
                             self.fs.scheme + childPathname,
                             childPathname,
                             childPathname,
                             self.context )

   def parent( self ):
      """Return a URL of the same class for the containing directory.

      The root ('/') is its own parent, since os.path.dirname('/') is '/'.
      """
      parentPathname = os.path.dirname( self.pathname )
      return self.__class__( self.fs,
                             self.fs.scheme + parentPathname,
                             parentPathname,
                             parentPathname,
                             self.context )

   def isWildcard( self ):
      """Return True if the final path component contains '*'.

      Only '*' is detected, although expandWildcard() matches with fnmatch,
      which also understands '?' and '[...]'.
      """
      return '*' in self.basename()

   def expandWildcard( self ):
      """Return the URLs in the parent directory whose names match this URL's
      basename, interpreted as an fnmatch pattern.

      listdir() is not defined on LocalUrl itself (hence the pylint disable);
      subclasses such as FlashUrl provide it.
      """
      parent = self.parent()
      pattern = self.basename()
      return [ parent.child( f )
               for f in parent.listdir() # pylint: disable=no-member
               if fnmatch.fnmatch( f, pattern ) ]

   def getCompletions( self ):
      """Return CliParser.Completion entries for the partially typed final
      component of this URL.

      Matching directories are returned with a trailing '/' and marked
      partial; all other matches (including the '' entry) are not partial.
      Returns [] if the containing directory cannot be parsed or listed
      (OSError).
      """
      # The implementation of <tab> and ? completion for filenames is rather nasty,
      # because the CLI infrastructure requires that the original token is a prefix
      # of any completion.  Therefore, we must use self.rawpathname_ rather than
      # self.pathname, in order that 'flash:/foo/../b' completes to
      # 'flash:/foo/../bar' rather than 'flash:/bar'.
      try:
         retList = []
         partialName = os.path.basename( self.rawpathname_ )
         dirPathname = os.path.dirname( self.rawpathname_ )
         dirUrl = self.fs.parseUrl( self.fs.scheme + dirPathname,
                                    dirPathname,
                                    self.context )
         try:
            children = dirUrl.listdir( )
         except OSError:
            # error, return nothing
            return []
         if partialName.startswith( '.' ):
            # listdir() never includes '.' or '..', which is generally what we want,
            # but without them 'flash:/foo/..?', for example, will give an
            # unrecognized command error.  Including '.' and '..' only if the
            # basename starts with '.' is similar to how bash handles this (the
            # difference is that bash doesn't include _any_ filenames that begin with
            # '.' unless the basename starts with '.').
            children += [ '.', '..' ]

         # If the filesystem ignores case, do a case insensitive comparison
         if self.fs.ignoresCase():
            partialName = partialName.lower()

         for f in children + [ '' ]:
            # Including '' is not industry-standard, nor is it what bash does, but it
            # makes sense.  Otherwise '?' after the name of an empty directory with a
            # trailing slash will print '% Unrecognized command', which is silly and
            # misleading.  The industry-standard doesn't really have this problem
            # because it doesn't include the trailing slash after the completion of a
            # directory name, but that is annoying and un-bash-like.
            #
            # One downside of doing this is that <tab> after the name of a directory
            # that contains a unique file won't complete to that filename, which is
            # perhaps a little annoying (and un-bash-like), but I think it's not as
            # bad as the alternatives.
            #
            # If the filesystem ignores case, do a completion match ignoring case,
            # but return the actual filename case back to the user
            if f.startswith( partialName ) or \
                   self.fs.ignoresCase() and f.lower().startswith( partialName ):
               if partialName:
                  # this rewrites the URL passed in, which requires some
                  # tricky completion handling in Cli.  Lowercasing above
                  # does not change the length of partialName, so slicing
                  # by its length still strips exactly the typed suffix.
                  name = self.url[ :-len( partialName ) ] + f
               else:
                  name = self.url + f
               partial = True
               childUrl = dirUrl.child( f )
               if f != '' and childUrl.isdir():
                  # Append a trailing slash to directory names.  This is not
                  # industry-standard, but is like bash and is very useful.
                  name = name + '/'
               else:
                  # This causes <tab>-completing the name of a file (not a directory)
                  # to append a trailing space.  This is not industry-standard, but
                  # is like bash and is user-friendly.
                  partial = False
               retList.append(
                  CliParser.Completion( name=name, help='', partial=partial ) )
         return retList
      except OSError:
         # The directory dirUrl doesn't exist (or isdir() on a child failed).
         return []


class LocalFilesystem( Url.Filesystem ):
   """Base class for filesystems whose URLs are LocalUrl-style pathnames.

   Subclasses are expected to define ``urlClass_`` (the default URL class) and
   may populate ``urlClassByPath_``, which maps a normalized pathname without
   its leading '/' to a more specific URL class.
   """
   urlClassByPath_ = {}

   def parseUrl( self, url, rest, context ):
      """Build the URL object for `url`, whose path part is `rest`.

      URLs on the CLI session's current filesystem are resolved relative to
      the current directory (whether or not the filesystem was named
      explicitly); all other URLs are resolved against '/'.
      """
      session = context.cliSession if context else None
      basedir = '/'
      if session and self == session.currentDirectory().fs:
         basedir = session.currentDirectory().pathname

      fullPath = os.path.join( basedir, rest )
      # Key for the per-path override table: normalized, without leading '/'.
      lookupKey = os.path.normpath( fullPath ).lstrip( '/' )
      # pylint: disable-next=no-member
      chosenClass = self.urlClassByPath_.get( lookupKey, self.urlClass_ )
      return chosenClass( self, url, rest, fullPath, context )

class FlashUrl( LocalUrl ):
   """LocalUrl on a FlashFilesystem.

   The URL pathname is mapped onto a host path, ``realFilename_``, under the
   filesystem's ``location_``.  The operations below act directly on that
   path with os/shutil, after checking the relevant supports* capability of
   the filesystem.
   """
   def __init__( self, fs, url, rawpathname, pathname, context=None ):
      LocalUrl.__init__( self, fs, url, rawpathname, pathname, context )
      # Host path for this URL: fs.location_ joined with the normalized URL
      # pathname minus its leading '/'.
      self.realFilename_ = os.path.join( self.fs.location_, self.pathname[ 1: ] )
      # os.path.normpath(pathname) removes a trailing '/' or '/.' from the
      # pathname (python bug 1707768), so we append '/' back here.
      # Otherwise, for example, "copy file:/tmp/x1 flash:/dir1/." would
      # copy '/tmp/x1' to a file 'flash:/dir1'!
      # NOTE(review): the isdir() test is applied to the raw URL path, not to
      # self.realFilename_, so unless fs.location_ is '/' it inspects an
      # unrelated host path -- confirm this is intended.
      npath = os.path.normpath( rawpathname )
      if ( not os.path.isdir( npath ) and
           rawpathname.endswith( ( '/.', '/' ) ) ):
         self.realFilename_ += '/'

   def _checkAllowedPath( self ):
      """Enforce the filesystem's allowedPaths_ restriction for CLI use.

      This is a no-op unless the filesystem has allowedPaths_ and this URL was
      parsed in a context with a cliSession.  Outside config mode the check
      is also skipped when self.allowAllPaths is set.  If the path is not at
      or below any allowed path, a warning is added while loading the
      startup-config; otherwise an error is added and
      CliParser.AlreadyHandledError is raised.
      """
      if not ( self.fs.allowedPaths_ and self.context and
               self.context.cliSession ):
         return

      # Import BasicCliModes here - we cannot do it at the top level, as
      # UrlPlugin plugins are loaded by "ShowCommand", and BasicCliModes
      # depends on that module, so we must load BasicCliModes lazily to avoid
      # having a partially imported ShowCommands module at the time
      # BasicCliModes is imported.
      import BasicCliModes # pylint: disable=import-outside-toplevel
      # NOTE(review): allowAllPaths is not defined in this module; presumably
      # it is provided by Url.Url -- confirm.
      if self.allowAllPaths and not isinstance( self.context.cliSession.mode,
                                                BasicCliModes.ConfigModeBase ):
         # allowAllPaths is honored only outside config mode; in config mode
         # the allowed-path check is always performed.
         return

      p = self.pathname
      if not p.endswith( '/' ):
         p += '/'
      # Match whole path components: '/tmp' allows '/tmp' and '/tmp/x' but not
      # '/tmpx'.  Test the match itself, not the truthiness of the allowed path.
      allowed = any( p.startswith( x + '/' ) or
                     ( self.fs.ignoresCase() and
                       p.lower().startswith( x.lower() + '/' ) )
                     for x in self.fs.allowedPaths_ )
      if not allowed:
         errmsg = ( "Only the following paths are allowed: " +
                    ', '.join( self.fs.allowedPaths_ ) )
         if self.context.cliSession.startupConfig():
            # For startup-config just print a warning
            self.context.cliSession.mode.addWarning( errmsg )
            return
         self.context.cliSession.mode.addError( errmsg )
         raise CliParser.AlreadyHandledError()

   def listdir( self ):
      """Return the entry names of this directory (os.listdir)."""
      self.checkOpSupported( self.fs.supportsListing )
      return os.listdir( self.realFilename_ )

   def exists( self ):
      return os.path.exists( self.realFilename_ )

   def isdir( self ):
      return os.path.isdir( self.realFilename_ )

   def islink( self ):
      return os.path.islink( self.realFilename_ )

   def readlink( self ):
      return os.readlink( self.realFilename_ )

   def size( self ):
      """Return the file size in bytes."""
      return os.path.getsize( self.realFilename_ )

   def date( self ):
      """Return the modification time as seconds since the epoch."""
      return os.path.getmtime( self.realFilename_ )

   def _erasable( self ):
      # XXX Disabling this mechanism for now.  It's not clear that this is a good
      # idea or very user-friendly (especially since it causes
      # 'delete flash:/*' to abort mid-way through without pinpointing the name of
      # the file that caused the failure - see BUG516) , and it's causing problems
      # with the tests.
      #unErasableUrls = [ 'flash:/startup-config', 'flash:/boot-config' ]
      #if str( self ) in unErasableUrls:
      #   return False
      return True

   def permission( self ):
      """Return file permissions as a 4-tuple of bools representing the directory,
      read, write and execute permissions, respectively (owner bits only)."""
      s = os.stat( self.realFilename_ )
      return ( self.isdir(),
               ( stat.S_IRUSR & s.st_mode ) != 0,
               ( stat.S_IWUSR & s.st_mode ) != 0,
               ( stat.S_IXUSR & s.st_mode ) != 0 )

   def get( self, dstFn ):
      """Copy this file to the host file `dstFn`."""
      # We open the source file first, so that if it doesn't exist then the
      # destination file doesn't get created/truncated.
      self.checkOpSupported( self.fs.supportsRead )
      with open( self.realFilename_, 'rb' ) as srcFile:
         with open( dstFn, 'wb' ) as dstFile:
            shutil.copyfileobj( srcFile, dstFile )

   def getWithFilter( self, dstFn, filterCmd ):
      """Write this file to host file `dstFn` after piping it through the shell
      command `filterCmd` (run as: bash -c "exec <filterCmd>").

      `filterCmd` is interpreted by bash, so it must not contain untrusted
      text.  Raises OSError( 0, <command output captured by Tac> ) if the
      command fails; the Tac.SystemCommandError is chained as the cause.
      """
      self.checkOpSupported( self.fs.supportsRead )
      with open( self.realFilename_, 'rb' ) as srcFile:
         with open( dstFn, 'wb' ) as dstFile:
            try:
               Tac.run( [ "bash", "-c", "exec " + filterCmd ],
                        stdin=srcFile, stdout=dstFile,
                        stderr=Tac.CAPTURE )
            except Tac.SystemCommandError as e:
               raise OSError( 0, e.output ) from e

   def put( self, srcFn, append=False ):
      """Copy host file `srcFn` into this URL, appending if `append` is true
      and truncating otherwise."""
      # We open the source file first, so that if it doesn't exist then the
      # destination file doesn't get created/truncated.
      self.checkOpSupported( self.fs.supportsWrite )
      mode = 'ab' if append else 'wb'

      with open( srcFn, 'rb' ) as srcFile:
         with open( self.realFilename_, mode ) as dstFile:
            shutil.copyfileobj( srcFile, dstFile )

   def renameto( self, dst ):
      """Rename this file to `dst`, which must be on the same filesystem."""
      self.checkOpSupported( self.fs.supportsRename )
      assert dst.fs == self.fs
      if not self._erasable():
         raise OSError( errno.EACCES, os.strerror( errno.EACCES ) )

      if dst.hasHeader():
         self.writeHeaderAndRenameFile( self.realFilename_, dst.realFilename_, dst )
      else:
         # Note that the industry-standard is not to allow a file or directory to be
         # renamed to the name of any existing file or directory, including itself.
         # However, for simplicity we use the semantics of the 'rename' system call,
         # which allows renaming a file to an existing file, and renaming a
         # directory to an existing empty directory. See 'man 2 rename' for details.
         os.rename( self.realFilename_, dst.realFilename_ )

   def delete( self, recursive ):
      """Delete this file.  With `recursive`, a directory is removed together
      with its subtree.  Deleting the filesystem root raises EPERM."""
      self.checkOpSupported( self.fs.supportsDelete )
      if not self._erasable():
         raise OSError( errno.EACCES, os.strerror( errno.EACCES ) )
      if self.pathname == '/':
         raise OSError( errno.EPERM, os.strerror( errno.EPERM ) )
      if recursive and self.isdir():
         # Note that this doesn't check that each of the files in the subtree are
         # _erasable().  This is OK because the only non-erasable files are
         # located in the filesystem root.
         shutil.rmtree( self.realFilename_ )
      else:
         os.unlink( self.realFilename_ )

   def mkdir( self ):
      """Create this directory, including any missing parents."""
      self.checkOpSupported( self.fs.supportsMkdir )
      os.makedirs( self.realFilename_ )

   def rmdir( self ):
      """Remove this (empty) directory; the root raises EPERM.  Gated by the
      supportsMkdir capability."""
      self.checkOpSupported( self.fs.supportsMkdir )
      if self.pathname == '/':
         raise OSError( errno.EPERM, os.strerror( errno.EPERM ) )
      os.rmdir( self.realFilename_ )

   def empty( self ):
      """Create the file if it doesn't exist; truncate it if it does exist."""
      self.checkOpSupported( self.fs.supportsWrite )
      with open( self.realFilename_, 'w' ):
         pass

   def localFilename( self, check=True ):
      """Return the host path of this URL.

      When `check` is true the allowed-path restriction is enforced first,
      which may add a CLI error and raise CliParser.AlreadyHandledError.
      """
      # Config commands should always call localFileName(), so check here.
      if check:
         self._checkAllowedPath( )
      return self.realFilename_

   def verifyHash( self, hashName, mode=None, hashInitializer=None ):
      """Return the hash of this file computed by
      FileCliUtil.chunkedHashCompute( self, hashInitializer ).

      `hashName` is accepted for interface compatibility and is not used.
      If reading fails with OSError, the error is reported through
      mode.addError() and None is returned; when no `mode` is supplied the
      OSError is re-raised instead, since there is nowhere to report it.
      """
      try:
         return FileCliUtil.chunkedHashCompute( self, hashInitializer )
      except OSError as e:
         if mode is None:
            raise
         mode.addError( f"Error reading {self.url} ({e.strerror})" )
         return None

class FlashFilesystem( LocalFilesystem ):
   """LocalFilesystem backed by a directory on the host.

   ``location`` is either an absolute host path or a path relative to the
   filesystem root reported by Url.fsRoot().  It is resolved by fsRootIs(),
   which __init__ calls when it can; otherwise it is called later, once the
   root is known (see syncFilesystemRoot()).
   """
   urlClass_ = FlashUrl
   # allow subclasses for specific files
   urlClassByPath_ = {}

   def __init__( self, scheme, location, fsType='flash',
                 allowedPaths=None, permission='rw', mask=None,
                 noPathComponent=False ):
      LocalFilesystem.__init__( self, scheme, fsType, permission, mask=mask,
                                noPathComponent=noPathComponent )
      self.rawlocation_ = location
      self.location_ = None
      self.fsRoot_ = None
      self.filenameOffset_ = -1
      self.linuxDevice_ = self.mountPoint_ = self.linuxFsType_ = None
      self.ignoresCase_ = False
      # Optional sequence of pathnames; when set, FlashUrl restricts CLI access
      # to paths at or below them.
      self.allowedPaths_ = allowedPaths
      self.removable_ = False
      if Url.fsRoot() or location.startswith( '/' ):
         self.fsRootIs( Url.fsRoot() )

   @classmethod
   def registerClass( cls, path, urlClass ):
      """Use `urlClass` for the URL whose normalized pathname, without its
      leading '/', equals `path`."""
      assert not path.startswith( '/' )
      cls.urlClassByPath_[ path ] = urlClass

   def fsRootIs( self, fsRoot ):
      '''callback when we know the fs root. Should be called before
      everything else.  Resolves location_ and records mount information;
      does nothing if location_ is already set.'''
      t0( "fsRootIs:", fsRoot )
      if self.location_:
         return
      self.fsRoot_ = fsRoot
      if self.rawlocation_.startswith( '/' ):
         # not depending on fsRoot
         location = self.rawlocation_
      else:
         location = os.path.join( fsRoot, self.rawlocation_ )
      self.location_ = location
      t0( "Filesystem", self.scheme, "location", location )
      # Slice offset used by filenameToUrl(): keeps the '/' that follows the
      # location in a host filename.
      offset = len( location )
      if location.endswith( '/' ):
         offset -= 1
      self.filenameOffset_ = offset
      ( self.linuxDevice_, self.mountPoint_, self.linuxFsType_ ) = self.mountInfo()
      if self.linuxFsType_ == 'vfat':
         self.ignoresCase_ = True
      else:
         self.ignoresCase_ = False
      # NOTE(review): stripping trailing digits maps '/dev/sda1' to 'sda', but
      # a name like 'mmcblk0p1' becomes 'mmcblk0p', which presumably has no
      # /sys/block entry; the OSError is then ignored and removable_ keeps its
      # previous value.
      try:
         with open( '/sys/block/' +
                    os.path.basename( self.linuxDevice_.rstrip( '0123456789' ) ) +
                    '/removable' ) as f:
            removableFlag = f.read( 1 )
            self.removable_ = ( removableFlag == '1' )
      except OSError:
         pass

   @staticmethod
   def _isUnderMountPoint( path, mountPoint ):
      """Return True if `path` is `mountPoint` or lies beneath it.

      Whole path components are compared, so '/mnt/flash' is not the mount
      point of '/mnt/flash2', and '/mnt' is not that of '/tmp/x/mnt'.
      """
      if mountPoint == '/':
         return path.startswith( '/' )
      mountPoint = mountPoint.rstrip( '/' )
      return path == mountPoint or path.startswith( mountPoint + '/' )

   def mountInfo( self ):
      """Return ( device, mountPoint, fsType ) for the mount holding location_.

      Picks the longest mount point listed in /proc/mounts that is
      self.location_ or one of its ancestor directories.  Does not modify
      self.
      """
      mntpnt = ""
      dev = fstype = None
      with open( "/proc/mounts" ) as f:
         for l in f:
            ( tdev, tmntpnt, tfstype ) = l.split()[ :3 ]
            if ( len( tmntpnt ) > len( mntpnt ) and
                 self._isUnderMountPoint( self.location_, tmntpnt ) ):
               ( dev, mntpnt, fstype ) = ( tdev, tmntpnt, tfstype )

      # Newer versions of the linux kernel do not show rootfs as the /-node,
      # but older ones do.  Without a /-node, the loop above will not always
      # set ( dev, mntpnt, fstype ), so we hardcode the old rootfs /-node in as
      # a default to preserve behavior from older kernel versions.
      if dev is None:
         dev = "rootfs"
         mntpnt = "/"
         fstype = "rootfs"
      return ( dev, mntpnt, fstype )

   def stat( self ):
      """Return ( size, free ) in bytes for the filesystem holding location_;
      `free` is the space available to unprivileged users (f_bavail)."""
      vfs = os.statvfs( self.location_ )
      size = vfs.f_frsize * vfs.f_blocks
      free = vfs.f_frsize * vfs.f_bavail
      return ( size, free )

   def ignoresCase( self ):
      return self.ignoresCase_

   def realFileSystem( self ):
      return True

   def filenameToUrl( self, filename ):
      """Return the URL string for host path `filename`, or None if it is not
      at or below this filesystem's location_ (or location_ is unknown)."""
      if self.location_ is None or not filename.startswith( self.location_ ):
         return None
      rest = filename[ self.filenameOffset_: ]
      # Require a path-component boundary, so that e.g. '/mnt/flash2/x' is not
      # mapped onto the filesystem at '/mnt/flash' as 'flash:2/x'.
      if rest and not rest.startswith( '/' ):
         return None
      return self.scheme + rest

   def filenameToUrlQuality( self ):
      """The longer the self.location_ value, the better quality match this is."""
      return self.filenameOffset_

   def _securebootEnabled( self, context ):
      """Returns true if secureboot is enabled (a callback is registered and
      returns a truthy value)."""
      return securebootEnabled.callback and securebootEnabled.callback()

   def validateFile( self, filename, durl=None, context=None ):
      """If this file is going to be copied to the boot-config destination,
      check the SWI signature first.

      An invalid signature raises OSError( 0, <error> ) when secure boot is
      enabled and only adds a CLI warning otherwise.
      """
      if context and context.cliSession:
         mode = context.cliSession.mode
         if durl:
            bootConfig = FileUrl.bootConfig( mode, createIfMissing=False )
            if ( bootConfig and 'SWI' in bootConfig and
                 str( durl ) == bootConfig[ 'SWI' ] ):
               sigValid, sigError, _ = SwiSignLib.verifySwiSignature( filename,
                                                                      userHint=True )
               if not sigValid:
                  if self._securebootEnabled( context ):
                     raise OSError( 0, sigError )
                  mode.addWarning( sigError )

   @property
   def removable( self ):
      return self.removable_

   @property
   def inconsistent( self ):
      """True if the mount information recorded by fsRootIs() no longer
      matches the current /proc/mounts."""
      return ( ( self.linuxDevice_, self.mountPoint_, self.linuxFsType_ ) !=
               self.mountInfo() )

def syncFilesystemRoot( fsRoot ):
   """One-time callback when Url.fsRoot() becomes known: pass the root to
   every registered FlashFilesystem so it can resolve its location."""
   t0( "Filesystem root is", fsRoot )

   registered = Url.filesystemsUnsynced()
   flashFilesystems = ( candidate for candidate in registered.values()
                        if isinstance( candidate, FlashFilesystem ) )
   for flashFs in flashFilesystems:
      t0( "update fsRoot for", flashFs.scheme )
      flashFs.fsRootIs( fsRoot )

def syncFlashFilesystems( ):
   """Installs a FlashFilesystem for every subdirectory of the filesystem root
   (typically '/mnt') that has a sibling '<name>.conf' file, and unregisters
   flash filesystems whose directory or .conf file has disappeared.

   The removal pass skips the built-in 'flash:' and 'file:' filesystems; the
   add pass skips the 'flash' directory.  A filesystem that is already
   registered is replaced only if its mount information is inconsistent.
   """
   t0( "syncFlashFilesystems" )

   def _isFlashDirectory( path ):
      # A flash filesystem is a directory accompanied by '<directory>.conf'.
      return os.path.isdir( path ) and os.path.exists( path + '.conf' )

   # Pass 1: drop stale filesystems.  Iterate over a snapshot, because
   # unregistering mutates the dict returned by filesystemsUnsynced().
   for scheme, fs in list( Url.filesystemsUnsynced().items() ):
      isBuiltin = scheme in ( 'flash:', 'file:' )
      if ( fs.fsType == 'flash' and not isBuiltin and
           not _isFlashDirectory( fs.location_ ) ):
         t0( "unregister", fs.scheme )
         Url.unregisterFilesystem( fs )

   # Pass 2: register new (or re-register inconsistent) filesystems.
   for entry in os.listdir( Url.fsRoot() ):
      if entry == 'flash':
         continue
      scheme = entry + ':'
      location = os.path.join( Url.fsRoot(), entry )
      if not _isFlashDirectory( location ):
         continue
      existing = Url.filesystemsUnsynced().get( scheme )
      if existing is not None:
         if not existing.inconsistent:
            # Already registered and still matches its mount; keep it.
            continue
         t0( "unregister", existing.scheme, "(inconsistent)" )
         Url.unregisterFilesystem( existing )
      t0( "register", scheme )
      Url.registerFilesystem( FlashFilesystem( scheme, location ) )

   # Note that after this call, _currentDirectory.fs may not be in the _filesystems
   # dict.  This is OK.

def initSessionDirectories( session ):
   """Set both the home and the current directory of `session` to 'flash:/'."""
   flashFs = Url.getFilesystem( 'flash:' )
   flashRoot = FlashUrl( flashFs, 'flash:/', '/', '/' )
   session.homeDirectoryIs( flashRoot )
   session.currentDirectoryIs( flashRoot )

def flashFileUrl( path ):
   """Return the URL string 'flash:/' followed by `path` (a leading '/' in
   `path` therefore yields 'flash://...')."""
   scheme = 'flash:/'
   return scheme + path

def fileUrl( path ):
   """Return the URL string 'file:/' followed by `path` (a leading '/' in
   `path` therefore yields 'file://...')."""
   scheme = 'file:/'
   return scheme + path

def Plugin( context=None ):
   """Plugin entry point: install this module's Url hooks and register the
   'flash:' and 'file:' filesystems.

   'file:' is rooted at the host '/' with allowedPaths limited to '/tmp',
   '/var/tmp' and '/var/log'.  `context` is not used.
   """
   hooks = ( ( Url.setSyncFlashFilesystems, syncFlashFilesystems ),
             ( Url.setSyncFilesystemRoot, syncFilesystemRoot ),
             ( Url.setInitSessionDirectories, initSessionDirectories ) )
   for installHook, hook in hooks:
      installHook( hook )

   Url.registerFilesystem( FlashFilesystem( 'flash:', 'flash' ) )
   fileAllowedPaths = ( '/tmp', '/var/tmp', '/var/log' )
   Url.registerFilesystem( FlashFilesystem( 'file:', '/',
                                            allowedPaths=fileAllowedPaths ) )
