# Copyright (c) 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import re
import sys
import time

import BasicCli
import BasicCliModes
import CliCommand
import CliMatcher
from CliMode.Vm import VmMode
import CliParser
# pylint: disable-next=consider-using-from-import
import CliPlugin.EthIntfCli as EthIntfCli
import CliPlugin.MacAddr as MacAddr # pylint: disable=consider-using-from-import
# pylint: disable-next=consider-using-from-import
import CliPlugin.PhysicalIntfRule as PhysicalIntfRule
# pylint: disable-next=consider-using-from-import
import CliPlugin.VlanIntfCli as VlanIntfCli
import ConfigMount
import Ethernet
import FileCliUtil
import LazyMount
import ShowCommand
import Tac
import Url

# Sysdb mount handles.  All four start as None and are assigned by Plugin().
virConfig = None          # "vm/config" (Vm::Config), mounted writable
virStatus = None          # "vm/status" (Vm::Status), mounted read-only
allIntfStatusDir = None   # "interface/status/all", mounted read-only
allIntfConfigDir = None   # "interface/config/all", mounted read-only
# Alias for the global configuration mode class.
globalConfig = BasicCli.GlobalConfigMode

def cpuSupportsVirt():
   '''Report whether /proc/cpuinfo lists a hardware virtualization flag.

   Looks for a "flags" line that contains " vmx ", " svm " or " hvm".

   Returns:
      True when such a line is present; False when it is not, or when
      /proc/cpuinfo cannot be read.
   '''
   try:
      # The context manager closes the file as soon as it has been read.
      with open( "/proc/cpuinfo" ) as f:
         cpuinfo = f.read()
   except OSError:
      # Best effort, like getMemInfo(): an unreadable file means "unsupported"
      # rather than an exception escaping into the CLI guard.
      return False
   return bool( re.search( "flags.*( vmx | svm | hvm).*", cpuinfo ) )

def getMemInfo():
   '''Parse /proc/meminfo into a dictionary.

   The output of /proc/meminfo is
   MemTotal:       263517764 kB
   MemFree:        14357464 kB
   MemAvailable:   251145068 kB
   Buffers:        25283576 kB
   Cached:         178619420 kB
   ...

   Returns:
      A dict mapping each field name to its integer value divided by 1024
      (so kB figures become MB).  The same division is applied to every
      field.  Lines that are not of the form "name: number [kB]" are
      skipped.  An empty dict is returned when the file cannot be read.
   '''
   try:
      with open( '/proc/meminfo' ) as f:
         meminfo = f.read()
   except OSError:
      return {}

   info = {}
   for line in meminfo.splitlines():
      # partition() never raises, unlike unpacking the result of split( ':' ).
      key, sep, value = line.partition( ':' )
      if not sep:
         continue
      try:
         info[ key ] = int( value.strip( ' kB' ) ) // 1024
      except ValueError:
         # Non-numeric value: ignore this line instead of failing the command.
         continue
   return info

def enoughMemoryAvailable():
   '''Return True when getMemInfo() reports a MemTotal above 1024.

   A missing MemTotal entry is treated as 0, which yields False.
   '''
   totalMem = getMemInfo().get( 'MemTotal', 0 )
   return totalMem > 1024

def vmsSupported( mode ):
   '''Return True when the system has enough memory and CPU virtualization.

   The mode argument is not used.  The CPU check is only made when the memory
   check has already passed.
   '''
   if not enoughMemoryAvailable():
      return False
   return cpuSupportsVirt()

def vmsSupportedGuard( mode, token ):
   '''CLI guard for the virtual-machine keyword.

   Returns None when vmsSupported( mode ) is true, otherwise
   CliParser.guardNotThisPlatform.  The token argument is not used.
   '''
   return None if vmsSupported( mode ) else CliParser.guardNotThisPlatform

