# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import socket
import BothTrace
import Server
import UpnpMsg_pb2
import google.protobuf as proto
import struct
import Tracing
import Tac
import os
import QuickTrace

# Shorthands for the tracing facilities used throughout this module.
t0 = Tracing.trace0
bt0 = BothTrace.tracef0
bvar = BothTrace.Var
# Trace handle named 'IgdUpnpMsg'; presumably picked up by Tracing as this
# module's default handle for t0 -- confirm against the Tracing docs.
__defaultTraceHandle__ = Tracing.Handle( 'IgdUpnpMsg' )

# QuickTrace shorthands: qv wraps a value passed to a qt0 trace record.
qv = QuickTrace.Var
qt0 = QuickTrace.trace0

class UpnpMsgHandler:
   """Validates miniupnpd port-mapping requests for one VRF against Sysdb.

   Add/delete requests are checked against the VRF's portMapConfig and, when
   accepted, written to it. Accepted requests then wait for the Nat agent to
   publish a status. While waiting they are tracked in ``request_``, which
   maps hashKey( externalPort, protocol ) to the fileno of the owning session.
   """
   def __init__( self, vrfName, igdPortMapConfig, portMapStatus, igdVrfStatus ):
      self.vrfName = vrfName
      self.igdPortMapConfig = igdPortMapConfig
      self.portMapStatus = portMapStatus
      self.igdVrfStatus = igdVrfStatus
      if vrfName not in self.igdPortMapConfig.upnpPortMapVrfConfig:
         self.igdPortMapConfig.upnpPortMapVrfConfig.newMember( vrfName )
      self.upnpPortMapVrfConfig = \
            self.igdPortMapConfig.upnpPortMapVrfConfig[ vrfName ]
      self.upnpPortMapVrfConfig.natProfileName = self.igdVrfStatus.natInfo.profile
      self.upnpPortMapVrfConfig.natAclName = self.igdVrfStatus.natInfo.acl
      # Requests that passed the initial checks and are waiting for status
      # updates from the Nat agent, keyed by hashKey(). The value is the
      # fileno of the session that owns the request, which provides the
      # reverse lookup from a port mapping to its connection.
      self.request_ = {}

   def __del__( self ):
      pass

   def hashKey( self, externalPort, ipProto ):
      """Return the ``request_`` key for an external port/protocol pair."""
      return ( externalPort, ipProto )

   def getPendingRequest( self, portMapKey ):
      """Return the fileno owning a pending request on portMapKey, or None."""
      hk = self.hashKey( portMapKey.externalPort, portMapKey.protocol )
      return self.request_.get( hk )

   def getCounters( self ):
      """Return the per-VRF request counters object."""
      return self.igdVrfStatus.counters

   REQ_COMPLETED = 0
   REQ_PENDING = 1
   REQ_CONFLICT = 2

   def handleAddPortMapReq( self, addPortMapReq, fileno ):
      """Handle an addPortMapReq message received on session `fileno`.

      Returns REQ_CONFLICT if a request on the same key is pending or the
      key is mapped to a different internal ip/port. Returns REQ_COMPLETED if
      this exact mapping already has status "portMapCreated". Otherwise writes
      the config (if missing), records `fileno` as the owner of the pending
      request and returns REQ_PENDING.
      """
      cnt = self.getCounters()
      cnt.addRequestsTotal = cnt.addRequestsTotal + 1

      # Decode the addPortMapReq message
      internalIp = Tac.Value( "Arnet::IpAddr",
            socket.ntohl( addPortMapReq.private_ip ) )
      internalPort = Tac.Value( "Arnet::Port", addPortMapReq.private_port )
      ipProto = Tac.Value( "Arnet::IpProto", addPortMapReq.protocol )
      externalIp = self.igdVrfStatus.externalIpv4Addr
      externalPort = Tac.Value( "Arnet::Port", addPortMapReq.public_port )
      portMapKey = Tac.Value( "IgdUpnpShared::UpnpPortMapKey",
            externalIp, externalPort, ipProto )

      hk = self.hashKey( addPortMapReq.public_port, addPortMapReq.protocol )

      # there is a pending request on the same key (membership test, since a
      # fileno of 0 is a valid owner)
      if hk in self.request_:
         cnt.addConflicts = cnt.addConflicts + 1
         return self.REQ_CONFLICT

      # Check if the port mapping already exists in portMapConfig and reject
      # mappings from different clients or many-to-one from the same client
      cfg = self.upnpPortMapVrfConfig.portMapConfig.get( portMapKey )
      if cfg and ( cfg.internalIp != internalIp.stringValue or
            cfg.internalPort != internalPort.value ):
         cnt.addConflicts = cnt.addConflicts + 1
         return self.REQ_CONFLICT

      # Check the status only if there is a config already from the same client
      # There is an edge case when the config is gone b/c of a previous delete
      # request but the status is still pending that we ignore
      if cfg:
         vrfPortMapStatus = self.portMapStatus.upnpPortMapVrfStatus.get(
               self.vrfName )
         if vrfPortMapStatus:
            portMapStatus = vrfPortMapStatus.portMapStatus.get( portMapKey )
            t0( 'Existing status: ', portMapStatus )
            if portMapStatus == "portMapCreated":
               return self.REQ_COMPLETED

      upnpConn = Tac.Value( "IgdUpnpShared::UpnpPortMap", portMapKey,
            internalIp, internalPort, addPortMapReq.lease_duration )

      # Update the sysdb
      if not cfg:
         self.upnpPortMapVrfConfig.portMapConfig.addMember( upnpConn )

      self.request_[ hk ] = fileno
      return self.REQ_PENDING

   def handleDelPortMapReq( self, delPortMapReq, fileno ):
      """Handle a delPortMapReq message received on session `fileno`.

      Returns False, leaving everything unchanged, when a request on the same
      key is pending, no config exists, or the config belongs to another
      internal ip. Otherwise deletes the config, records `fileno` as the owner
      of the pending request and returns True.
      """
      cnt = self.getCounters()
      cnt.delRequestsTotal = cnt.delRequestsTotal + 1

      # Decode delPortMapReq message
      externalIp = self.igdVrfStatus.externalIpv4Addr
      externalPort = Tac.Value( "Arnet::Port", delPortMapReq.public_port )
      ipProto = Tac.Value( "Arnet::IpProto", delPortMapReq.protocol )
      internalIp = Tac.Value( "Arnet::IpAddr",
            socket.ntohl( delPortMapReq.private_ip ) )
      portMapKey = Tac.Value( "IgdUpnpShared::UpnpPortMapKey",
            externalIp, externalPort, ipProto )

      # there is a pending request on the same key
      hk = self.hashKey( delPortMapReq.public_port, delPortMapReq.protocol )

      if hk in self.request_:
         return False

      cfg = self.upnpPortMapVrfConfig.portMapConfig.get( portMapKey )
      if not cfg:
         return False

      if cfg.internalIp != internalIp.stringValue:
         return False

      # Delete the req from the portMapConfig
      del self.upnpPortMapVrfConfig.portMapConfig[ portMapKey ]
      self.request_[ hk ] = fileno
      return True

   def handleGetSpecificPortMappingEntryReq( self,
         getSpecificPortMappingEntryReq, fileno ):
      """Build the reply for a getSpecificPortMappingEntryReq message.

      The reply status is FAIL when a request on the key is pending or no
      config exists. Otherwise it is SUCCESS and carries the configured
      internal ip/port and lease duration. `fileno` is not used.
      """
      cnt = self.getCounters()
      cnt.getSpecificMappingRequestsTotal = cnt.getSpecificMappingRequestsTotal + 1

      externalIp = self.igdVrfStatus.externalIpv4Addr
      externalPort = Tac.Value( "Arnet::Port",
            getSpecificPortMappingEntryReq.public_port )
      ipProto = Tac.Value( "Arnet::IpProto",
            getSpecificPortMappingEntryReq.protocol )

      portMapKey = Tac.Value( "IgdUpnpShared::UpnpPortMapKey",
            externalIp, externalPort, ipProto )

      msg = UpnpMsg_pb2.UpnpMessage() # pylint: disable=no-member
      #pylint: disable=E1101
      msg.getSpecificPortMappingEntryRsp.status = UpnpMsg_pb2.Status.FAIL

      # there is a pending request on the same key
      if self.getPendingRequest( portMapKey ) is not None:
         t0( 'GetSpecificPortMapping pending request: %d' %
               getSpecificPortMappingEntryReq.public_port )
         qt0( 'GetSpecificPortMapping pending request ', qv( externalPort ) )
         cnt.getSpecificMappingFailed = cnt.getSpecificMappingFailed + 1
         return msg

      cfg = self.upnpPortMapVrfConfig.portMapConfig.get( portMapKey )
      if not cfg:
         t0( 'GetSpecificPortMapping config not found: %d' %
               getSpecificPortMappingEntryReq.public_port )
         qt0( 'GetSpecificPortMapping config not found ', qv( externalPort ) )
         cnt.getSpecificMappingFailed = cnt.getSpecificMappingFailed + 1
         return msg

      #pylint: disable=E1101
      msg.getSpecificPortMappingEntryRsp.private_port = cfg.internalPort
      # Native-order unpack of the network-order bytes mirrors the
      # socket.ntohl() applied to private_ip when requests are decoded.
      msg.getSpecificPortMappingEntryRsp.private_ip = \
         struct.unpack( "I", socket.inet_aton( cfg.internalIp ) )[ 0 ]
      msg.getSpecificPortMappingEntryRsp.lease_duration = int( cfg.leaseTime )
      msg.getSpecificPortMappingEntryRsp.status = UpnpMsg_pb2.Status.SUCCESS
      qt0( 'GetSpecificPortMapping config found ', qv( externalPort ) )
      cnt.getSpecificMappingRequestsSuccess = \
         cnt.getSpecificMappingRequestsSuccess + 1
      return msg

   def closeRequest( self, externaPort, ipProto, fileno ):
      """Forget the pending request on ( externaPort, ipProto ) owned by fileno.

      Only an entry owned by `fileno` is removed. A session whose own request
      was rejected (e.g. because another session's request on the same key is
      pending) must not drop the other session's entry. Otherwise the status
      update for that key could no longer be routed back to its owner.
      """
      hk = self.hashKey( externaPort, ipProto )
      # the request may not be found because it hasn't been received
      # or was rejected immediately
      if hk in self.request_ and self.request_[ hk ] == fileno:
         del self.request_[ hk ]

   def deleteConfig( self, portMapKey, privateIp ):
      """Delete the config for portMapKey if it maps to privateIp.

      privateIp is the raw value from the request message; it is converted
      with socket.ntohl() before being compared with the config.
      """
      cfg = self.upnpPortMapVrfConfig.portMapConfig.get( portMapKey )
      if not cfg:
         return

      internalIp = Tac.Value( "Arnet::IpAddr", socket.ntohl( privateIp ) )

      # another client mapping, do nothing
      if cfg.internalIp != internalIp.stringValue:
         return

      del self.upnpPortMapVrfConfig.portMapConfig[ portMapKey ]

