#!/usr/bin/env python3
# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
import os, shutil, tempfile, re
import subprocess 
import Tac, Tracing

traceHandle = Tracing.Handle( 'Sswan' )
t0 = traceHandle.trace0

# Directory containing the strongSwan files this script rewrites.
ipsecConfigDir = "/etc/strongswan/"

# File names (relative to ipsecConfigDir) of the connection config and the
# pre-shared-key secrets file.
ipsecConfigFile = "ipsec.conf"
ipsecSecretsFile = "ipsec.secrets"

# Presumably Diffie-Hellman group number -> modulus/curve size in bits (as
# strings). NOTE(review): not referenced by the code in this file.
diffieHillGroup = { '1' : '768' , '14' : '2048', '15' : '3072', '16' : '4096',
                    '19' : '256', '2' : '1024', '5': '1536' }
# Maps the --authby value (after '_' is replaced by '-') to the keyword
# written as 'authby=' in ipsec.conf.
keyType = { 'pre-share' : 'secret', 'rsa-sig' : 'pubkey' }

# Fallbacks used when writing the 'conn %default' section; autoDef is also
# written for every named connection.
authbyDefault = 'pre-share'
ikeLifetimeDef = 28800    # seconds (written with an 's' suffix)
ipsecLifetimeDef = 3600   # seconds (written with an 's' suffix)
ikeVersionDef = "ikev1"
autoDef = "add"

def isRunning( pidFilePath='/etc/strongswan/ipsec.d/run/charon.pid' ):
   """Return True if the strongSwan charon daemon appears to be running.

   The daemon counts as running when its pid file exists, the first line holds
   a numeric pid, and /proc has an entry for that pid.

   Args:
      pidFilePath: charon's pid file. Defaults to the path this script has
         always used.
   """
   try:
      with open( pidFilePath ) as pidFile:
         # strip() rather than chopping the last character: a pid file without
         # a trailing newline must not lose its last digit.
         pid = pidFile.readline().strip()
   except FileNotFoundError:
      return False
   # Require digits so that an empty/blank pid file can never resolve to the
   # bare '/proc/' directory (which always exists).
   return pid.isdigit() and os.path.exists( f'/proc/{pid}' )

def connectionCommand( intfName, enable=False ):
   """Bring a strongSwan connection up or down, if charon is running.

   Launches `sudo strongswan up|down <intfName>` in the background and does
   not wait for it. Does nothing when isRunning() reports the daemon absent.

   Args:
      intfName: connection name passed to strongswan.
      enable: True to bring the connection up, False to take it down.
   """
   if isRunning():
      verb = "up" if enable else "down"
      # pylint: disable-next=consider-using-with
      subprocess.Popen( [ "sudo", "strongswan", verb, intfName ] )


def secretCommand( parser, **kwargs ): # pylint: disable-msg=W0621
   """Add, replace or delete one PSK entry in ipsec.secrets.

   Options are re-parsed from the command line with parser.parse_args().
   With --secret %default the entry is the wildcard line ' : PSK <key>';
   otherwise it is '<peerIp> : PSK <key>' and --peerIp is required. The first
   existing line for that entry is dropped. Unless --delete is given, the new
   entry is then appended to the file. The previous file is saved as
   ipsec.secrets.bkup.

   Args:
      parser: argparse parser providing --secret, --peerIp, --key, --delete.
      kwargs: unused.
   """
   arguments = parser.parse_args()
   isDefault = arguments.secret == "%default"
   if not isDefault and not arguments.peerIp:
      parser.error( "--peerIp is required unless --secret is %default" )

   # Prefix identifying the line to replace/delete. Plain prefix matching (not
   # a regex) so the dots in an IP address only match literal dots.
   entryPrefix = ' : PSK' if isDefault else f'{arguments.peerIp} :'

   newEntry = None
   if not arguments.delete:
      key = arguments.key
      if key is None:
         parser.error( "--key is required unless --delete is given" )
      # Key is always prefixed with a \, to stop the argument parser from
      # interpreting special characters. Hence remove the escape character
      # if present.
      if key.startswith( "\\" ):
         key = key[ 1: ]
      if isDefault:
         newEntry = f' : PSK {key}\n'
      else:
         newEntry = f'{arguments.peerIp} : PSK {key}\n'

   secretsFilePath = os.path.join( ipsecConfigDir, ipsecSecretsFile )
   shutil.copyfile( secretsFilePath, secretsFilePath + ".bkup" )

   ## remove the old entry and, unless deleting, add the new one
   fd, newPath = tempfile.mkstemp()
   try:
      with os.fdopen( fd, 'w' ) as newFile:
         with open( secretsFilePath ) as oldFile:
            removed = False
            for line in oldFile:
               if not removed and line.startswith( entryPrefix ):
                  removed = True
                  continue
               newFile.write( line )
         # A delete must not re-append anything for the removed entry.
         if newEntry is not None:
            newFile.write( newEntry )
   except BaseException:
      # Don't leave a partial temp file behind; ipsec.secrets is untouched.
      os.remove( newPath )
      raise

   os.remove( secretsFilePath )
   shutil.move( newPath, secretsFilePath )
   

