Dijkstra in Scala
Dijkstra’s algorithm is a fundamental graph algorithm, which allows us to compute the shortest path from a source node to a target node in a directed, weighted graph, i.e. the path with the smallest accumulated weight. In fact, the algorithm not only computes the shortest path from one node to another, but a shortest-path tree, which describes the shortest paths from the root of the tree (the source node) to every node that is reachable from the source.
In an earlier post, I presented a concise implementation of Dijkstra’s algorithm in Clojure, but the algorithm presented there only computes the distance (the price of the shortest path) from the source to any other node, not the shortest-path tree. So let us now look at how to implement the algorithm in Scala, this time computing a shortest-path tree.
We first have to fix a type for our graphs. Since we consider directed, weighted graphs, where every edge carries a price, the simplest way to represent a graph is a function mapping nodes to a set of pairs of the form (node, price). Instead of a set of pairs, we can also use a map from nodes to prices, which allows us to easily compute the price of an edge given a successor node. Hence, let us start with the following type definition, which is parameterized by the node type N:
type Graph[N] = N => Map[N, Int]
Hence, given an instance g of type Graph[N] and a node n, the map returned by g(n) maps successor nodes of n to the price of the corresponding edge.
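To make this concrete, here is a small example graph over string nodes (the node names and prices are made up for illustration). Since an immutable Map is itself a function from keys to values, a Map of maps already has the right type:

```scala
type Graph[N] = N => Map[N, Int]

// A made-up four-node graph:
//   A --1--> B,  A --4--> C,  B --2--> C,  C --1--> D
val g: Graph[String] = Map(
  "A" -> Map("B" -> 1, "C" -> 4),
  "B" -> Map("C" -> 2),
  "C" -> Map("D" -> 1),
  "D" -> Map.empty[String, Int]
)

g("A")  // the successors of "A" with their edge prices: Map(B -> 1, C -> 4)
```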
The output of Dijkstra’s algorithm is a tree, which can be described as a map from nodes to their predecessor in the tree. It would be nice to also include the distances from the source in the output, so let us also return a map from nodes to their distance. Hence, a possible signature for our implementation of Dijkstra’s algorithm is:
def dijkstra[N](g: Graph[N])(source: N):
    (Map[N, Int], Map[N, N])
From this we can easily define a function that optionally returns the shortest path from a source node to a target node as follows:
def shortestPath[N](g: Graph[N])(source: N, target: N):
    Option[List[N]] = {
  val pred = dijkstra(g)(source)._2
  if (pred.contains(target) || source == target)
    Some(iterateRight(target)(pred.get))
  else None
}
Here, iterateRight is a function that builds up a list from the end by repeatedly applying a function returning an option until it returns None:
def iterateRight[N](x: N)(f: N => Option[N]): List[N] = {
  def go(x: N, acc: List[N]): List[N] = f(x) match {
    case None    => x :: acc
    case Some(y) => go(y, x :: acc)
  }
  go(x, List.empty)
}
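As a quick sanity check, we can apply iterateRight to the get method of a made-up predecessor map; the path is rebuilt back to front:

```scala
def iterateRight[N](x: N)(f: N => Option[N]): List[N] = {
  // (definition repeated from above so the snippet is self-contained)
  def go(x: N, acc: List[N]): List[N] = f(x) match {
    case None    => x :: acc
    case Some(y) => go(y, x :: acc)
  }
  go(x, List.empty)
}

// A made-up predecessor map: 1 is the root, 4 the target.
val pred = Map(2 -> 1, 3 -> 2, 4 -> 3)

iterateRight(4)(pred.get)  // List(1, 2, 3, 4)
```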
Since Scala is a hybrid language, which allows us to write imperative code as well as functional code, it is perfectly fine to implement Dijkstra’s algorithm in an imperative fashion using mutable data structures. Such an implementation looks as follows:
import scala.collection.mutable

def dijkstra1[N](g: Graph[N])(source: N):
    (Map[N, Int], Map[N, N]) = {
  val active = mutable.Set(source)
  val res = mutable.Map(source -> 0)
  val pred = mutable.Map.empty[N, N]
  while (active.nonEmpty) {
    val node = active.minBy(res)
    active -= node
    val cost = res(node)
    for ((n, c) <- g(node)) {
      val cost1 = cost + c
      if (cost1 < res.getOrElse(n, Int.MaxValue)) {
        active += n
        res += (n -> cost1)
        pred += (n -> node)
      }
    }
  }
  (res.toMap, pred.toMap)
}
This implementation is similar to the pseudocode given on Wikipedia, but omits the unnecessary initialization phase. Instead, nodes are added to the set of active nodes only when they are discovered.
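To see dijkstra1 in action, here is a quick run on a made-up three-node graph, where the indirect path via B is cheaper than the direct edge:

```scala
import scala.collection.mutable

type Graph[N] = N => Map[N, Int]

def dijkstra1[N](g: Graph[N])(source: N): (Map[N, Int], Map[N, N]) = {
  val active = mutable.Set(source)
  val res = mutable.Map(source -> 0)
  val pred = mutable.Map.empty[N, N]
  while (active.nonEmpty) {
    val node = active.minBy(res)
    active -= node
    val cost = res(node)
    for ((n, c) <- g(node)) {
      val cost1 = cost + c
      if (cost1 < res.getOrElse(n, Int.MaxValue)) {
        active += n
        res += (n -> cost1)
        pred += (n -> node)
      }
    }
  }
  (res.toMap, pred.toMap)
}

// Made-up graph: A --1--> B --2--> C and a direct edge A --4--> C
val g: Graph[String] = Map(
  "A" -> Map("B" -> 1, "C" -> 4),
  "B" -> Map("C" -> 2),
  "C" -> Map.empty[String, Int]
)

val (dist, pred) = dijkstra1(g)("A")
// dist("C") == 3: the path via B beats the direct edge of price 4
// pred("C") == "B" and pred("B") == "A"
```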
What is the problem with this solution? First of all, scala.Predef defines Map as an alias for scala.collection.immutable.Map, so in the end we need to convert our mutable maps res and pred to immutable maps. Second, in functional programming it is considered good style to avoid loops and mutable data structures. So let’s try to come up with a more functional solution: instead of using mutable data structures and modifying them inside a loop, we can convert the loop into a tail-recursive function which we call with the new values of our (now immutable) data structures.
def dijkstra2[N](g: Graph[N])(source: N):
    (Map[N, Int], Map[N, N]) = {
  def go(active: Set[N], res: Map[N, Int], pred: Map[N, N]):
      (Map[N, Int], Map[N, N]) =
    if (active.isEmpty) (res, pred)
    else {
      val node = active.minBy(res)
      val cost = res(node)
      val neighbours = for {
        (n, c) <- g(node)
        if cost + c < res.getOrElse(n, Int.MaxValue)
      } yield n -> (cost + c)
      val active1 = active - node ++ neighbours.keys
      val preds = neighbours mapValues (_ => node)
      go(active1, res ++ neighbours, pred ++ preds)
    }
  go(Set(source), Map(source -> 0), Map.empty)
}
The code is a straightforward adaptation of the previous imperative solution. Inside the go function, we replaced the side-effecting for loop by a for expression whose result we use in the recursive call.
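Wiring dijkstra2 into the shortestPath function from above (with dijkstra2 standing in for dijkstra), we can recover an actual path on a small made-up graph. Note that in this self-contained snippet I compute the predecessor entries with .map instead of mapValues, purely to keep it warning-free on newer Scala versions; the behaviour is the same:

```scala
type Graph[N] = N => Map[N, Int]

def dijkstra2[N](g: Graph[N])(source: N): (Map[N, Int], Map[N, N]) = {
  def go(active: Set[N], res: Map[N, Int], pred: Map[N, N]): (Map[N, Int], Map[N, N]) =
    if (active.isEmpty) (res, pred)
    else {
      val node = active.minBy(res)
      val cost = res(node)
      val neighbours = for {
        (n, c) <- g(node)
        if cost + c < res.getOrElse(n, Int.MaxValue)
      } yield n -> (cost + c)
      val active1 = active - node ++ neighbours.keys
      val preds = neighbours.map { case (n, _) => n -> node }
      go(active1, res ++ neighbours, pred ++ preds)
    }
  go(Set(source), Map(source -> 0), Map.empty)
}

def iterateRight[N](x: N)(f: N => Option[N]): List[N] = {
  def go(x: N, acc: List[N]): List[N] = f(x) match {
    case None    => x :: acc
    case Some(y) => go(y, x :: acc)
  }
  go(x, List.empty)
}

def shortestPath[N](g: Graph[N])(source: N, target: N): Option[List[N]] = {
  val pred = dijkstra2(g)(source)._2
  if (pred.contains(target) || source == target)
    Some(iterateRight(target)(pred.get))
  else None
}

// Made-up graph: the path A -> B -> C (price 3) beats the direct edge A -> C (price 4)
val g: Graph[String] = Map(
  "A" -> Map("B" -> 1, "C" -> 4),
  "B" -> Map("C" -> 2),
  "C" -> Map.empty[String, Int]
)

shortestPath(g)("A", "C")  // Some(List(A, B, C))
```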
The code is looking good, but we used Set’s minBy method to identify the active node with the least distance from the source. Since regular sets are unordered, this method has to look at every node in the set in order to identify the one with the least distance; this becomes a problem if, in a large proportion of calls to go, many nodes are active. In order to demonstrate the problem, let us test our code on some generated graphs. The following function generates a binary tree of the given depth where every leaf of the tree has an edge leading back to the root, so that the graph becomes strongly connected:
def tree(depth: Int): Graph[List[Boolean]] = {
  case x if x.length < depth =>
    Map((true :: x) -> 1, (false :: x) -> 2)
  case x if x.length == depth => Map(Nil -> 1)
  case _ => Map.empty
}
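For instance, at the (deliberately tiny) depth 2 the generated graph looks like this:

```scala
type Graph[N] = N => Map[N, Int]

def tree(depth: Int): Graph[List[Boolean]] = {
  case x if x.length < depth =>
    Map((true :: x) -> 1, (false :: x) -> 2)
  case x if x.length == depth => Map(Nil -> 1)
  case _ => Map.empty
}

val t = tree(2)
t(Nil)               // Map(List(true) -> 1, List(false) -> 2): the two children of the root
t(List(true, true))  // Map(List() -> 1): a leaf, with an edge of price 1 back to the root
```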
Since going to the left successor costs 1 but going to the right successor costs 2, the shortest path from List(true), the left successor of the root Nil, back to the root should go via the left branch of the tree, so that the distance from List(true) to Nil should be equal to the depth of the tree. We can confirm this in the REPL:
scala> val t = tree(10)
t: Graph[List[Boolean]] = <function1>
scala> dijkstra1(t)(List(true))._1(Nil)
res1: Int = 10
scala> dijkstra2(t)(List(true))._1(Nil)
res2: Int = 10
Now, if you choose greater depths, you will notice that both implementations slow down very quickly. For instance, on my machine dijkstra1 and dijkstra2 return on tree(14) only after around 30 seconds and on tree(15) only after more than two minutes!
How can we make our code run faster? In the previous post, I presented priority maps and their implementation in Scala. Priority maps can be used like regular maps but offer fast access to entries with a small value. In particular, we can call head on a priority map to access the entry with the smallest value in (almost) constant time. In order to profit from priority maps in our implementation of Dijkstra’s algorithm, we only need to adapt the type of active and replace the call to minBy by a call to head. Since nodes in the priority map also carry their distance from the source, we can also delay putting a node into res until it is removed from the priority map:
def dijkstra3[N](g: Graph[N])(source: N):
    (Map[N, Int], Map[N, N]) = {
  def go(active: PriorityMap[N, Int], res: Map[N, Int], pred: Map[N, N]):
      (Map[N, Int], Map[N, N]) =
    if (active.isEmpty) (res, pred)
    else {
      val (node, cost) = active.head
      val neighbours = for {
        (n, c) <- g(node)
        if !res.contains(n) && cost + c < active.getOrElse(n, Int.MaxValue)
      } yield n -> (cost + c)
      val preds = neighbours mapValues (_ => node)
      go(active.tail ++ neighbours, res + (node -> cost), pred ++ preds)
    }
  go(PriorityMap(source -> 0), Map.empty, Map.empty)
}
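The PriorityMap used here is the one from the previous post, not part of the standard library. If you just want to try dijkstra3 without that dependency, the handful of operations it relies on can be faked with a naive stand-in like the following. Be aware that this sketch scans all entries on every head and tail, so it is O(n) per step and forfeits exactly the speed-up the real priority map provides; it only mimics the interface:

```scala
// Naive stand-in for the PriorityMap of the previous post (illustration only).
// It supports just the operations dijkstra3 uses: isEmpty, head, tail,
// ++ and getOrElse. head and tail are O(n) here, not logarithmic.
final case class NaivePriorityMap[K, V](entries: Map[K, V])(implicit ord: Ordering[V]) {
  def isEmpty: Boolean = entries.isEmpty
  def head: (K, V) = entries.minBy(_._2)  // entry with the smallest value
  def tail: NaivePriorityMap[K, V] = NaivePriorityMap(entries - head._1)
  def ++(kvs: IterableOnce[(K, V)]): NaivePriorityMap[K, V] =
    NaivePriorityMap(entries ++ kvs)
  def getOrElse(k: K, default: V): V = entries.getOrElse(k, default)
}

val pm = NaivePriorityMap(Map("a" -> 3, "b" -> 1, "c" -> 2))
pm.head       // ("b", 1): the entry with the smallest value
pm.tail.head  // ("c", 2)
```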
Now, running dijkstra3 on tree(15) takes less than a second, and we can even run the code on tree(20), a graph which contains more than a million nodes, in less than 30 seconds.
If you want to play with this code, it is available on GitHub.