#!/usr/bin/env python3
# Copyright (c) 2005-2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
#

from AaaPluginLib import TR_ERROR, TR_AUTHEN
from BothTrace import traceX as bt
from BothTrace import Var as bv

import errno
import os
import stat
import sys

def logError( exception, quiet=True ):
   '''
   Report a failure hit while setting up a user's home-directory links.

   exception: the error (or any object) to report; it is rendered with str().
   quiet: when True (the default) the message is traced at TR_ERROR level;
          when False it is printed to stderr instead.
   '''
   message = f"Error setting up symlinks: {exception}"
   if quiet:
      bt( TR_ERROR, message )
      return
   print( message, file=sys.stderr )

def chownUsersFiles( path, uid, gid, quiet=True ):
   '''
   Recursively set the owner:group of path and everything beneath it to
   uid:gid.

   This method is expected to be called for both /home/<user> and
   /mnt/flash/home/<user>. Symlinks are never followed: the owner:group of
   the symlink itself is changed, not that of its target.

   Entries that disappear while the tree is being walked (ENOENT) are
   ignored. Any other failure, including a directory that cannot be listed,
   is logged via logError() and counted, and the walk carries on with the
   remaining entries rather than stopping at the first failure.

   Returns the number of errors encountered.
   '''
   errors = 0

   def recordError( e ):
      nonlocal errors
      # An entry removed concurrently (or a broken path) is not worth
      # reporting; everything else is.
      if e.errno == errno.ENOENT:
         return
      errors += 1
      logError( e, quiet )

   def chownFile( filePath ):
      # os.lchown() never dereferences a symlink and behaves like os.chown()
      # on every other kind of file. Calling it unconditionally, instead of
      # testing os.path.islink() first, removes the window in which a file
      # could be replaced by a symlink between the test and the chown and
      # thereby redirect the ownership change to the link's target.
      try:
         os.lchown( filePath, uid, gid )
      except OSError as e:
         recordError( e )

   # Every directory below the top is chowned as a subdir entry of its
   # parent, so only the top-level path needs an explicit chown.
   chownFile( path )
   # os.walk() silently skips directories it cannot list unless onerror is
   # supplied; route those failures through the same error accounting.
   for dirpath, subdirs, files in os.walk( path, onerror=recordError ):
      # subdirs also contains symlinks to directories, which os.walk() does
      # not descend into (followlinks defaults to False); the link itself is
      # chowned here.
      for name in subdirs + files:
         chownFile( os.path.join( dirpath, name ) )
   return errors

def getFsType( mntPt, mountsFile='/proc/mounts' ):
   '''
   Return the filesystem type mounted at mntPt, e.g. 'ext4' or 'vfat'.

   mntPt may be a name relative to /mnt (e.g. 'flash') or an absolute path
   (e.g. '/mnt/flash'); for an absolute path os.path.join() discards the
   '/mnt' prefix. The type is read from mountsFile (default /proc/mounts),
   whose lines look like:
   /dev/sda1 /mnt/flash ext4 rw,noatime,grpid,commit=1 0 0

   If several entries share the mount point (a later mount stacked on top
   of an earlier one), the last entry is the one in effect, so its type is
   returned. If no entry matches, 'tmpfs' is returned because in the breadth
   test environment FILESYSTEM_ROOT is on tmpfs.
   '''
   # normpath tolerates a trailing slash such as '/mnt/flash/'.
   flashMntPt = os.path.normpath( os.path.join( '/mnt', mntPt ) )
   fsType = "tmpfs"
   with open( mountsFile ) as f:
      for line in f:
         parts = line.split()
         # Skip blank or malformed lines rather than raising IndexError.
         if len( parts ) >= 3 and parts[ 1 ] == flashMntPt:
            # Keep scanning: a later entry for the same path shadows this one.
            fsType = parts[ 2 ]
   return fsType

def ownerIdChanged( flashDir, pwent_uid, pwent_gid ):
   '''
   Return True if the owner uid or group gid of flashDir differs from
   pwent_uid / pwent_gid, False if both match.

   os.stat() follows symlinks, so for a symlink the target's ownership is
   what gets compared.
   '''
   info = os.stat( flashDir )
   return ( info.st_uid, info.st_gid ) != ( pwent_uid, pwent_gid )