# 'virtual-machine' keyword node used by the "virtual-machine VM_NAME ..."
# commands in this file.  vmsSupportedGuard is attached as its guard.
nodeVirtualMachine = CliCommand.Node(
      matcher=CliMatcher.KeywordMatcher( 'virtual-machine',
         helpdesc='Virtual Machine Subsystem' ),
      guard=vmsSupportedGuard )

# Matcher for the VM_NAME token.  The lambda reads virConfig.vmConfig each time
# it is called, so it sees the mount made by Plugin() rather than the initial
# None.  NOTE(review): presumably the collection is used to offer existing VM
# names for completion -- confirm against DynamicNameMatcher.
vmNameMatcher = CliMatcher.DynamicNameMatcher( lambda mode: virConfig.vmConfig,
      helpdesc='Virtual Machine name' )

#-------------------------------------------------------------
# Plugin method - Mount the objects we need from Sysdb
#-------------------------------------------------------------
def Plugin( entityManager ):
   '''Mount the Sysdb paths this plugin uses and store them in module globals.

   vm/config is mounted writable through ConfigMount; vm/status and the two
   interface directories are mounted read-only through LazyMount.
   '''
   global virConfig, virStatus, allIntfStatusDir, allIntfConfigDir

   lazyMount = LazyMount.mount
   virConfig = ConfigMount.mount( entityManager, "vm/config", "Vm::Config", "w" )
   virStatus = lazyMount( entityManager, "vm/status", "Vm::Status", "r" )
   allIntfConfigDir = lazyMount( entityManager, "interface/config/all",
                                 "Interface::AllIntfConfigDir", "r" )
   allIntfStatusDir = lazyMount( entityManager, "interface/status/all",
                                 "Interface::AllIntfStatusDir", "r" )

#-------------------------------------------------------------
# Generate a KVM mac address
#-------------------------------------------------------------
def genMacAddr():
   '''Return a MAC address string "52:54:00:xx:xx:xx".

   The last three octets are drawn independently with random.randint over
   0x00..0xff and printed as two lower-case hex digits each.
   '''
   import random # pylint: disable=import-outside-toplevel
   octets = [ random.randint( 0x00, 0xff ) for _ in range( 3 ) ]
   return "52:54:00:{:02x}:{:02x}:{:02x}".format( *octets )

#-----------------------------------------------------------------
# Vnic class -- holds vnic state for each vnic configured for a VM
#-----------------------------------------------------------------
class Vnic:
   '''In-memory record of one virtual NIC configured for a VM.

   Stores the vnic's name, backend device name, model type and MAC address,
   with accessor methods for reading them and setters for everything except
   the name.
   '''

   def __init__( self, name, backendDev, modelType, macAddr ):
      self.name_, self.backendDev_ = name, backendDev
      self.modelType_, self.macAddr_ = modelType, macAddr
      # No method of this class reads or writes this attribute.
      self.diskImageFormat_ = None

   # ---- read accessors ----

   def name( self ):
      '''Return the vnic name.'''
      return self.name_

   def backendDev( self ):
      '''Return the backend device name.'''
      return self.backendDev_

   def modelType( self ):
      '''Return the NIC model type.'''
      return self.modelType_

   def macAddr( self ):
      '''Return the MAC address.'''
      return self.macAddr_

   # ---- setters ----

   def setBackendDev( self, backendDev ):
      '''Replace the backend device name.'''
      self.backendDev_ = backendDev

   def setModelType( self, modelType ):
      '''Replace the NIC model type.'''
      self.modelType_ = modelType

   def setMacAddr( self, macAddr ):
      '''Replace the MAC address.'''
      self.macAddr_ = macAddr


