#!/usr/bin/env python3
# Copyright (c) 2012 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import hashlib
import sys, hmac, array, optparse, struct, re # pylint: disable=deprecated-module
import Sol, Tac
from M2Crypto import RSA

# Size in bytes of one G key (2048 bits); testLoadG() pushes a key to the
# chip as keySize // chunkSize chunks.
keySize = 2048 // 8
# Bytes moved per Sol.set() / Sol.get() transfer.
chunkSize = 64

# ARun public key used by verifyPca() to check the SolMfg signature.
ARunpubfile = "/usr/share/Sol/ARun-pub.pem"

def keysInPrefdl( chipId ):
   """ Return True when chipId names a card whose Sol keys are read from the
   prefdl rather than from idseeprom: ids beginning with "Fabric",
   "Linecard" or "supervisor" (case-sensitive).  None gives False. """
   if chipId is None:
      return False
   # NOTE(review): "supervisor" is lowercase while the other prefixes are
   # capitalized, and readPrefdl() matches case-insensitively -- confirm the
   # exact chip id spelling reported by 'solid list'.
   return chipId.startswith( ( "Fabric", "Linecard", "supervisor" ) )


def readPrefdl( chip ):
   """ Return the prefdl text for a modular card, or None.

   chip is searched (case-insensitively) for "supervisor<N>", "fabric<N>"
   or "linecard<N>"; the card type and slot number are passed to
   modulargenprefdl and its captured stdout is returned.  Returns None when
   chip does not contain such a name. """
   chipRe = re.search( r'(supervisor|fabric|linecard)(\d+)', chip,
                       re.IGNORECASE )
   if chipRe is None:
      return None

   cmdOpt = '--%s' % chipRe.group( 1 ).lower()
   # Normalize the slot number (e.g. "01" -> "1") but hand it over as a
   # string: entries of a process argv must be strings (presumably Tac.run
   # passes this list on to subprocess).
   cmdOptArg = str( int( chipRe.group( 2 ) ) )
   return Tac.run( [ "modulargenprefdl", cmdOpt, cmdOptArg ],
                   stdout=Tac.CAPTURE )


def fetchSolKey( chip, keyName ):
   """ Fetch the sol keys.  On Napa & Ahwahnee they are stored in the
       MAC SEEPROM.  On a M8 line or fabric card, they are stored in
       the FDL SEEPROM.  keyName is either SolKey (chip_pub) or SolMfg.

       On prefdl cards (see keysInPrefdl) the key is stored as hex digits
       and is returned as a str of 256 characters chr( 0 ) .. chr( 255 ).
       Otherwise the first 256 elements of the idseeprom output are
       returned.  Prints an error and exits with status 1 if keyName is
       unknown or the key cannot be found in the prefdl. """

   # keyName -> ( prefdl field name, idseeprom device )
   keys = { 'SolKey' : ( 'SolKey', '1.86' ),
            'SolMfg' : ( 'SolMfg', '1.87' ) }
   key = keys.get( keyName )
   if key is None:
      print( "error: keyName must be one of %s" % list( keys ) )
      sys.exit( 1 )

   if not keysInPrefdl( chip ):
      output = Tac.run( [ "idseeprom", "-d", key[ 1 ], "read" ],
                        stdout=Tac.CAPTURE )
      return output[ 0:256 ]

   prefdl = readPrefdl( chip )
   # readPrefdl() returns None when chip has no slot number; treat that the
   # same as a missing key instead of crashing inside re.search().
   solAsciiRe = None
   if prefdl is not None:
      solAsciiRe = re.search( r"%s\s*:\s*(\S+)" % key[ 0 ], prefdl )
   if solAsciiRe is None:
      print( "ERROR: couldn't find key in prefdl [%s]" % key[ 0 ] )
      print( "Please make sure that the security chip been programmed." )
      sys.exit( 1 )

   # Each pair of hex digits encodes one key byte, kept as one character.
   solAscii = solAsciiRe.group( 1 )
   return ''.join( chr( int( solAscii[ i:i + 2 ], 16 ) )
                   for i in range( 0, 512, 2 ) )

