# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import BasicCli
from CliMode.Restconf import MgmtRestconfMode, RestconfTransportMode
from CliPlugin import OpenConfigCliLib
import CliMatcher
import ConfigMount
import DscpCliLib
from IpLibConsts import DEFAULT_VRF
import LazyMount


# Module-level config handles. Both are placeholders here and are assigned
# by Plugin() at the bottom of this file.
sslConfig = None
restconfConfig = None

# ------------------------------------------------------
# RESTCONF config commands
# ------------------------------------------------------

class MgmtRestconfConfigMode( MgmtRestconfMode, BasicCli.ConfigModeBase ):
   """CLI configuration mode 'management api restconf'.

   Combines MgmtRestconfMode with BasicCli.ConfigModeBase; both base
   initializers are invoked explicitly from __init__.
   """

   name = "RESTCONF configuration"

   def __init__( self, parent, session ):
      # Captured before the base initializers run. This is the module-level
      # config handle, which Plugin() assigns.
      self.config_ = restconfConfig

      # "api-restconf" is handed to the mode base; presumably the key used
      # for this mode -- TODO confirm against CliMode.Restconf.
      MgmtRestconfMode.__init__( self, "api-restconf" )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

def gotoMgmtRestconfConfigMode( mode, args ):
   """Enters the RESTCONF management configuration mode.

   Builds a MgmtRestconfConfigMode child of `mode` and asks the session to
   switch to it. `args` is part of the handler signature but is not used.
   """
   restconfMode = mode.childMode( MgmtRestconfConfigMode )
   session = mode.session_
   session.gotoChildMode( restconfMode )

def noMgmtRestconfConfigMode( mode, args ):
   """Resets RESTCONF configuration to default.

   Clears the top-level enabled flag, then removes every configured
   transport by calling noRestconfTransportConfigMode for each name.
   `args` is part of the handler signature but is not used.
   """
   restconfConfig.enabled = False
   # noRestconfTransportConfigMode deletes entries from
   # restconfConfig.endpoints, so walk a snapshot of the names instead of
   # the collection that is being modified underneath the loop.
   for name in list( restconfConfig.endpoints ):
      noRestconfTransportConfigMode( mode, { 'TRANSPORT_NAME': name } )

class RestconfTransportConfigMode( RestconfTransportMode, BasicCli.ConfigModeBase ):
   """CLI configuration submode 'transport https <name>'.

   An instance is tied to one transport: the command handlers in this file
   find their endpoint with mode.config_.endpoints[ mode.name ].
   """

   # NOTE(review): __init__ assigns self.name, so on instances this
   # class-level value is shadowed by the transport name.
   name = 'Transport for RESTCONF'

   def __init__( self, parent, session, name ):
      # Module-level config handle, which Plugin() assigns.
      self.config_ = restconfConfig
      # Transport name; the handlers use it as the key into config_.endpoints.
      self.name = name

      RestconfTransportMode.__init__( self, name )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

# ---------------------------------------------------------------------
# switch(config-mgmt-api-restconf)#transport https <name>
# ---------------------------------------------------------------------
def gotoRestconfTransportConfigMode( mode, args ):
   """Enters the submode for transport args[ 'TRANSPORT_NAME' ].

   An unknown name creates a new endpoint first: https transport, default
   VRF, default port, enabled, plus an SSL service status entity named
   "restconf-<name>". If OpenConfigCliLib.otherEnabledTransportExists
   reports a conflict for a new name, nothing is created and the submode is
   not entered. An existing name just enters the submode.
   """
   transportName = args[ 'TRANSPORT_NAME' ]
   endpoints = restconfConfig.endpoints
   if transportName not in endpoints:
      if OpenConfigCliLib.otherEnabledTransportExists( mode, transportName ):
         return

      newEndpoint = restconfConfig.newEndpoints( transportName )
      newEndpoint.transport = 'https'
      newEndpoint.vrfName = DEFAULT_VRF
      # 'initially' attributes are not attrlogged, so the default port is
      # written explicitly for now
      newEndpoint.port = newEndpoint.portDefault
      newEndpoint.enabled = True

      OpenConfigCliLib.createSslServiceStatusEntity( "restconf-" + transportName )
      OpenConfigCliLib.updateLevelEnabledFlag( restconfConfig )

   transportMode = mode.childMode( RestconfTransportConfigMode,
                                   name=transportName )
   mode.session_.gotoChildMode( transportMode )