#-------------------------------------------------------------
# Vm class -- holds cli state for each configured VM
#-------------------------------------------------------------
class Vm:
   '''Editable CLI-side copy of one virtual machine's settings.

   A new instance starts from built-in defaults.  When virConfig.vmConfig
   already holds an entry with the same name, the disk image, memory size,
   VNC port, vnics and enabled flag are taken from that entry instead.  The
   setters change this object only; nothing here writes back to Sysdb.
   '''

   def __init__( self, vmName ):
      # Defaults for a VM that has no Sysdb entry yet.
      self.name_ = vmName
      self.diskImage_ = ""
      self.diskImageFormat_ = None
      self.memorySize_ = 128
      self.vncPort_ = 0
      self.enabled_ = False
      self.vnics = {}

      if vmName not in virConfig.vmConfig:
         return

      # An entry exists in Sysdb: adopt its values.
      stored = virConfig.vmConfig[ vmName ]
      self.diskImage_ = stored.diskImage
      self.memorySize_ = stored.memorySize
      self.vncPort_ = stored.vncPort
      for vnicName in stored.vnicConfig:
         src = stored.vnicConfig[ vnicName ]
         self.vnics[ vnicName ] = Vnic( src.name, src.backendDev,
                                        src.modelType, src.macAddr )
      self.enabled_ = stored.enabled

   def diskImage( self ):
      '''Return the disk image file path ("" when none is set).'''
      return self.diskImage_

   def setDiskImage( self, path ):
      '''Set the disk image file path.'''
      self.diskImage_ = path

   def setDiskImageFormat( self, diskImageFormat ):
      '''Set the disk image format.'''
      self.diskImageFormat_ = diskImageFormat

   def noDiskImage( self ):
      '''Unset the disk image path by setting it to the empty string.'''
      self.setDiskImage( "" )

   def vncPort( self ):
      '''Return the VNC port (0 when none is set).'''
      return self.vncPort_

   def setVncPort( self, port ):
      '''Set the VNC port.'''
      self.vncPort_ = port

   def noVncPort( self ):
      '''Unset the VNC port by setting it to 0.'''
      self.vncPort_ = 0

   def addVnic( self, mode, vnicName, backendDevName, macAddr ):
      '''Add a vnic, or update the backend device / MAC of an existing one.

      Errors are reported through mode.addError and leave self.vnics
      untouched.  A given MAC address must be unicast and is stored in
      canonical form.  The backend device must either have an interface
      status with a non-empty deviceName whose forwarding model is routed,
      or at least be present in allIntfConfigDir.  A new vnic uses model
      type "e1000" and, when no MAC address is given, a generated one.
      '''
      if macAddr:
         if not Ethernet.isUnicast( macAddr ):
            mode.addError( "Configuration ignored: Mac address must be unicast." )
            return
         macAddr = Ethernet.convertMacAddrToCanonical( macAddr )

      status = allIntfStatusDir.intfStatus.get( backendDevName )
      if status and status.deviceName != "":
         if status.forwardingModel != 'intfForwardingModelRouted':
            mode.addError( "Configuration ignored: Device %s is not routed."
                           % backendDevName )
            return
      elif not allIntfConfigDir.intfConfig.get( backendDevName ):
         # No usable status and no config entry either: not a valid interface.
         mode.addError( "Configuration ignored: Device %s does not exist."
                        % backendDevName )
         return

      if vnicName not in self.vnics:
         self.vnics[ vnicName ] = Vnic( vnicName, backendDevName, "e1000",
                                        macAddr or genMacAddr() )
         return

      # Existing vnic: keep its MAC unless a new one was supplied.
      existing = self.vnics[ vnicName ]
      existing.setBackendDev( backendDevName )
      if macAddr:
         existing.setMacAddr( macAddr )

   def delVnic( self, mode, vnicName ):
      '''Remove a vnic; report an error through mode when it is not present.'''
      if vnicName not in self.vnics:
         mode.addError( "%s does not exist" % vnicName )
         return
      del self.vnics[ vnicName ]

   def enabled( self ):
      '''Return the enabled flag.'''
      return self.enabled_

   def enable( self ):
      '''Mark the VM enabled.'''
      self.enabled_ = True

   def disable( self ):
      '''Mark the VM disabled.'''
      self.enabled_ = False

   def name( self ):
      '''Return the name of this VM.'''
      return self.name_

   def memorySize( self ):
      '''Return the memory size.'''
      return self.memorySize_

   def setMemorySize( self, size ):
      '''Set the memory size.'''
      self.memorySize_ = size

   @staticmethod
   def getAll():
      '''Return Vm objects for every entry of virConfig.vmConfig, by name order.

      Names whose entry reads back as None are left out.
      '''
      return [ Vm( vmName )
               for vmName in sorted( virConfig.vmConfig.keys() )
               if virConfig.vmConfig[ vmName ] is not None ]