def isakmpCommand( parser, **kwargs): # pylint: disable-msg=W0621
   """Add, replace or delete one 'conn' section of ipsec.conf.

   Options are re-parsed from the command line with parser.parse_args().
   The section is named by --default (the command line only passes
   '%default') or by --config <name>. The first existing section with that
   name is removed. Unless --delete is given, a freshly rendered section is
   then appended to the end of the file. The previous file is saved as
   ipsec.conf.bkup.

   Args:
      parser: argparse parser whose parse_args() provides every option read
         here and by the section writers.
      kwargs: unused.
   """
   arguments = parser.parse_args()
   connName = arguments.default or arguments.config
   if not connName:
      parser.error( "one of --default or --config is required" )

   configFilePath = os.path.join( ipsecConfigDir, ipsecConfigFile )
   shutil.copyfile( configFilePath, configFilePath + ".bkup" )

   # Escape the name: connection names may contain regex metacharacters
   # (e.g. '.'), which must match literally or the wrong section is removed.
   header = re.compile( 'conn ' + re.escape( connName ) + '$' )
   # A section ends at the next column-0 section keyword ('conn', 'ca',
   # 'config') or a top-level 'include'; parameter lines are indented.
   nextSection = re.compile( r'(conn|ca|config|include)\s' )

   ## remove the old connection configuration and add the new one
   fd, newPath = tempfile.mkstemp()
   try:
      with os.fdopen( fd, 'w' ) as newFile:
         with open( configFilePath ) as oldFile:
            _copyWithoutSection( oldFile, newFile, header, nextSection )
         if not arguments.delete:
            if arguments.default:
               _writeDefaultSection( newFile, arguments )
            else:
               _writeConnSection( newFile, arguments )
   except BaseException:
      # Don't leave a partial temp file behind; ipsec.conf is untouched.
      os.remove( newPath )
      raise

   os.remove( configFilePath )
   shutil.move( newPath, configFilePath )


def _copyWithoutSection( oldFile, newFile, header, nextSection ):
   """Copy oldFile to newFile, omitting the first section whose header line
   matches `header`. The omitted span runs up to, but not including, the
   next line matching `nextSection`."""
   state = 'before'
   for line in oldFile:
      if state == 'before' and header.match( line ):
         state = 'inside'
         continue
      if state == 'inside':
         if not nextSection.match( line ):
            continue
         state = 'after'
      newFile.write( line )