def noRestconfTransportConfigMode( mode, args ):
   """Removes transport args[ 'TRANSPORT_NAME' ] if it is configured.

   For an existing endpoint: disables it, restores the default port, deletes
   the "restconf-<name>" SSL service status entity and drops the endpoint
   from the config. The level enabled flag is refreshed in every case, even
   when the name was not configured.
   """
   transportName = args[ 'TRANSPORT_NAME' ]
   endpoint = restconfConfig.endpoints.get( transportName )
   if endpoint is not None:
      # Only touch 'enabled' when it actually changes
      if endpoint.enabled:
         endpoint.enabled = False
      endpoint.port = endpoint.portDefault

      OpenConfigCliLib.deleteSslServiceStatusEntity( "restconf-" + transportName )

      try:
         del restconfConfig.endpoints[ transportName ]
      except KeyError:
         # Entry already gone; removal is best effort
         pass

   OpenConfigCliLib.updateLevelEnabledFlag( restconfConfig )

def shutdown( mode, args ):
   """Disables the transport configured by this submode.

   Does nothing when the endpoint is already disabled; otherwise clears its
   enabled flag and refreshes the level enabled flag. Raises KeyError if
   mode.name is not in the endpoints collection.
   """
   endpoint = mode.config_.endpoints[ mode.name ]
   if not endpoint.enabled:
      return
   endpoint.enabled = False
   OpenConfigCliLib.updateLevelEnabledFlag( restconfConfig )

def noShutdown( mode, args ):
   """Enables this transport unless a same-type transport is already enabled.

   Only one endpoint per 'transport' type may be enabled at a time. If
   another enabled endpoint of the same type is found, an error is added to
   the mode and nothing changes. Otherwise both the endpoint and the
   top-level config are marked enabled.
   """
   config = mode.config_
   target = config.endpoints.get( mode.name )
   # First other endpoint that is enabled with the same transport type
   blocker = next(
         ( other for other in config.endpoints.values()
           if other.name != target.name
           and other.enabled
           and other.transport == target.transport ),
         None )
   if blocker is not None:
      errFormat = ( "transport '%s' of type '%s' already "
                    "enabled; can not enable another" )
      mode.addError( errFormat % ( blocker.name, blocker.transport ) )
      return
   target.enabled = True
   config.enabled = True
def setSslProfile( mode, args ):
   """Stores args[ 'PROFILENAME' ] as the SSL profile of this transport."""
   selectedProfile = args[ 'PROFILENAME' ]
   endpoint = mode.config_.endpoints[ mode.name ]
   endpoint.sslProfile = selectedProfile

def noSslProfile( mode, args ):
   """Clears the SSL profile of this transport (sets it to the empty string)."""
   endpoint = mode.config_.endpoints[ mode.name ]
   endpoint.sslProfile = ''

# Matcher for SSL profile names. The lambda defers the read of
# sslConfig.profileConfig until it is called, so it does not see the None
# placeholder that sslConfig holds at import time. 'Profile name' is
# presumably the help text shown for the token -- TODO confirm.
profileNameMatcher = CliMatcher.DynamicNameMatcher(
      lambda mode: sslConfig.profileConfig,
      'Profile name')