def fetchBoardAttributes( chip ):
   """ Return ( pca, sn, mac ) strings for the FRU.

   On prefdl cards (see keysInPrefdl) all three come from readPrefdl();
   otherwise PCA and serial number come from the decoded idseeprom 1.82
   contents and the MAC from macseeprom, falling back to the Bodega MAC
   SEEPROM when the MAC SEEPROM reads as all zeros.  The MAC has its ':'/'-'
   separators removed and always ends in a linefeed ("\\012"); a prefdl
   without MacAddrBase yields "None" plus the linefeed. """

   if not keysInPrefdl( chip ):
      decoded = Tac.run( [ "/bin/bash", "-c",
                           "idseeprom -d 1.82 read | prefdl --decode -" ],
                         stdout=Tac.CAPTURE )
      pca = re.search( r"PCA: (\S+)", decoded ).group( 1 )
      sn = re.search( r"SerialNumber: (\S+)", decoded ).group( 1 )

      # The macseeprom output already carries the trailing linefeed, so
      # only the separators are stripped.
      rawMac = Tac.run( [ "macseeprom", "read" ], stdout=Tac.CAPTURE )
      mac = "".join( re.split( "[:-]", rawMac ) )
      if mac == "000000000000\012":
         # Legacy format: read the Bodega MAC SEEPROM instead.
         bodega = Tac.run( [ "macseeprom", "--system", "bodega", "read" ],
                           stdout=Tac.CAPTURE )
         base = re.search( r"MacAddrBase: (\S+)", bodega ).group( 1 )
         mac = "".join( re.split( "[:-]", base ) ) + "\012"
      return ( pca, sn, mac )

   prefdl = readPrefdl( chip )
   pca = re.search( r"PCA\s*:\s*(\S+)", prefdl ).group( 1 )
   sn = re.search( r"SerialNumber\s*:\s*(\S+)", prefdl ).group( 1 )
   baseMatch = re.search( r"MacAddrBase\s*:\s*(\S+)", prefdl )
   if baseMatch is None:
      macDigits = "None"
   else:
      macDigits = "".join( re.split( "[:-]", baseMatch.group( 1 ) ) )
   # A linefeed (^J, hex 10) was appended to the MAC when the mfg data was
   # generated, but the prefdl value lacks it, so add it back here.
   return ( pca, sn, macDigits + "\012" )

# Slot keys 0..3, filled in by readKGFile() and used as HMAC keys by
# testGenMV() and testAuthM().
k = {}


def readKGFile( kgfile ):
   """ Load four consecutive 20-byte keys from kgfile into k[ 0 ] .. k[ 3 ].

   The file is opened in binary mode: the keys are raw bytes, and text mode
   would try to decode them and translate carriage returns, corrupting the
   key material.  The keys are stored as bytes, which is what hmac.new()
   requires.  Prints an error and exits with status 1 if the file is too
   short to hold all four keys. """
   with open( kgfile, 'rb' ) as kgfp:
      for x in range( 4 ):
         key = kgfp.read( 20 )
         if len( key ) != 20:
            print( "error: %s is truncated (key %d)" % ( kgfile, x ) )
            sys.exit( 1 )
         k[ x ] = key

def initialize( opt ):
   """ Open the Sol chip chosen on the command line.

   Passes opt.chip and opt.sysname to Sol.setup() and returns the handle it
   gives back; exits with status 1 when Sol.setup() returns None. """
   solHandle = Sol.setup( opt.chip, opt.sysname )
   if solHandle is not None:
      return solHandle
   sys.exit( 1 )

def testGenMV( handle ):
   """ Check the chip's GenM / GetV results against a software HMAC.

   The chip generates a 20-byte M.  For each key slot 0..3, the chip's V
   for that slot is read back and compared with HMAC-SHA1( k[ slot ], M ),
   computed from the keys readKGFile() stored in k.  On the first mismatch
   the slot number and both values are printed and the process exits with
   status 1. """
   Sol.call( handle, Sol.SolCmdGenM )
   M = Sol.get( handle, 20 )

   for slot in range( 4 ):
      # Value computed by the security chip.
      Sol.call( handle, Sol.SolCmdGetV, data=slot )
      chipV = Sol.get( handle, 20 )

      # The same value computed in software.
      digest = hmac.new( k[ slot ], M, hashlib.sha1 ).digest()
      softV = array.array( 'B', digest )

      if chipV != softV:
         print( slot )
         Sol.dump1( chipV )
         Sol.dump1( softV )
         sys.exit( 1 )

