# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=arguments-renamed

from ApiBaseModels import BaseModel
from ApiBaseModels import Str
from ApiBaseModels import Int
from ApiBaseModels import Bool
from ApiBaseModels import Float
from ApiBaseModels import List

import Tac
import Tracing
import Logging
import OpenStackLogMsgs

traceHandle = Tracing.Handle( 'OpenStackUwsgiServer' )
# Short module-wide aliases for the trace0..trace4 levels of this handle.
log, warn, info, trace, debug = ( traceHandle.trace0,
                                  traceHandle.trace1,
                                  traceHandle.trace2,
                                  traceHandle.trace3,
                                  traceHandle.trace4 )

# Tac types used by the models below: SwitchInterface is built from a
# ( switch, interface ) pair, SyncStatus supplies the region sync-state values.
SwitchInterface = Tac.Type( 'VirtualNetwork::Client::SwitchInterface' )
SyncStatus = Tac.Type( 'OpenStack::SyncStatus' )

class AgentModel( BaseModel ):
   ''' API view of the OpenStack agent's state.

   Every field is declared with inputOk=False, i.e. none of them is meant to be
   supplied by a client (how BaseModel enforces this is not visible here --
   TODO confirm).
   '''
   # Stored in Sysdb as 'agentUuid'.
   uuid = Str(
      tacName='agentUuid',
      inputOk=False,
      description='Agent UUID' )
   agentMode = Str(
      inputOk=False,
      description='Provision or visibility mode' )
   # List of strings, read from the Sysdb attribute 'supportedApiTypes'.
   supportedApis = List(
      tacName='supportedApiTypes',
      inputOk=False,
      description='API types this agent supports',
      valueType=str )
   # tacName=None: not mapped to a Sysdb attribute, so presumably the caller
   # sets it explicitly rather than the generic fromSysdb -- TODO confirm.
   isLeader = Bool(
      tacName=None,
      inputOk=False,
      description='CVX leader status' )

class ServiceEndPointModel( BaseModel ):
   ''' A named service (auth) endpoint configured under a region. '''
   name = Str( description='Endpoint name' )
   authUrl = Str( description='Endpoint auth URL' )
   user = Str( description='Endpoint username' )
   password = Str( description='Endpoint password' )
   tenant = Str( description='Endpoint tenant' )

   def toSysdb( self, region ):
      ''' Add an endpoint keyed by self.name to region.serviceEndPoint and
      copy the auth URL, user, password and tenant onto it. '''
      endPoint = region.serviceEndPoint.newMember( self.name )
      # Copied in the same order as the field declarations above.
      for attrName in ( 'authUrl', 'user', 'password', 'tenant' ):
         setattr( endPoint, attrName, getattr( self, attrName ) )

class RegionModel( BaseModel ):
   ''' An OpenStack region: its writable sync interval plus read-only sync
   status and heartbeat. '''
   name = Str( description='Region name' )
   syncStatus = Str(
      inputOk=False,
      description='Sync status' )
   syncInterval = Float( description='Synchronization interval' )
   syncHeartbeat = Float(
      inputOk=False,
      description='Synchronization heartbeat' )

   def fromSysdb( self, rConfig, rStatus ): # pylint: disable-msg=W0221
      ''' Populate from the region config entity rConfig, then take syncStatus
      from rStatus when a status entity is available. '''
      super().fromSysdb( rConfig )
      if rStatus:
         self.syncStatus = rStatus.syncStatus

   def toSysdb( self, region ):
      ''' Write the writable region settings onto the Sysdb region.

      syncInterval is the only writable setting. It is written only when the
      caller actually populated it, so an omitted value leaves the existing
      region setting untouched.
      '''
      # The field is declared as 'syncInterval'. Testing for an 'interval' key
      # and reading self.interval, as before, could never apply a supplied
      # interval. Use the same populated-field idiom as the other models.
      for tacName, value in self.getPopulatedModelFields( key='tacName' ):
         if tacName == 'syncInterval':
            region.syncInterval = value

