After finding out that my previous implementations were incorrect, I decided to give it another try. This time I relied on this post.
(The entire project resides in this GitHub repository, which also contains some unit tests that are not included in this post.)
Code
com.github.coderodde.pathfinding.BidirectionalDijkstrasAlgorithm.java:
package com.github.coderodde.pathfinding;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
/**
 * This class implements a bidirectional Dijkstra's algorithm.
 *
 * <p>A forward search grows from the source over child arcs, and a backward
 * search grows from the target over parent arcs. The two alternate, settling
 * one node each per round. Whenever an arc joins a node settled by one search
 * to a node settled by the other, the length of the resulting source/target
 * path is compared with the best one seen so far ({@code mu}). The searches
 * stop once the distances of the two nodes settled in a round add up to at
 * least {@code mu}, or once either search runs out of nodes.
 *
 * @param <N> the actual graph node type.
 * @param <W> the value type of arc weights.
 */
public final class BidirectionalDijkstrasAlgorithm<N, W> {

    /**
     * Searches for a shortest {@code source/target} path. Throws an
     * {@link IllegalStateException} if the target node is not reachable from
     * the source node.
     *
     * <p>As with any variant of Dijkstra's algorithm, arc weights are assumed
     * to be non-negative.
     *
     * @param source           the source node.
     * @param target           the target node.
     * @param childrenExpander the node expander generating child nodes.
     * @param parentsExpander  the node expander generating parent nodes.
     * @param weightFunction   the weight function of the graph.
     * @param scoreComparator  the comparator for comparing weights/node
     *                         g-scores.
     *
     * @return the shortest path, listed from {@code source} to
     *         {@code target} inclusive.
     * @throws IllegalStateException if {@code target} is not reachable from
     *                               {@code source}.
     */
    public List<N> findShortestPath(N source,
                                    N target,
                                    NodeExpander<N> childrenExpander,
                                    NodeExpander<N> parentsExpander,
                                    WeightFunction<N, W> weightFunction,
                                    Comparator<W> scoreComparator) {
        if (source.equals(target)) {
            // mu is only updated when an arc joins the two searches, so a
            // path consisting of no arcs would never be reported below.
            return Arrays.asList(target);
        }

        Queue<HeapNodeWrapper<N, W>> queueF = new PriorityQueue<>();
        Queue<HeapNodeWrapper<N, W>> queueB = new PriorityQueue<>();
        Map<N, W> distancesF = new HashMap<>();
        Map<N, W> distancesB = new HashMap<>();
        Map<N, N> parentsF = new HashMap<>();
        Map<N, N> parentsB = new HashMap<>();
        Set<N> settledF = new HashSet<>();
        Set<N> settledB = new HashSet<>();

        queueF.add(new HeapNodeWrapper<>(weightFunction.getZero(),
                                         source,
                                         scoreComparator));
        queueB.add(new HeapNodeWrapper<>(weightFunction.getZero(),
                                         target,
                                         scoreComparator));
        distancesF.put(source, weightFunction.getZero());
        distancesB.put(target, weightFunction.getZero());
        parentsF.put(source, null);
        parentsB.put(target, null);

        // Length of the best source/target path found so far, and the arc
        // (touchNodeF, touchNodeB) through which that path crosses from the
        // forward search into the backward search.
        W mu = weightFunction.getInfinity();
        N touchNodeF = null;
        N touchNodeB = null;

        while (true) {
            N currentNodeF = pollUnsettled(queueF, settledF);
            N currentNodeB = pollUnsettled(queueB, settledB);

            if (currentNodeF == null || currentNodeB == null) {
                // One search has settled everything it can reach, so every
                // source/target path has already been offered to mu. Fall
                // through and report whatever was found, if anything.
                break;
            }

            settledF.add(currentNodeF);
            settledB.add(currentNodeB);

            // Both nodes are settled, so these distances are final.
            W distanceF = distancesF.get(currentNodeF);
            W distanceB = distancesB.get(currentNodeB);

            for (N childNode : childrenExpander.expand(currentNodeF)) {
                if (settledF.contains(childNode)) {
                    continue;
                }

                W arcWeight = weightFunction.getWeight(currentNodeF, childNode);
                W tentativeDistance = weightFunction.sum(distanceF, arcWeight);
                W knownDistance = distancesF.get(childNode);

                if (knownDistance == null
                        || scoreComparator.compare(knownDistance,
                                                   tentativeDistance) > 0) {
                    distancesF.put(childNode, tentativeDistance);
                    parentsF.put(childNode, currentNodeF);
                    queueF.add(new HeapNodeWrapper<>(tentativeDistance,
                                                     childNode,
                                                     scoreComparator));
                }

                if (settledB.contains(childNode)) {
                    W pathLength = weightFunction.sum(
                            distanceF,
                            arcWeight,
                            distancesB.get(childNode));

                    if (scoreComparator.compare(mu, pathLength) > 0) {
                        mu = pathLength;
                        touchNodeF = currentNodeF;
                        touchNodeB = childNode;
                    }
                }
            }

            for (N parentNode : parentsExpander.expand(currentNodeB)) {
                if (settledB.contains(parentNode)) {
                    continue;
                }

                W arcWeight = weightFunction.getWeight(parentNode, currentNodeB);
                W tentativeDistance = weightFunction.sum(distanceB, arcWeight);
                W knownDistance = distancesB.get(parentNode);

                if (knownDistance == null
                        || scoreComparator.compare(knownDistance,
                                                   tentativeDistance) > 0) {
                    distancesB.put(parentNode, tentativeDistance);
                    parentsB.put(parentNode, currentNodeB);
                    queueB.add(new HeapNodeWrapper<>(tentativeDistance,
                                                     parentNode,
                                                     scoreComparator));
                }

                if (settledF.contains(parentNode)) {
                    W pathLength = weightFunction.sum(
                            distancesF.get(parentNode),
                            arcWeight,
                            distanceB);

                    if (scoreComparator.compare(mu, pathLength) > 0) {
                        mu = pathLength;
                        touchNodeF = parentNode;
                        touchNodeB = currentNodeB;
                    }
                }
            }

            // Any path not yet offered to mu is at least as long as the sum
            // of the two search radii. The comparison must be non-strict:
            // with zero-weight arcs the two can be equal, and a strict test
            // could let the queues run dry.
            if (scoreComparator.compare(
                    weightFunction.sum(distanceF, distanceB), mu) >= 0) {
                break;
            }
        }

        if (touchNodeF == null) {
            throw new IllegalStateException(
                    "The target node is not reachable from the source node.");
        }

        return tracebackPath(touchNodeF, touchNodeB, parentsF, parentsB);
    }

