Hindley-Milner type inference in Rust

I've recently been writing a compiler in Rust, and one part of it is type inference. I've written up some notes (as a kind of "what I wish I'd known" thing) in the hope that they will be of use to someone.

The basic problem of type inference is this: we don't want to make the programmer write out all the types, so we "infer" the ones that are (to the human mind) obvious. For example, in the following program, not all of the type annotations are necessary.

function f(x: int) -> int {
	return x * x;
}

let y: int = 12;
let x: int = f(y);

We could definitely remove the type annotation from y, because we can work out that y must be an integer: we assign the integer literal 12 to it.

function f(x: int) -> int {
	return x * x;
}

let y = 12;
let x: int = f(y);

For this specific program, we can also remove the annotation from x. We can work it out because we already know that y is an integer, and f takes an integer and is declared to return an integer (which makes sense, since multiplying an integer by another integer gives an integer), so f(y) must be an integer too.

function f(x: int) -> int {
	return x * x;
}

let y = 12;
let x = f(y);

We can do a similar thing for the function f. Because we only ever pass integers into it, we can remove the int annotation from its parameter x. Whether this is actually desirable is questionable (bidirectional type checking, for example, typically requires type annotations on function definitions).

function f(x) -> int {
	return x * x;
}

let y = 12;
let x = f(y);

And we can also remove the return type annotation (because returning int * int implies that the return type is int).

function f(x) {
	return x * x;
}

let y = 12;
let x = f(y);

So, in essence, type inference is about getting from:

function f(x) {
	return x * x;
}

let y = 12;
let x = f(y);

to

function f(x: int) -> int {
	return x * x;
}

let y: int = 12;
let x: int = f(y);

Of course, this is not (in my mind) an easy problem to solve! My instinct was to build a tree representing the program and traverse it, inferring the types as we go. It turns out that this is definitely possible, but there's a much simpler way (in my opinion): split the work into two phases, first building a set of constraints and then solving it. A constraint just states that two types have to be equal to each other; every time we removed a type annotation in getting from the typed to the untyped program (above), we were effectively relying on a constraint. A constraint set representing the types in the program above might look something like this:

y = Int
x (the variable) = f (the return type of f)
x (from f(x)) = y
x (from f(x)) = f (the return type of f)
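
One way to build such a set is a single walk over the syntax tree that returns the type of each expression and records an equality whenever two types must agree. The sketch below is a minimal version under several assumptions: a hypothetical toy AST (`Expr`, `Stmt`) whose names have already been resolved to numeric ids, `Int` as the only concrete type, and `*` typed as taking two operands of the same type and returning that type. The names `infer` and `generate` are likewise made up for illustration.

```rust
/// A type is either a concrete base type or a type variable with a unique id.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Type {
    Int,
    Var(u32),
}

/// Hypothetical expression AST; variables are already resolved to ids.
enum Expr {
    IntLit(i64),
    Var(u32),
    Mul(Box<Expr>, Box<Expr>),
    /// A call, carrying the ids of the callee's parameter and return type.
    Call { param: u32, ret: u32, arg: Box<Expr> },
}

/// Hypothetical top-level statements.
enum Stmt {
    /// `function f(x) { return body; }`, with `ret` the id of f's return type.
    Function { ret: u32, body: Expr },
    /// `let var = value;`
    Let { var: u32, value: Expr },
}

/// Walk `e`, pushing equality constraints into `out`, and return e's type.
fn infer(e: &Expr, out: &mut Vec<(Type, Type)>) -> Type {
    match e {
        Expr::IntLit(_) => Type::Int,
        Expr::Var(id) => Type::Var(*id),
        // Assumption: both operands of `*` share one type, which is also the result.
        Expr::Mul(a, b) => {
            let ta = infer(a, out);
            let tb = infer(b, out);
            if ta != tb {
                out.push((ta, tb));
            }
            ta
        }
        // The argument must match the parameter; the call has the return type.
        Expr::Call { param, ret, arg } => {
            let targ = infer(arg, out);
            out.push((Type::Var(*param), targ));
            Type::Var(*ret)
        }
    }
}

/// Generate the constraint set for a whole program, statement by statement.
fn generate(program: &[Stmt]) -> Vec<(Type, Type)> {
    let mut out = Vec::new();
    for stmt in program {
        match stmt {
            // A function's return type equals the type of the returned expression.
            Stmt::Function { ret, body } => {
                let t = infer(body, &mut out);
                out.push((Type::Var(*ret), t));
            }
            // A let-bound variable has the type of its initialiser.
            Stmt::Let { var, value } => {
                let t = infer(value, &mut out);
                out.push((Type::Var(*var), t));
            }
        }
    }
    out
}
```

Numbering y as 0, x (the variable) as 1, x (from f(x)) as 2 and f's return type as 3, the example program produces `(Var(3), Var(2))`, `(Var(0), Int)`, `(Var(2), Var(0))` and `(Var(1), Var(3))`: the same four constraints as the list above, up to order and which side each type is written on.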

But how do we tell apart names that are bound to different things? In our example, x the argument to f(x) is not the same variable as x the variable declared with let. One easy way of doing this is to assign a unique identifier (distinct from the variable's name in the program) to every binding, and to have each use of a name refer to the identifier of the binding it resolves to. That way, variables that share a name but are bound to different things (for instance, because they live in different scopes) get different identifiers.
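
A minimal sketch of this in Rust, assuming a hypothetical `Renamer` that keeps a stack of scopes: each binding (a `let` or a function parameter) gets a fresh number, and each use of a name is resolved to the innermost binding of that name currently in scope.

```rust
use std::collections::HashMap;

/// Hypothetical helper: hands out a fresh id for every binding and resolves
/// uses of a name to the nearest enclosing binding with that name.
struct Renamer {
    next_id: u32,
    // The innermost scope is the last element.
    scopes: Vec<HashMap<String, u32>>,
}

impl Renamer {
    fn new() -> Self {
        Renamer { next_id: 0, scopes: vec![HashMap::new()] }
    }

    /// Called when entering e.g. a function body.
    fn enter_scope(&mut self) {
        self.scopes.push(HashMap::new());
    }

    fn exit_scope(&mut self) {
        self.scopes.pop();
    }

    /// A new binding gets a fresh id, shadowing any outer binding of the same name.
    fn bind(&mut self, name: &str) -> u32 {
        let id = self.next_id;
        self.next_id += 1;
        self.scopes
            .last_mut()
            .expect("at least one scope")
            .insert(name.to_string(), id);
        id
    }

    /// A use of a name refers to the innermost binding in scope, if any.
    fn resolve(&self, name: &str) -> Option<u32> {
        self.scopes
            .iter()
            .rev()
            .find_map(|scope| scope.get(name).copied())
    }
}
```

Walking the example program, the parameter x is bound inside f's scope, so the x in `x * x` resolves to it, while the later `let x` gets a different id.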

Once we have the constraint set, it's not too hard to solve it. This kind of problem is solved using something very similar to Gaussian elimination (except that we are solving type equality constraints rather than systems of linear equations). The algorithm in question is called unification; the version described below is one I extracted from a compiler I wrote and modified a bit to make it more readable.

The general method for solving the constraints is:

  1. Remove the first item from the constraint set (if the set is empty, return the results).
  2. Generate a substitution for it. There are a few cases:
    • if we have a constraint variable=type (or type=variable, since equality is symmetric), we substitute the variable with the type
    • if we have a constraint variableY=variableX, we replace variableY with variableX (if both sides are the same variable, there is nothing to do)
    • if we have a constraint type=type, we check whether the two types are equal; if they are, we generate no substitution; if they are not, there is a type error in the program, and we abort
  3. Record the substitution so that we can work out the value of each type later. For this, maintain a map from each variable to the value it corresponds to:
    • if we are applying the substitution "variable=type", insert variable -> type into the solutions map
    • if we are applying the substitution "replace variableY with variableX", insert variableY -> variableX into the map
  4. Apply this substitution to the rest of the set, and then call unify recursively (with the new data).
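
A minimal illustrative sketch of these four steps in Rust follows; it is not production code. It assumes a hypothetical `Type` with only base types (`Int`, `Bool`) and numbered type variables, plus made-up helpers `subst` and `unify`. Full Hindley-Milner also has compound types such as function types, which need structural unification and an occurs check; this sketch leaves those out.

```rust
use std::collections::HashMap;

/// A type is either a concrete base type or a type variable with a unique id.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Type {
    Int,
    Bool,
    Var(u32),
}

/// A constraint says that two types must be equal.
type Constraint = (Type, Type);

/// Replace variable `var` with `ty` inside a single type.
fn subst(t: Type, var: u32, ty: Type) -> Type {
    match t {
        Type::Var(v) if v == var => ty,
        other => other,
    }
}

/// Solve `constraints`, recording each binding in `solutions`.
/// Returns Err if two different concrete types are required to be equal.
fn unify(
    mut constraints: Vec<Constraint>,
    solutions: &mut HashMap<u32, Type>,
) -> Result<(), String> {
    // Step 1: take the first constraint; an empty set means we are done.
    if constraints.is_empty() {
        return Ok(());
    }
    let (lhs, rhs) = constraints.remove(0);

    // Step 2: decide which substitution (if any) this constraint produces.
    let binding = match (lhs, rhs) {
        // Identical sides need no substitution. This also covers `a = a`:
        // binding a variable to itself would make later lookups loop forever.
        _ if lhs == rhs => None,
        // variable = type, or variableY = variableX: bind the left variable.
        (Type::Var(v), other) => Some((v, other)),
        // type = variable: equality is symmetric, so bind the right variable.
        (other, Type::Var(v)) => Some((v, other)),
        // type = type with different concrete types: a type error.
        (a, b) => return Err(format!("type mismatch: {:?} vs {:?}", a, b)),
    };

    if let Some((var, ty)) = binding {
        // Step 3: record the solution.
        solutions.insert(var, ty);
        // Step 4: apply the substitution to the remaining constraints.
        for (a, b) in constraints.iter_mut() {
            *a = subst(*a, var, ty);
            *b = subst(*b, var, ty);
        }
    }
    unify(constraints, solutions)
}
```

Numbering y, x (the variable), x (from f(x)) and f's return type as 0 to 3, the example's constraints become `(Var(0), Int)`, `(Var(1), Var(3))`, `(Var(2), Var(0))` and `(Var(2), Var(3))`, and `unify` records the same four solutions as the trace that follows. `Vec::remove(0)` shifts the whole vector each time, so for large constraint sets a `VecDeque` (and a loop instead of recursion) would be the more efficient choice.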

In the case of the program above, we start from the constraint set:

y = Int
x (the variable) = f (the return type of f)
x (from f(x)) = y
x (from f(x)) = f (the return type of f)

Unification then runs something like this:

  1. Substitute y = Int in the rest of the set, and add y -> Int to our solutions.
solutions = {
	y -> Int
}
constraints = {
	x (the variable) = f (the return type of f)
	x (from f(x)) = Int
	x (from f(x)) = f (the return type of f)
}
  2. Substitute x (the variable) with f (the return type of f), and add x (the variable) -> f to our solutions. Nothing else in the set mentions x (the variable), so the remaining constraints are unchanged.
solutions = {
	y -> Int
	x (the variable) -> f (the return type of f)
}
constraints = {
	x (from f(x)) = Int
	x (from f(x)) = f (the return type of f)
}
  3. Substitute x (from f(x)) = Int in the rest of the set, and add x (from f(x)) -> Int to our solutions.
solutions = {
	y -> Int
	x (the variable) -> f (the return type of f)
	x (from f(x)) -> Int
}
constraints = {
	Int = f (the return type of f)
}
  4. This is a type=variable constraint, so substitute f (the return type of f) = Int, and add f (the return type of f) -> Int to our solutions.
solutions = {
	y -> Int
	x (the variable) -> f (the return type of f)
	x (from f(x)) -> Int
	f (the return type of f) -> Int
}
constraints = {
}
  5. The constraint set is empty, so we're done!

To look up a variable's type in the solutions map, I wrote a small algorithm, something like this:

  1. find the item in the map; return None if it cannot be found
  2. if the item corresponds to a concrete type (e.g. Int), return it (e.g. Some(Int))
  3. otherwise, recursively look up the item that the one we just found points to
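
A sketch of that lookup in Rust, assuming type variables are numbered and the solutions live in a `HashMap<u32, Type>`; `Type` and `lookup` are hypothetical names. It assumes the map contains no cycles (in particular, no variable bound to itself), otherwise the recursion would never end.

```rust
use std::collections::HashMap;

/// A type is either a concrete base type or a type variable with a unique id.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Type {
    Int,
    Bool,
    Var(u32),
}

/// Find the concrete type of variable `var`, following chains such as
/// `x (the variable) -> f (the return type of f) -> Int`.
fn lookup(solutions: &HashMap<u32, Type>, var: u32) -> Option<Type> {
    // 1. Find the item in the map; `?` returns None if it is missing.
    match solutions.get(&var)? {
        // 3. The item points at another variable: keep following the chain.
        Type::Var(next) => lookup(solutions, *next),
        // 2. A concrete type ends the search.
        concrete => Some(*concrete),
    }
}
```

With the final solutions from the worked example, looking up x (the variable) follows x -> f -> Int and returns `Some(Int)`.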

Further reading