#-------------------------------------------------------------
# config-vm mode
#-------------------------------------------------------------
class VmConfigMode( VmMode, BasicCli.ConfigModeBase ):
   '''config-vm mode: edits a Vm object and writes it to Sysdb on exit.

   Commands in this mode change self.vm only.  onExit() calls _commitVm(),
   which copies the Vm object's state into virConfig.vmConfig.
   '''
   name = "Virtual Machine configuration"

   def __init__( self, parent, session, vm ):
      self.vm = vm
      self.session_ = session
      VmMode.__init__( self, self.vm.name() )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

   def _commitVm( self ):
      '''Copy self.vm into virConfig.vmConfig, creating the entry if needed.

      Scalars are written first, then vnics are added or amended, then vnics
      no longer present in self.vm are removed, and the enabled flag is
      written last.
      '''
      vmName = self.vm.name()
      if vmName not in virConfig.vmConfig:
         vmConfig = virConfig.newVmConfig( vmName )
      else:
         vmConfig = virConfig.vmConfig[ vmName ]

      vmConfig.diskImage = self.vm.diskImage()
      vmConfig.memorySize = self.vm.memorySize()
      vmConfig.vncPort = self.vm.vncPort()

      # Add new vnics or amend existing ones
      for vnicName in self.vm.vnics:
         vnic = self.vm.vnics[ vnicName ]
         if vnicName not in vmConfig.vnicConfig:
            vmConfig.newVnicConfig( vnic.name(),
                                    vnic.backendDev(),
                                    vnic.modelType(),
                                    vnic.macAddr() )
         else:
            vnicConfig = vmConfig.vnicConfig[ vnicName ]
            vnicConfig.backendDev = vnic.backendDev()
            vnicConfig.modelType = vnic.modelType()
            vnicConfig.macAddr = vnic.macAddr()

      # Clean up deleted vnics.  Iterate over a snapshot of the keys so that
      # entries are not deleted from the collection while it is being walked.
      for vnicName in list( vmConfig.vnicConfig ):
         if vnicName not in self.vm.vnics:
            del vmConfig.vnicConfig[ vnicName ]

      vmConfig.enabled = self.vm.enabled()

   def onExit( self ):
      '''Commit the edited VM, then run the base class exit handling.'''
      self._commitVm( )
      BasicCli.ConfigModeBase.onExit( self )

#--------------------------------------------------------------------------------
# [ no | default ] virtual-machine VM_NAME
#--------------------------------------------------------------------------------
def _gotoVmConfigMode( mode, args ):
   '''Enter config-vm mode for the VM named by args[ 'VM_NAME' ].

   A Vm object is built for that name and handed to a new VmConfigMode child
   mode, which the session then enters.
   '''
   editedVm = Vm( args[ 'VM_NAME' ] )
   vmMode = mode.childMode( VmConfigMode, vm=editedVm )
   mode.session_.gotoChildMode( vmMode )

def destroyVm( mode, args ):
   '''Delete the entry for args[ 'VM_NAME' ] from virConfig.vmConfig.

   The mode argument is not used.
   '''
   vmConfigs = virConfig.vmConfig
   del vmConfigs[ args[ 'VM_NAME' ] ]

# Global-config command "virtual-machine VM_NAME".  The positive form enters
# config-vm mode via _gotoVmConfigMode; the no/default form removes the VM via
# destroyVm.
class VirtualMachineVmnameCmd( CliCommand.CliCommandClass ):
   syntax = 'virtual-machine VM_NAME'
   # The no/default form takes exactly the same tokens as the positive form.
   noOrDefaultSyntax = syntax
   data = {
      'virtual-machine': nodeVirtualMachine,
      'VM_NAME': vmNameMatcher
   }
   handler = _gotoVmConfigMode
   noOrDefaultHandler = destroyVm