    /**
     * Removes entries from {@code queue} until one whose node is not yet in
     * {@code settled} is found. Entries for settled nodes are stale leftovers,
     * because improved distances are pushed as new entries instead of
     * updating existing ones.
     *
     * @param queue   the search queue.
     * @param settled the nodes already settled by the same search.
     * @return the first unsettled node, or {@code null} if the queue runs out.
     */
    private static <N, W> N pollUnsettled(Queue<HeapNodeWrapper<N, W>> queue,
                                          Set<N> settled) {
        while (!queue.isEmpty()) {
            N node = queue.remove().getNode();

            if (!settled.contains(node)) {
                return node;
            }
        }

        return null;
    }

    /**
     * Builds the final path. It consists of the forward parent chain of
     * {@code touchNodeF}, reversed so that it starts at the source, followed
     * by the backward parent chain of {@code touchNodeB}, which ends at the
     * target.
     *
     * @param touchNodeF the last forward-search node on the path.
     * @param touchNodeB the first backward-search node on the path.
     * @param parentsF   the forward-search parent map.
     * @param parentsB   the backward-search parent map.
     * @return the source/target path.
     */
    private static <N> List<N> tracebackPath(N touchNodeF,
                                             N touchNodeB,
                                             Map<N, N> parentsF,
                                             Map<N, N> parentsB) {
        List<N> path = new ArrayList<>();
        N node = touchNodeF;

        while (node != null) {
            path.add(node);
            node = parentsF.get(node);
        }

        Collections.reverse(path);
        node = touchNodeB;

        while (node != null) {
            path.add(node);
            node = parentsB.get(node);
        }

        return path;
    }
}
com.github.coderodde.pathfinding.DijkstrasAlgorithm.java:
package com.github.coderodde.pathfinding;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
/**
 * This class implements the (unidirectional) Dijkstra's algorithm.
 *
 * <p>Improved distances are pushed as new priority-queue entries rather than
 * updating existing ones. Outdated entries are therefore discarded when
 * popped, if their node has already been closed.
 *
 * @param <N> the actual graph node type.
 * @param <W> the weight value type.
 */
