I'm learning about how full nodes can send a Merkle root + a list of hashes, so that a light client can verify a transaction.
I couldn't find any good resources on how to implement it. I just knew what nodes I needed in order to verify a leaf, so I just found those nodes using DFS.
I'm sure there is a better way to construct the HashTree, find the list of nodes the light client needs, and verification algorithm. Therefore I would really appreciate it if you could review the algorithms and code and give me advice/guidance:
HashTree.java:
package core;
import static util.Bytes.SHA256;
import static util.Bytes.merge;
import java.util.Vector;
public class HashTree {
private final HashNode root;
public HashTree(byte[]... leaves) {
final Vector<HashNode> hashes = new Vector<>();
for (byte[] leaf : leaves) {
hashes.add(new HashNode(leaf, null, null, null));
}
root = construct(hashes);
}
public byte[] getRootHash() {
return root.hash;
}
public static byte[] getRootHash(byte[] hash, boolean oddHashIndex, Vector<byte[]> siblings) {
for (byte[] sibling : siblings) {
hash = (oddHashIndex) ? SHA256(merge(sibling, hash)) : SHA256(merge(hash, sibling));
oddHashIndex = !oddHashIndex;
for (byte b : hash) {
System.out.print(String.format("%02X", b));
}
System.out.println();
}
return hash;
}
/**
* Performs a depth-first search to find a leaf in this HashTree.
* If the leaf is found, construct a path with all the siblings
* needed to verify the leaf.
*
* @param leaf to find
* @return path, otherwise {@code null}
*/
public Vector<byte[]> authenticationPath(byte[] leaf) {
final Vector<HashNode> visited = new Vector<>();
final Vector<byte[]> path = new Vector<>();
if (!dfs(root, leaf, visited, path)) {
throw new RuntimeException("could not find the given leaf");
}
return path;
}
private boolean dfs(HashNode current, byte[] target, Vector<HashNode> visited, Vector<byte[]> path) {
boolean found = false;
if (current.hash == target) {
path.add(current.getSibling().hash);
return true;
}
visited.add(current);
if (current.left != null && !visited.contains(current.left) && !found) {
found = dfs(current.left, target, visited, path);
}
if (current.right != null && !visited.contains(current.right) && !found) {
found = dfs(current.right, target, visited, path);
}
if (found && current != root) {
path.add(current.getSibling().hash);
}
return found;
}
/**
* Constructs a new hash tree from the given leaves.
* @param hashes (leaves)
*/
private HashNode construct(Vector<HashNode> hashes) {
if (hashes == null || hashes.size() < 1) {
throw new IllegalArgumentException("no leaves given");
}
if (hashes.size() == 1) {
return hashes.firstElement();
}
if (hashes.size() % 2 != 0) {
hashes.add(hashes.lastElement());
}
final Vector<HashNode> parents = new Vector<>();
for (int i = 0; i < hashes.size() - 1; i += 2) {
final byte[] parentHash = SHA256(merge(hashes.get(i).hash, hashes.get(i + 1).hash));
final HashNode parent = new HashNode(parentHash, null, hashes.get(i), hashes.get(i + 1));
hashes.get(i).parent = parent;
hashes.get(i + 1).parent = parent;
parents.add(parent);
}
return construct(parents);
}
private static final class HashNode {
final byte[] hash;
HashNode parent;
final HashNode left;
final HashNode right;
private HashNode(byte[] hash, HashNode parent, HashNode left, HashNode right) {
this.hash = hash;
this.parent = parent;
this.left = left;
this.right = right;
}
HashNode getSibling() {
if (parent == null) {
return null;
}
if (parent.left == this) {
return parent.right;
} else {
return parent.left;
}
}
}
}
HashTreeTest.java:
package core;
import java.util.Vector;
import static util.Bytes.SHA256;
import static util.Bytes.merge;
import static org.junit.jupiter.api.Assertions.*;
class HashTreeTest {
@org.junit.jupiter.api.Test
void test() {
final byte[][] leaves = new byte[][] {
SHA256("ABC".getBytes()), // 0
SHA256("DEF".getBytes()), // 1
SHA256("GHI".getBytes()), // 2
SHA256("JKL".getBytes()), // 3
SHA256("MNO".getBytes()), // 4
SHA256("PQR".getBytes()), // 5
SHA256("STU".getBytes()), // 6
SHA256("VWX".getBytes()), // 7
SHA256("YZA".getBytes()), // 8
};
final byte[][] internal1 = new byte[][] {
SHA256(merge(leaves[0], leaves[1])), // 0
SHA256(merge(leaves[2], leaves[3])), // 1
SHA256(merge(leaves[4], leaves[5])), // 2
SHA256(merge(leaves[6], leaves[7])), // 3
SHA256(merge(leaves[8], leaves[8])), // 4
};
final byte[][] internal2 = new byte[][] {
SHA256(merge(internal1[0], internal1[1])), // 0
SHA256(merge(internal1[2], internal1[3])), // 1
SHA256(merge(internal1[4], internal1[4])), // 2
};
final byte[][] internal3 = new byte[][] {
SHA256(merge(internal2[0], internal2[1])), // 0
SHA256(merge(internal2[2], internal2[2])), // 1
};
final HashTree hashTree = new HashTree(leaves);
final byte[] expectedRootHash = SHA256(merge(internal3[0], internal3[1]));
final byte[] actualRootHash = hashTree.getRootHash();
equals(expectedRootHash, actualRootHash);
final Vector<byte[]> path = hashTree.authenticationPath(leaves[5]);
equals(path.get(0), leaves[4]);
equals(path.get(1), internal1[3]);
equals(path.get(2), internal2[0]);
equals(path.get(3), internal3[1]);
assertEquals(4, path.size());
equals(expectedRootHash, HashTree.getRootHash(leaves[5], true, path));
}
void equals(byte[] expected, byte[] actual) {
assertEquals(expected.length, actual.length);
for (int i = 0; i < expected.length; i++) {
assertEquals(expected[i], actual[i]);
}
}
}
Bytes.java
package util;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
public final class Bytes {
public static byte[] merge(byte[]... bytes) {
final ByteArrayOutputStream stream = new ByteArrayOutputStream();
try {
for (byte[] b : bytes) {
stream.write(b);
}
} catch (IOException e) {
System.err.println(e.getMessage());
}
return stream.toByteArray();
}
public static byte[] SHA256(byte[] bytes) {
try {
final MessageDigest digester = MessageDigest.getInstance("SHA-256");
return digester.digest(bytes);
} catch (NoSuchAlgorithmException e) {
System.err.println(e.getMessage());
return null;
}
}
}
```
current.hash == targetwhere both elements arebyte[]arrays, you meant to check if the two arrays contain the same elements in the same positions ? \$\endgroup\$