class SyncModel( BaseModel ):
   ''' A client's request to take or release the region sync lock. '''
   requester = Str( description='Id of the client requesting the lock' )
   requestId = Str( description='The request id for the lock' )

   def toSysdb( self, region ):
      ''' Apply this lock request to the region's sync state.

      - requester and requestId both non-empty: record them, mark the sync as
        in progress and stamp lastSyncStartTime.
      - both exactly "": clear them and mark the sync complete.
      - anything else (only one supplied, None, ...): leave region unchanged.
      '''
      requester, requestId = self.requester, self.requestId
      if not ( requester and requestId ):
         # Not a valid start request; the only other accepted form is an
         # explicit release with both values set to the empty string.
         if requester == "" and requestId == "":
            region.syncStatus = SyncStatus.syncComplete
            region.requestId = ""
            region.requester = ""
         return
      region.requester = requester
      region.requestId = requestId
      region.syncStatus = SyncStatus.syncInProgress
      region.lastSyncStartTime = Tac.now()

class TenantModel( BaseModel ):
   ''' A tenant, identified solely by its id. '''
   # apiName and tacName are both 'id'.
   tenantId = Str(
      apiName='id',
      tacName='id',
      description='Tenant ID' )

class NetworkModel( BaseModel ):
   ''' A tenant network. Only its name and shared flag may change after it
   is created. '''
   networkId = Str(
      apiName='id',
      tacName='id',
      description='Network ID' )
   tenantId = Str(
      tacName=None,
      description='Tenant this network belongs to' )
   networkName = Str(
      apiName='name',
      description='Network Name' )
   shared = Bool( description='Shared network' )

   def fromSysdb( self, t ):
      ''' Populate from the Sysdb network t. tenantId is declared with
      tacName=None, so it is taken here from the tenant pointer. '''
      super().fromSysdb( t )
      owner = t.tenant
      if owner:
         self.tenantId = owner.id

   def toSysdb( self, network ):
      ''' Copy the populated mutable fields (networkName, shared) onto an
      existing Sysdb network. '''
      mutableFields = ( 'networkName', 'shared' )
      for tacName, value in self.getPopulatedModelFields( key='tacName' ):
         if tacName not in mutableFields:
            continue
         setattr( network, tacName, value )

class SwitchportModel( BaseModel ):
   ''' A switch identifier paired with an interface name. '''
   switchId = Str(
      apiName='id',
      description='Switch ID' )
   interface = Str( description='Switchports' )

   def fromSysdb( self, t ): # pylint: disable=useless-super-delegation
      ''' Populate from t using only the generic BaseModel field mapping. '''
      super().fromSysdb( t )

   def toSysdb( self, switchport ):
      ''' Does nothing: this model writes nothing back to Sysdb. '''

class PortModel( BaseModel ):
   ''' A port on a network, owned by a vm/dhcp/router/baremetal instance.

   networkId, instanceId, instanceType and tenantId are declared with
   tacName=None; fromSysdb fills them in itself.
   '''
   portId = Str(
      apiName='id',
      tacName='id',
      description='Port ID' )
   name = Str(
      tacName='portName',
      description='Port name' )
   vlanType = Str(
      tacName='portVlanType',
      description='Port VLAN type' )
   networkId = Str(
      tacName=None,
      description='Tenant network' )
   instanceId = Str(
      tacName=None,
      description='The VM/DHCP/Router ID this port belongs to' )
   instanceType = Str(
      tacName=None,
      description='Either vm/dhcp/router/baremetal' )
   tenantId = Str(
      tacName=None,
      description='Tenant ID' )

   def fromSysdb( self, t, iType ): # pylint: disable-msg=W0221
      ''' Populate from the Sysdb port t.

      The network and owning-instance pointers are followed by hand to fill
      networkId, instanceId and the instance's tenantId. iType is stored
      verbatim as instanceType.
      '''
      super().fromSysdb( t )
      portNetwork = t.network
      if portNetwork:
         self.networkId = portNetwork.id
      owner = t.instance
      if owner:
         self.instanceId = owner.id
         ownerTenant = owner.tenant
         if ownerTenant:
            self.tenantId = ownerTenant.id
      self.instanceType = iType
      self.vlanType = t.portVlanType

   # pylint: disable-next=arguments-differ
   def toSysdb( self, region, pId, instance, network ):
      ''' Create port pId in region, with VLAN type 'allowed' unless vlanType
      is set, and name it when a name was given. Then link it to instance
      (both directions) and to network. '''
      vlanType = self.vlanType or 'allowed'
      port = region.newPort( pId, vlanType )
      portName = self.name
      if portName:
         port.portName = portName
      port.instance = instance
      instance.port.addMember( port )
      port.network = network

