#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
from __future__ import absolute_import, division, print_function

import re
from operator import attrgetter

import TableOutput
from CliModel import Dict, Enum, Int, List, Model, Str, Submodel


class FeatureModel( Model ):
   __public__ = False
   details = List( valueType=str, help="Inconsistencies for the feature" )

   def setTable( self, table, feature ):
      """
      Add one "warn" row for ``feature`` to ``table``.

      The detail cell contains every entry of ``details`` joined by newlines.
      No row is added when ``details`` is empty.
      """
      if not self.details:
         return
      table.newRow( feature, "warn", "\n".join( self.details ) )

class AclTypeModel( Model ):
   __public__ = False
   ip = Submodel( valueType=FeatureModel, help="IP ACL inconsistencies",
                  optional=True )
   ipv6 = Submodel( valueType=FeatureModel, help="IPv6 ACL inconsistencies",
                    optional=True )
   mac = Submodel( valueType=FeatureModel, help="MAC ACL inconsistencies",
                   optional=True )

   def setTable( self, table ):
      """
      Add rows for each ACL family (IP, IPv6, MAC) that has a result.

      Families are emitted in the fixed order IP, IPv6, MAC.
      """
      families = ( ( self.ip, "IP access list" ),
                   ( self.ipv6, "IPv6 access list" ),
                   ( self.mac, "MAC access list" ) )
      for feature, label in families:
         if feature is not None:
            feature.setTable( table, label )

class PfxListTypeModel( Model ):
   __public__ = False
   ip = Submodel( valueType=FeatureModel, help="IP prefix list inconsistencies",
                  optional=True )
   ipv6 = Submodel( valueType=FeatureModel, help="IPv6 prefix list inconsistencies",
                    optional=True )

   def setTable( self, table ):
      """
      Add rows for each prefix-list family (IP, then IPv6) that has a result.
      """
      families = ( ( self.ip, "IP prefix list" ),
                   ( self.ipv6, "IPv6 prefix list" ) )
      for feature, label in families:
         if feature is not None:
            feature.setTable( table, label )

class CommListTypeModel( Model ):
   __public__ = False
   ip = Submodel( valueType=FeatureModel, help="IP community list inconsistencies",
                  optional=True )

   def setTable( self, table ):
      """Add the IP community list row to ``table`` when a result is present."""
      ipFeature = self.ip
      if ipFeature is None:
         return
      ipFeature.setTable( table, "IP community list" )

class UndefinedReferenceCheckerModel( Model ):
   __public__ = False
   acls = Submodel( valueType=AclTypeModel, help="Undefined references to ACLs",
                    optional=True )
   pfxLists = Submodel( valueType=PfxListTypeModel,
                        help="Undefined references to prefix lists", optional=True )
   commLists = Submodel( valueType=CommListTypeModel,
                         help="Undefined references to community lists",
                         optional=True )
   # There is no meaningful way to categorize route maps in a similar way that
   # ACLs, prefix lists, and community lists are categorized above, so another
   # submodel for route maps is unnecessary.
   # It's sufficient just to set it to a FeatureModel.
   routeMaps = Submodel( valueType=FeatureModel,
                         help="Undefined references to route maps", optional=True )

   def setTable( self, table ):
      """
      Add rows for every feature group that reported undefined references.

      Groups are emitted in alphabetical order of their row labels: access
      lists, community lists, prefix lists, and finally route maps.
      """
      groups = [ self.acls, self.commLists, self.pfxLists ]
      for group in [ g for g in groups if g is not None ]:
         group.setTable( table )

      routeMaps = self.routeMaps
      if routeMaps is not None:
         routeMaps.setTable( table, "Route map" )

class ConfigConsistencyModel( Model ):
   __public__ = False
   undefinedReferences = Submodel( valueType=UndefinedReferenceCheckerModel,
                                   help="Undefined reference checker",
                                   optional=True )

   def render( self ):
      """
      Print the config consistency report.

      If undefined-reference results exist, this prints a section title
      followed by a Feature/Result/Detail table. The table is included only
      when at least one row was added to it. Without results, an empty line
      is printed.
      """
      columns = [ "Feature", "Result", "Detail" ]
      leftJustified = TableOutput.Format( justify='left' )
      leftJustified.noPadLeftIs( True )
      # Every column shares the same Format object.
      columnFmts = [ leftJustified ] * len( columns )

      report = ""
      checker = self.undefinedReferences
      if checker is not None:
         report += "\nUndefined references\n"
         table = TableOutput.createTable( columns )
         table.formatColumns( *columnFmts )
         # Snapshot the header-only output so we can tell whether any rows
         # were added by the checker.
         headerOnly = table.output()
         checker.setTable( table )
         rendered = table.output()
         if rendered != headerOnly:
            # Drop the last character (presumably the table's trailing
            # newline); print() below terminates the line.
            report += "\n" + rendered[ : -1 ]

      print( report )