public final class DijkstrasAlgorithm<N, W> {

    /**
     * Finds the shortest {@code source/target} path or throws an
     * {@link IllegalStateException} if the target node is not reachable from
     * the source node.
     *
     * <p>Arc weights are assumed to be non-negative.
     *
     * @param source           the source node.
     * @param target           the target node.
     * @param childrenExpander the children expander.
     * @param weightFunction   the graph weight function.
     * @param scoreComparator  the score comparator.
     *
     * @return the shortest path, listed from {@code source} to
     *         {@code target} inclusive.
     * @throws IllegalStateException if {@code target} is not reachable from
     *                               {@code source}.
     */
    public List<N> findShortestPath(N source,
                                    N target,
                                    NodeExpander<N> childrenExpander,
                                    WeightFunction<N, W> weightFunction,
                                    Comparator<W> scoreComparator) {
        Queue<HeapNodeWrapper<N, W>> open = new PriorityQueue<>();
        Map<N, W> distanceMap = new HashMap<>();
        Map<N, N> parentMap = new HashMap<>();
        Set<N> closed = new HashSet<>();

        open.add(new HeapNodeWrapper<>(weightFunction.getZero(),
                                       source,
                                       scoreComparator));
        distanceMap.put(source, weightFunction.getZero());
        parentMap.put(source, null);

        while (!open.isEmpty()) {
            N currentNode = open.remove().getNode();

            if (currentNode.equals(target)) {
                return tracebackSolution(target, parentMap);
            }

            if (!closed.add(currentNode)) {
                // Stale entry left behind by a later distance improvement;
                // the node has already been expanded with its final distance.
                continue;
            }

            W currentDistance = distanceMap.get(currentNode);

            for (N childNode : childrenExpander.expand(currentNode)) {
                if (closed.contains(childNode)) {
                    continue;
                }

                W tentativeDistance =
                        weightFunction.sum(
                                currentDistance,
                                weightFunction.getWeight(currentNode,
                                                         childNode));

                W knownDistance = distanceMap.get(childNode);

                if (knownDistance == null
                        || scoreComparator.compare(knownDistance,
                                                   tentativeDistance) > 0) {
                    distanceMap.put(childNode, tentativeDistance);
                    parentMap.put(childNode, currentNode);
                    open.add(new HeapNodeWrapper<>(tentativeDistance,
                                                   childNode,
                                                   scoreComparator));
                }
            }
        }

        throw new IllegalStateException(
                "Target not reachable from the source.");
    }