class SegmentModel( BaseModel ):
   ''' A network segment (segmentation type + id) belonging to a network,
   classified as static or dynamic. '''
   segmentId = Str(
      apiName='id',
      tacName='id',
      description='Segmentation ID' )
   segmentationType = Str(
      apiName='type',
      tacName='type',
      description='Segmentation Type' )
   segmentationId = Int( description='Segmentation Type Id' )
   networkId = Str(
      tacName=None,
      description='Network ID' )
   segmentType = Str(
      tacName=None,
      description='Indicates whether the segment type is static or dynamic.' )

   # pylint: disable-next=inconsistent-return-statements
   def toSysdb( self, region, network, vlanPool ): # pylint: disable-msg=W0221
      ''' Create this segment in region and attach it to network.

      For 'vlan' segments with a vlanPool (expected when the arista_vlan type
      driver is used), the segmentation id must be one of the pool's available
      or allocated VLANs. If it is not, a warning is traced and logged,
      nothing is created, and the rejected segmentation id is returned.
      Otherwise None is returned.

      The new segment goes into network.staticSegment or
      network.dynamicSegment according to segmentType. Any other segmentType
      leaves it unattached to the network.
      '''
      segmentationType = self.segmentationType
      segmentationId = self.segmentationId
      if segmentationType == 'vlan' and vlanPool:
         poolVlans = ( list( vlanPool.availableVlan ) +
                       list( vlanPool.allocatedVlan ) )
         if segmentationId not in poolVlans:
            # pylint: disable-next=consider-using-f-string
            warn( "VLAN segmentation id %d is not available" % segmentationId )
            Logging.log( OpenStackLogMsgs.CVX_OPENSTACK_INVALID_NETWORK_VLAN,
                         segmentationId, network.id )
            return segmentationId
      segment = region.newSegment( self.segmentId, segmentationType,
                                   segmentationId, network.id )
      segmentType = self.segmentType
      if segmentType == 'static':
         network.staticSegment.addMember( segment )
      elif segmentType == 'dynamic':
         network.dynamicSegment.addMember( segment )

   def fromSysdb( self, region, segment ): # pylint: disable=arguments-differ
      ''' Populate from the Sysdb segment and derive segmentType.

      segmentType is 'unknown' if region has no network with the segment's
      networkId. Otherwise it is 'static' or 'dynamic' depending on which of
      that network's collections holds the segment id, and is left unset if
      neither does.
      '''
      super().fromSysdb( segment )
      self.networkId = segment.networkId
      owningNetwork = region.network.get( segment.networkId )
      if owningNetwork is None:
         self.segmentType = 'unknown'
      elif segment.id in owningNetwork.staticSegment:
         self.segmentType = 'static'
      elif segment.id in owningNetwork.dynamicSegment:
         self.segmentType = 'dynamic'

class PortToHostBindingModel( BaseModel ):
   ''' Binding of a port to a host, with its ordered segment levels.

   Internal implementation. Do not expose via any APIs (yet).
   '''
   host = Str( description='Port to host binding' )
   segments = List(
      apiName='segment',
      tacName='segment',
      description='Network segment the port is connected to.',
      valueType=SegmentModel )

   # pylint: disable-msg=W0221
   def toSysdb( self, region, binding ):
      ''' Create the binding for self.host under binding and store each listed
      segment at its list position (level). Segments whose id is not in
      region.segment are skipped, so no entry is written for that level. '''
      hostBinding = binding.newPortToHostBinding( self.host )
      regionSegments = region.segment
      for level, segmentModel in enumerate( self.segments or [] ):
         segmentId = segmentModel.segmentId
         if segmentId not in regionSegments:
            continue
         hostBinding.segment[ level ] = regionSegments[ segmentId ]

   def fromSysdb( self, region, portBinding ):
      ''' Populate from the Sysdb host binding: copy the host and rebuild the
      segment models ordered by ascending level. '''
      self.host = portBinding.host
      segmentByLevel = portBinding.segment
      segmentModels = []
      for level in sorted( segmentByLevel ):
         segmentModel = SegmentModel()
         segmentModel.fromSysdb( region, segmentByLevel[ level ] )
         segmentModels.append( segmentModel )
      self.segments = segmentModels