def setVrfName( mode, args ):
   """Moves this transport to args[ 'VRFNAME' ] (DEFAULT_VRF when absent).

   No-op when the VRF is unchanged. If the endpoint has a service ACL, the
   move is abandoned when existsOtherTransportSameVrfWithServiceAcl reports
   a conflict in the target VRF; otherwise the ACL is removed before the VRF
   changes and applied again afterwards.
   """
   endpoint = mode.config_.endpoints[ mode.name ]
   targetVrf = args.get( 'VRFNAME', DEFAULT_VRF )
   if targetVrf == endpoint.vrfName:
      return

   if endpoint.serviceAcl:
      conflict = OpenConfigCliLib.existsOtherTransportSameVrfWithServiceAcl(
            mode, targetVrf )
      if conflict:
         return
      # Drop the ACL while the endpoint still refers to the old VRF
      OpenConfigCliLib.noServiceAcl( mode, 'restconf', endpoint )

   endpoint.vrfName = targetVrf

   if endpoint.serviceAcl:
      # Apply the ACL again now that the endpoint refers to the new VRF
      OpenConfigCliLib.setServiceAcl( mode, 'restconf', endpoint )

def setPort( mode, args ):
   """Sets the transport port to args[ 'PORT' ], or to the default if absent.

   Nothing happens when the port is unchanged. After a change, an endpoint
   that has a service ACL gets it applied again through
   OpenConfigCliLib.setServiceAcl.
   """
   endpoint = mode.config_.endpoints[ mode.name ]
   requestedPort = args.get( 'PORT', endpoint.portDefault )
   if requestedPort != endpoint.port:
      endpoint.port = requestedPort
      if endpoint.serviceAcl:
         OpenConfigCliLib.setServiceAcl( mode, 'restconf', endpoint )
# ---------------------------------------------------------------------
# switch(config-mgmt-api-restconf-transport-<name>)# qos dscp <dscpValue>
# ---------------------------------------------------------------------
def setDscp( mode, args ):
   """Stores args[ 'DSCP' ] as the qosDscp value of this transport."""
   endpoint = mode.config_.endpoints[ mode.name ]
   endpoint.qosDscp = args[ 'DSCP' ]

def noDscp( mode, args ):
   """Resets the qosDscp value of this transport to 0."""
   endpoint = mode.config_.endpoints[ mode.name ]
   endpoint.qosDscp = 0

# Hand the transport submode and its DSCP set/reset handlers to the shared
# DscpCliLib helper; presumably this registers the 'qos dscp' command in
# that submode -- TODO confirm against DscpCliLib.
DscpCliLib.addQosDscpCommandClass( RestconfTransportConfigMode, setDscp, noDscp )

# -----------------------------------------------------------------------------------
# switch(config-mgmt-api-restconf-transport-<name>)# ip access-group <access-list>
# -----------------------------------------------------------------------------------
def setRestconfAcl( mode, args ):
   """Applies the service ACL args[ 'ACLNAME' ] to this transport.

   An endpoint that already has a service ACL always gets the new one. An
   endpoint without one is left untouched when
   existsOtherTransportSameVrfWithServiceAcl reports a conflict in its VRF.
   """
   endpoint = mode.config_.endpoints.get( mode.name )
   if ( not endpoint.serviceAcl
        and OpenConfigCliLib.existsOtherTransportSameVrfWithServiceAcl(
           mode, endpoint.vrfName ) ):
      return
   endpoint.serviceAcl = args[ 'ACLNAME' ]
   OpenConfigCliLib.setServiceAcl( mode, 'restconf', endpoint )

def noRestconfAcl( mode, args ):
   """Removes the service ACL from this transport, if it has one.

   The endpoint's serviceAcl is cleared first and OpenConfigCliLib.noServiceAcl
   is called afterwards. An endpoint without a service ACL is left untouched.
   """
   endpoint = mode.config_.endpoints.get( mode.name )
   if not endpoint.serviceAcl:
      return
   endpoint.serviceAcl = ""
   OpenConfigCliLib.noServiceAcl( mode, 'restconf', endpoint )

def Plugin( entityManager ):
   """Mounts the config paths this plugin uses and stores the handles.

   restconfConfig is mounted writable through ConfigMount; sslConfig is
   mounted read-only through LazyMount. Both module-level names are rebound.
   """
   global restconfConfig, sslConfig
   restconfConfig = ConfigMount.mount(
         entityManager, "mgmt/restconf/config", "Restconf::Config", "w" )
   sslConfig = LazyMount.mount(
         entityManager, "mgmt/security/ssl/config",
         "Mgmt::Security::Ssl::Config", "r" )