# -----------------------------------------------------------------------------------
# "table-item-ish" config consistency models
#
# The command is `show configuration consistency CLUSTER [ CATEGORY ]`
# CLUSTER: cluster name, e.g. `nat`.
# CATEGORY: optional category name, e.g. `access-list`. If omitted, all categories
#     under this cluster are checked.
#
# A cluster is mapped to a ConfigConsistencyMultiTableModel.
# A category is mapped to one or more ConfigConsistencyTableModels; that is, a
# category check can generate multiple tables.
# -----------------------------------------------------------------------------------

# Fixed column widths, in characters, for config consistency tables, keyed by
# width profile. Key order within a profile is significant: it is the column
# order used when formatting table columns and when slicing rendered rows back
# into cells.
_COLUMN_WIDTHS = dict(
   default=dict( name=35, status=9, description=35 ),
)

# One row of a config consistency table. The public fields correspond to the
# three table columns (name, status, description); keep them in this order so
# they line up with the column layout.
class ConfigConsistencyItemModel( Model ):
   # Sort key among the items of one table; items are rendered in ascending
   # order of this value. It is internal and is not rendered as a cell.
   _priority = Int( help='Priority for sorting', default=1 )
   name = Str( help='Name of item', default='Unknown Item' )
   # Items left at 'ok' are omitted when their table is rendered.
   status = Enum( help='Item consistency level',
                  values=( 'fail', 'warn', 'ok' ),
                  default='ok' )
   description = Str( help='Description about the consistency item', default='' )

class ConfigConsistencyTableModel( Model ):
   # Sort key among the tables of a multi-table model (ascending).
   _priority = Int( help='Priority for sorting', default=1 )
   name = Str( help='Name of table', default='Unknown Table' )
   headings = List( valueType=str, help='Table header for this table',
                    default=[ 'Item', 'Status', 'Description' ] )
   # Worst status among the table's items; 'ok' tables are not rendered.
   status = Enum( help='Table consistency level',
                  values=( 'fail', 'warn', 'ok' ),
                  default='ok' )
   items = Dict( help='Config consistency items, keyed by item names',
                 keyType=str, valueType=ConfigConsistencyItemModel )

   def getFormat( self ):
      """
      Build a TableFormatter with this table's heading row already added.

      Column formats are fixed-width, left-justified and wrapping, with the
      widths taken (in order) from ``_COLUMN_WIDTHS[ 'default' ]``.
      """
      fmtter = TableOutput.TableFormatter()

      # set headings
      headingsFmt = TableOutput.Format( isHeading=True, border=True )
      headingsFmt.noPadLeftIs( True )
      headingsFmt.padLimitIs( True )
      headingsFmt.noTrailingSpaceIs( True )
      fmtter.startRow( headingsFmt )
      fmtter.newCell( *self.headings )

      # set columns
      columnFmts = []
      for w in _COLUMN_WIDTHS[ 'default' ].values():
         f = TableOutput.Format( justify='left', minWidth=w, maxWidth=w,
                                 isHeading=False, border=False, wrap=True )
         f.noPadLeftIs( True )
         f.padLimitIs( True )
         f.noTrailingSpaceIs( True )
         columnFmts.append( f )
      fmtter.formatColumns( *columnFmts )

      return fmtter

   def render( self ):
      """
      Print the table name, a blank line, then the formatted table.

      Only items whose status is not 'ok' get a row; rows are ordered by
      ascending ``_priority``. Nothing is printed when the table status is
      'ok'.
      """
      if self.status == 'ok':
         return

      # The cells must line up with the column formats built in getFormat(),
      # which follow _COLUMN_WIDTHS[ 'default' ]. Its keys are the item
      # attribute names, so select the cells explicitly in that order rather
      # than relying on the Model's attribute iteration order.
      columns = list( _COLUMN_WIDTHS[ 'default' ] )
      fmtter = self.getFormat()
      for item in sorted( self.items.values(), key=attrgetter( '_priority' ) ):
         if item.status != 'ok':
            fmtter.newRow( *[ getattr( item, column ) for column in columns ] )
      print( f'{ self.name }\n\n{ fmtter.output() }' )