# Make the command available in global configuration mode.
BasicCliModes.GlobalConfigMode.addCommandClass( VirtualMachineVmnameCmd )

def showOneVm( format_, vm, detail ):
   '''Print one VM, either as a summary row or as a detailed block.

   The state shown is "Stopped" unless virStatus.vmStatus has an entry for the
   VM with a non-empty state.  In summary form a single line is printed using
   format_ with the name, enabled flag and state.  In detail form the
   settings and every vnic (sorted by name) are printed, followed by a blank
   line.
   '''
   vmName = vm.name()
   stateStr = "Stopped"
   if vmName in virStatus.vmStatus:
      reported = virStatus.vmStatus[ vmName ].state
      if reported != "":
         stateStr = reported

   enabledStr = 'Yes' if vm.enabled() else 'No'
   if not detail:
      print( format_ % ( vmName, enabledStr, stateStr ) )
      return

   port = vm.vncPort()
   portStr = str( port ) if port != 0 else 'None'
   print( "Virtual Machine: %s" % vmName )
   print( "   Enabled:             %s" % enabledStr )
   print( "   State:               %s" % stateStr )
   print( "   Disk Image:          %s" % vm.diskImage() )
   print( "   Memory Size:         %dMB" % vm.memorySize() )
   print( "   VNC port:            %s" % portStr )

   for vnicName in sorted( vm.vnics ):
      vnic = vm.vnics[ vnicName ]
      if not vnic:
         continue
      print( "   Virtual Nic: %s" % vnic.name() )
      print( "      Mac Address:    %s" % vnic.macAddr() )
      print( "      Device:         %s" % vnic.backendDev() )
      print( "      Model Type:     %s" % vnic.modelType() )
   print( "" )

#--------------------------------------------------------------------------------
# show virtual-machine [ detail ]
#--------------------------------------------------------------------------------
def showVm( mode, args ):
   '''Handler for "show virtual-machine [ detail ]".

   Prints every VM returned by Vm.getAll().  Without 'detail' a two-line
   column header is printed first and each VM becomes one row.
   '''
   vms = Vm.getAll()
   rowFormat = '%-20s %-10s %-10s'
   wantDetail = 'detail' in args
   if not wantDetail:
      for headerRow in ( ( 'VM Name', 'Enabled', 'State' ),
                         ( '-------', '-------', '-----' ) ):
         print( rowFormat % headerRow )

   for vm in vms:
      showOneVm( rowFormat, vm, wantDetail )

# Show command "show virtual-machine [ detail ]", handled by showVm.
# The 'virtual-machine' token here is a plain help string, not
# nodeVirtualMachine, so vmsSupportedGuard is not attached to it.
class VirtualMachineCmd( ShowCommand.ShowCliCommandClass ):
   syntax = 'show virtual-machine [ detail ]'
   data = {
      'virtual-machine': "VM's status",
      'detail': 'Show additional information',
   }
   handler = showVm

# Register with the show-command set.
BasicCli.addShowCommandClass( VirtualMachineCmd )

#--------------------------------------------------------------------------------
# [ no | default ] config-file URL
#--------------------------------------------------------------------------------
class ConfigFileUrlCmd( CliCommand.CliCommandClass ):
   # Hidden, deprecated command.  The positive form only reports an error and
   # the no/default form does nothing.
   syntax = 'config-file URL'
   noOrDefaultSyntax = 'config-file ...'
   data = {
      'config-file': "VM's libvirt configuration file (overrides other settings)",
      'URL': Url.UrlMatcher(
         lambda fs: fs.supportsListing(),
         helpdesc='Path to libvirt config file' ),
   }
   hidden = True

   @staticmethod
   def handler( mode, args ):
      '''Reject the command with a deprecation error.'''
      deprecation = ( "'config-file' is deprecated. "
                      "Please configure the VM using the CLI." )
      mode.addError( deprecation )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      '''Accept the no/default form without doing anything.'''
      return None