def testAuthM( handle, kgfile=None ):
   """ Test AuthM. If the kgfile isn't available, use some pre-canned values """

   if kgfile:
      M = "abcde12345qwert7890-"
      N = "zxcvbASDFGNM<>?&*()_"

      for ki in range( 4 ):
         Sol.set( handle, array.array('B', M+N) )
         Sol.call( handle, Sol.SolCmdAuthM, data=ki )
         H = Sol.get( handle, 20 )

         v = hmac.new( k[ki], M, hashlib.sha1 ).digest()
         h = hmac.new( v, N, hashlib.sha1 ).digest()

         if array.array('B', h) != H:
            print( "AuthM failure:\nchip: %s\n  sw: %s" %
                  ( Sol.format1( array.array( 'B', H ) ),
                    Sol.format1( array.array( 'B', h ) ) ) )
            sys.exit( 1 )

         # print ki, Sol.format1( array.array('B',h) )


   else:
      data = [ [ "abcde12345qwert7890-", "zxcvbASDFGNM<>?&*()_",
               [ [ 0x46, 0xb2, 0x45, 0x70, 0x2f, 0xfb, 0x19, 0xac, 0x60, 0x3f,
                   0x2e, 0x64, 0x07, 0x90, 0x90, 0x07, 0xa0, 0x3a, 0x94, 0x0b ],
                 [ 0xc7, 0x21, 0x7a, 0x48, 0x0a, 0xad, 0x16, 0x78, 0x38, 0xc9,
                   0x78, 0x17, 0xa7, 0xed, 0x14, 0xb1, 0x24, 0xda, 0xd4, 0xc5 ],
                 [ 0xe8, 0x30, 0xe7, 0x4f, 0x28, 0x2e, 0x77, 0x47, 0x6a, 0x8a,
                   0xd7, 0xed, 0xad, 0xca, 0x3c, 0xb3, 0xa5, 0xbb, 0x13, 0x09 ],
                 [ 0x73, 0x3e, 0x17, 0x1c, 0x49, 0x5b, 0x19, 0x01, 0x70, 0x7d,
                   0xd0, 0x46, 0x3b, 0xd3, 0x67, 0x98, 0x91, 0x85, 0x1d, 0x15 ]]],

               [ array.array('B', [0] * 20).tostring(), # pylint: disable=no-member
                 # pylint: disable-next=no-member
                 array.array('B', [0xff] * 20).tostring(),
               [ [ 0x37, 0xe1, 0x10, 0x57, 0x6d, 0xda, 0xcf, 0x7e, 0xb7, 0xe1,
                   0xe4, 0x4d, 0x19, 0x7a, 0x1b, 0x25, 0x21, 0x42, 0x02, 0x41 ],
                 [ 0x30, 0xaa, 0x83, 0x37, 0xc1, 0x83, 0xe4, 0xbb, 0xb1, 0x18,
                   0xa4, 0x8d, 0x7a, 0x6c, 0xff, 0x3a, 0xb2, 0x76, 0x7b, 0x37 ],
                 [ 0x11, 0x69, 0x85, 0xcd, 0x80, 0x02, 0x8a, 0xe3, 0x38, 0x46,
                   0x3a, 0x02, 0x30, 0x3b, 0xe1, 0x46, 0xdd, 0x14, 0x2e, 0x16 ],
                 [ 0xda, 0xc3, 0x51, 0xba, 0xc2, 0x18, 0x94, 0xcb, 0x23, 0xa3,
                   0xef, 0xdd, 0x4e, 0x26, 0x28, 0x91, 0x82, 0x61, 0xa2, 0x4b ]]],

               # pylint: disable-next=no-member
               [ array.array('B', [0xff] * 20).tostring(),
                 # pylint: disable-next=no-member
                 array.array('B', [0x00] * 20).tostring(),
               [ [ 0x29, 0xd7, 0x1c, 0xe4, 0x75, 0xfa, 0x0c, 0x23, 0x1e, 0x83,
                   0x68, 0xb4, 0x8a, 0xc8, 0x90, 0x27, 0x6e, 0x28, 0x48, 0x00 ],
                 [ 0x49, 0x00, 0x4f, 0xfc, 0xc3, 0xc7, 0x28, 0x53, 0x69, 0xfd,
                   0x7d, 0x87, 0xc2, 0xa1, 0x4d, 0xc5, 0x78, 0x22, 0x3b, 0x61 ],
                 [ 0x1b, 0xb3, 0xdf, 0xfd, 0xc1, 0x09, 0xef, 0xbf, 0x65, 0x68,
                   0xad, 0x21, 0x09, 0x7b, 0x56, 0x05, 0x6a, 0x70, 0x15, 0x7b ],
                 [ 0x97, 0x90, 0xe6, 0x0e, 0x41, 0xab, 0xe9, 0x9c, 0xdc, 0x32,
                   0xd7, 0xd7, 0xf5, 0x9b, 0x17, 0x20, 0xd0, 0x13, 0x82, 0x65 ]]],

               # pylint: disable-next=no-member
               [ array.array('B', [0xa5] * 20).tostring(),
                 # pylint: disable-next=no-member
                 array.array('B', [0x5a] * 20).tostring(),
               [ [ 0x09, 0x75, 0xac, 0x1e, 0x53, 0xf7, 0x09, 0x79, 0xc1, 0xb7,
                   0xf3, 0x34, 0x95, 0x5e, 0x91, 0x8e, 0x1f, 0x9a, 0x84, 0x35 ],
                 [ 0xc9, 0x66, 0xba, 0x49, 0x1b, 0xa7, 0x3d, 0xb1, 0x0b, 0xe2,
                   0x2a, 0xf5, 0x45, 0x68, 0x29, 0x97, 0x97, 0xdc, 0x0b, 0x28 ],
                 [ 0x2e, 0x2e, 0x82, 0x5f, 0xe6, 0xe5, 0x2b, 0xd7, 0x0e, 0x95,
                   0x5e, 0x45, 0x69, 0x08, 0xcf, 0xa3, 0xa5, 0x6f, 0x02, 0x08 ],
                 [ 0x51, 0x57, 0x34, 0xad, 0x27, 0x2e, 0x1d, 0x70, 0x5f, 0x46,
                   0x56, 0xb3, 0x82, 0x6a, 0xd6, 0x32, 0x2c, 0x53, 0xaa, 0xf6 ]]]
             ]
      for (M, N, result) in data:
         for ki in range( 4 ):
            Sol.set( handle, array.array('B', M+N) )
            Sol.call( handle, Sol.SolCmdAuthM, data=ki )
            H = Sol.get( handle, 20 )

            # pylint: disable-next=no-member
            if H.tostring() != array.array('B', result[ki]).tostring():
               print( "AuthM failure:\nchip: %s\n  sw: %s" %
                     ( Sol.format1( array.array( 'B', H ) ),
                       Sol.format1( array.array( 'B', result[ ki ] ) ) ) )
               sys.exit( 1 )

