#!/usr/bin/env python3
# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

"""Python bindings for Afetch.

A callback-based Afetch client is provided, allowing one
to make HTTP(S) requests and receive updates via callback when
complete as well as receive callbacks for further updates.
"""

from functools import partial
import os
import shutil
import Tac
import Tracing
import requests
from requests.exceptions import RequestException

# Short aliases for the Tracing functions used throughout this module.
# NOTE(review): t1 is bound to Tracing.trace5, i.e. the same function as t5.
# Confirm whether a different trace level was intended for t1.
t1 = Tracing.trace5
t5 = Tracing.trace5
t8 = Tracing.trace8

# Handles to the Afetch Tac types referenced by this module
Error = Tac.Type( 'Afetch::Error' )
Method = Tac.Type( 'Afetch::Method' )
RequestState = Tac.Type( 'Afetch::RequestState' )
Scheme = Tac.Type( 'Afetch::Scheme' )
AuthType = Tac.Type( 'Afetch::HttpAuthenticationType' )

# HTTP method name shortcuts, one per Afetch::Method value used by Client
CONNECT = Method.CONNECT
DELETE = Method.DELETE
GET = Method.GET
HEAD = Method.HEAD
OPTIONS = Method.OPTIONS
POST = Method.POST
PUT = Method.PUT
TRACE = Method.TRACE
PATCH = Method.PATCH

# Attributes to ignore when making dictionaries of counters and timers
TAC_ATTR_IGNORE = frozenset(
   ( 'entity', 'fullName', 'isNondestructing', 'name', 'parent', 'parentAttrName' ) )

def _valueToDict( value ):
   result = {}
   for key in value.attributes:
      if key not in TAC_ATTR_IGNORE:
         result[ key ] = getattr( value, key )
   return result


class ResponseReactor( Tac.Notifiee ):
   """Notifiee on an Afetch::Response that forwards finished responses.

   The callback passed to the constructor is called with the notifying response
   object as its only argument.
   """

   notifierTypeName = 'Afetch::Response'

   def __init__( self, afetchResponseDir, callback ):
      Tac.Notifiee.__init__( self, afetchResponseDir )
      self.afetchResponseDir_ = afetchResponseDir
      self.callback_ = callback

   @Tac.handler( 'requestState' )
   def handleRequestState( self ):
      """Runs the callback when a finished response carries a changed result."""
      response = self.notifier_
      # The result counts as changed on the first response and whenever an
      # important response field changes afterwards; only then is the change
      # time at or past the update time.
      if ( response.finished() and
           response.lastChangeTime >= response.lastUpdateTime ):
         self.callback_( response )


