# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import AsuPStore
import SharedMem
import Smash
import Shark
from collections import defaultdict

class NatPStoreEventHandler( AsuPStore.PStoreEventHandler ):
   """ASU PStore event handler that snapshots NAT state.

   save() stores two keys: 'dynamicConnections' (from getDynamicConnections)
   and 'natIntfs' (from getNatIntfs).
   """

   # Keys written by save() and reported by getSupportedKeys() / getKeys().
   _natPStoreKeys = ( 'dynamicConnections', 'natIntfs' )

   def __init__( self,
                 natStatus,
                 dynamicConnectionStatus,
                 dynamicConnectionMappingStatus ):
      """Keep references to the NAT status objects that save() reads.

      The module's Plugin passes None for all three when the opcode is
      'GetSupportedKeys'; in that case only the key-listing methods are used.
      """
      AsuPStore.PStoreEventHandler.__init__( self )
      self.natStatus = natStatus
      self.dynamicConnectionStatus = dynamicConnectionStatus
      self.dynamicConnectionMappingStatus = dynamicConnectionMappingStatus

   def getDynamicConnections( self ):
      """Return a snapshot of every dynamic NAT connection.

      Keys are the NAT interface concatenated with the connection tuple's
      stringValue(). Each value is a dict with the fields globalAddress,
      target, fullCone, addrOnly, twiceNat, established, cmrAclName, cmrGroup
      and connMap. The result is a defaultdict: looking up an unknown key
      yields an entry whose fields are all ''.

      Entries are collected first from dynamicConnectionStatus (connMap is
      False) and then from the full-cone / address-only mappings in
      dynamicConnectionMappingStatus (connMap is True, twiceNat is False). A
      key produced by both sources ends up with the mapping's values.

      Raises AssertionError if a connection's connMark is missing from
      natStatus.connMarkRule, or a mapping has an unexpected mapType.
      """
      fieldNames = ( 'globalAddress', 'target', 'fullCone', 'addrOnly',
                     'twiceNat', 'established', 'cmrAclName', 'cmrGroup',
                     'connMap' )
      # Every new entry starts with all fields present (in this order) as ''.
      connections = defaultdict( lambda: dict.fromkeys( fieldNames, '' ) )

      def _connMarkRuleInfo( connMark ):
         # Resolve a connMark to ( aclName, group ) via natStatus.connMarkRule.
         assert connMark in self.natStatus.connMarkRule
         rule = self.natStatus.connMarkRule[ connMark ]
         ruleKey = rule.cmrKey
         aclName = ruleKey.aclName
         try:
            group = ruleKey.group
         except AttributeError:
            # The rule key has no 'group' attribute; use the rule's own group.
            group = rule.group
         return aclName, group

      def _record( key, globalAddress, target, fullCone, addrOnly, twiceNat,
                   established, connMark, connMap ):
         # Overwrite every field of the entry for key, creating it if needed.
         aclName, group = _connMarkRuleInfo( connMark )
         connections[ key ].update( {
            'globalAddress': globalAddress,
            'target': target,
            'fullCone': fullCone,
            'addrOnly': addrOnly,
            'twiceNat': twiceNat,
            'established': established,
            'cmrAclName': aclName,
            'cmrGroup': group,
            'connMap': connMap,
         } )

      # Connections tracked directly in the dynamic connection status table.
      statusTable = self.dynamicConnectionStatus.dynamicConnection
      for dcKey, status in statusTable.items():
         _record( dcKey.natIntf + dcKey.connTuple.stringValue(),
                  status.globalAddress.stringValue(),
                  status.target(),
                  status.fullCone(),
                  status.addrOnly(),
                  status.twiceNat(),
                  status.established(),
                  status.connMark,
                  False )

      # Connections created through mappings (FullCone, AddrOnly).
      mappingStatus = self.dynamicConnectionMappingStatus
      for intfId, intfMapping in mappingStatus.intfMapping.items():
         for connMapping in intfMapping.connMapping.values():
            sharedGlobalAddress = connMapping.globalAddress.stringValue()
            mapType = connMapping.mapType
            assert mapType in ( 'mappingFullCone', 'mappingAddrOnly' )
            isFullCone = mapType == 'mappingFullCone'
            isAddrOnly = mapType == 'mappingAddrOnly'
            for connTuple, conn in connMapping.conn.items():
               _record( intfId + connTuple.stringValue(),
                        sharedGlobalAddress,
                        conn.target,
                        isFullCone,
                        isAddrOnly,
                        False,
                        conn.established,
                        conn.connMark,
                        True )

      return connections

   def getNatIntfs( self ):
      """Return the NAT interfaces from natStatus.intfConfigKey.

      Each key of intfConfigKey maps to { 'name': <its value> }. The result is
      a defaultdict, so an unknown key yields { 'name': '' }.
      """
      natIntfs = defaultdict( lambda: { 'name': '' } )
      natIntfs.update(
         ( intf, { 'name': intfName } )
         for intf, intfName in self.natStatus.intfConfigKey.items() )
      return natIntfs

   def save( self, pStoreIO ):
      """Store the connection and interface snapshots in pStoreIO."""
      producers = ( ( 'dynamicConnections', self.getDynamicConnections ),
                    ( 'natIntfs', self.getNatIntfs ) )
      for pStoreKey, produce in producers:
         pStoreIO.set( pStoreKey, produce() )

   def getSupportedKeys( self ):
      """Return (as a new list) the keys this handler knows how to store."""
      return list( self._natPStoreKeys )

   def getKeys( self ):
      """Return (as a new list) the keys save() writes; same as supported."""
      return list( self._natPStoreKeys )

def Plugin( ctx ):
   """Register the NAT ASU PStore event handler under the feature 'Nat'.

   For the 'GetSupportedKeys' opcode nothing is mounted, and the handler is
   built with None for every status object. For any other opcode, the Sysdb
   NAT status, the Smash dynamic-connection table and the Shark
   connection-mapping table are mounted first.
   """

   def _mountNatState( sysdbEm ):
      # Returns ( natStatus, dynamicConnectionStatus,
      #           dynamicConnectionMappingStatus ).
      shmemEm = SharedMem.entityManager( sysdbEm=sysdbEm )

      # Sysdb: NAT status, read-only; close() blocks until mounted.
      sysdbGroup = sysdbEm.mountGroup()
      natStatus = sysdbGroup.mount( 'ip/nat/status', 'Ip::Nat::Status', 'r' )
      sysdbGroup.close( blocking=True )

      # Smash: mounted as "shadow" rather than "reader" to avoid a mismatch
      # with NatCli.py.
      connStatus = shmemEm.doMount(
         "ip/nat/status/dynamicConnection",
         "Ip::Nat::DynamicConnectionStatus",
         Smash.mountInfo( 'shadow' ) )

      # Shark: connection mappings, also mounted as "shadow".
      sharkGroup = shmemEm.getMountGroup()
      mappingStatus = sharkGroup.doMount(
         "nat/connMapStatus",
         "Ip::Nat::DynamicConnectionMappingStatus",
         Shark.mountInfo( "shadow" ) )
      sharkGroup.doClose()

      return natStatus, connStatus, mappingStatus

   if ctx.opcode() == 'GetSupportedKeys':
      # Only the key list is needed, so no state is mounted.
      natState = ( None, None, None )
   else:
      natState = _mountNatState( ctx.entityManager() )

   ctx.registerAsuPStoreEventHandler( 'Nat',
                                      NatPStoreEventHandler( *natState ) )