def _writeDefaultSection( newFile, arguments ):
   """Append the 'conn %default' section to newFile.

   authby, ikelifetime, lifetime and keyexchange fall back to the module
   defaults when not given. esp/ike are written only when given.
   """
   if arguments.authby:
      authby = arguments.authby.replace( "_", "-" )
   else:
      authby = authbyDefault
   ikeLifetime = arguments.ikelifetime or ikeLifetimeDef
   ipsecLifetime = arguments.lifetime or ipsecLifetimeDef
   ikeVersion = arguments.ikeVersion or ikeVersionDef

   newFile.write( '\nconn %default\n' )
   newFile.write( f"\tauthby={keyType[authby]}\n" )
   newFile.write( f"\tikelifetime={int(ikeLifetime)}s\n" )
   # lifetime (synonym: keylife) is how long an IPsec SA lasts before it
   # expires; negotiating a replacement starts 'rekeymargin' before that.
   newFile.write( f"\tlifetime={int(ipsecLifetime)}s\n" )
   newFile.write( "\trekeymargin=3m\n" )
   newFile.write( "\tkeyingtries=1\n" )
   newFile.write( "\tmobike=no\n" )
   newFile.write( f"\tkeyexchange={ikeVersion}\n" )
   newFile.write( f"\tauto={autoDef}\n" )
   if arguments.esp:
      newFile.write( f"\tesp={arguments.esp}\n" )
   if arguments.ike:
      newFile.write( f"\tike={arguments.ike}\n" )


def _writeConnSection( newFile, arguments ):
   """Append a 'conn <--config>' section to newFile.

   The header, left, right, auto, replay_window and mobike lines are always
   written. Every other parameter is written only when its option was given.
   """
   newFile.write( f'\nconn {arguments.config}' + '\n' )
   if arguments.ikeVersion:
      newFile.write( f"\tkeyexchange={arguments.ikeVersion}\n" )
   if arguments.authby:
      authby = arguments.authby.replace( "_", "-" )
      newFile.write( f"\tauthby={keyType[authby]}\n" )
   if arguments.ikelifetime:
      newFile.write( f"\tikelifetime={int(arguments.ikelifetime)}s\n" )
   if arguments.lifetime:
      newFile.write( f"\tlifetime={int(arguments.lifetime)}s\n" )
   newFile.write( f"\tleft={arguments.left}\n" )
   if arguments.leftsubnet:
      newFile.write( f"\tleftsubnet={arguments.leftsubnet}\n" )
   if arguments.localid:
      newFile.write( f"\tleftid={arguments.localid}\n" )
   newFile.write( f"\tright={arguments.right}\n" )
   if arguments.remoteid:
      newFile.write( f"\trightid={arguments.remoteid}\n" )
   newFile.write( f"\tauto={autoDef}\n" )
   if arguments.ike:
      newFile.write( f"\tike={arguments.ike}\n" )
   if arguments.esp:
      newFile.write( f"\tesp={arguments.esp}\n" )
   if arguments.tunnelMode == "IpsecGre":
      # Protect only IP protocol 47 (GRE) between the endpoints.
      newFile.write( "\tleftprotoport=47\n" )
      newFile.write( "\trightprotoport=47\n" )
   elif arguments.tunnelMode == "IpsecVti":
      # VTI: any-to-any traffic selectors; routing decides what enters.
      newFile.write( "\tleftsubnet=0.0.0.0/0\n" )
      newFile.write( "\trightsubnet=0.0.0.0/0\n" )
   if arguments.mark:
      newFile.write( f"\tmark={arguments.mark}\n" )
   if arguments.action:
      newFile.write( f"\tdpdaction={arguments.action}\n" )
      newFile.write( f"\tdpddelay={arguments.interval}\n" )
      newFile.write( f"\tdpdtimeout={arguments.timeout}\n" )
   if arguments.encap:
      newFile.write( "\tforceencaps=yes\n" )
   if arguments.type:
      newFile.write( f"\ttype={arguments.type}\n" )
   if arguments.replayWindowSize:
      newFile.write( "\treplay_window="
                     f"{int(arguments.replayWindowSize)}\n" )
   else:
      # Explicit 0 when not configured (in strongSwan, 0 disables replay
      # protection).
      newFile.write( "\treplay_window=0\n" )
   if arguments.packetLimit:
      # delete at packetLimit
      newFile.write( "\tlifepackets="
                     f"{int(arguments.packetLimit)}\n" )
      # rekey at 0.75 * packetLimit
      # pylint: disable-next=consider-using-f-string
      newFile.write( "\tmarginpackets=%d\n" %
                     int( 0.25 * arguments.packetLimit ) )
   if arguments.byteLimit:
      newFile.write( f"\tlifebytes={int(arguments.byteLimit)}\n" )
      # pylint: disable-next=consider-using-f-string
      newFile.write( "\tmarginbytes=%d\n" %
                     int( 0.25 * arguments.byteLimit ) )

   newFile.write( "\tmobike=no\n" )
   