def setupSymlinksToPersistentHomeDir( userName, uid, gid, quiet=True ):
   '''
   Populate /home/<user> with symlinks to the entries of the user's
   persistent home directory <FILESYSTEM_ROOT>/flash/home/<user>
   (FILESYSTEM_ROOT defaults to /mnt). Then, unless the flash filesystem is
   vfat, re-own the files on flash to uid:gid if the flash home directory's
   owner no longer matches.

   Nothing is done unless both home directories exist and neither
   <flash>/home nor <flash>/home/<user> is a symlink. The .ssh entry is never
   linked, and an entry that already exists in /home/<user>, of any kind, is
   left untouched.

   for ext4 /mnt/flash
   ===================
   1. User 'admin' can read files of 'foo' by default.
   2. if 'read' is disabled for 'others' on /mnt/flash/home/foo/file1
      then user 'admin' can't read the file via symlink /home/foo/file1
      as well

   for vfat /mnt/flash
   ===================
   1. owner is always root, group is eosadmin on EOS
   2. vfat doesn't support chmod or chown, so unless permissions are blocked,
      at /home/ directory, user 'admin' can read the files of 'foo' on /home
      or on /mnt/flash/home/foo
   3. We cannot create a symlink from a vfat fs to file on any other fs

   Returns the number of errors encountered; each one is logged via
   logError().
   '''
   errors = 0
   fsRoot = os.environ.get( 'FILESYSTEM_ROOT', '/mnt' )
   flashRoot = os.path.join( fsRoot, 'flash' )
   flashDir = os.path.join( flashRoot, 'home', userName )
   homeDir = os.path.join( '/home', userName )

   # This user may not have homedir on /mnt/flash or
   # maybe the device mounted at /mnt/flash disappeared
   if not ( os.path.isdir( flashDir ) and os.path.isdir( homeDir ) ):
      return errors

   # We'll setup symlinks from /home only if /mnt/flash/home is a directory
   # and /mnt/flash/home/<user> is also a directory
   if ( os.path.islink( os.path.join( flashRoot, 'home' ) ) or
        os.path.islink( flashDir ) ):
      if quiet:
         bt( TR_AUTHEN, "User", bv( userName ),
             "home directory on flash is not a directory." )
      return errors

   try:
      entries = os.listdir( flashDir )
   except OSError as e:
      # flashDir may have vanished or become unreadable since the check above.
      errors += 1
      logError( e, quiet )
      return errors

   for fl in entries:
      # Do not symlink /mnt/flash/home/<user>/.ssh to /home/<user>/.ssh
      if fl == ".ssh":
         continue
      symlinkName = os.path.join( homeDir, fl )

      # The user may have created new files or directories under
      # /mnt/flash/home/<user>. As the user logs into a new session, create
      # symlinks for them in /home/<user>.
      try:
         os.lstat( symlinkName )
      except OSError as e:
         if e.errno != errno.ENOENT:
            errors += 1
            logError( e, quiet )
            continue
      else:
         # Something already exists at this name: a symlink from an earlier
         # session, or an entry on tmpfs with the same name as one on flash.
         # Leave it as is. os.symlink() would fail with EEXIST for any
         # existing entry anyway, so trying would only report a bogus error.
         continue

      # No entry on tmpfs. Create the symlink now.
      target = os.path.join( flashDir, fl )
      try:
         os.symlink( target, symlinkName )
         os.lchown( symlinkName, uid, gid )
      except OSError as e:
         errors += 1
         logError( e, quiet )

   # uid may change for a user post reboot. Ensure uid of files on flash
   # for this user matches with uid of /home/user. If the file is a
   # symlink, change the uid:gid of symlink and not the target
   try:
      needsChown = ( getFsType( flashRoot ) != 'vfat' and
                     ownerIdChanged( flashDir, uid, gid ) )
   except OSError as e:
      # e.g. /proc/mounts unreadable or flashDir removed meanwhile.
      errors += 1
      logError( e, quiet )
      return errors
   if needsChown:
      errors += chownUsersFiles( flashDir, uid, gid, quiet )
   return errors
      
def removeBrokenSymlinks( userName, quiet=True, homeRoot='/home' ):
   '''
   Delete dangling symlinks directly inside <homeRoot>/<userName>.

   A symlink counts as dangling when os.path.exists() is False for it, e.g.
   after the user removed their home directory on flash. Only the top level
   of the home directory is scanned. A missing home directory means there is
   nothing to clean up and is not counted as an error.

   quiet: passed to logError() for each failure.
   homeRoot: parent of the per-user home directories (default /home).

   Returns the number of errors encountered.
   '''
   errors = 0
   homeDir = os.path.join( homeRoot, userName )
   try:
      entries = os.listdir( homeDir )
   except OSError as e:
      # No home directory at all: nothing to clean up.
      if e.errno != errno.ENOENT:
         errors += 1
         logError( e, quiet )
      return errors

   for ent in entries:
      fl = os.path.join( homeDir, ent )
      # NOTE(review): this removes every dangling symlink, including ones the
      # user created that never pointed at flash; os.path.exists() is also
      # False when the target cannot be stat()ed. Confirm that is intended.
      if os.path.islink( fl ) and not os.path.exists( fl ):
         try:
            # This is a broken symlink as resolved symlink doesn't exist.
            os.unlink( fl )
         except OSError as e:
            # Another session may already have removed it.
            if e.errno != errno.ENOENT:
               errors += 1
               logError( e, quiet )
   return errors

def createHomeLinks( userName, uid, gid, quiet=True ):
   '''
   Sync /home/<user> with the user's persistent home directory on flash and
   return the total number of errors from both passes.

   The passes run in this order:
     1. setupSymlinksToPersistentHomeDir() links the user's flash entries
        into /home/<user>.
     2. removeBrokenSymlinks() prunes dangling symlinks, which covers a user
        who has removed his/her home directory on flash.
   '''
   linkErrors = setupSymlinksToPersistentHomeDir( userName, uid, gid, quiet )
   pruneErrors = removeBrokenSymlinks( userName, quiet )
   return linkErrors + pruneErrors
