#!/usr/bin/env python3
# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# Monitor xinetd and restart it if it is not listening for SSH connections
#
# - hzhong@arista.com 03/27/2020
import argparse
import datetime
import glob
import time
import Tac
import Tracing

# Trace helpers at levels 0 and 1
t0, t1 = Tracing.trace0, Tracing.trace1

# Namespace name used for the default VRF, whose SSH service is described by
# the un-suffixed xinetd file (see getEnabledNamespaces)
defaultVrfName = 'default'

def isSshOpen( nsName, v6=False ):
   """Return True if something in netns nsName is listening on TCP port 22.

   Runs `lsof -i tcp:22` inside the namespace (as root). The port counts as
   open only if some output line contains both '(LISTEN)' and the requested
   address family: 'IPv4' by default, or 'IPv6' when v6 is True. Only that
   one family is checked per call.

   If the command fails with Tac.SystemCommandError (presumably also when lsof
   finds no matching socket), the port is treated as not open.
   """
   family = 'IPv6' if v6 else 'IPv4'
   lsofCmd = [ 'ip', 'netns', 'exec', nsName, 'lsof', '-i', 'tcp:22' ]
   try:
      lsofOut = Tac.run( lsofCmd, asRoot=True, stdout=Tac.CAPTURE )
   except Tac.SystemCommandError:
      return False

   for entry in lsofOut.splitlines():
      cols = entry.split()
      if '(LISTEN)' in cols and family in cols:
         t0( "SSH v6 is enabled in netns" if v6 else "SSH v4 is enabled in netns",
             nsName )
         return True
   return False

def getEnabledNamespaces( v6=False ):
   """Return the set of network namespaces in which SSH is enabled in xinetd.

   The default-VRF service file is /etc/xinetd.d/ssh (ssh6 when v6 is True)
   and maps to the namespace named by defaultVrfName. Each per-VRF file
   /etc/xinetd.d/ssh[6]-<vrf> maps to namespace 'ns-<vrf>'. A file counts as
   enabled when it contains an uncommented 'disable = no' setting.
   """
   basename = '/etc/xinetd.d/ssh' + ( '6' if v6 else '' )
   vrfPrefix = basename + '-'
   sshVrfFiles = glob.glob( vrfPrefix + '*' )
   # Anchored at line start so commented-out settings ('# disable = no') are
   # not mistaken for enabled services. Any whitespace, including tabs or
   # none, is accepted around '='. 'grep -E' replaces the obsolescent egrep.
   enabledPattern = ( r'^[[:space:]]*disable[[:space:]]*=[[:space:]]*no'
                      r'([[:space:]]|$)' )
   # -H: always prefix matches with the file name, even for a single file.
   # Missing files and "no match" are not errors here, so ignore the exit
   # code and discard stderr.
   output = Tac.run( [ 'grep', '-E', '-H', enabledPattern,
                       basename ] + sshVrfFiles,
                     stdout=Tac.CAPTURE,
                     stderr=Tac.DISCARD,
                     ignoreReturnCode=True )
   t1( "SSH xinetd files:", output )

   namespaces = set()
   for line in output.splitlines():
      fname = line.partition( ':' )[ 0 ]
      if fname == basename:
         namespaces.add( defaultVrfName ) # default
      elif fname.startswith( vrfPrefix ):
         namespaces.add( 'ns-' + fname[ len( vrfPrefix ): ] )

   return namespaces

def syslog( message, prio='warn' ):
   """Send message to syslog through logger(1).

   The message is tagged 'XinetdMonitor', logged at facility local4 with the
   given priority, and logger's -i flag adds a PID.
   """
   facilityPrio = "local4." + prio
   loggerCmd = [ "logger", "-i", "-t", "XinetdMonitor", "-p", facilityPrio ]
   loggerCmd.append( message )
   Tac.run( loggerCmd )

def restartXinetd():
   """Restart the xinetd service as root (`service xinetd restart`)."""
   restartCmd = [ 'service', 'xinetd', 'restart' ]
   Tac.run( restartCmd, asRoot=True )