VmConfigMode.addCommandClass( ConfigFileUrlCmd )

#--------------------------------------------------------------------------------
# [ no | default ] disk-image URL [ image-format IMAGE_FORMAT ]
#--------------------------------------------------------------------------------
class DiskImageUrlCmd( CliCommand.CliCommandClass ):
   # config-vm command "disk-image URL [ image-format IMAGE_FORMAT ]".
   # The image-format keyword is hidden.  The positive form stores the local
   # filename of URL and the chosen format on mode.vm; the no/default form
   # clears both.
   syntax = 'disk-image URL [ image-format IMAGE_FORMAT ]'
   noOrDefaultSyntax = 'disk-image ...'
   data = {
      'disk-image': 'Add Virtual Machine disk image',
      'URL': Url.UrlMatcher( lambda fs: fs.supportsListing(),
         helpdesc='Path to disk image' ),
      'image-format': CliCommand.Node(
         matcher=CliMatcher.KeywordMatcher( 'image-format',
            helpdesc='Virtual Machine disk format' ),
         hidden=True ),
      'IMAGE_FORMAT': CliMatcher.EnumMatcher( {
         'raw': 'Raw image format',
         'iso': 'Iso image format',
         'qcow': 'Qcow image format',
         'qcow2': 'Qcow2 image format',
         'vmdk': 'Vmdk image format',
      } )
   }

   @staticmethod
   def handler( mode, args ):
      '''Record the disk image path and, when given, its format.'''
      url = args[ 'URL' ]
      # NOTE(review): presumably validates the URL and reports problems
      # itself -- confirm against FileCliUtil.checkUrl.
      FileCliUtil.checkUrl( url )
      mode.vm.setDiskImage( url.localFilename() )
      # IMAGE_FORMAT is optional, so read it with get(); a missing format
      # resets any value stored by an earlier disk-image command.
      mode.vm.setDiskImageFormat( args.get( 'IMAGE_FORMAT' ) )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      '''Clear the disk image path and its format.'''
      mode.vm.noDiskImage()
      mode.vm.setDiskImageFormat( None )

VmConfigMode.addCommandClass( DiskImageUrlCmd )

#--------------------------------------------------------------------------------
# [ no | default ] virtual-nic VNIC_NUMBER ( INTF | VLAN_INTF ) [ MAC_ADDR ]
#--------------------------------------------------------------------------------
class VirtualNicNumberCmd( CliCommand.CliCommandClass ):
   # config-vm command that attaches vnic<N> (N in 1..4) to a Management or
   # VLAN interface, with an optional MAC address.  The no/default form
   # removes vnic<N>.
   syntax = 'virtual-nic VNIC_NUMBER ( INTF | VLAN_INTF ) [ MAC_ADDR ]'
   noOrDefaultSyntax = 'virtual-nic VNIC_NUMBER ...'
   data = {
      'virtual-nic': 'Add virtual NIC',
      'VNIC_NUMBER': CliMatcher.IntegerMatcher( 1, 4, helpdesc='Virtual NIC ID' ),
      'INTF': PhysicalIntfRule.PhysicalIntfMatcher( 'Management' ),
      'VLAN_INTF': VlanIntfCli.VlanIntf.matcher,
      'MAC_ADDR': MacAddr.macAddrMatcher,
   }

   @staticmethod
   def handler( mode, args ):
      '''Add or update the vnic on mode.vm using the matched interface.'''
      if 'INTF' in args:
         backend = EthIntfCli.EthPhyIntf( args[ 'INTF' ], mode )
      elif 'VLAN_INTF' in args:
         backend = args[ 'VLAN_INTF' ]
      else:
         assert False, "Didn't get an interface"

      nicName = "vnic%d" % args[ 'VNIC_NUMBER' ]
      # MAC_ADDR is optional; None lets addVnic pick or keep the address.
      mode.vm.addVnic( mode, nicName, backend.name, args.get( 'MAC_ADDR' ) )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      '''Remove the vnic with the given number from mode.vm.'''
      mode.vm.delVnic( mode, "vnic%d" % args[ 'VNIC_NUMBER' ] )