class UpnpdBuffer:
   """FIFO byte buffer used to reassemble framed messages read from a socket."""
   def __init__( self ):
      self.buffer_ = bytearray()
      # Number of buffered bytes; kept equal to len( self.buffer_ ).
      self.size_ = 0

   def size( self ):
      """Return the number of buffered bytes not yet read."""
      return self.size_

   def read( self, size ):
      """Remove and return up to `size` bytes from the front of the buffer.

      The result is a bytearray. It is shorter than `size` when fewer bytes
      are buffered.
      """
      data = self.buffer_[ : size ]
      del self.buffer_[ : size ]
      # Account for what was actually removed, not what was requested, so the
      # tracked size can never drift below the real buffer length.
      self.size_ -= len( data )
      return data

   def write( self, data ):
      """Append data to the end of the buffer."""
      self.buffer_ += data
      self.size_ += len( data )

class UpnpdSocketSession( Server.Session ):
   """Session for one miniupnpd connection carrying a single request.

   Messages are framed as a 4-byte length in network byte order followed by a
   serialized UpnpMessage. Only the first complete message on the session is
   handled. Sending the reply (or failing to) cleans up the session's state.
   """
   def __init__( self, clientFd, server ):
      Server.Session.__init__( self, clientFd, server )
      self.clientFd = clientFd
      self.upnpMsgHandler = server.upnpMsgHandler
      self.upnpdBuffer = UpnpdBuffer()
      self.delimiterLen = 4
      # Payload length from a frame header that has already been consumed
      # while its payload is still incomplete. None when the next buffered
      # bytes begin a new header.
      self.pendingMsgLen = None
      self.externalIp = server.igdVrfStatus.externalIpv4Addr
      self.protocol = 0
      self.publicPort = 0
      self.privateIp = 0
      self.msgType = ""
      self.responseSent = False
      self.sessionClosed = False

   def getPortMapKey( self ):
      """Return the UpnpPortMapKey for this session's port and protocol."""
      externalPort = Tac.Value( "Arnet::Port", self.publicPort )
      ipProto = Tac.Value( "Arnet::IpProto", self.protocol )
      return Tac.Value( "IgdUpnpShared::UpnpPortMapKey",
            self.externalIp, externalPort, ipProto )

   def parseUpnpdMessage( self, data ):
      """Decode data into an UpnpMessage.

      Malformed input raises protobuf's DecodeError, which _handleInput
      catches.
      """
      upnpMessage = UpnpMsg_pb2.UpnpMessage() # pylint: disable=no-member
      upnpMessage.ParseFromString( data )
      t0( "UpnpMessage:", bvar( upnpMessage ) )
      return upnpMessage

   def sendMsg( self, msg ):
      """Send msg as a length-prefixed frame, then clean up the session state."""
      try:
         size = socket.htonl( msg.ByteSize() )
         self.socket_.sendall( struct.pack( "I", size ) + msg.SerializeToString() )
         self.responseSent = True
      except OSError as e:
         # errno is None for some OSErrors (e.g. socket timeouts). Passing it
         # to os.strerror would raise TypeError and skip the cleanup below.
         reason = os.strerror( e.errno ) if e.errno is not None else str( e )
         t0( "sendMsg exception: %s" % reason )
      self.cleanupState()

   def sendAddReqRsp( self, status ):
      """Send an addPortMapRsp with `status`; counts successes."""
      #pylint: disable=E1101
      msg = UpnpMsg_pb2.UpnpMessage()
      msg.addPortMapRsp.protocol = self.protocol
      msg.addPortMapRsp.public_port = self.publicPort
      msg.addPortMapRsp.status = status
      if status == UpnpMsg_pb2.Status.FAIL:
         qt0( 'AddReq failed: ', qv( self.publicPort ) )
      else:
         cnt = self.upnpMsgHandler.getCounters()
         cnt.addSuccess = cnt.addSuccess + 1
      self.sendMsg( msg )

   def sendDelReqRsp( self, status ):
      """Send a delPortMapRsp with `status`; counts failures and successes."""
      #pylint: disable=E1101
      msg = UpnpMsg_pb2.UpnpMessage()
      msg.delPortMapRsp.protocol = self.protocol
      msg.delPortMapRsp.public_port = self.publicPort
      msg.delPortMapRsp.status = status
      cnt = self.upnpMsgHandler.getCounters()
      if status == UpnpMsg_pb2.Status.FAIL:
         qt0( 'DelReq failed: ', qv( self.publicPort ) )
         cnt.delFailed = cnt.delFailed + 1
      else:
         cnt.delSuccess = cnt.delSuccess + 1
      self.sendMsg( msg )

   def replyPortMapStatus( self, status ):
      """React to a Nat agent port-map status for this session's request.

      For add requests, "portMapCreated" replies SUCCESS. "portMapRefused"
      and "portMapCleared" delete the config and reply FAIL. For delete
      requests, an empty status replies SUCCESS. Intermediate states are
      ignored.
      """
      if self.sessionClosed:
         return
      if self.msgType == 'addPortMapReq':
         if status == "portMapInitAdd":
            return
         elif status == "portMapCreated":
            # pylint: disable=no-member
            self.sendAddReqRsp( UpnpMsg_pb2.Status.SUCCESS )
         elif status in ( "portMapRefused", "portMapCleared" ):
            # delete the config
            portMapKey = self.getPortMapKey()
            t0( 'Delete the refused port mapping config for the port: ',
                 self.publicPort )
            del self.upnpMsgHandler.upnpPortMapVrfConfig.portMapConfig[ portMapKey ]
            cnt = self.upnpMsgHandler.getCounters()
            if status == "portMapRefused":
               cnt.addRefused = cnt.addRefused + 1
            else:
               cnt.portMappingClearedLocally = cnt.portMappingClearedLocally + 1
            self.sendAddReqRsp( UpnpMsg_pb2.Status.FAIL ) # pylint: disable=no-member
         else:
            t0( 'unknown addPortMapStatus ', status )
      elif self.msgType == 'delPortMapReq':
         # for delete requests we wait until the status is deleted
         if not status:
            # pylint: disable=no-member
            self.sendDelReqRsp( UpnpMsg_pb2.Status.SUCCESS )
         elif status == "portMapInitDel":
            pass
         else:
            t0( 'unknown delPortMapStatus ', status )

   def _readUpnpdFrame( self ):
      """Return the payload of the next complete frame, or None if incomplete.

      The length header is consumed as soon as it is available and remembered
      in pendingMsgLen. A payload split across several reads is therefore
      still framed correctly, and payload bytes are never re-read as a header.
      """
      if self.pendingMsgLen is None:
         if self.upnpdBuffer.size() < self.delimiterLen:
            return None
         msgLen = struct.unpack( "I",
                self.upnpdBuffer.read( self.delimiterLen ) )[ 0 ]
         self.pendingMsgLen = socket.ntohl( msgLen )

      if self.upnpdBuffer.size() < self.pendingMsgLen:
         return None

      payload = self.upnpdBuffer.read( self.pendingMsgLen )
      self.pendingMsgLen = None
      return payload

   def _handleInput( self, data ):
      """Buffer incoming data and dispatch the first complete request."""
      # a request has been created
      if self.msgType:
         t0( "unexpected input" )
         return

      #pylint: disable=E1101
      if not data:
         return

      self.upnpdBuffer.write( data )
      payload = self._readUpnpdFrame()
      if payload is None:
         # wait for the rest of the frame
         return

      try:
         upnpMessage = self.parseUpnpdMessage( payload )
         self.msgType = upnpMessage.WhichOneof( 'msg' )
         if self.msgType == 'addPortMapReq':
            self.protocol = upnpMessage.addPortMapReq.protocol
            self.publicPort = upnpMessage.addPortMapReq.public_port
            self.privateIp = upnpMessage.addPortMapReq.private_ip

            s = self.upnpMsgHandler.handleAddPortMapReq(
                  upnpMessage.addPortMapReq, self.socket_.fileno() )
            if s == UpnpMsgHandler.REQ_CONFLICT:
               self.sendAddReqRsp( UpnpMsg_pb2.Status.FAIL )
            elif s == UpnpMsgHandler.REQ_COMPLETED:
               self.sendAddReqRsp( UpnpMsg_pb2.Status.SUCCESS )

         elif self.msgType == 'delPortMapReq':
            self.protocol = upnpMessage.delPortMapReq.protocol
            self.publicPort = upnpMessage.delPortMapReq.public_port
            if not self.upnpMsgHandler.handleDelPortMapReq(
                  upnpMessage.delPortMapReq, self.socket_.fileno() ):
               self.sendDelReqRsp( UpnpMsg_pb2.Status.FAIL )

         elif self.msgType == 'getSpecificPortMappingEntryReq':
            response = self.upnpMsgHandler.handleGetSpecificPortMappingEntryReq(
                  upnpMessage.getSpecificPortMappingEntryReq,
                  self.socket_.fileno() )
            self.sendMsg( response )

      except proto.message.DecodeError:
         bt0( "Exception while decoding message" )
         return

   def cleanupState( self ):
      """Release this session's request state; only the first call has effect.

      An add request that never got a reply is treated as failed: its config
      is deleted (if it still maps to this client) and counted as a premature
      session closure. Any pending request owned by this session is closed.
      """
      if self.sessionClosed:
         t0( "session already closed" )
         return
      self.sessionClosed = True
      # if there was a port mapping request message that hasn't been replied to
      # for whatever reasons, we consider it as a failed transaction
      if self.msgType == 'addPortMapReq' and not self.responseSent:
         portMapKey = self.getPortMapKey()
         self.upnpMsgHandler.deleteConfig( portMapKey, self.privateIp )
         cnt = self.upnpMsgHandler.getCounters()
         cnt.addPrematureSessionClosure = cnt.addPrematureSessionClosure + 1

      # if a request has been received, close it
      self.upnpMsgHandler.closeRequest( self.publicPort,
            self.protocol, self.socket_.fileno() )