def logXinetdInfo( ns ):
   """Append SSH/xinetd diagnostics for namespace ns to the monitor log.

   The log file is /var/log/XinetdMonitor.log. Entries are written in this order:
     - a timestamp and a header naming ns;
     - the `netstat -tulp` lines mentioning ssh, run inside ns;
     - either a note that zero or several xinetd processes are running, or the
       `lsof -p` listing of the single running xinetd.

   Non-zero exit codes from these diagnostic commands are ignored.
   """
   logPath = "/var/log/XinetdMonitor.log"
   with open( logPath, "a" ) as logFile:
      print( datetime.datetime.today(), file=logFile )
      print( "SSH is not active in namespace:", ns, file=logFile )
      # The commands below write to the log file directly (presumably via its
      # descriptor), so push our buffered text out first to keep the order.
      logFile.flush()
      netstatCmd = [ 'ip', 'netns', 'exec', ns,
                     'bash', '-c', 'netstat -tulp | grep ssh' ]
      Tac.run( netstatCmd, stdout=logFile, stderr=logFile,
               asRoot=True, ignoreReturnCode=True )

      xinetdPids = Tac.run( [ 'pidof', 'xinetd' ],
                            stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                            ignoreReturnCode=True ).strip()
      if not xinetdPids:
         print( "no xinetd running\n", file=logFile )
         return
      if ' ' in xinetdPids:
         # pidof separates multiple PIDs with spaces
         # pylint: disable-next=consider-using-f-string
         print( "multiple xinetd running: %s\n" % xinetdPids, file=logFile )
         return
      Tac.run( [ 'lsof', '-p', xinetdPids ],
               stdout=logFile, stderr=logFile,
               ignoreReturnCode=True, asRoot=True )

def getSshNotOpen( v6=False ):
   """Return the first SSH-enabled namespace where SSH is not listening.

   Namespaces are checked in sorted order for the requested address family
   (IPv6 when v6 is True). Checking stops at the first failure. Returns None
   if SSH is listening in every namespace where it is enabled.
   """
   orderedNs = sorted( getEnabledNamespaces( v6=v6 ) )
   t0( "SSH is enabled in:", ' '.join( orderedNs ) )
   return next( ( ns for ns in orderedNs if not isSshOpen( ns, v6=v6 ) ),
                None )

def _getAllSshNotOpen( v6 ):
   """Return every SSH-enabled namespace where SSH isn't listening, sorted."""
   namespaces = sorted( getEnabledNamespaces( v6=v6 ) )
   t0( "SSH is enabled in:", ' '.join( namespaces ) )
   return [ ns for ns in namespaces if not isSshOpen( ns, v6=v6 ) ]

def checkSsh( v6=False ):
   """Restart xinetd if SSH is persistently not listened on in some namespace.

   Every SSH-enabled namespace is checked twice, 10 seconds apart.
     - If a namespace fails both checks, the first such namespace (sorted
       order) is reported to syslog, diagnostics are logged for it, and
       xinetd is restarted.
     - If failures appear in only one of the two checks, the discrepancy is
       reported to syslog and diagnostics are logged, but xinetd is not
       restarted.

   v6: check IPv6 listeners instead of IPv4.
   """
   firstFailures = _getAllSshNotOpen( v6 )
   if not firstFailures:
      return

   # to avoid unnecessarily restarting xinetd due to a transient issue,
   # wait for 10 seconds and try again
   time.sleep( 10 )
   secondFailures = _getAllSshNotOpen( v6 )

   # Compare every failing namespace, not just the first of each pass: when
   # several namespaces fail, one of them recovering must not hide another
   # that failed both times.
   persistent = [ ns for ns in firstFailures if ns in secondFailures ]
   if persistent:
      badNs = persistent[ 0 ]
      # pylint: disable-next=consider-using-f-string
      syslog( "xinetd is not listening on SSH %s in namespace %s, restarting" %
              ( "IPv6" if v6 else "IPv4", badNs ) )
      # log more info
      logXinetdInfo( badNs )
      restartXinetd()
   else:
      # pylint: disable-next=consider-using-f-string
      syslog( "inconsistent result within 10 secs (%s), skipping" %
              firstFailures[ 0 ] )
      logXinetdInfo( firstFailures[ 0 ] )
      if secondFailures:
         logXinetdInfo( secondFailures[ 0 ] )

if __name__ == '__main__':
   # Usage: <this script> [ -6 ]
   #   -6  check IPv6 SSH listeners instead of IPv4
   argParser = argparse.ArgumentParser( description='Xinetd monitor' )
   argParser.add_argument( '-6', dest='v6', action="store_true",
                           help="monitor IPv6 (default IPv4)" )
   checkSsh( v6=argParser.parse_args().v6 )
