# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from collections import namedtuple
import CliSave
from CliSavePlugin import IntfCliSave
import Tac

# Tac type handles used below to recognise and decode multicast tunnel
# interface ids (see intfBackwardConvert and demuxByInterfaceTypes).
IrVxlanIpv4TunnelId = Tac.Type( "Multicast::Tunnel::IrVxlanIpv4TunnelId" )
McastTunnelIntfId = Tac.Type( "Multicast::Tunnel::McastTunnelIntfId" )
PimVxlanTunnel = Tac.Type( "Multicast::Tunnel::PimVxlanTunnel" )
McastTunnelType = Tac.Type( "Multicast::Tunnel::McastTunnelType" )

# Maps the PIM tunnel type name returned by
# PimVxlanTunnel.makePimTunnelType() to the keyword that intfBackwardConvert
# appends after the address family for 'tunnel' members (e.g. 'ipv4-pim-sm').
pimTunnelTypeToKeyword = { 'pimsmTunnel': 'pim-sm',
                           'pimssmTunnel': 'pim-ssm' }

def intfBackwardConvert( intfId, af ):
   """Translate a multicast tunnel interface id back into CLI terms.

   Returns a ( keyword, endpoint ) pair:
     * ( 'tunnel', af + '-pim-sm' or af + '-pim-ssm' ) for PIM VXLAN tunnels;
     * ( 'vtep', IrVxlanIpv4TunnelId.vtepIp( intfId ) ) for IrVxlanIpv4
       tunnels;
     * ( None, None ) for any other tunnel type.
   """
   tunnelType = McastTunnelIntfId.tunnelId( intfId ).tunnelType

   if tunnelType == McastTunnelType.pimVxlan:
      pimType = PimVxlanTunnel.makePimTunnelType( intfId )
      return 'tunnel', af + '-' + pimTunnelTypeToKeyword[ pimType ]

   if tunnelType == McastTunnelType.irVxlanIpv4:
      return 'vtep', IrVxlanIpv4TunnelId.vtepIp( intfId )

   # No CLI representation for other tunnel types.
   return None, None

def getIntfCmdsHelper( vtepCmd, tunnelCmd, portCmd, etCmd, demuxIntfs ):
   """Build one CLI command per non-empty interface category.

   Each command is the category's prefix followed by its members in sorted
   order. VTEP and tunnel endpoints are joined with spaces; Port-Channel and
   the remaining interfaces are joined with commas. Commands are returned in
   vtep, tunnel, port, et order.
   """
   categories = (
      ( vtepCmd, demuxIntfs.vtepIntfs, ' ' ),
      ( tunnelCmd, demuxIntfs.tunnelIntfs, ' ' ),
      ( portCmd, demuxIntfs.portIntfs, ',' ),
      ( etCmd, demuxIntfs.etIntfs, ',' ),
   )
   return [ prefix + separator.join( sorted( members ) )
            for prefix, members, separator in categories
            if members ]

# Result type of demuxByInterfaceTypes(); built once at import time instead of
# on every call.
_DemuxIntfs = namedtuple( 'demuxIntfs',
                          'vtepIntfs tunnelIntfs portIntfs etIntfs' )

def demuxByInterfaceTypes( intfs, af ):
   """Split intfs into the categories used to build member commands.

   Multicast tunnel interface ids are converted with intfBackwardConvert() and
   filed as VTEP endpoints ('vtep') or tunnel endpoints ('tunnel'). Names
   containing 'Port-Channel' go to portIntfs, and everything else goes to
   etIntfs. Input order is preserved within each category.

   Returns a namedtuple with fields vtepIntfs, tunnelIntfs, portIntfs and
   etIntfs (lists).
   """
   vtepIntfs = []
   tunnelIntfs = []
   portIntfs = []
   etIntfs = []
   for intf in intfs:
      if McastTunnelIntfId.isMcastTunnelIntfId( intf ):
         keyword, tunnelEndpoint = intfBackwardConvert( intf, af )
         if keyword == 'vtep':
            vtepIntfs.append( tunnelEndpoint )
         elif keyword == 'tunnel':
            tunnelIntfs.append( tunnelEndpoint )
         # Other multicast tunnel types have no CLI keyword: intfBackwardConvert
         # returns ( None, None ) for them. Skip them, because a None endpoint
         # would make the later sorted()/str.join() raise TypeError and abort
         # the whole config save.
      elif 'Port-Channel' in intf:
         portIntfs.append( intf )
      else:
         etIntfs.append( intf )
   return _DemuxIntfs( vtepIntfs, tunnelIntfs, portIntfs, etIntfs )