class ConfigConsistencyMultiTableModel( Model ):
   tables = Dict( help='Config check result tables, keyed by table names',
                  keyType=str, valueType=ConfigConsistencyTableModel )

   def render( self ):
      """Render every table in ascending ``_priority`` order (stable on ties)."""
      ordered = list( self.tables.values() )
      ordered.sort( key=attrgetter( '_priority' ) )
      for table in ordered:
         table.render()

class ConfigConsistencyMultiTableModelBuilder:
   """
   Fluent helper for assembling a ConfigConsistencyMultiTableModel.

   Every mutator returns ``self`` so that calls can be chained, e.g.
   ``builder.addTable( ... ).addItem( ... ).build()``.
   """
   def __init__( self, model=None ):
      if model is None:
         model = ConfigConsistencyMultiTableModel()
      self.model = model
      # Severity ranking used to combine statuses; higher is worse.
      self.stLevels = { "ok": 1, "warn": 2, "fail": 3 }

   def addTable( self, name, headings=None, status=None, items=None,
                 priority=None ):
      """
      Create table ``name``, replacing any existing table with that name.

      Only the optional fields that are not None are applied; the rest keep
      the table model's defaults.
      """
      table = ConfigConsistencyTableModel( name=name )
      self._assignGiven( table, ( ( 'headings', headings ),
                                  ( 'status', status ),
                                  ( '_priority', priority ),
                                  ( 'items', items ) ) )
      self.model.tables[ name ] = table
      return self

   def addItem( self, tableName, itemName, status=None, description=None,
                priority=None ):
      """
      Add (or replace) item ``itemName`` in table ``tableName``.

      The table is created with defaults if it does not exist yet. When a
      status is given, the table status is raised to the worse of its current
      status and the item's status.
      """
      if not self.hasTable( tableName ):
         self.addTable( tableName )
      table = self.model.tables[ tableName ]
      item = ConfigConsistencyItemModel( name=itemName )
      self._assignGiven( item, ( ( 'status', status ),
                                 ( 'description', description ),
                                 ( '_priority', priority ) ) )
      if status is not None:
         table.status = self.worst( table.status, item.status )
      table.items[ itemName ] = item
      return self

   def patchTable( self, name, headings=None, status=None, items=None,
                   priority=None ):
      """
      Overwrite the given (non-None) fields of existing table ``name``.

      The table must already exist (checked with ``assert``).
      """
      assert self.hasTable( name )
      self._assignGiven( self.model.tables[ name ],
                         ( ( 'headings', headings ),
                           ( 'status', status ),
                           ( 'items', items ),
                           ( '_priority', priority ) ) )
      return self

   def patchItem( self, tableName, itemName, status=None, description=None,
                  descriptionToAppend=None, priority=None ):
      """
      Overwrite the given (non-None) fields of an existing item.

      ``descriptionToAppend`` is appended after ``description`` (if any) has
      been applied. A given status can only worsen the table status, never
      improve it. The table and item must already exist (checked with
      ``assert``).
      """
      assert self.hasTable( tableName )
      assert self.hasItem( tableName, itemName )
      table = self.model.tables[ tableName ]
      item = table.items[ itemName ]
      self._assignGiven( item, ( ( 'status', status ),
                                 ( 'description', description ),
                                 ( '_priority', priority ) ) )
      if status is not None:
         table.status = self.worst( table.status, item.status )
      if descriptionToAppend is not None:
         item.description += descriptionToAppend
      return self

   def build( self ):
      """Return the model assembled so far."""
      return self.model

   # --------------------------------------------------------------------------------
   # helper functions below
   # --------------------------------------------------------------------------------

   @staticmethod
   def _assignGiven( target, fields ):
      """Set each ``( attr, value )`` pair on ``target`` unless value is None."""
      for attr, value in fields:
         if value is not None:
            setattr( target, attr, value )

   def worst( self, st1, st2 ):
      """Return the more severe status; ``st2`` wins ties."""
      # max() returns the first maximal element, so listing st2 first makes it
      # the winner when both statuses have the same level.
      return max( ( st2, st1 ), key=self.stLevels.__getitem__ )

   def worse( self, st1, st2 ):
      """True iff ``st1`` is strictly more severe than ``st2``."""
      return self.better( st2, st1 )

   def better( self, st1, st2 ):
      """True iff ``st1`` is strictly less severe than ``st2``."""
      return self.stLevels[ st1 ] < self.stLevels[ st2 ]

   def hasTable( self, name ):
      """True iff a table named ``name`` has been added."""
      return name in self.model.tables

   def hasItem( self, tableName, itemName ):
      """True iff table ``tableName`` exists and contains ``itemName``."""
      if not self.hasTable( tableName ):
         return False
      return itemName in self.model.tables[ tableName ].items