    /**
     * Walks the parent links from {@code target} back to the source (whose
     * parent is {@code null}). Returns the nodes in source-to-target order.
     *
     * @param target    the last node of the path.
     * @param parentMap the parent links recorded by the search.
     * @return the path from the source to {@code target}.
     */
    private static <N> List<N> tracebackSolution(N target, Map<N, N> parentMap) {
        List<N> path = new ArrayList<>();
        N node = target;

        while (node != null) {
            path.add(node);
            node = parentMap.get(node);
        }

        Collections.reverse(path);
        return path;
    }
}
com.github.coderodde.pathfinding.HeapNodeWrapper.java:
package com.github.coderodde.pathfinding;
import java.util.Comparator;
final class HeapNodeWrapper<N, W> implements Comparable<HeapNodeWrapper<N, W>> {
private final W score;
private final N node;
private final Comparator<W> scoreComparator;
HeapNodeWrapper(W score,
N node,
Comparator<W> scoreComparator) {
this.score = score;
this.node = node;
this.scoreComparator = scoreComparator;
}
N getNode() {
return node;
}
@Override
public int compareTo(HeapNodeWrapper<N, W> o) {
return scoreComparator.compare(this.score, o.score);
}
}
com.github.coderodde.pathfinding.NodeExpander.java:
package com.github.coderodde.pathfinding;
import java.util.Collection;
/**
 * This interface defines the API for all the node expanders.
 *
 * <p>An expander maps a node to its neighbours in one direction of the
 * graph. For example, the path finders use one expander for child nodes
 * (forward search) and, in the bidirectional case, another for parent nodes
 * (backward search). Because this is a single-method interface, it can be
 * implemented with a lambda or method reference.
 *
 * @param <N> the actual type of the nodes.
 */
@FunctionalInterface
public interface NodeExpander<N> {

    /**
     * Returns the expansion view of the input node.
     *
     * @param node the node to expand.
     * @return the collection of "next" nodes to consider in search.
     */
    Collection<N> expand(N node);
}
com.github.coderodde.pathfinding.WeightFunction.java:
package com.github.coderodde.pathfinding;
/**
 * Abstracts the arithmetic on arc weights, so that the path finders can work
 * with any weight representation. Implementations supply arc lookup, the
 * neutral element, an "infinite" upper bound, and addition.
 *
 * @param <N> the actual graph node type.
 * @param <W> the type of the weight values.
 */
public interface WeightFunction<N, W> {

    /**
     * Looks up the weight of the arc that leads from {@code tail} to
     * {@code head}.
     *
     * @param tail the node at which the arc starts.
     * @param head the node at which the arc ends.
     * @return the weight of that arc.
     */
    W getWeight(N tail, N head);

    /**
     * Supplies the weight value that plays the role of zero.
     *
     * @return the zero weight.
     */
    W getZero();

    /**
     * Supplies the largest weight value the type can represent. It is used
     * as the initial "no path yet" bound.
     *
     * @return the largest weight.
     */
    W getInfinity();

    /**
     * Adds two weights.
     *
     * @param w1 the left operand.
     * @param w2 the right operand.
     * @return {@code w1 + w2}.
     */
    W sum(W w1, W w2);

    /**
     * Adds three weights as {@code w1 + (w2 + w3)}. The bidirectional search
     * uses it to join a forward distance, a crossing arc and a backward
     * distance. The grouping is kept explicit because it can matter for
     * floating-point weights.
     *
     * @param w1 the left operand.
     * @param w2 the middle operand.
     * @param w3 the right operand.
     * @return {@code w1 + (w2 + w3)}.
     */
    default W sum(W w1, W w2, W w3) {
        W rightPart = sum(w2, w3);
        return sum(w1, rightPart);
    }
}
com.github.coderodde.pathfinding.benchmark.Benchmark.java:
package com.github.coderodde.pathfinding.benchmark;
import com.github.coderodde.pathfinding.BidirectionalDijkstrasAlgorithm;
import com.github.coderodde.pathfinding.DijkstrasAlgorithm;
import com.github.coderodde.pathfinding.NodeExpander;
import com.github.coderodde.pathfinding.WeightFunction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
/**
 * Command-line benchmark that compares {@link DijkstrasAlgorithm} with
 * {@link BidirectionalDijkstrasAlgorithm} on a random directed graph, and
 * reports whether the two return the same path.
 *
 * <p>The optional first command-line argument is the random seed. When it is
 * missing or cannot be parsed, the current time is used instead.
 */