if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   # NOTE: 'method' is not read by the dispatch below.
   parser.set_defaults(method = isakmpCommand)

   # What this invocation operates on: the %default profile, a named
   # connection profile, or a pre-shared secret.
   group = parser.add_mutually_exclusive_group()
   group.add_argument( '--default', help="Set the Default profile" )
   group.add_argument( '--config', help="Set the Connection profile" )
   group.add_argument( '--secret', help="Set the Secrets key and value" )
   parser.add_argument( '--delete', help="Delete the Connection",
                        action='store_true' )

   # ( flag, value type, help text ) for the remaining options, in --help order.
   valueOptions = (
      ( "--name", str, "Connection name" ),
      ( "--authby", str, "Authentication mode" ),
      ( "--ike", str, "IKE parameters" ),
      ( "--ikeVersion", str, "IKEv1 or IKEv2 mode" ),
      ( "--esp", str, "ESP parameters" ),
      ( "--ikelifetime", int, "IKE lifetime" ),
      ( "--lifetime", int, "IPSEC SA lifetime" ),
      ( "--left", str, "Left tunnel Id" ),
      ( "--leftsubnet", str, "Left Subnet" ),
      ( "--right", str, "Right tunnel Id" ),
      ( "--auto", str, "Initiator/Reactor config" ),
      ( "--peerIp", str, "Peer IP Address" ),
      ( "--key", str, "Key value" ),
      ( "--tunnelMode", str, "GRE/VTI mode" ),
      ( "--up", str, "Enables Ipsec connection" ),
      ( "--down", str, "Disable Ipsec connection" ),
      ( "--mark", str, "packet mark for VTI" ),
      ( "--interval", str, "DPD keepalive interval" ),
      ( "--timeout", str, "DPD Timeout interval" ),
      ( "--action", str, "DPD Action" ),
      ( "--encap", str, "Force UDP encapsulation" ),
      ( "--type", str, "Ipsec Mode" ),
      ( "--localid", str, "Local identification" ),
      ( "--remoteid", str, "Remote peer identification" ),
      ( "--replayWindowSize", int, "IPsec Anti-Replay Window" ),
      ( "--packetLimit", int, "IPsec SA packet limit" ),
      ( "--byteLimit", int, "IPsec SA byte limit" ),
   )
   for flag, valueType, helpText in valueOptions:
      parser.add_argument( flag, type=valueType, help=helpText )

   args = parser.parse_args()
   # ipsec.conf is only rewritten for %default, for a delete, or when both
   # tunnel endpoints are known.
   if args.default == "%default":
      isakmpCommand( parser )
   elif args.config and (args.delete or ( args.left and args.right )):
      isakmpCommand( parser )
   elif args.secret:
      secretCommand( parser )
      # pylint: disable-next=consider-using-with
      subprocess.Popen( [ "sudo", "strongswan", "rereadall" ] )
   elif args.up:
      connectionCommand( args.up, enable=True )
   elif args.down:
      connectionCommand( args.down, enable=False )

   # Always ask strongSwan to apply ipsec.conf changes, whichever branch ran;
   # launched in the background and not waited on.
   # pylint: disable-next=consider-using-with
   subprocess.Popen( [ "sudo", "strongswan", "update" ] )
   if args.config and args.delete:
      t0(" Stop connection" )
      connectionCommand( args.config, enable=False )