class ConfigConsistencyMultiTableModelParser:
   @classmethod
   def parse( cls, output ):
      """
      Parse multi-tabled output to dictionary.

      From:
      ACL Overlap

      Access List                        Result   Description
      ---------------------------------- -------- ----------------------------------
      acl1                               fail     Overlaps with acl2,
                                                  acl3WithASuperLongName
      acl2                               fail     Overlaps with
                                                  acl3WithASuperLongName

      Other Table
      ...

      To:
      {
         "tables": {
            "ACL Overlap": {
                  "name": "ACL Overlap",
                  "headings": [
                     "Access List",
                     "Result",
                     "Description"
                  ],
                  "status": "fail",
                  "items": {
                     "acl1": {
                        "name": "acl1",
                        "status": "fail",
                        "description": "Overlaps with acl2, acl3WithASuperLongName"
                     },
                     "acl2": {
                        "name": "acl2",
                        "status": "fail",
                        "description": "Overlaps with acl3WithASuperLongName"
                     }
                  }
            }
         }
      }

      The input alternates between a table-name chunk and a table-content
      chunk, separated by blank lines. Trailing blank lines are ignored, and
      empty output yields ``{ "tables": {} }``.
      """
      model = { "tables": {} }
      # Each rendered table is followed by a blank line. Without normalization
      # the trailing "\n\n" leaves an empty final chunk that would be taken as
      # a table name, creating a bogus "" table. Strip trailing newlines and
      # re-add exactly one, so every chunk (including the last) ends in "\n".
      body = output.rstrip( "\n" )
      if not body:
         return model
      chunks = re.split( r"(?<=\n)\n", body + "\n" )

      tableName = None
      for i, chunk in enumerate( chunks ):
         if i % 2 == 0:
            # Even chunks are table names; drop the trailing newline.
            tableName = chunk[ : -1 ]
            model[ "tables" ][ tableName ] = {
               "name": tableName
            }
         else:
            cls._parseTableContent( model[ "tables" ][ tableName ], chunk )

      return model

   @staticmethod
   def _splitRow( row, widthType='default' ):
      """Slice a fixed-width row into stripped cells, one per configured column."""
      cells = []
      start = 0
      for width in _COLUMN_WIDTHS[ widthType ].values():
         cells.append( row[ start : start + width ].strip() )
         start += width
      return cells

   @classmethod
   def _parseTableContent( cls, table, content ):
      """
      Fill ``table`` with headings, status and items parsed from ``content``.

      ``content`` is a heading line, a separator line, then item rows. A row
      with an empty status cell continues the previous item: its name and
      description fragments are appended with a single space.
      """
      lines = content.split( "\n" )
      table[ "headings" ] = cls._splitRow( lines[ 0 ] )
      table[ "items" ] = {}
      # Only non-'ok' tables are rendered, so a parsed table is at least
      # 'warn'; it becomes 'fail' as soon as one item fails.
      table[ "status" ] = 'warn'

      def addItem( name, status, desc ):
         table[ "items" ][ name ] = {
            "name": name,
            "status": status,
            "description": desc
         }
         if status == 'fail':
            table[ "status" ] = 'fail'

      prevName, prevStatus, prevDesc = None, None, None
      for line in lines[ 2 : ]:
         name, status, desc, *_ = cls._splitRow( line )
         if status == '':
            if prevStatus is None:
               # Continuation text before any item row has nothing to attach
               # to (previously this crashed on None += str).
               continue
            if name != '':
               prevName += f' { name }'
            if desc != '':
               prevDesc += f' { desc }'
         else:
            if prevStatus is not None:
               addItem( prevName, prevStatus, prevDesc )
            prevName, prevStatus, prevDesc = name, status, desc
      if prevStatus is not None:
         addItem( prevName, prevStatus, prevDesc )