final class Benchmark {

    private static final int NUMBER_OF_NODES = 100_000;
    private static final int NUMBER_OF_ARCS = 1_000_000;

    public static void main(String[] args) {
        long seed = parseSeed(args);
        System.out.println("Seed = " + seed);
        Random random = new Random(seed);

        long startTime = System.nanoTime();
        GraphData graphData = getRandomGraph(NUMBER_OF_NODES,
                                             NUMBER_OF_ARCS,
                                             random);

        System.out.printf("Built the graph in %d milliseconds.%n",
                          millisSince(startTime));

        DirectedGraphNode source = graphData.getRandomNode(random);
        DirectedGraphNode target = graphData.getRandomNode(random);

        System.out.printf("Source node: %s%n", source);
        System.out.printf("Target node: %s%n", target);

        DijkstrasAlgorithm<DirectedGraphNode, Float> pathfinderDijkstra =
                new DijkstrasAlgorithm<>();

        BidirectionalDijkstrasAlgorithm<DirectedGraphNode, Float>
                pathfinderBidirectionalDijkstra =
                new BidirectionalDijkstrasAlgorithm<>();

        NodeExpander<DirectedGraphNode> childNodeExpander =
                new DirectedGraphNodeChildrenExpander();

        NodeExpander<DirectedGraphNode> parentNodeExpander =
                new DirectedGraphNodeParentsExpander();

        DirectedGraphWeightFunction weightFunction =
                new DirectedGraphWeightFunction();

        startTime = System.nanoTime();

        List<DirectedGraphNode> pathDijkstra =
                pathfinderDijkstra.findShortestPath(
                        source,
                        target,
                        childNodeExpander,
                        weightFunction,
                        Float::compare);

        System.out.printf("Dijkstra's algorithm in %d milliseconds.%n",
                          millisSince(startTime));

        startTime = System.nanoTime();

        List<DirectedGraphNode> pathBidirectionalDijkstra =
                pathfinderBidirectionalDijkstra.findShortestPath(
                        source,
                        target,
                        childNodeExpander,
                        parentNodeExpander,
                        weightFunction,
                        Float::compare);

        System.out.printf(
                "Bidirectional Dijkstra's algorithm in %d milliseconds.%n",
                millisSince(startTime));

        if (pathDijkstra.equals(pathBidirectionalDijkstra)) {
            System.out.println("Paths agree:");
            printPath(pathDijkstra);
            System.out.printf(
                    "Path cost: %.3f%n",
                    computePathCost(pathDijkstra, weightFunction));
        } else {
            System.out.println("Paths disagree!");
            System.out.println("Dijkstra's algorithm's path:");
            printPath(pathDijkstra);
            System.out.printf("Dijkstra's path cost: %.3f%n",
                              computePathCost(pathDijkstra, weightFunction));

            System.out.println("Bidirectional Dijkstra's algorithm's path:");
            printPath(pathBidirectionalDijkstra);
            System.out.printf("Bidirectional Dijkstra's path cost: %.3f%n",
                              computePathCost(pathBidirectionalDijkstra,
                                              weightFunction));
        }
    }

    /**
     * Prints each node of {@code path} on its own line.
     */
    private static void printPath(List<DirectedGraphNode> path) {
        for (DirectedGraphNode node : path) {
            System.out.println(node);
        }
    }

    /**
     * Returns the whole milliseconds elapsed since {@code startNanos}, which
     * must be a {@link System#nanoTime()} reading. {@code nanoTime} is used
     * because, unlike wall-clock time, it is monotonic.
     */
    private static long millisSince(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000L;
    }