class PortToSwitchInterfaceBindingModel( BaseModel ):
   ''' Binding of a port to a switch interface, with its ordered segment
   levels.

   Internal implementation. Do not expose via any APIs (yet).
   '''
   host = Str( description='Switch hostname' )
   switch = Str(
      tacName=None,
      description='Switch ID' )
   interface = Str(
      tacName=None,
      description='Switch Interface' )
   segments = List(
      apiName='segment',
      tacName='segment',
      valueType=SegmentModel,
      description='Network segment the port is connected to.' )

   # pylint: disable-msg=W0221
   def toSysdb( self, region, binding ):
      ''' Create the binding for ( switch, interface ) and host under binding.
      Then store each listed segment at its list position (level), skipping
      segments whose id is not in region.segment. '''
      switchInterface = SwitchInterface( self.switch, self.interface )
      interfaceBinding = binding.newPortToSwitchInterfaceBinding(
         switchInterface, self.host )
      regionSegments = region.segment
      for level, segmentModel in enumerate( self.segments or [] ):
         segmentId = segmentModel.segmentId
         if segmentId not in regionSegments:
            continue
         interfaceBinding.segment[ level ] = regionSegments[ segmentId ]

   def fromSysdb( self, region, portBinding ):
      ''' Populate from the Sysdb switch-interface binding: host, switch id,
      interface name, and the segment models ordered by ascending level. '''
      self.host = portBinding.host
      boundInterface = portBinding.switchInterface
      self.switch = boundInterface.switchId
      self.interface = boundInterface.interface
      segmentByLevel = portBinding.segment
      segmentModels = []
      for level in sorted( segmentByLevel ):
         segmentModel = SegmentModel()
         segmentModel.fromSysdb( region, segmentByLevel[ level ] )
         segmentModels.append( segmentModel )
      self.segments = segmentModels

class PortBindingModel( BaseModel ):
   ''' All host and switch-interface bindings of a single port. '''
   portId = Str( description='Port ID' )
   hostBinding = List(
      tacName='portToHostBinding',
      valueType=PortToHostBindingModel,
      description='Host to which the port is bound to' )
   # Entries are switch-interface bindings (switch, interface, host); the
   # description no longer repeats the host-binding text.
   switchBinding = List(
      valueType=PortToSwitchInterfaceBindingModel,
      tacName='portToSwitchInterfaceBinding',
      description='Switch interface to which the port is bound' )

   # pylint: disable-msg=W0221
   def toSysdb( self, region, portId ):
      ''' Write every host binding, then every switch-interface binding, of
      this model into region for port portId.

      Each binding model writes its entry into the port binding returned by
      region.newPortBinding( portId, port ). The port is looked up only when
      at least one binding is supplied.
      '''
      bindingModels = ( list( self.hostBinding or [] ) +
                        list( self.switchBinding or [] ) )
      if not bindingModels:
         return
      port = region.port[ portId ]
      for bindingModel in bindingModels:
         # NOTE(review): newPortBinding is still called once per binding model.
         # Presumably repeat calls return the same binding -- confirm.
         portBinding = region.newPortBinding( portId, port )
         bindingModel.toSysdb( region, portBinding )

   def fromSysdb( self, region, binding ):
      ''' Populate from the Sysdb port binding: the port id plus one sub-model
      per switch-interface binding and per host binding. '''
      self.portId = binding.portId
      switchBindings = []
      for switchBinding in binding.portToSwitchInterfaceBinding.values():
         switchBindingModel = PortToSwitchInterfaceBindingModel()
         switchBindingModel.fromSysdb( region, switchBinding )
         switchBindings.append( switchBindingModel )
      self.switchBinding = switchBindings
      hostBindings = []
      for hostBinding in binding.portToHostBinding.values():
         hostBindingModel = PortToHostBindingModel()
         hostBindingModel.fromSysdb( region, hostBinding )
         hostBindings.append( hostBindingModel )
      self.hostBinding = hostBindings

