#!/usr/bin/env python3
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import Tac
from CliMode.Rpki import ( RpkiCacheMode,
                           RpkiTransportTcpMode,
                           RpkiTransportTlsMode,
                           RpkiOriginValidationBaseMode )
import CliSave
from CliSavePlugin.RoutingBgpCliSave import ( RouterBgpBaseConfigMode,
                                              neighborRpkiHook )
from IpLibConsts import DEFAULT_VRF
from RouteMapLib import isAsdotConfigured

# Tac enum type for BGP RPKI origin-validation methods. The members referenced
# in this file are ovLocal, ovCommunity, ovPreferCommunity and ovDisabled.
ovMethodEnum = Tac.Type( "Routing::Bgp::Rpki::OriginValidationMethod" )

class RpkiCacheConfigCliSaveMode( RpkiCacheMode, CliSave.Mode ):
   """CliSave mode for a single RPKI cache, nested under the BGP router mode.

   saveRpkiCacheConfig creates one instance per cache name. Its
   'Bgp.rpki.config' command sequence holds that cache's commands.
   """
   def __init__( self, param ):
      # Both bases are initialized explicitly and both receive the mode
      # parameter (the cache name, as passed by saveRpkiCacheConfig).
      RpkiCacheMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Make cache modes children of the BGP router mode, with their own sequence.
RouterBgpBaseConfigMode.addChildMode( RpkiCacheConfigCliSaveMode )
RpkiCacheConfigCliSaveMode.addCommandSequence( 'Bgp.rpki.config' )

class RpkiTransportTcpCliSaveMode( RpkiTransportTcpMode,
                                        CliSave.Mode ):
   """CliSave mode for the TCP transport sub-mode of an RPKI cache.

   saveRpkiCacheConfig uses it as a singleton child of a cache mode when the
   cache's transport authentication type is 'tcp'.
   """
   def __init__( self, param ):
      # RpkiTransportTcpMode is initialized without the mode parameter; only
      # CliSave.Mode receives it.
      RpkiTransportTcpMode.__init__( self )
      CliSave.Mode.__init__( self, param )

# Nest the TCP transport mode inside each cache mode.
RpkiCacheConfigCliSaveMode.addChildMode( RpkiTransportTcpCliSaveMode )
RpkiTransportTcpCliSaveMode.addCommandSequence(
      'Bgp.rpki.transport-tcp.config' )

class RpkiTransportTlsCliSaveMode( RpkiTransportTlsMode, CliSave.Mode ):
   """CliSave mode for the TLS transport sub-mode of an RPKI cache.

   saveRpkiCacheConfig uses it as a singleton child of a cache mode when the
   cache's transport authentication type is 'tls'.
   """
   def __init__( self, param ):
      # RpkiTransportTlsMode is initialized without the mode parameter; only
      # CliSave.Mode receives it.
      RpkiTransportTlsMode.__init__( self )
      CliSave.Mode.__init__( self, param )

# Nest the TLS transport mode inside each cache mode.
RpkiCacheConfigCliSaveMode.addChildMode( RpkiTransportTlsCliSaveMode )
RpkiTransportTlsCliSaveMode.addCommandSequence(
      'Bgp.rpki.transport-tls.config' )

#------------------------------------------------------------------------------------
# bgp rpki origin-validation mode
#------------------------------------------------------------------------------------
class RpkiOriginValidationCliSaveMode( RpkiOriginValidationBaseMode, CliSave.Mode ):
   """CliSave mode for the BGP RPKI origin-validation sub-mode.

   Registered as a child of RouterBgpBaseConfigMode. saveRpkiConfig uses it as
   a singleton, and its commands go in the 'Bgp.rpki.origin-validation'
   command sequence.
   """
   def __init__( self, param ):
      # Only CliSave.Mode receives the mode parameter.
      RpkiOriginValidationBaseMode.__init__( self )
      CliSave.Mode.__init__( self, param )

RouterBgpBaseConfigMode.addChildMode( RpkiOriginValidationCliSaveMode )

RpkiOriginValidationCliSaveMode.addCommandSequence( 'Bgp.rpki.origin-validation' )

def _ovMethodKeyword( method, communityAllowed=True ):
   """Return the CLI keyword for origin-validation method `method`.

   'local' is always accepted. 'community' and 'prefer-community' are accepted
   only when `communityAllowed` is True. Any other value raises AssertionError.
   The exception is raised explicitly rather than through an `assert`
   statement, so an unrecognized method is never silently skipped when Python
   runs with -O.
   """
   if method == ovMethodEnum.ovLocal:
      return 'local'
   if communityAllowed:
      if method == ovMethodEnum.ovCommunity:
         return 'community'
      if method == ovMethodEnum.ovPreferCommunity:
         return 'prefer-community'
   raise AssertionError( "Unrecognized method %s" % method )