    /**
     * Reads the seed from the first argument. Falls back to the current time
     * when there is no argument or when it is not a valid {@code long}; in
     * the latter case a warning is printed first.
     */
    private static long parseSeed(String[] args) {
        if (args.length == 0) {
            return System.currentTimeMillis();
        }

        try {
            return Long.parseLong(args[0]);
        } catch (NumberFormatException ex) {
            System.err.printf("WARNING: Could not parse %s as a long value.%n",
                              args[0]);

            return System.currentTimeMillis();
        }
    }

    /**
     * Sums the arc weights along consecutive nodes of {@code path}.
     */
    private static float computePathCost(
            List<DirectedGraphNode> path,
            DirectedGraphWeightFunction weightFunction) {

        float cost = 0.0f;

        for (int i = 0; i < path.size() - 1; i++) {
            DirectedGraphNode tail = path.get(i);
            DirectedGraphNode head = path.get(i + 1);
            float arcWeight = weightFunction.getWeight(tail, head);
            cost += arcWeight;
        }

        return cost;
    }

    /**
     * Holds the nodes of a generated graph; arcs live inside the nodes.
     */
    private static final class GraphData {
        private final List<DirectedGraphNode> graphNodes;

        GraphData(List<DirectedGraphNode> graphNodes) {
            this.graphNodes = graphNodes;
        }

        DirectedGraphNode getRandomNode(Random random) {
            return choose(graphNodes, random);
        }
    }

    /**
     * Builds a random directed graph with {@code nodes} nodes and
     * {@code edges} distinct arcs (self-loops allowed). Arc weights are drawn
     * uniformly from {@code [0, 100)}.
     *
     * @throws IllegalArgumentException if {@code edges} exceeds the number of
     *                                  distinct arcs possible, which would
     *                                  otherwise make the generator loop
     *                                  forever.
     */
    private static GraphData getRandomGraph(int nodes,
                                            int edges,
                                            Random random) {
        if ((long) nodes * nodes < edges) {
            throw new IllegalArgumentException(
                    "Cannot place " + edges + " distinct arcs among "
                            + nodes + " nodes.");
        }

        List<DirectedGraphNode> graph = new ArrayList<>(nodes);
        Set<Arc> arcs = new HashSet<>(edges);

        for (int i = 0; i < nodes; i++) {
            graph.add(new DirectedGraphNode());
        }

        while (arcs.size() < edges) {
            DirectedGraphNode tail = choose(graph, random);
            DirectedGraphNode head = choose(graph, random);
            arcs.add(new Arc(tail, head));
        }

        for (Arc arc : arcs) {
            float weight = 100.0f * random.nextFloat();
            arc.getTail().addChild(arc.getHead(), weight);
        }

        return new GraphData(graph);
    }

    /**
     * Picks a uniformly random element of {@code list}.
     */
    private static <T> T choose(List<T> list, Random random) {
        return list.get(random.nextInt(list.size()));
    }

    /**
     * A directed (tail, head) pair, used only to de-duplicate generated arcs.
     */
    private static final class Arc {
        private final DirectedGraphNode tail;
        private final DirectedGraphNode head;

        Arc(DirectedGraphNode tail, DirectedGraphNode head) {
            this.tail = tail;
            this.head = head;
        }

        DirectedGraphNode getTail() {
            return tail;
        }

        DirectedGraphNode getHead() {
            return head;
        }

        @Override
        public int hashCode() {
            return Objects.hash(tail, head);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }

            if (!(o instanceof Arc)) {
                // Also covers null, as required by the Object.equals contract.
                return false;
            }