class VmModel( BaseModel ):
   ''' A VM instance: its host and owning tenant. '''
   instanceId = Str(
      apiName='id',
      tacName='vmInstanceId',
      description='VM Instance ID' )
   vmHostId = Str(
      apiName='hostId',
      description='VM Host ID' )
   tenantId = Str(
      tacName=None,
      description='VM Tenant ID' )

   def toSysdb( self, vm ):
      ''' Copy vmHostId, the only mutable field, onto the Sysdb vm when the
      caller populated it. '''
      populatedFields = self.getPopulatedModelFields( key='tacName' )
      for tacName, value in populatedFields:
         if tacName == 'vmHostId':
            vm.vmHostId = value

   def fromSysdb( self, vm ):
      ''' Populate from the Sysdb vm; tenantId (tacName=None) is taken from
      the tenant pointer. '''
      super().fromSysdb( vm )
      owner = vm.tenant
      if owner:
         self.tenantId = owner.id

class BaremetalModel( BaseModel ):
   ''' A baremetal instance: its host and owning tenant. '''
   instanceId = Str(
      apiName='id',
      tacName='baremetalInstanceId',
      description='Baremetal Instance ID' )
   baremetalHostId = Str(
      apiName='hostId',
      description='Baremetal Host ID' )
   tenantId = Str(
      tacName=None,
      description='Baremetal Tenant ID' )

   def toSysdb( self, baremetal ):
      ''' Copy baremetalHostId, the only mutable field, onto the Sysdb
      baremetal entry when the caller populated it. '''
      for tacName, value in self.getPopulatedModelFields( key='tacName' ):
         if tacName != 'baremetalHostId':
            continue
         baremetal.baremetalHostId = value

   def fromSysdb( self, baremetal ):
      ''' Populate from the Sysdb baremetal entry; tenantId (tacName=None) is
      taken from the tenant pointer. '''
      super().fromSysdb( baremetal )
      owner = baremetal.tenant
      if owner:
         self.tenantId = owner.id

class RouterModel( BaseModel ):
   ''' A router instance: its host and owning tenant. '''
   instanceId = Str(
      apiName='id',
      tacName='routerInstanceId',
      description='Router Instance ID' )
   routerHostId = Str(
      apiName='hostId',
      description='Router Host ID' )
   tenantId = Str(
      tacName=None,
      description='Router Tenant ID' )

   def toSysdb( self, router ):
      ''' Copy routerHostId, the only mutable field, onto the Sysdb router
      when the caller populated it. '''
      populatedFields = self.getPopulatedModelFields( key='tacName' )
      for tacName, value in populatedFields:
         if tacName == 'routerHostId':
            router.routerHostId = value

   def fromSysdb( self, router ):
      ''' Populate from the Sysdb router; tenantId (tacName=None) is taken
      from the tenant pointer. '''
      super().fromSysdb( router )
      owner = router.tenant
      if owner:
         self.tenantId = owner.id

class DhcpModel( BaseModel ):
   ''' A DHCP instance: its host and owning tenant. '''
   instanceId = Str(
      apiName='id',
      tacName='dhcpInstanceId',
      description='DHCP instance ID' )
   dhcpHostId = Str(
      apiName='hostId',
      description='DHCP host ID' )
   tenantId = Str(
      tacName=None,
      description='DHCP Tenant ID' )

   def toSysdb( self, dhcp ):
      ''' Copy dhcpHostId, the only mutable field, onto the Sysdb dhcp entry
      when the caller populated it. '''
      for tacName, value in self.getPopulatedModelFields( key='tacName' ):
         if tacName != 'dhcpHostId':
            continue
         dhcp.dhcpHostId = value

   def fromSysdb( self, dhcp ):
      ''' Populate from the Sysdb dhcp entry; tenantId (tacName=None) is taken
      from the tenant pointer. '''
      super().fromSysdb( dhcp )
      owner = dhcp.tenant
      if owner:
         self.tenantId = owner.id