def saveRpkiOriginValidationConfig( bgpConfig, saveAll=False ):
   """Build the command list for the BGP RPKI origin-validation mode.

   Commands are emitted in this order:
   1. eBGP method and send setting.
   2. iBGP method and send setting.
   3. Validation route-map.
   4. Redistributed-route method.
   5. Redistributed validation route-map.

   When `saveAll` is set, each setting that is not configured is emitted in
   its 'no ...' form.

   Args:
      bgpConfig: BGP config entity holding the rpki* origin-validation fields.
      saveAll: whether to also emit 'no ...' forms for unset settings.

   Returns:
      list of command strings.

   Raises:
      AssertionError: if a configured method cannot be rendered. Redistributed
         routes only support 'local'.
   """
   cmds = []

   # eBGP and iBGP share the same method keywords and the same 'send' toggle.
   for prefix, ovMethod, ovSend in (
         ( 'ebgp', bgpConfig.rpkiEbgpOvMethod, bgpConfig.rpkiEbgpOvSend ),
         ( 'ibgp', bgpConfig.rpkiIbgpOvMethod, bgpConfig.rpkiIbgpOvSend ) ):
      if ovMethod.isSet:
         cmds.append( '%s %s' % ( prefix, _ovMethodKeyword( ovMethod.value ) ) )
      elif saveAll:
         cmds.append( 'no %s' % prefix )

      # The send setting is enabled only when it equals the string 'isTrue'.
      if ovSend == 'isTrue':
         cmds.append( '%s send' % prefix )
      elif saveAll:
         cmds.append( 'no %s send' % prefix )

   if bgpConfig.rpkiOvRouteMap:
      cmds.append( 'validation route-map %s' % bgpConfig.rpkiOvRouteMap )
   elif saveAll:
      cmds.append( 'no validation route-map' )

   if bgpConfig.rpkiRedistOvMethod.isSet:
      # Only 'local' can be rendered for redistributed routes.
      keyword = _ovMethodKeyword( bgpConfig.rpkiRedistOvMethod.value,
                                  communityAllowed=False )
      cmds.append( 'redistributed %s' % keyword )
   elif saveAll:
      cmds.append( 'no redistributed' )

   if bgpConfig.rpkiRedistOvRouteMap:
      cmds.append( 'redistributed validation route-map %s' %
                   bgpConfig.rpkiRedistOvRouteMap )
   elif saveAll:
      cmds.append( 'no redistributed validation route-map' )

   return cmds

def isDefaultPortInUseBasedOnTransportType( cacheConfig ):
   """Tell whether the cache's port is the default for its transport type.

   Only the 'tls' and 'tcp' transport types have a default port to compare
   against. For any other transport type the answer is always False.
   """
   transportType = cacheConfig.transportConfig.transportAuthenticationType
   port = cacheConfig.port
   if transportType == 'tls':
      return port == cacheConfig.tlsPortDefault
   if transportType == 'tcp':
      return port == cacheConfig.tcpPortDefault
   return False

def saveCacheConfig( cacheName, cacheConfig, saveAll=False ):
   """Return the commands for one RPKI cache config mode.

   Args:
      cacheName: name of the cache. Accepted for interface compatibility but
         not used to build the commands.
      cacheConfig: the cache's config entity.
      saveAll: when True, settings at their default are also emitted, either
         spelled out (vrf/port on the host line, 'preference N') or in their
         'no ...' form.

   Returns:
      list of command strings, in rendering order.
   """
   cmds = []

   # host <host> [vrf <vrf>] [port <port>]: vrf and port are only shown when
   # they differ from the default, unless saveAll is set.
   if cacheConfig.host:
      hostParts = [ f"host {cacheConfig.host}" ]
      if saveAll or cacheConfig.vrf != DEFAULT_VRF:
         hostParts.append( f"vrf {cacheConfig.vrf}" )
      if saveAll or not isDefaultPortInUseBasedOnTransportType( cacheConfig ):
         hostParts.append( f"port {cacheConfig.port}" )
      cmds.append( ' '.join( hostParts ) )
   elif saveAll:
      cmds.append( 'no host' )

   if saveAll or cacheConfig.preference != cacheConfig.preferenceDefault:
      cmds.append( f'preference {cacheConfig.preference}' )

   # A zero interval is treated as unset.
   intervals = ( ( 'refresh-interval', cacheConfig.refreshInterval ),
                 ( 'retry-interval', cacheConfig.retryInterval ),
                 ( 'expire-interval', cacheConfig.expireInterval ) )
   for keyword, interval in intervals:
      if interval != 0:
         cmds.append( f'{keyword} {interval}' )
      elif saveAll:
         cmds.append( f'no {keyword}' )

   sourceIntf = cacheConfig.sourceIntf
   if sourceIntf != cacheConfig.sourceIntfDefault:
      cmds.append( f'local-interface {sourceIntf}' )
   elif saveAll:
      cmds.append( 'no local-interface' )

   # ROA table names are rendered as one sorted, space-separated list. Nothing
   # is emitted when that list equals the default ROA table name.
   roaTables = ' '.join( sorted( cacheConfig.roaTableConfig.roaTableNames ) )
   defaultRoaTable = Tac.Type( 'Rpki::RpkiDefaults' ).defaultRoaTableName
   if roaTables != defaultRoaTable:
      cmds.append( f'roa table {roaTables}' )
   elif saveAll:
      cmds.append( 'no roa table' )
   return cmds