VmConfigMode.addCommandClass( VirtualNicNumberCmd )

#--------------------------------------------------------------------------------
# [ no | default ] vnc-port VNC_PORT
#--------------------------------------------------------------------------------
class VncPortVncportCmd( CliCommand.CliCommandClass ):
   # config-vm command "vnc-port VNC_PORT" (5900..5910).  The no/default
   # form unsets the port.
   syntax = 'vnc-port VNC_PORT'
   noOrDefaultSyntax = 'vnc-port ...'
   data = {
      'vnc-port': 'Set VNC server port',
      'VNC_PORT': CliMatcher.IntegerMatcher(
         5900, 5910, helpdesc="VNC server's tcp port number" ),
   }

   @staticmethod
   def handler( mode, args ):
      '''Store the requested VNC port on mode.vm.'''
      port = args[ 'VNC_PORT' ]
      mode.vm.setVncPort( port )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      '''Unset the VNC port on mode.vm.'''
      mode.vm.noVncPort()

VmConfigMode.addCommandClass( VncPortVncportCmd )

#--------------------------------------------------------------------------------
# [ no | default ] enable
#--------------------------------------------------------------------------------
class EnableCmd( CliCommand.CliCommandClass ):
   # config-vm command "enable".  The no/default form disables the VM.
   syntax = 'enable'
   noOrDefaultSyntax = 'enable'
   data = { 'enable': 'Enable VM' }

   @staticmethod
   def handler( mode, args ):
      '''Mark mode.vm enabled.'''
      mode.vm.enable()

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      '''Mark mode.vm disabled.'''
      mode.vm.disable()

VmConfigMode.addCommandClass( EnableCmd )

