Polymorphic variant inference that beats OCaml

29 May 2023

You may remember that a few months ago, I spent quite some time implementing type inference for polymorphic variant patterns in Polaris. In that post, I settled on roughly the same approach as OCaml. As it turns out, we can do much better!

A quick recap

The problem with inferring types for these patterns is that a match on polymorphic variants should sometimes be inferred to have a closed variant type and sometimes an open one:

# This should have type f : < A, B > -> Bool
let f(x) = match x {
    A -> true
    B -> false
}

# This should have type g : forall r. < A, B | r > -> Bool
let g(x) = match x {
    A -> true
    B -> false
    _ -> false
}

The solution both Polaris and OCaml use is to always infer an open variant type, then check the patterns for exhaustiveness afterwards and close the type when appropriate (by setting the row variable to the empty variant <>).
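The two-step approach (infer open, then close if needed) can be sketched in Python for the simplest case, where every arm is a single top-level pattern. This is a toy model, not the Polaris or OCaml implementation: `VariantType`, `is_catch_all` and `infer_scrutinee_type` are hypothetical names, and a real exhaustiveness check also has to handle nested patterns.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class VariantType:
    """A polymorphic variant type < L1, L2, ... | rest >.

    `rest` names the row variable; None means the type has been closed.
    """
    labels: tuple
    rest: Optional[str]

    def __str__(self) -> str:
        body = ", ".join(self.labels)
        return f"< {body} | {self.rest} >" if self.rest else f"< {body} >"


def is_catch_all(pattern: str) -> bool:
    # In this sketch, constructors are capitalised (A, B); anything else,
    # such as `_` or a variable like `y`, matches every value.
    return not pattern[0].isupper()


def infer_scrutinee_type(patterns: list, row_var: str = "r") -> VariantType:
    # Step 1: always infer an open type listing each constructor once.
    labels = tuple(dict.fromkeys(p for p in patterns if not is_catch_all(p)))
    ty = VariantType(labels, row_var)
    # Step 2: exhaustiveness check. Without a catch-all arm, the match only
    # covers every value if nothing else can appear, so close the row by
    # setting the row variable to the empty variant <>.
    if not any(is_catch_all(p) for p in patterns):
        ty = VariantType(ty.labels, None)
    return ty


print(infer_scrutinee_type(["A", "B"]))       # < A, B >
print(infer_scrutinee_type(["A", "B", "_"]))  # < A, B | r >
```

The two calls mirror f and g above: only the match with a catch-all arm keeps its row variable.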

This is not enough

Consider this function (it may seem a bit artificial, but I ran into something similar in a real Polaris script!):

let f(x) = match x {
    A -> B
    y -> y
}

What is the type of y? Both OCaml and the current version of Polaris infer < A, B | ?r >, and therefore infer the full type f : forall r. < A, B | r > -> < A, B | r >. But if you look closely, you will notice that this type is more permissive than it needs to be: y can never be A, because if x is A, the first pattern has already matched.

It gets a bit more complicated still. If we change the function slightly,

let f(x) = match x {
    (A, 5) -> B
    (y, _) -> y
}

then the first pattern can fail even when the first component is A (namely when the second component is not 5), so y can be A now!

Fortunately, there is a relatively simple rule for deciding when to refine variants: whenever a pattern is irrefutable (i.e. always matches) apart from a variant pattern, no later pattern can see that variant in that position. For example, ((A, x), _) is irrefutable apart from the variant pattern A, but (A, 5) is not, because the literal pattern 5 can fail.
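A minimal sketch of this check in Python, over a hypothetical pattern AST (`Var`, `Wildcard`, `IntLit`, `Variant`, `TuplePat`) that only supports nullary variants such as A. The hypothetical function `refinement` returns the position and label that later patterns can no longer see, or None. As an assumption of this sketch, it only refines when the pattern contains exactly one variant pattern: for (A, B), a later (y, _) can still see A whenever the second component is not B.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:        # binds a variable, e.g. x or y
    name: str


@dataclass(frozen=True)
class Wildcard:   # _
    pass


@dataclass(frozen=True)
class IntLit:     # a literal such as 5 (can fail to match)
    value: int


@dataclass(frozen=True)
class Variant:    # a nullary variant pattern such as A
    label: str


@dataclass(frozen=True)
class TuplePat:   # (p1, p2, ...)
    elements: tuple


def variants_if_otherwise_irrefutable(p, path=()):
    """List every (path, label) variant pattern inside p, or return None
    if p contains some other refutable sub-pattern (e.g. a literal)."""
    if isinstance(p, (Var, Wildcard)):
        return []
    if isinstance(p, Variant):
        return [(path, p.label)]
    if isinstance(p, IntLit):
        return None
    if isinstance(p, TuplePat):
        found = []
        for i, element in enumerate(p.elements):
            sub = variants_if_otherwise_irrefutable(element, path + (i,))
            if sub is None:
                return None
            found.extend(sub)
        return found
    raise TypeError(f"unknown pattern: {p!r}")


def refinement(p):
    """(path, label) that later patterns can no longer see, or None."""
    found = variants_if_otherwise_irrefutable(p)
    if found is not None and len(found) == 1:
        return found[0]
    return None


# ((A, x), _) rules out A at position (0, 0) for later arms.
print(refinement(TuplePat((TuplePat((Variant("A"), Var("x"))), Wildcard()))))  # ((0, 0), 'A')
# (A, 5) rules out nothing, because the literal 5 can fail.
print(refinement(TuplePat((Variant("A"), IntLit(5)))))  # None
```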

We can then infer the correct type for the original function f (forall r. < A, B | r > -> < B | r >) by refining the variant type seen by the remaining patterns.
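Refining the row itself can be sketched as removing a label from the scrutinee's row type. The `Row` representation, `remove_label` and `fresh_row_var` below are hypothetical, and the sketch inspects the row directly, which is the easy, OCaml-like situation rather than how Polaris' constraint solver works:

```python
from dataclasses import dataclass
from itertools import count
from typing import Optional

_counter = count()


def fresh_row_var() -> str:
    # Hypothetical supply of fresh row variables; the "_" prefix keeps them
    # distinct from the hand-written names used below.
    return f"_r{next(_counter)}"


@dataclass(frozen=True)
class Row:
    """A variant row < labels | tail >; tail is a row variable or None (closed)."""
    labels: frozenset
    tail: Optional[str]


def remove_label(row: Row, label: str):
    """Refine `row` for later arms, given that `label` has been fully handled.

    Returns (refined_row, substitution), where the substitution maps the old
    row variable to its new value if the row had to be split.
    """
    if label in row.labels:
        # The label is listed explicitly: just drop it and keep the tail.
        return Row(row.labels - {label}, row.tail), {}
    if row.tail is None:
        # Closed row without the label: there is nothing to remove.
        return row, {}
    # Open row without the label: if the label occurs at all, it lives in
    # the tail, so split tail = < label | fresh > and keep only the fresh part.
    fresh = fresh_row_var()
    return Row(row.labels, fresh), {row.tail: Row(frozenset({label}), fresh)}


# In f, the first arm gives x : < A | r0 >; after refinement, y only sees r0.
x_type = Row(frozenset({"A"}), "r0")
y_type, subst = remove_label(x_type, "A")
print(sorted(y_type.labels), y_type.tail, subst)  # [] r0 {}
```

Since both arms return the result, unifying y's type r0 with the B returned by the first arm gives r0 = < B | r >, which yields exactly forall r. < A, B | r > -> < B | r >.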

But how do you implement refinement?

This is a little harder for Polaris than it would be for OCaml, since Polaris’ type system is constraint-based. This means we usually cannot match on types directly during type inference, as they may still contain unification variables that the constraint solver will only substitute later.

Right now, match expressions are inferred like this:

To refine variant types, we need to take a slightly different approach:

And that’s it! For only a little more effort, this should infer polymorphic variant types much more precisely than OCaml does!