unification (sort of, in rust)

I've recently been ~~attempting~~ failing to write a compiler in Rust. (I keep getting stuck on name resolution for modules, and on scopes/namespaces; if you have any good resources, please do point me to them!) One part of this that I did (eventually) manage to get working was type inference. I've written up a tutorial-esque thing (you know, the "what I wish I'd known" thing), and I hope that it will be of use to someone.

The basic problem of type inference is this: we don't want to make the programmer write out all the types, so we "infer" the ones that are (to the human mind) 'obvious'. For example, in the following program not all of the type annotations are necessary.

function f(x: int) -> int {
    return x * x;
}

let y: int = 12;
let x: int = f(y);

We could definitely remove the type annotation from y, because we can work out that y must be an integer: we assign the literal 12 to it!

function f(x: int) -> int {
    return x * x;
}

let y = 12;
let x: int = f(y);

For this specific program, we can also remove the annotation from x. We can work it out because we worked out that y is an integer, f takes an int and returns an int, and we know that when we multiply an integer by another integer we get an integer.

function f(x: int) -> int {
    return x * x;
}

let y = 12;
let x = f(y);

We can do a similar thing for the function f. Because we only ever pass integers into it, we can remove the int annotation on its parameter. Whether this is actually desirable is questionable (bidirectional type checking, for example, typically relies on type annotations on function signatures).

function f(x) -> int {
    return x * x;
}

let y = 12;
let x = f(y);

And we can also remove the return type annotation (because returning int * int implies that the return type is int).

function f(x) {
    return x * x;
}

let y = 12;
let x = f(y);

So, in essence, type inference is about getting from:

function f(x) {
    return x * x;
}

let y = 12;
let x = f(y);

to

function f(x: int) -> int {
    return x * x;
}

let y: int = 12;
let x: int = f(y);

Of course, this is not (in my mind) an easy problem to solve! My instinct was to use a tree representing the program state and traverse it, inferring the types as we go. It turns out that this is definitely possible, but there's a much simpler way (in my opinion) to do it: separate finding "constraints" from solving them. A "constraint" is a reason why something must have a particular type. For example, every time we removed a type annotation in getting from the typed to the untyped program, we were effectively relying on a constraint. A sufficient constraint set for the program above might look something like this:

y = Int
x (the variable) = f (the return type of f)
x (from f(x)) = y
x (from f(x)) = f (the return type of f)

But how do we distinguish between the different names in use? x from f(x) isn't the same thing as x the variable, after all. I'm actually not sure, but the ~~approach~~ guess I took was something like this: assuming you've constructed your syntax tree not dissimilarly to that of the Rust compiler (see this page of the rustc-dev-guide for details), and every identifier (and expression, function return, etc.) has a unique tag (e.g. a usize), we can use these tags as the variables in our constraints. Once I've worked out how to handle name resolution, I'll write up my findings in the hope that they can be of help to somebody else.
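As a rough illustration of that tagging idea, here is how the example's constraint set might be encoded with `usize` tags. The tag values (`Y`, `X_VAR`, `X_PARAM`, `F_RET`) and the `Ty`/`Constraint` definitions are hypothetical choices for this sketch (the variant names are borrowed from the `unify` function below), not rustc's actual representation.

```rust
use std::collections::HashSet;

/// A concrete type (only `Int` is needed for this example).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Ty {
    Int,
}

/// Hypothetical constraint representation: type variables are `usize` tags.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Constraint {
    IdToTy { id: usize, ty: Ty },    // variable = type
    IdToId { id: usize, to: usize }, // variable = variable
}

// Hypothetical tags, one per identifier / return type in the example program.
const Y: usize = 0; // `y`
const X_VAR: usize = 1; // `x`, the top-level variable
const X_PARAM: usize = 2; // `x`, the parameter of `f`
const F_RET: usize = 3; // the return type of `f`

/// The four constraints from the list above, written with tags.
fn example_constraints() -> HashSet<Constraint> {
    HashSet::from([
        Constraint::IdToTy { id: Y, ty: Ty::Int },     // let y = 12;
        Constraint::IdToId { id: X_VAR, to: F_RET },   // let x = f(y);
        Constraint::IdToId { id: X_PARAM, to: Y },     // f(y) passes y as x
        Constraint::IdToId { id: X_PARAM, to: F_RET }, // return x * x;
    ])
}

fn main() {
    let constraints = example_constraints();
    println!("{} constraints", constraints.len());
}
```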

Once we have the constraint set, it's not too hard to solve it. We write a unification algorithm, such as the one below (I hope it's correct :D).

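use std::collections::{HashMap, HashSet};

// NOTE: the supporting types below are a minimal, hypothetical sketch that
// makes `unify` compile; a real compiler's types will look different.

/// A concrete type (just `Int` and `Bool`, for illustration).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Ty {
    Int,
    Bool,
}

/// A type variable: the unique tag given to each identifier/expression.
type Id = usize;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Constraint {
    IdToTy { id: Id, ty: Ty }, // variable = type
    IdToId { id: Id, to: Id }, // variable = variable
    TyToTy { ty: Ty, to: Ty }, // type = type
}

#[derive(Clone, Copy, Debug)]
enum Substitution {
    XforY(Id, Id),        // wherever we see the second id, replace it with the first
    ConcreteForX(Ty, Id), // wherever we see the id, replace it with the type
}

#[derive(Debug)]
enum TyCheckError {
    TypeMismatch,
}

/// What a solved variable was replaced with (hypothetical).
#[derive(Clone, Copy, Debug, PartialEq)]
enum Solution {
    Concrete(Ty),
    Var(Id),
}

/// The solutions map (hypothetical): replaced variable -> its replacement.
#[derive(Debug, Default)]
struct TyEnv {
    map: HashMap<Id, Solution>,
}

impl TyEnv {
    fn feed_substitution(&mut self, sub: Substitution) {
        match sub {
            Substitution::XforY(x, y) => {
                self.map.insert(y, Solution::Var(x));
            }
            Substitution::ConcreteForX(ty, x) => {
                self.map.insert(x, Solution::Concrete(ty));
            }
        }
    }
}

/// Unify a set of constraints (i.e. solve them, if that is possible)
fn unify(set: HashSet<Constraint>, mut solved: TyEnv) -> Result<TyEnv, TyCheckError> {
    let mut iter = set.into_iter();

    // take any constraint out of the set; if there are none left, we're done
    let Some(next) = iter.next() else {
        return Ok(solved);
    };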

    let u = match next {
        Constraint::IdToTy { id, ty } => Some(Substitution::ConcreteForX(ty, id)),
        Constraint::IdToId { id, to } => {
            if id != to {
                Some(Substitution::XforY(id, to))
            } else {
                None
            }
        }
        Constraint::TyToTy { ty, to } => {
            if ty != to {
                return Err(TyCheckError::TypeMismatch);
            } else {
                None
            }
        }
    };

    if let Some(u) = u {
        // add the substitution to the list of substitutions
        solved.feed_substitution(u);

        // apply the newly generated substitution to the rest of the set
        let new_set: HashSet<Constraint> = iter
            .map(|constraint| match (u, constraint) {
                // wherever we see y, replace with x
                (Substitution::XforY(x, y), Constraint::IdToTy { id, ty }) => Constraint::IdToTy {
                    id: if id == y { x } else { id },
                    ty,
                },
                // wherever we see y, replace with x
                (Substitution::XforY(x, y), Constraint::IdToId { id, to }) => Constraint::IdToId {
                    id: if id == y { x } else { id },
                    to: if to == y { x } else { to },
                },
                (Substitution::ConcreteForX(sub_with, x), Constraint::IdToTy { id, ty }) => {
                    if id == x {
                        Constraint::TyToTy {
                            ty: sub_with,
                            to: ty,
                        }
                    } else {
                        Constraint::IdToTy { id, ty }
                    }
                }
                (Substitution::ConcreteForX(sub_with, x), Constraint::IdToId { id, to }) => {
                    if id == x {
                        Constraint::IdToTy {
                            id: to,
                            ty: sub_with,
                        }
                    } else if to == x {
                        Constraint::IdToTy { id, ty: sub_with }
                    } else {
                        Constraint::IdToId { id, to }
                    }
                }
                // the remaining cases (TyToTy) contain no variables, so they are left unchanged
                (Substitution::XforY(_, _), constraint)
                | (Substitution::ConcreteForX(_, _), constraint) => constraint,
            })
            .collect();

        unify(new_set, solved)
    } else {
        unify(iter.collect(), solved)
    }
}

Solving these constraints shouldn't be too tricky. What we do is:

  1. remove an item from the constraint set (any item will do; if the set is empty, then return the results)
  2. generate a substitution for this item. There are a few cases:
    • if we have a constraint variable = type, we substitute the type for the variable
    • if we have a constraint variableY = variableX, we replace every occurrence of one variable with the other (it doesn't matter which, as long as we record what we did)
    • if we have a constraint type = type, we check whether the two types are equal. If they are, we generate no substitution; if they are not, there is a type error in the program, and we abort
  3. we record the substitution so that we can work out each variable's type later. For this, we maintain a map from each variable to the value it was replaced with
    • if we are applying the substitution "variable = type", then insert variable -> type into the solutions map
    • if we are applying the substitution "replace variableY with variableX", then insert variableY -> variableX into the map
  4. we apply this substitution to the rest of the set, and then call unify recursively (with the new data)
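The same four steps can also be written as a loop rather than recursion. The sketch below is my own compact restatement, not the code above: `Term`, `solve` and `TypeMismatch` are hypothetical names, both sides of a constraint are `Term`s (so `variable = type` and `type = variable` are handled alike), and a `Vec` stands in for the set. `main` runs it on the example constraints with the hypothetical tags 0 = y, 1 = x (the variable), 2 = x (from f(x)), 3 = the return type of f.

```rust
use std::collections::HashMap;

/// Concrete types (only two, for illustration).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Ty {
    Int,
    Bool,
}

/// One side of a constraint: a type variable (a `usize` tag) or a concrete type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Term {
    Var(usize),
    Concrete(Ty),
}

/// Returned when a `type = type` constraint relates two different types.
#[derive(Debug, PartialEq)]
struct TypeMismatch(Ty, Ty);

/// Solve a list of `lhs = rhs` constraints, returning the solutions map
/// (replaced variable -> what it was replaced with).
fn solve(mut constraints: Vec<(Term, Term)>) -> Result<HashMap<usize, Term>, TypeMismatch> {
    let mut solutions = HashMap::new();
    // 1. take a constraint out of the set; stop once it is empty
    while let Some((lhs, rhs)) = constraints.pop() {
        // 2. generate a substitution `var := replacement`, if there is one
        let (var, replacement) = match (lhs, rhs) {
            (Term::Concrete(a), Term::Concrete(b)) => {
                if a != b {
                    return Err(TypeMismatch(a, b)); // type error: abort
                }
                continue; // equal types: no substitution
            }
            (Term::Var(a), Term::Var(b)) if a == b => continue, // a = a holds trivially
            (Term::Var(a), other) | (other, Term::Var(a)) => (a, other),
        };
        // 3. record the substitution in the solutions map
        solutions.insert(var, replacement);
        // 4. apply the substitution to the remaining constraints
        let replace = |t: Term| if t == Term::Var(var) { replacement } else { t };
        for c in constraints.iter_mut() {
            *c = (replace(c.0), replace(c.1));
        }
    }
    Ok(solutions)
}

fn main() {
    let constraints = vec![
        (Term::Var(0), Term::Concrete(Ty::Int)), // y = Int
        (Term::Var(1), Term::Var(3)),            // x (the variable) = f
        (Term::Var(2), Term::Var(0)),            // x (from f(x)) = y
        (Term::Var(2), Term::Var(3)),            // x (from f(x)) = f
    ];
    let solutions = solve(constraints).expect("the example program is well typed");
    println!("{:?}", solutions);
}
```

Because `pop` takes constraints from the end of the list, the order of steps differs from the walkthrough below, but each of the four tags still leads to Int by following the map.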

Applying this to

y = Int
x (the variable) = f (the return type of f)
x (from f(x)) = y
x (from f(x)) = f (the return type of f)

would run something like this.

  1. Substitute y = Int in the rest of the set, and add y -> Int to our solutions.
solutions = {
    y -> Int
}
constraints = {
    x (the variable) = f (the return type of f)
    x (from f(x)) = Int
    x (from f(x)) = f (the return type of f)
}
  2. Replace x (the variable) with f (the return type of f), and add x (the variable) -> f to our solutions. x (the variable) doesn't appear in the remaining constraints, so they are unchanged.
solutions = {
    y -> Int
    x (the variable) -> f (the return type of f)
}
constraints = {
    x (from f(x)) = Int
    x (from f(x)) = f (the return type of f)
}
  3. Substitute x (from f(x)) = Int in the rest of the set, and add x (from f(x)) -> Int to our solutions.
solutions = {
    y -> Int
    x (the variable) -> f (the return type of f)
    x (from f(x)) -> Int
}
constraints = {
    Int = f (the return type of f)
}
  4. Substitute f (the return type of f) = Int, and add f (the return type of f) -> Int to our solutions.
solutions = {
    y -> Int
    x (the variable) -> f (the return type of f)
    x (from f(x)) -> Int
    f (the return type of f) -> Int
}
constraints = {
}
  5. Done! The constraint set is empty.

To find something from the solution set, I wrote a small algorithm something like:

  1. find the item in the map – return None if it could not be found
  2. if the item corresponds to a concrete type (e.g. Int) then return Some(Int)
  3. otherwise recursively search for the item that is pointed to by the one we just found
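A minimal sketch of that lookup, assuming each entry in the solutions map holds either a concrete type or another variable. The `Solution` type and the tags in `main` (0 = y, 1 = x (the variable), 2 = x (from f(x)), 3 = the return type of f) are hypothetical, and the function assumes the map contains no cycles.

```rust
use std::collections::HashMap;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Ty {
    Int,
}

/// What a solved variable maps to: a concrete type, or another variable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Solution {
    Concrete(Ty),
    Var(usize),
}

/// Follow the chain of solutions for `id` until we reach a concrete type.
fn lookup(solutions: &HashMap<usize, Solution>, id: usize) -> Option<Ty> {
    // 1. find the item in the map (`?` returns None if it isn't there)
    match solutions.get(&id)? {
        // 2. a concrete type: we're done
        Solution::Concrete(ty) => Some(*ty),
        // 3. another variable: recursively look that one up
        Solution::Var(next) => lookup(solutions, *next),
    }
}

fn main() {
    // the final solutions from the walkthrough above
    let solutions: HashMap<usize, Solution> = HashMap::from([
        (0, Solution::Concrete(Ty::Int)),
        (1, Solution::Var(3)),
        (2, Solution::Concrete(Ty::Int)),
        (3, Solution::Concrete(Ty::Int)),
    ]);
    println!("{:?}", lookup(&solutions, 1)); // Some(Int)
}
```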

further reading



More from Teymour Aldridge