#--------------------------------------------------------------------------------
# memory-size SIZE
#--------------------------------------------------------------------------------
def memRangeFn( mode, context ):
   '''Return the ( low, high ) range accepted for memory-size.

   The low end is 32.  The high end is half of the MemTotal value from
   getMemInfo(), with 4094 used when MemTotal is missing.  Neither argument
   is used.
   '''
   totalMem = getMemInfo().get( 'MemTotal', 4094 )
   return ( 32, totalMem // 2 )

class MemorySizeSizeCmd( CliCommand.CliCommandClass ):
   # config-vm command "memory-size SIZE".  The accepted range comes from
   # memRangeFn; the handler also checks current memory figures.
   syntax = 'memory-size SIZE'
   data = {
      'memory-size': 'Set memory size',
      'SIZE': CliMatcher.DynamicIntegerMatcher(
         rangeFn=memRangeFn, helpdesc='Memory size (in MB)' ),
   }

   @staticmethod
   def handler( mode, args ):
      '''Set the VM memory size after checking total and available memory.

      The request is rejected with an error when MemTotal or MemAvailable is
      missing or zero, when MemTotal is at most 1024, or when MemAvailable is
      below the requested size.
      '''
      requested = int( args[ 'SIZE' ] )
      memInfo = getMemInfo()
      total = memInfo.get( 'MemTotal' )
      available = memInfo.get( 'MemAvailable' )

      if not ( total and available ):
         mode.addError( 'Configuration ignored: '
                        'Unable to determine available free memory' )
         return

      if total <= 1024:
         mode.addError( 'Configuration ignored: '
                        'System does not have enough memory to run VMs.' )
         return

      if available < requested:
         mode.addError( 'Configuration ignored: '
                        'System does not have enough '
                        'free memory available '
                        '(%dMB currently available).' % available )
         return

      mode.vm.setMemorySize( requested )

VmConfigMode.addCommandClass( MemorySizeSizeCmd )

#--------------------------------------------------------------------------------
# virtual-machine VM_NAME console [ ESCAPE_CHAR ]
#--------------------------------------------------------------------------------
def runVmConsole( mode, args  ):
   '''Attach the terminal to a running VM's serial console.

   Runs /usr/bin/vConsole as root with the VM name, the console pty path and
   the escape character code, wired to this process's stdin and stdout.

   Errors reported through mode.addError:
      - the VM has no entry in virStatus.vmStatus
      - the VM's state is not "Running"
      - vConsole fails and produced output (the output is the error text)
   '''
   vmName = args[ 'VM_NAME' ]
   # ESCAPE_CHAR is optional in the syntax, so it may be missing from args;
   # fall back to ASCII code 29 when it is not given.
   escapeChar = args.get( 'ESCAPE_CHAR' ) or 29

   if vmName not in virStatus.vmStatus:
      mode.addError( "Virtual Machine %s does not exist" % vmName )
      return

   vmStatus = virStatus.vmStatus[ vmName ]
   if vmStatus.state != "Running":
      mode.addError( "Virtual Machine %s is not running" % vmName )
      return

   try:
      # consolePtyPath as set by QMP is pty:/dev/pts/XX
      # consolePtyPath[4:] chops off 'pty:' from the string
      Tac.run( [ '/usr/bin/vConsole', vmName,
                 vmStatus.consolePtyPath[ 4: ], str( escapeChar ) ],
               asRoot=True,
               stdin=sys.stdin,
               stdout=sys.stdout )
   except Tac.SystemCommandError as e:
      # A failure without output is deliberately not reported.
      if e.output:
         mode.addError( e.output )

# Enable-mode command "virtual-machine VM_NAME console [ ESCAPE_CHAR ]",
# handled by runVmConsole.  ESCAPE_CHAR is a hidden integer token in 27..31.
class VirtualMachineVmnameConsoleCmd( CliCommand.CliCommandClass ):
   syntax = 'virtual-machine VM_NAME console [ ESCAPE_CHAR ]'
   data = {
      'virtual-machine': nodeVirtualMachine,
      'VM_NAME': vmNameMatcher,
      'console': "Attach to VM's serial console",
      'ESCAPE_CHAR': CliCommand.Node(
         matcher=CliMatcher.IntegerMatcher( 27, 31, helpdesc='Escape Character' ),
         hidden=True ),
   }
   handler = runVmConsole

# Make the command available in enable mode.
BasicCliModes.EnableMode.addCommandClass( VirtualMachineVmnameConsoleCmd )

#--------------------------------------------------------------------------------
# virtual-machine VM_NAME restart
#--------------------------------------------------------------------------------
def restartVm( mode, args ):
   '''Restart an enabled VM by clearing and re-setting its enabled flag.

   Reports an error through mode when the VM has no entry in
   virConfig.vmConfig or when it is not enabled.
   '''
   vmName = args[ 'VM_NAME' ]
   if vmName not in virConfig.vmConfig:
      mode.addError( "Virtual Machine %s does not exist" % vmName )
      return

   vmConfig = virConfig.vmConfig[ vmName ]
   if not ( vmConfig and vmConfig.enabled ):
      mode.addError( "Virtual Machine %s is not enabled" % vmName )
      return

   vmConfig.enabled = False
   # We need to make sure the QEMU process has finished handling the SIGTERM
   time.sleep( 1 )
   vmConfig.enabled = True

# Enable-mode command "virtual-machine VM_NAME restart", handled by restartVm.
class VirtualMachineVmnameRestartCmd( CliCommand.CliCommandClass ):
   syntax = 'virtual-machine VM_NAME restart'
   data = {
      'virtual-machine': nodeVirtualMachine,
      'VM_NAME': vmNameMatcher,
      'restart': 'Restart VM',
   }
   handler = restartVm

# Make the command available in enable mode.
BasicCliModes.EnableMode.addCommandClass( VirtualMachineVmnameRestartCmd )