def testLoadG( handle, args ):
   if len(args) != 3:
      print( "usage: soldiag testLoadG idx g.akey" )
      sys.exit( 1 )

   idx = int( args[1] )
   fp = open( args[2] ) # pylint: disable=consider-using-with
   g = fp.read()

   if idx < 0 or idx > 127:
      # idx == 0 is always stored in the security chip
      print( "error: invalid index. must be [1..127]" )
      sys.exit( 1 )

   g = array.array( 'B', g[256 * idx : 256 * (idx + 1)] )

   Sol.call( handle, Sol.SolCmdSetupMfg )
   for j in range( keySize // chunkSize ):
      chunk = g[(j*chunkSize):((j+1)*chunkSize)]
      Sol.set( handle, chunk )
      Sol.call( handle, Sol.SolCmdSetData, data=j*chunkSize )
   Sol.call( handle, Sol.SolCmdSetG, sleep=90 )

def testLoadG0( handle ):
   """ Issue SolCmdSetG0 to the security chip behind handle. """
   setG0 = Sol.SolCmdSetG0
   Sol.call( handle, setG0 )

def testAuthG( handle ):
   """ First call testLoadG to get G, then testAuthG to run with that G.
       For this to work, the key must equal the G set in LoadG

       A fixed 220-byte message is sent to the chip in chunkSize pieces
       (SolCmdStartG, then SolCmdAddEEPROM per chunk) and AuthG is run.  The
       chip's 20-byte result is compared with HMAC-SHA1 of the same message
       under `key` below; on a mismatch both values are dumped and the
       process exits with status 1. """

   # bytes, not str: on Python 3 array.array( 'B', ... ) rejects a str.
   message = array.array( 'B', b"abcde12345fghij67890" * 11 )

   # Security Chip version
   Sol.call( handle, Sol.SolCmdStartG )
   for offset in range( 0, len( message ), chunkSize ):
      Sol.set( handle, message[ offset:offset + chunkSize ] )
      Sol.call( handle, Sol.SolCmdAddEEPROM )

   Sol.call( handle, Sol.SolCmdAuthG )
   output = Sol.get( handle, 20 )

   # software version
   # note: for this to work, you need to set key to
   #       one of the G values.
   key = bytes( 20 )

   # tobytes() replaces tostring(), which was removed in Python 3.9.
   v = hmac.new( key, message.tobytes(), hashlib.sha1 ).digest()

   if output != array.array( 'B', v ):
      print( "output" )
      Sol.dump( output )
      print( "v" )
      Sol.dump( array.array( 'B', v ) )
      sys.exit( 1 )

def testAuthN( handle, chip ):
   """ authenticate PCA
       hint: chipkey is in the seeprom

       The chip signs a fixed 20-byte nonce (SolCmdAuthN); the 2048-bit
       signature is read back in chunkSize pieces and verified with the
       chip public key returned by fetchSolKey( chip, "SolKey" ).  Prints an
       error and exits with status 1 if the key check or the signature
       verification fails. """

   chipPub = fetchSolKey( chip, "SolKey" )
   if isinstance( chipPub, str ):
      # fetchSolKey() builds the key from chr( 0..255 ) characters; latin-1
      # maps them back to the original bytes needed for the bytes concat
      # below.
      chipPub = chipPub.encode( 'latin-1' )
   # bytes, not str: on Python 3 array.array( 'B', ... ) rejects a str.
   nonce = array.array( 'B', b"abcde12345vwxyz67890" )

   # Security Chip portion
   Sol.call( handle, Sol.SolCmdSetupChip )
   Sol.set( handle, nonce )

   sig = array.array( 'B' )
   for offset in range( 0, 2048 // 8, chunkSize ):
      # Only the first AuthN call gets sleep=60 (presumably while the chip
      # computes the signature); later calls just read the next chunk.
      Sol.call( handle, Sol.SolCmdAuthN, data=offset,
                sleep=( 60 if offset == 0 else 0 ) )
      sig += Sol.get( handle, chunkSize )

   # Length-prefixed big-endian integers: the exponent is 0x010001 (3 bytes)
   # and the modulus is the 256-byte key behind a zero byte (length 0x101).
   exponent = struct.pack( "!LBBB", 3, 0x01, 0x00, 0x01 )
   modulusPrefix = struct.pack( "!LB", 0x101, 0 )

   try:
      chipPubKey = RSA.new_pub_key( ( exponent, modulusPrefix + chipPub ) )
      # Explicit checks rather than assert: asserts are stripped under -O,
      # which would silently accept a bad key or signature.
      if chipPubKey.check_key() != 1:
         print( "ERROR: Chip_pub key check failed" )
         sys.exit( 1 )

      digest = hashlib.sha1( nonce ).digest()
      if chipPubKey.verify( digest, sig.tobytes(), 'sha1' ) != 1:
         print( "ERROR: AuthN signature verification failed" )
         sys.exit( 1 )
   except RSA.RSAError as err:
      print( "ERROR: RSA Error [%s]" % err )
      sys.exit( 1 )

def verifyPca( handle, chip ):
   """ Verify PCA.  mfg format must match ssmfg::genmfg()

   Rebuilds the mfg record "1" + PCA + serial number + MAC + chip public
   key and checks the SolMfg signature over SHA1( SHA1( record ) ) with the
   ARun public key in ARunpubfile.  Prints an error and exits with status 1
   if the signature does not verify. """

   def asBytes( value ):
      # The record fields are str built from 8-bit characters (fetchSolKey
      # uses chr( 0..255 )); latin-1 maps them back to the original bytes,
      # which hashlib and M2Crypto require on Python 3.
      return value.encode( 'latin-1' ) if isinstance( value, str ) else value

   ( pca, sn, mac ) = fetchBoardAttributes( chip )

   solConfig = getattr( handle, 'solConfig', None )
   if solConfig and not solConfig.hasBoardMac:
      # Without a board MAC the record uses the literal "None" plus a
      # linefeed.  This matches the placeholder fetchBoardAttributes() uses
      # for a prefdl without MacAddrBase.
      mac = "None" + "\012"

   chippub = fetchSolKey( chip, "SolKey" )

   formatStr = "1"
   mfg = b"".join( asBytes( part )
                   for part in ( formatStr, pca, sn, mac, chippub ) )
   mfghash = hashlib.sha1( mfg ).digest()

   sig = asBytes( fetchSolKey( chip, "SolMfg" ) )

   # should probably just include ARun_pub into the source code
   # when this is ported to C
   try:
      arunPub = RSA.load_pub_key( ARunpubfile )
      doubledigest = hashlib.sha1( mfghash ).digest()
      verified = arunPub.verify( doubledigest, sig )
   except RSA.RSAError as err:
      print( "ERROR: RSA Error [%s]" % err )
      sys.exit( 1 )

   # Explicit check rather than assert: asserts are stripped under -O,
   # which would silently accept a bad signature.
   if not verified:
      print( "ERROR: PCA signature verification failed" )
      sys.exit( 1 )

# Help text passed to optparse.OptionParser in main().
usage = """usage: soldiag [--chip <chipid>] [--kgfile <kg.bin>] <command>
   soldiag testGenMV             (signing server)
   soldiag testAuthM             (signing server, EOS dut)
   soldiag testLoadG idx g.akey  (EOS dut)
   soldiag testLoadG0            (EOS dut)
   soldiag testAuthG             (EOS dut)
   soldiag testAuthN             (EOS dut)
   soldiag verifyPCA             (EOS dut)

   Use 'solid list' to get a list of valid chipids. """

def main():
   """ Parse the command line, open the Sol chip and run one diagnostic.

   The command is validated before the chip is opened.  Sol.cleanup() runs
   even when a diagnostic stops early through sys.exit(). """
   parser = optparse.OptionParser( usage=usage )
   parser.add_option( "--verbose", action="store_true",
                      help="chatty" )
   parser.add_option( "--chip", action="store",
                      help="which solchip to access" )
   parser.add_option( "--kgfile", action="store", type="string",
                      help="key file" )
   parser.add_option( "--sysname", action="store", default="ar",
                      help="Sysdb sysname" )

   ( opt, args ) = parser.parse_args()

   if len( args ) < 1:
      parser.error( "missing command" )
   command = args[ 0 ]

   # command name -> callable taking the Sol handle
   commands = {
      'testGenMV' : testGenMV,
      'testAuthM' : lambda handle: testAuthM( handle, opt.kgfile ),
      'testLoadG' : lambda handle: testLoadG( handle, args ),
      'testLoadG0' : testLoadG0,
      'testAuthG' : testAuthG,
      'testAuthN' : lambda handle: testAuthN( handle, opt.chip ),
      'verifyPCA' : lambda handle: verifyPca( handle, opt.chip ),
   }
   if command not in commands:
      # parser.error() already prints the usage text; just name the problem.
      parser.error( "unknown command: %s" % command )
   # testGenMV compares against HMACs keyed by the kgfile keys; without them
   # it would die with a KeyError on the empty key table.
   if command == 'testGenMV' and not opt.kgfile:
      parser.error( "testGenMV requires --kgfile" )

   if opt.kgfile:
      readKGFile( opt.kgfile )

   handle = initialize( opt )
   try:
      commands[ command ]( handle )
   finally:
      Sol.cleanup( handle )

if __name__ == "__main__":
   main()
   # Exit explicitly only when run as a script; at module level this call
   # would also terminate any program that merely imports this module.
   sys.exit( 0 )