def saveTcpKeepalive( cacheConfig, saveAll=False ):
   """Return the 'tcp keepalive' command for a cache's transport config.

   A non-default keepalive is rendered as
   'tcp keepalive <idleTime> <probeInterval> <probeCount>'. A default
   keepalive yields no command, or 'no tcp keepalive' when `saveAll` is set.
   A new list is returned on every call, so callers may append to it.
   """
   transportConfig = cacheConfig.transportConfig
   keepalive = transportConfig.tcpKeepaliveOptions
   if keepalive != transportConfig.tcpKeepaliveDefault:
      return [ f'tcp keepalive {keepalive.idleTime} {keepalive.probeInterval} '
               f'{keepalive.probeCount}' ]
   return [ 'no tcp keepalive' ] if saveAll else []

def saveTransportTcp( cacheConfig, saveAll=False ):
   """Return the commands for a cache's 'tcp' transport sub-mode.

   The TCP transport has only the keepalive setting to save.
   """
   keepaliveCmds = saveTcpKeepalive( cacheConfig, saveAll=saveAll )
   return keepaliveCmds

def saveTransportTls( cacheConfig, saveAll=False ):
   """Return the commands for a cache's 'tls' transport sub-mode.

   The keepalive command(s) from saveTcpKeepalive come first. They are
   followed by 'ssl profile <name>' when the profile differs from the default,
   or by 'no ssl profile' when it is the default and `saveAll` is set.
   """
   cmds = saveTcpKeepalive( cacheConfig, saveAll )
   transportConfig = cacheConfig.transportConfig
   profileName = transportConfig.sslProfileName
   if profileName != transportConfig.sslProfileNameDefault:
      sslCmd = f"ssl profile {profileName}"
   else:
      sslCmd = 'no ssl profile' if saveAll else None
   if sslCmd is not None:
      cmds.append( sslCmd )
   return cmds

@CliSave.saver( 'Rpki::CacheConfigDir', 'routing/rpki/cache/config',
                requireMounts=( 'routing/bgp/config', 'routing/bgp/asn/config' ) )
def saveRpkiCacheConfig( cacheConfigDir, root, requireMounts, options ):
   """Render every RPKI cache, with its transport sub-mode, under router bgp.

   Caches are visited in sorted name order, and each one gets its own
   RpkiCacheConfigCliSaveMode instance. A 'tcp' or 'tls' transport also gets
   the matching singleton transport sub-mode.
   """
   bgpConfig = requireMounts[ 'routing/bgp/config' ]
   asnConfig = requireMounts[ 'routing/bgp/asn/config' ]
   # With no BGP instance configured (AS 0), save nothing. This keeps a BGP
   # mode for AS 0 out of the running config.
   if bgpConfig.asNumber == 0:
      return

   saveAll = options.saveAll
   bgpModeKey = ( bgpConfig.asNumber, isAsdotConfigured( asnConfig ) )
   bgpMode = root[ RouterBgpBaseConfigMode ].getOrCreateModeInstance( bgpModeKey )

   for cacheName in sorted( cacheConfigDir.cacheConfig ):
      cacheConfig = cacheConfigDir.cacheConfig[ cacheName ]
      cacheMode = bgpMode[ RpkiCacheConfigCliSaveMode ].getOrCreateModeInstance(
         cacheName )
      cacheSeq = cacheMode[ 'Bgp.rpki.config' ]
      for cmd in saveCacheConfig( cacheName, cacheConfig, saveAll ):
         cacheSeq.addCommand( cmd )

      # Pick the transport sub-mode, its command sequence and its saver.
      transportType = cacheConfig.transportConfig.transportAuthenticationType
      if transportType == 'tcp':
         transportModeClass = RpkiTransportTcpCliSaveMode
         transportSeqName = 'Bgp.rpki.transport-tcp.config'
         transportSaver = saveTransportTcp
      elif transportType == 'tls':
         transportModeClass = RpkiTransportTlsCliSaveMode
         transportSeqName = 'Bgp.rpki.transport-tls.config'
         transportSaver = saveTransportTls
      else:
         continue

      transportMode = cacheMode[ transportModeClass ].getSingletonInstance()
      transportSeq = transportMode[ transportSeqName ]
      for cmd in transportSaver( cacheConfig, saveAll ):
         transportSeq.addCommand( cmd )