            Arc other = (Arc) o;
            return tail.equals(other.tail) && head.equals(other.head);
        }
    }
}
final class DirectedGraphNode {
private static int nodeIdCounter = 0;
private final int id;
private final Map<DirectedGraphNode, Float> outgoingArcs =
new HashMap<>();
private final Map<DirectedGraphNode, Float> incomingArcs =
new HashMap<>();
DirectedGraphNode() {
this.id = nodeIdCounter++;
}
void addChild(DirectedGraphNode child, Float weight) {
outgoingArcs.put(child, weight);
child.incomingArcs.put(this, weight);
}
List<DirectedGraphNode> getChildren() {
return new ArrayList<>(outgoingArcs.keySet());
}
List<DirectedGraphNode> getParents() {
return new ArrayList<>(incomingArcs.keySet());
}
Float getWeightTo(DirectedGraphNode headNode) {
return outgoingArcs.get(headNode);
}
@Override
public String toString() {
return String.format("[DirectedGraphNode id = %d]", id);
}
@Override
public int hashCode() {
return id;
}
@Override
public boolean equals(Object obj) {
DirectedGraphNode other = (DirectedGraphNode) obj;
return this.id == other.id;
}
}
/**
 * Weight function for {@link DirectedGraphNode} graphs with {@code Float}
 * weights. Arc weights are read from the tail node's outgoing arcs.
 */
class DirectedGraphWeightFunction
        implements WeightFunction<DirectedGraphNode, Float> {

    /**
     * @return the weight recorded on {@code tail} for its arc to
     *         {@code head}, or {@code null} when no such arc exists.
     */
    @Override
    public Float getWeight(DirectedGraphNode tail, DirectedGraphNode head) {
        Float weight = tail.getWeightTo(head);
        return weight;
    }

    /**
     * @return {@code 0.0f}.
     */
    @Override
    public Float getZero() {
        return Float.valueOf(0.0f);
    }

    /**
     * @return positive infinity, which compares greater than every finite
     *         weight.
     */
    @Override
    public Float getInfinity() {
        return Float.valueOf(Float.POSITIVE_INFINITY);
    }

    /**
     * @return the float sum of both weights.
     */
    @Override
    public Float sum(Float w1, Float w2) {
        return Float.sum(w1, w2);
    }
}
/**
 * Expands a {@link DirectedGraphNode} into its children, that is, the heads
 * of its outgoing arcs.
 */
class DirectedGraphNodeChildrenExpander
        implements NodeExpander<DirectedGraphNode> {

    /**
     * @param node the node to expand.
     * @return a new list holding the children of {@code node}.
     */
    @Override
    public List<DirectedGraphNode> expand(DirectedGraphNode node) {
        final List<DirectedGraphNode> children = node.getChildren();
        return children;
    }
}
/**
 * Expands a {@link DirectedGraphNode} into its parents, that is, the tails of
 * its incoming arcs.
 */
class DirectedGraphNodeParentsExpander
        implements NodeExpander<DirectedGraphNode> {

    /**
     * @param node the node to expand.
     * @return a new list holding the parents of {@code node}.
     */
    @Override
    public List<DirectedGraphNode> expand(DirectedGraphNode node) {
        final List<DirectedGraphNode> parents = node.getParents();
        return parents;
    }
}
Typical output
Seed = 1705171998017
Built the graph in 1768 milliseconds.
Source node: [DirectedGraphNode id = 80226]
Target node: [DirectedGraphNode id = 33520]
Dijkstra's algorithm in 1056 milliseconds.
Bidirectional Dijkstra's algorithm in 31 milliseconds.
Paths agree:
[DirectedGraphNode id = 80226]
[DirectedGraphNode id = 35320]
[DirectedGraphNode id = 77598]
[DirectedGraphNode id = 93003]
[DirectedGraphNode id = 34031]
[DirectedGraphNode id = 32260]
[DirectedGraphNode id = 53773]
[DirectedGraphNode id = 53078]
[DirectedGraphNode id = 35871]
[DirectedGraphNode id = 15879]
[DirectedGraphNode id = 79948]
[DirectedGraphNode id = 31828]
[DirectedGraphNode id = 10811]
[DirectedGraphNode id = 44856]
[DirectedGraphNode id = 33520]
Path cost: 123,482
Critique request
As always, I would like to hear whatever comes to mind.