class UpnpdSocket( Server.Server ):
   """Unix-domain server for one VRF's miniupnpd connections.

   Listens on /var/run/miniupnpd-<vrfName>.sock and creates one
   UpnpdSocketSession per accepted connection, keyed by fileno in session_.
   Nat agent port-map status updates are routed to the session that owns the
   pending request.
   """
   def __init__( self, vrfName, portMapConfig, portMapStatus, igdVrfStatus ):
      self.socketPath = "/var/run/miniupnpd-%s.sock" % vrfName
      Server.Server.__init__( self,
            name='UPNP socket',
            port=self.socketPath,
            domain='unix' )
      self.portMapConfig = portMapConfig
      self.portMapStatus = portMapStatus
      self.igdVrfStatus = igdVrfStatus
      self.session_ = {}
      self.upnpMsgHandler = UpnpMsgHandler( vrfName, self.portMapConfig,
            self.portMapStatus, igdVrfStatus )

   def onAccept( self, newSocket ):
      """Register a session for a newly accepted connection."""
      t0( "New session" )
      self.session_[ newSocket.fileno() ] = UpnpdSocketSession( newSocket, self )

   def closeSession( self, session ):
      """Unregister the session for session's socket and clean up its state."""
      t0( ' session close', session )
      fileno = session.socket_.fileno()
      # Unregister first so the entry cannot leak if cleanup fails. This also
      # avoids swallowing KeyErrors raised from inside cleanupState().
      upnpSession = self.session_.pop( fileno, None )
      if upnpSession is not None:
         upnpSession.cleanupState()

   def processPortMapStatus( self, portMapKey, status ):
      """Deliver a port-map status for portMapKey to the session waiting on it.

      With no pending request, a "portMapCleared" status is counted as a
      locally cleared mapping and the mapping's config is deleted.
      """
      # find connection by the key
      fileno = self.upnpMsgHandler.getPendingRequest( portMapKey )

      # fileno 0 is a valid descriptor, so test for absence explicitly
      if fileno is None:
         t0( 'unable to find pending request for the portMapKey', portMapKey )
         # port mapping could be cleared after the request had been processed
         if status == "portMapCleared":
            cnt = self.upnpMsgHandler.getCounters()
            cnt.portMappingClearedLocally = cnt.portMappingClearedLocally + 1
            # delete the config
            del self.upnpMsgHandler.upnpPortMapVrfConfig.portMapConfig[ portMapKey ]
         return

      session = self.session_.get( fileno )
      if not session:
         t0( 'unable to find session for the portMapKey', portMapKey )
         return
      session.replyPortMapStatus( status )
