#!/usr/bin/env python3
# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
import os
import sys

# This program runs two commands with a pipe in between, but makes the producer
# use a specified pipe fd instead of stdout or stderr. This is because some
# programs change their behavior when stdout is not a terminal, which we'd like
# to avoid.
# The use case is for scp where we would like to transform files it downloads,
# so we can do the following:
#
# $ RunPipe -d 3 "scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null
# <user>@<host>:/etc/passwd /dev/fd/3" "grep root"
# Warning: Permanently added 'localhost' ( RSA ) to the list of known hosts.
# passwd                                         100 % 3848    13.2MB / s   00:00
# root:x:0:0:root:/root:/bin/bash

def exitStatusFromWait( status ):
   """Convert a raw wait status into a shell-style exit code.

   A child that exited normally yields its own exit code. A child that was
   terminated by a signal yields the signal number plus 128.
   """
   if not os.WIFEXITED( status ):
      # The only other outcome handled here is termination by a signal.
      assert os.WIFSIGNALED( status )
      signalNumber = os.WTERMSIG( status )
      return signalNumber + 128
   return os.WEXITSTATUS( status )

def runPipe():
   """Run PRODUCER and CONSUMER, each through ``bash -c``, joined by a pipe.

   The write end of the pipe is exposed to the producer on the file descriptor
   given by -d/--pipe-fd (default 3) rather than on stdout/stderr. The read
   end becomes the consumer's stdin. With -o/--output-file the consumer's
   stdout is redirected to that file.

   This function does not return; it always ends in sys.exit():
     0 - both children exited with status 0
     1 - invalid pipe fd, the output file could not be opened, or either
         child exited non-zero / was killed by a signal. In the latter case
         the output file, if any, is removed.
   """
   parser = argparse.ArgumentParser( description='Run two programs with pipe.' )
   parser.add_argument( '-d', '--pipe-fd', type=int,
                        default=3, help='File descriptor for pipe in producer' )
   parser.add_argument( '-o', '--output-file', type=str,
                        help='output to a file in consumer' )
   parser.add_argument( 'producer', metavar='PRODUCER',
                        help='Producer command run in bash' )
   parser.add_argument( 'consumer', metavar='CONSUMER',
                        help='Consumer command run in bash' )

   args = parser.parse_args()
   if args.pipe_fd < 1 or args.pipe_fd > 1023:
      print( "Error: pipe fd has to be between 1 and 1023", file=sys.stderr )
      sys.exit( 1 )

   # close all FDs except 0-2 (closerange ignores descriptors that are not open)
   os.closerange( 3, 1024 )

   output_fd = -1
   if args.output_file:
      try:
         output_fd = os.open( args.output_file,
                              os.O_WRONLY | os.O_CREAT | os.O_TRUNC )
      except OSError as e:
         # pylint: disable-next=consider-using-f-string
         print( "Error: %s" % e.strerror, file=sys.stderr )
         sys.exit( 1 )

   r, w = os.pipe()

   # create consumer
   cPid = os.fork()
   if cPid:
      # parent: the read end and the output file belong to the consumer only
      os.close( r )
      if output_fd >= 0:
         # output_fd stays -1 when no output file was requested; closing -1
         # would raise OSError( EBADF ).
         os.close( output_fd )
   else:
      # child
      os.close( w )
      # dup pipe to stdin
      os.dup2( r, sys.stdin.fileno() )
      os.close( r )
      if output_fd >= 0:
         os.dup2( output_fd, sys.stdout.fileno() )
         os.close( output_fd )
      os.execvp( "bash", [ "bash", "-c", "exec " + args.consumer ] )

   # create producer
   pPid = os.fork()
   if pPid:
      # parent: drop the write end so the consumer sees EOF once the
      # producer is done
      os.close( w )
   else:
      # dupe w to specified fd
      if w != args.pipe_fd:
         os.dup2( w, args.pipe_fd )
         os.close( w )
      else:
         # The pipe already sits on the requested fd, but descriptors
         # returned by os.pipe() are non-inheritable and would be closed by
         # exec. Mark it inheritable so the producer actually receives it.
         os.set_inheritable( w, True )
      os.execvp( "bash", [ "bash", "-c", "exec " + args.producer ] )

   # wait for all children to exit

   # In bash, the exit status of a pipe command is the same as the
   # final consumer. It's not what we want here. We want to make sure
   # we can detect errors in both producer and consumer. However, we purposefully
   # ignore SIGPIPE in the producer as it can happen normally.
   # NOTE(review): nothing below special-cases a producer killed by SIGPIPE;
   # presumably this relies on the signal disposition the children inherit.
   # TODO confirm.
   pStatus = 0
   cStatus = 0

   while True:
      try:
         pid, status = os.wait()
      except OSError:
         # no more children
         break

      _status = exitStatusFromWait( status )
      if pid == pPid:
         pStatus = _status
      elif pid == cPid:
         cStatus = _status

   exitStatus = 0
   if cStatus or pStatus:
      exitStatus = 1

   if exitStatus and args.output_file:
      # clean up output file
      try:
         os.unlink( args.output_file )
      except OSError:
         pass
   sys.exit( exitStatus )

if __name__ == '__main__':
   # Script entry point: run the pipeline only when executed directly.
   try:
      runPipe()
   except KeyboardInterrupt:
      # A Ctrl-C in this process is swallowed so no traceback is printed.
      pass