class Client:
   """The Python Afetch Client interface.

   Requests are built with request() or one of the per-method shortcuts (get(),
   post(), ...), started with start() or by passing running=True, and their
   responses are handed to the callback registered for the request key.
   """

   def __init__( self, baseDir=None ):
      """Creates the Afetch::Root and attaches the response reactor.

      Args:
         baseDir: optional parent; when given, the root is created through
            baseDir.newEntity(), otherwise as a standalone Tac instance.
      """
      if baseDir:
         self.root = baseDir.newEntity( 'Afetch::Root', 'afetchRoot' )
      else:
         self.root = Tac.newInstance( 'Afetch::Root', 'afetch' )
      self.aresolveSm = self.root.aresolveSm
      self.afetchSm = self.root.afetchSm
      self.requestDir = self.root.afetchRequest
      self.responseDir = self.root.afetchResponse
      # Maps Tac.hashOf( requestKey ) to the callback given to request()
      self.callbacks_ = {}
      # Previously we had _handleResponse as an instance method, and when passing it
      # into the collectionChangeReactor it would create a bound method. Thus, the
      # reactor had a reference to the Client object and the Client object also has a
      # reference to the reactor. This creates a circular dependency, and the Client
      # will not be destroyed until the GC happens to detect the circular dependency,
      # leading to us "leaking" the AfetchSm.
      # To avoid this circular dependency, we can instead make _handleResponse into
      # a static method which takes the callbacks as a parameter, and use `partial`
      # to create a new function which always has our callback dict as the value for
      # that parameter. We can then pass this new function into the reactor, and it
      # won't have a reference to the Client object, only the callback dict, and thus
      # this eliminates the circular dependency and allows the Client to be destroyed
      # once the user deletes their reference to it.
      # See https://docs.python.org/3/library/functools.html#functools.partial for a
      # better explanation of `partial`.
      self.reactor_ = Tac.collectionChangeReactor(
         self.responseDir.response,
         ResponseReactor,
         reactorArgs=( partial( Client._handleResponse, self.callbacks_ ), ) )

   @property
   def numRequestsRunning( self ):
      """The number of responses whose requestState is past stateNotInit."""
      return sum( 1 for r in self.responseDir.response
                  if r.requestState > RequestState.stateNotInit )

   @property
   def numRequests( self ):
      """The number of entries in the Afetch request collection."""
      return len( self.requestDir.request )

   def request( self, method, uri, callback=None, headers=None, running=False ):
      """Builds an Afetch::Request object to call 'uri' with HTTP 'method'.

      By default the request is only built and returned, not started, which
      leaves room to finish setting it up first (useful for POST requests to
      set a body); pass it to start() when it is ready. Use running=True as a
      keyword argument to place it in the Afetch request collection straight
      away instead.

      The response to this request, once running, will be received by your callback.

      Args:
         method: str, an enum value from Afetch.Method (e.g., Afetch.Method.GET)
         uri: str, the URI to retrieve
         callback: A callable, the callback for this request (key); if not specified,
            the response will only be available by calling getResponse for the key
            and polling the response object manually.
         headers: A dict, HTTP request headers
         running: bool, if True, start the request immediately (default False)

      Returns:
        An Afetch::Request object. Unless running=True was passed, hand it
        to start() to add it to the Afetch request collection.
      """
      key = Tac.Value( 'Afetch::RequestKey', method, uri )
      request = Tac.Value( 'Afetch::Request', key )
      if headers:
         for headerName, headerValue in headers.items():
            request.header[ headerName ] = headerValue
      # Registered even when callback is None; _handleResponse skips falsy entries
      self.callbacks_[ Tac.hashOf( key ) ] = callback

      if running:
         self.requestDir.addRequest( request )
      return request

   def connect( self, uri, **kwargs ):
      """Shortcut for request( CONNECT, uri, **kwargs )."""
      return self.request( CONNECT, uri, **kwargs )

   def get( self, uri, **kwargs ):
      """Shortcut for request( GET, uri, **kwargs )."""
      return self.request( GET, uri, **kwargs )

   def head( self, uri, **kwargs ):
      """Shortcut for request( HEAD, uri, **kwargs )."""
      return self.request( HEAD, uri, **kwargs )

   def post( self, uri, **kwargs ):
      """Shortcut for request( POST, uri, **kwargs )."""
      return self.request( POST, uri, **kwargs )

   def put( self, uri, **kwargs ):
      """Shortcut for request( PUT, uri, **kwargs )."""
      return self.request( PUT, uri, **kwargs )

   def delete( self, uri, **kwargs ):
      """Shortcut for request( DELETE, uri, **kwargs )."""
      return self.request( DELETE, uri, **kwargs )

   def options( self, uri, **kwargs ):
      """Shortcut for request( OPTIONS, uri, **kwargs )."""
      return self.request( OPTIONS, uri, **kwargs )

   def trace( self, uri, **kwargs ):
      """Shortcut for request( TRACE, uri, **kwargs )."""
      return self.request( TRACE, uri, **kwargs )

   def patch( self, uri, **kwargs ):
      """Shortcut for request( PATCH, uri, **kwargs )."""
      return self.request( PATCH, uri, **kwargs )

   def getResponse( self, requestKey ):
      """Returns the current value of the response for the request key."""
      return self.responseDir.response[ requestKey ]

   def requestCounters( self, requestKey ):
      """Returns the current dictionary of response counters for a request."""
      response = self.responseDir.response[ requestKey ]
      return _valueToDict( response.counter )

   def requestTimers( self, requestKey ):
      """Returns the current dictionary of response timers for a request."""
      response = self.responseDir.response[ requestKey ]
      return _valueToDict( response.timer )

   def start( self, request ):
      """Starts a request (such as that returned by request()) running."""
      self.requestDir.addRequest( request )

   def stop( self, requestKey ):
      """Cancels the request, also deleting the response object immediately."""
      self.deleteRequest( requestKey )

   def deleteRequest( self, requestKey ):
      """Erases the request/response object pair for the given request."""
      # NOTE(review): the callback registered for this key stays in callbacks_;
      # nothing in this class removes entries from that dict. Confirm whether
      # that is relied upon (e.g. to restart the same request later).
      del self.requestDir.request[ requestKey ]

   # _handleResponse is now a static method, taking the callback dict as a parameter.
   # This avoids a circular dependency from the Client to its reactor_ which
   # prevented the Client from being destroyed.
   # See the explanation in Client.__init__ for more info.
   @staticmethod
   def _handleResponse( callbacks, response ):
      """Select the appropriate callback for the response and call it.

      Args:
         callbacks: dict mapping Tac.hashOf( requestKey ) to a callable or None
         response: the response object, passed through to the callback
      """
      callback = callbacks.get( Tac.hashOf( response.requestKey ) )
      if callback:
         callback( response )