@CliSave.saver( 'Routing::Bgp::Config', 'routing/bgp/config',
                requireMounts=( 'routing/bgp/asn/config', ) )
def saveRpkiConfig( bgpConfig, root, requireMounts, options ):
   """Render the BGP RPKI origin-validation sub-mode under router bgp.

   The BGP mode instance is always fetched or created for a configured AS.
   The origin-validation sub-mode is only populated when at least one command
   is generated.
   """
   asnConfig = requireMounts[ 'routing/bgp/asn/config' ]
   asNumber = bgpConfig.asNumber
   # No BGP instance configured: generate nothing.
   if asNumber == 0:
      return

   bgpModeKey = ( asNumber, isAsdotConfigured( asnConfig ) )
   bgpMode = root[ RouterBgpBaseConfigMode ].getOrCreateModeInstance( bgpModeKey )

   ovCmds = saveRpkiOriginValidationConfig( bgpConfig, options.saveAll )
   if not ovCmds:
      return

   ovMode = bgpMode[ RpkiOriginValidationCliSaveMode ].getSingletonInstance()
   ovSeq = ovMode[ 'Bgp.rpki.origin-validation' ]
   for cmd in ovCmds:
      ovSeq.addCommand( cmd )

# CliSave for neighbor rpki origin validation settings
def saveNeighborRpkiOvMethod( peer, peerConfig, saveAll ):
   """Return the per-neighbor origin-validation method command, or None.

   Args:
      peer: neighbor identifier, formatted into the command.
      peerConfig: the neighbor's config. Reads rpkiOvMethodPresent,
         rpkiOvMethod and isPeerGroupPeer.
      saveAll: whether to emit the 'default ...' form for an unset method.

   Returns:
      'neighbor <peer> rpki origin-validation <method>' when a method is set.
      Otherwise, 'default neighbor <peer> rpki origin-validation' when
      `saveAll` is set and the neighbor is not a peer-group peer.
      Otherwise None.

   Raises:
      AssertionError: if the configured method is not recognized.
   """
   ovMethodCmd = 'neighbor %s rpki origin-validation' % peer
   if peerConfig.rpkiOvMethodPresent:
      method = peerConfig.rpkiOvMethod
      if method == ovMethodEnum.ovLocal:
         keyword = 'local'
      elif method == ovMethodEnum.ovCommunity:
         keyword = 'community'
      elif method == ovMethodEnum.ovPreferCommunity:
         keyword = 'prefer-community'
      elif method == ovMethodEnum.ovDisabled:
         keyword = 'disabled'
      else:
         # Raise explicitly: an `assert` is stripped under -O, which would
         # let the command be returned without any method keyword.
         raise AssertionError( "Unrecognized method %s" % method )
      return ovMethodCmd + ' ' + keyword
   if saveAll and not peerConfig.isPeerGroupPeer:
      return 'default ' + ovMethodCmd
   return None

def saveNeighborRpkiOvSend( peer, peerConfig, saveAll ):
   """Return the per-neighbor origin-validation 'send' command, or None.

   An explicitly configured value yields
   'neighbor <peer> rpki origin-validation send', with ' disabled' appended
   when rpkiOvSend is false. An unconfigured value yields the 'default ...'
   form only when `saveAll` is set and the neighbor is not a peer-group peer.
   Otherwise None is returned.
   """
   sendCmd = 'neighbor %s rpki origin-validation send' % peer
   if not peerConfig.rpkiOvSendPresent:
      if saveAll and not peerConfig.isPeerGroupPeer:
         return 'default ' + sendCmd
      return None
   return sendCmd if peerConfig.rpkiOvSend else sendCmd + ' disabled'

# Register the per-neighbor RPKI origin-validation savers with the BGP neighbor
# hook at import time: the method saver first, then the send saver.
neighborRpkiHook.addExtension( saveNeighborRpkiOvMethod )
neighborRpkiHook.addExtension( saveNeighborRpkiOvSend )