def _memberCmds( prefix, intfs, af ):
   """Return the member commands for intfs, each starting with prefix.

   prefix ends with a space. VTEP and tunnel endpoints get an extra
   'vtep ' / 'tunnel ' keyword after it; port-channel and other interfaces
   follow the prefix directly.
   """
   return getIntfCmdsHelper( prefix + 'vtep ', prefix + 'tunnel ',
                             prefix, prefix,
                             demuxByInterfaceTypes( sorted( intfs ), af ) )

def getIntfCmds( membershipJoinStatus, af ):
   """Return the 'vxlan vlan ... member' commands for membershipJoinStatus.

   VLANs are visited in ascending order. For each VLAN, the router-interface
   members come first. Then, for every ( group, source ) pair in ascending
   order, the include members are emitted, followed by the exclude members.
   """
   cmds = []
   vlans = membershipJoinStatus.vlan
   for vlanId in sorted( vlans ):
      vlanJoin = vlans[ vlanId ]
      cmds.extend( _memberCmds( 'vxlan vlan %d member ' % vlanId,
                                vlanJoin.routerIntf, af ) )

      for groupAddr in sorted( vlanJoin.group ):
         sources = vlanJoin.group[ groupAddr ].source
         for sourceAddr in sorted( sources ):
            sourceJoin = sources[ sourceAddr ]
            addrs = ( vlanId, groupAddr, sourceAddr )
            cmds.extend( _memberCmds( 'vxlan vlan %d member %s %s ' % addrs,
                                      sourceJoin.includeIntf, af ) )
            cmds.extend( _memberCmds(
               'vxlan vlan %d member exclude %s %s ' % addrs,
               sourceJoin.excludeIntf, af ) )
   return cmds

def _addVxlan1Cmds( root, intfCmds ):
   """Add intfCmds to the 'McastVpn.VxlanIntf' section of interface Vxlan1.

   The Vxlan1 mode instance is looked up (or created) only when there is at
   least one command to add.
   """
   if not intfCmds:
      return
   vxlanMode = root[ IntfCliSave.IntfConfigMode ].getOrCreateModeInstance(
      'Vxlan1' )
   section = vxlanMode[ 'McastVpn.VxlanIntf' ]
   for cmd in intfCmds:
      section.addCommand( cmd )

@CliSave.saver( 'Irb::Multicast::Gmp::MembershipJoinStatusCli',
                'multicast/ipv4/irb/membership/join/cli' )
def saveIpv4MembershipJoinStatus( membershipJoinStatus, root, requireMounts,
                                  options ):
   """Save IPv4 IRB membership joins as Vxlan1 member commands.

   requireMounts and options are unused here; presumably required by the
   CliSave.saver callback signature.
   """
   _addVxlan1Cmds( root, getIntfCmds( membershipJoinStatus, 'ipv4' ) )

@CliSave.saver( 'Irb::Multicast::Gmp::MembershipJoinStatusCli',
                'multicast/ipv6/irb/membership/join/cli' )
def saveIpv6MembershipJoinStatus( membershipJoinStatus, root, requireMounts,
                                  options ):
   """Save IPv6 IRB membership joins as Vxlan1 member commands.

   requireMounts and options are unused here; presumably required by the
   CliSave.saver callback signature.
   """
   _addVxlan1Cmds( root, getIntfCmds( membershipJoinStatus, 'ipv6' ) )