class SimpleHttpGetRequest:
   """A wrapper class for requests.get() that's simpler than Afetch.Client.

   A failed attempt is rescheduled RETRY_TIME after the failure, and when an
   outputPath is given the response body is written to that file in chunks,
   driven by a clock notifiee, instead of being read in one go.

   The callback, if any, is called with None every time an attempt fails and
   with the requests response object once the request has completed.
   """
   RETRY_TIME = 60
   # Default for the 'timeout' argument of requests.get()
   REQUEST_TIMEOUT = 5
   # Appended to outputPath to name the file the download is written to first
   TEMP_FILE_EXT = '.tmp'
   # Chunk size asked of iter_content(), and chunks written per readByteStream()
   CHUNK_SIZE = 1024 * 1024
   CHUNKS_PER_CYCLE = 10

   def __init__( self, url, callback=None, outputPath='', **kwargs ):
      """Sets up the request; nothing is sent until start() is called.

      Args:
         url: the URL to GET
         callback: optional callable taking the response (or None on failure)
         outputPath: if non-empty, the file the response body is downloaded to
         **kwargs: passed on to requests.get(); 'verify', 'stream' and 'timeout'
            get this class's defaults when absent
      """
      self.url = url
      self.cb = callback
      self.outputPath = outputPath
      self.kwargs = kwargs
      self.response = None
      self.complete = False
      self.file = None
      # Both clocks stay parked at endOfTime until start()/retry() arm them
      self.streamClock = Tac.ClockNotifiee( self.readByteStream,
                                            timeMin=Tac.endOfTime )
      self.retryClock = Tac.ClockNotifiee( self.start, timeMin=Tac.endOfTime )

   def retry( self ):
      """Abandons the current attempt, schedules start() again and reports
      the failure to the callback by calling it with None."""
      self.stop()
      self.retryClock.timeMin = Tac.now() + self.RETRY_TIME
      if self.cb:
         self.cb( None )

   def start( self ):
      """Issues the GET request and acts on its outcome.

      A request exception, a non-2xx status or a failure to open the temporary
      file leads to retry(). A successful non-streamed request completes
      immediately; a download is handed over to readByteStream().
      """
      t5( "GET", self.url )
      # Pop from a copy: start() runs again on every retry and has to see the
      # caller's original options each time.
      kwargs = dict( self.kwargs )
      # TODO : 'verify' should be True by default and if possible, point to a cert
      verify = kwargs.pop( 'verify', False )
      stream = kwargs.pop( 'stream', False ) or bool( self.outputPath )
      timeout = kwargs.pop( 'timeout', self.REQUEST_TIMEOUT )

      try:
         self.response = requests.get( self.url, verify=verify, stream=stream,
                                       timeout=timeout, **kwargs )
      except RequestException as e:
         t1( "FAILED with %s" % e )
         print( e )
         self.retry()
         return

      statusCode = self.response.status_code
      if statusCode // 100 != 2:
         t1( "FAILED with code=%d" % statusCode )
         self.retry()
         return

      if not stream:
         t1( "COMPLETED with code=%d" % statusCode )
         self.stop( complete=True )
         return

      if not self.outputPath:
         # stream=True was passed without an outputPath, so there is no file to
         # write the body to. This remains a failed attempt, as it was before.
         t1( "FAILED with no outputPath to stream into" )
         self.retry()
         return

      t5( "Downloading the file" )
      targetDir = os.path.dirname( self.outputPath )
      tempFileName = self.outputPath + self.TEMP_FILE_EXT
      try:
         # dirname is '' for a bare file name; there is no directory to create
         if targetDir:
            os.makedirs( targetDir, exist_ok=True )
         self.file = open( tempFileName, 'w+b' )
      except OSError as e:
         t1( "FAILED with %s" % e )
         self.retry()
         return

      # We don't know the size of the downloaded file, so push the reading to
      # the next activity cycle and do it in chunks
      self.streamClock.timeMin = Tac.now()

   def stop( self, complete=False ):
      """Typically called when request is complete.

      Parks both clocks and closes the file and the response if there are any.
      With complete=True the request is also marked complete and the callback
      is called with the response.
      """
      t5( "stop() -", self.url )
      self.streamClock.timeMin = Tac.endOfTime
      self.retryClock.timeMin = Tac.endOfTime
      if self.file:
         self.file.close()
      if self.response is not None:
         self.response.close()
      if complete:
         self.complete = True
         if self.cb:
            self.cb( self.response )

   def readByteStream( self ):
      """Meant to be called only by the Clock notifiee to read the GET response
      content in chunks.

      Writes up to CHUNKS_PER_CYCLE chunks to the temporary file and then
      reschedules itself. Once the content is exhausted, the temporary file is
      synced, closed and moved to outputPath. Any failure leads to retry().
      """
      t8( "readByteStream() -", self.url.split( '/' )[ -1 ] )
      assert not self.complete
      try:
         chunks = self.response.iter_content( chunk_size=self.CHUNK_SIZE )
         for count, buf in enumerate( chunks, 1 ):
            self.file.write( buf )
            if count == self.CHUNKS_PER_CYCLE:
               # Break after this cycle's share (10 MB with the defaults)
               t8( "Read a chunk" )
               self.streamClock.timeMin = Tac.now()
               return
      except ( RequestException, OSError ) as e:
         t1( "Exception while receiving data from controller:", e )
         print( e ) # To capture it in agent log when tracing is not enabled
         self.retry()
         return

      # Download complete. Move the temp file to the desired file.
      try:
         self.file.flush()
         os.fsync( self.file.fileno() )
         self.file.close()
      except OSError as e:
         t1( "Exception while closing the file", e )
         self.retry()
         return

      t8( "moving", self.file.name, "to", self.outputPath )
      try:
         shutil.move( self.file.name, self.outputPath )
      except OSError as e:
         # shutil.Error derives from OSError, so this covers both what
         # shutil.move() raises itself and errors from the underlying calls
         t1( "Exception while rewriting SWI file:", e )
         self.retry()
         try:
            os.remove( self.file.name )
         except OSError:
            print( "Removing %s failed" % self.file.name )
         return

      self.stop( complete=True )
