Welcome to part 2 of this series on compiling functional languages to LLVM. In part 1 we created a very simple calculator that let us add, subtract and multiply integers like 1 + 1 or 6 * (5 - 2).
Today we’re going to spice things up a touch by adding some basic control flow. By the end of today we’re going to be writing sweet syntax such as:
2 + 2 == 5
if 6 == 6 then False else True
What’s new then?
To make our dreams come true, we’re going to add two new syntactic features:
if expressions
an == infix operator
Let’s do that now!
-- | operators for combining expressions
data Op
  = OpAdd
  | OpMultiply
  | OpSubtract
  | OpEquals -- this is new!
  deriving stock (Eq, Ord, Show)

-- | Expressions, decorated with some unknown `ann`
data Expr ann
  = EPrim ann Prim -- this `Prim` used to always be `Int`
  | EInfix ann Op (Expr ann) (Expr ann)
  | EIf ann (Expr ann) (Expr ann) (Expr ann) -- this is new!
  deriving stock (Eq, Ord, Show, Functor, Foldable, Traversable)
Our EIf constructor takes three Expr ann as arguments. The first is the predicate, ie, the thing that must evaluate to True or False, and the other two are expressions to be evaluated on the then and else branches. For all of this to make any sense, the two branches will need to be the same type.
Types?
Previously we sort of glossed over the idea of types, because every value in our calculator was either an Integer (ie, 1, 42) or an expression that would eventually evaluate into an Integer (like 1 + 1, 6 * 12).
However, the expression 1 == 1 doesn’t resolve to an Integer, it can only be True or False, ie a Boolean type. (It is true that we could express this with an Integer but if we start cutting corners this early in the game we’ll never get anywhere.) This means we’ll need to extend our Prim type to also describe Boolean values as well as Integers.
-- | types of basic values
data Prim
  = PInt Integer
  | PBool Bool
  deriving stock (Eq, Ord, Show)
However, this means we are in danger of our users being able to make silly mistakes like if 27 then False else 6. How can we stop this? This can only mean one thing: we are going to need to write a bidirectional type checker.
Bidirect what?
A bidirectional type checker is a way of working out which types the parts of an expression have, and identifying parts that don’t make sense. What makes it “bidirectional” is that it works in two “modes”:
infer mode: given an expression, give me the type
check mode: given an expression and the type we think it has, give me the type
The broad idea is when we don’t know anything about an expression, we infer what types it has, but as we learn more, we use that information to help us work the rest out. Although we could arguably get away with just an infer mode for a language this simple, we will need this special two-way magic in future.
Enough waffle, let’s see some code, and then talk about it.
Code
Firstly, we need a type for types. We’ll call it Type. We are adding an ann type argument to it, so that we can attach source code locations etc. This will be helpful for showing our user helpful errors, which we will be doing today:
data TypePrim = TBool | TInt
  deriving stock (Eq, Ord, Show)

-- the `ann` is used to attach source code location etc
data Type ann
  = TPrim ann TypePrim
  deriving stock (Eq, Ord, Show, Functor)
Things go wrong
Any old typechecker can tell you when things are going well, but the ones that are really worth their salt are the ones that tell you helpfully what is going wrong. We will need these ones:
data TypeError ann
  = PredicateIsNotBoolean ann (Type ann)
  | InfixTypeMismatch Op [(Type ann, Type ann)]
  | TypeMismatch (Type ann) (Type ann)
  deriving stock (Eq, Ord, Show)
The meat of the thing
The point of the typechecker, then, is to take Expr ann (ie, an Expr carrying around some ann that does not concern us), and turn it into either TypeError ann or Expr (Type ann).
Expr (Type ann) means that we’ll have “decorated” each part of the Expr with its type. We’ll take whatever ann was in there and put it in the Type instead. For example:
oneWhichIsAnInteger :: Expr (Type ())
oneWhichIsAnInteger = EPrim (TPrim () TInt) (PInt 1)
The entire typechecker lives here. Let’s go through the key parts:
-- | this is the function we run
elaborate :: Expr ann -> Either (TypeError ann) (Expr (Type ann))
elaborate = infer -- start with `infer` because we know nothing
elaborate is the function the typechecker exports. It takes an untypechecked Expr ann and returns either Expr (Type ann) or an excuse. It starts by running infer, which we’ll see shortly.
typeFromPrim :: ann -> Prim -> Type ann
typeFromPrim ann (PInt _) = TPrim ann TInt -- It's an Integer!
typeFromPrim ann (PBool _) = TPrim ann TBool -- It's a Boolean!
The most basic type inference we can do is looking at a primitive value. As it stands in our language, there is one number type and one boolean type, so we can unambiguously work out the type just by looking at the value.
inferIf :: ann -> Expr ann -> Expr ann -> Expr ann -> Either (TypeError ann) (Expr (Type ann))
inferIf ann predExpr thenExpr elseExpr = do
  predA <- infer predExpr
  case getOuterAnnotation predA of
    (TPrim _ TBool) -> pure ()
    otherType -> throwError (PredicateIsNotBoolean ann otherType)
  thenA <- infer thenExpr
  elseA <- check (getOuterAnnotation thenA) elseExpr
  pure (EIf (getOuterAnnotation elseA) predA thenA elseA)
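This is how if works. We infer the type of the predicate, then use getOuterAnnotation to get the Type ann out of it. We then case match on it to see if it’s a Boolean or not, “throwing” an error if not. Finally, we infer the type of the then branch and check the else branch against it, so both arms must agree.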
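The post uses getOuterAnnotation without showing it. A minimal sketch of what it might look like, assuming it simply returns the annotation stored on the outermost constructor (the definition below is a guess, and the data types are repeated with plain deriving only so the snippet stands alone):

```haskell
-- Trimmed-down copies of the post's types, so this compiles on its own
data Prim = PInt Integer | PBool Bool
  deriving (Eq, Show)

data Op = OpAdd | OpMultiply | OpSubtract | OpEquals
  deriving (Eq, Show)

data Expr ann
  = EPrim ann Prim
  | EInfix ann Op (Expr ann) (Expr ann)
  | EIf ann (Expr ann) (Expr ann) (Expr ann)
  deriving (Eq, Show)

-- Assumed definition: grab the annotation of the top-level node only,
-- ignoring whatever is attached to the sub-expressions
getOuterAnnotation :: Expr ann -> ann
getOuterAnnotation (EPrim ann _) = ann
getOuterAnnotation (EInfix ann _ _ _) = ann
getOuterAnnotation (EIf ann _ _ _) = ann
```

After typechecking, the annotation is a Type ann, so calling this on an elaborated expression hands back the type of the whole expression.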
inferInfix ::
  ann ->
  Op ->
  Expr ann ->
  Expr ann ->
  Either (TypeError ann) (Expr (Type ann))
inferInfix ann OpEquals a b = do
  elabA <- infer a
  elabB <- infer b
  ty <- case (getOuterAnnotation elabA, getOuterAnnotation elabB) of
    (TPrim _ tA, TPrim _ tB)
      | tA == tB ->
          -- if the types are the same, then great! it's a bool!
          pure (TPrim ann TBool)
    (otherA, otherB) ->
      -- otherwise, error!
      throwError (TypeMismatch otherA otherB)
  pure (EInfix ty OpEquals elabA elabB)
When typechecking ==, we want to make sure both sides have the same type, “throwing” an error if not.
inferInfix ann op a b = do
  elabA <- infer a
  elabB <- infer b
  -- all the other infix operators need to be Int -> Int -> Int
  ty <- case (getOuterAnnotation elabA, getOuterAnnotation elabB) of
    (TPrim _ TInt, TPrim _ TInt) ->
      -- if the types are the same, then great! it's an int!
      pure (TPrim ann TInt)
    (TPrim _ TInt, other) ->
      throwError
        ( InfixTypeMismatch
            op
            [ ( TPrim (getOuterTypeAnnotation other) TInt,
                other
              )
            ]
        )
    (other, TPrim _ TInt) ->
      throwError
        ( InfixTypeMismatch
            op
            [ ( TPrim (getOuterTypeAnnotation other) TInt,
                other
              )
            ]
        )
    (otherA, otherB) ->
      -- otherwise, error!
      throwError
        ( InfixTypeMismatch
            op
            [ (TPrim (getOuterTypeAnnotation otherA) TInt, otherA),
              (TPrim (getOuterTypeAnnotation otherB) TInt, otherB)
            ]
        )
  pure (EInfix ty op elabA elabB)
Here are the other operators. Both the arguments should be Integer and the return type is Integer, otherwise we construct and return an error. Each entry in the error pairs the type we expected with the type we actually found. It seems like a lot of work to be so specific, but look how helpful our errors are!
You can see all the error rendering code here.
infer :: Expr ann -> Either (TypeError ann) (Expr (Type ann))
infer (EPrim ann prim) =
  pure (EPrim (typeFromPrim ann prim) prim)
infer (EIf ann predExpr thenExpr elseExpr) =
  inferIf ann predExpr thenExpr elseExpr
infer (EInfix ann op a b) =
  inferInfix ann op a b
That’s how we put infer together, easy!
check :: Type ann -> Expr ann -> Either (TypeError ann) (Expr (Type ann))
check ty expr = do
  exprA <- infer expr
  if void (getOuterAnnotation exprA) == void ty
    then pure (expr $> ty)
    else throwError (TypeMismatch ty (getOuterAnnotation exprA))
Lastly, here’s check. We only use it when comparing arms of if statements, but soon this will become more interesting.
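One detail worth spelling out is the void in the comparison. Two types that differ only in their annotations (say, different source locations) should still count as equal, and void from Data.Functor replaces every annotation with () before comparing. A small sketch, where sameType is a hypothetical helper name used only for illustration:

```haskell
{-# LANGUAGE DeriveFunctor #-}

import Data.Functor (void)

data TypePrim = TBool | TInt
  deriving (Eq, Ord, Show)

data Type ann
  = TPrim ann TypePrim
  deriving (Eq, Ord, Show, Functor)

-- Hypothetical helper: compare two types while ignoring annotations.
-- `void` turns `Type ann` into `Type ()`, so only the shape is compared
-- and no `Eq ann` constraint is needed.
sameType :: Type ann -> Type ann -> Bool
sameType a b = void a == void b
```

With this, sameType (TPrim 1 TInt) (TPrim 99 TInt) holds even though the derived == on the annotated values does not.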
You can see all of the typechecker code here.
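To see the two modes working together end to end, here is a deliberately cut-down model of the checker. It is a sketch, not the code from the post: annotations are dropped, it returns just the type rather than an elaborated expression, and the names Ty, EInt, EBool and the simplified TypeError constructors are assumptions made to keep it short.

```haskell
data Ty = TyInt | TyBool
  deriving (Eq, Show)

data Op = OpAdd | OpSubtract | OpMultiply | OpEquals
  deriving (Eq, Show)

data Expr
  = EInt Integer
  | EBool Bool
  | EInfix Op Expr Expr
  | EIf Expr Expr Expr
  deriving (Eq, Show)

data TypeError
  = PredicateIsNotBoolean Ty
  | InfixTypeMismatch Op Ty Ty
  | TypeMismatch Ty Ty
  deriving (Eq, Show)

-- infer mode: we know nothing, so work the type out from the expression
infer :: Expr -> Either TypeError Ty
infer (EInt _) = Right TyInt
infer (EBool _) = Right TyBool
infer (EInfix OpEquals a b) = do
  tyA <- infer a
  tyB <- infer b
  if tyA == tyB then Right TyBool else Left (TypeMismatch tyA tyB)
infer (EInfix op a b) = do
  tyA <- infer a
  tyB <- infer b
  case (tyA, tyB) of
    (TyInt, TyInt) -> Right TyInt
    _ -> Left (InfixTypeMismatch op tyA tyB)
infer (EIf predExpr thenExpr elseExpr) = do
  tyPred <- infer predExpr
  case tyPred of
    TyBool -> Right ()
    other -> Left (PredicateIsNotBoolean other)
  -- infer the first arm, then check the second arm against it
  tyThen <- infer thenExpr
  check tyThen elseExpr

-- check mode: we already expect a type, so confirm the expression has it
check :: Ty -> Expr -> Either TypeError Ty
check expected expr = do
  actual <- infer expr
  if actual == expected
    then Right actual
    else Left (TypeMismatch expected actual)
```

Under these assumptions, the silly example from earlier, if 27 then False else 6, written as EIf (EInt 27) (EBool False) (EInt 6), is rejected with PredicateIsNotBoolean TyInt, while if 6 == 6 then False else True infers as TyBool.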
Interpreting our new friends
Before heading back into LLVM land, let’s update our manual interpreter so we can understand what’s needed here.
Firstly, it’s now possible that our interpreter can fail. This will only happen if our typechecker is not working as expected, but we should make a proper error type for it anyway because we are good programmers who care about our users.
data InterpreterError ann
  = NonBooleanPredicate ann (Expr ann)
  deriving stock (Eq, Ord, Show)
Interpreting infix expressions is a little bit more complicated, as our pattern matches have to make sure we’re looking at the right Prim values. The eagle-eyed may notice that a broken typechecker could send this into a loop. Can you see where?
interpretInfix ::
  (MonadError (InterpreterError ann) m) =>
  ann ->
  Op ->
  Expr ann ->
  Expr ann ->
  m (Expr ann)
interpretInfix ann OpAdd (EPrim _ (PInt a)) (EPrim _ (PInt b)) =
  pure $ EPrim ann (PInt $ a + b)
interpretInfix ann OpSubtract (EPrim _ (PInt a)) (EPrim _ (PInt b)) =
  pure $ EPrim ann (PInt $ a - b)
interpretInfix ann OpMultiply (EPrim _ (PInt a)) (EPrim _ (PInt b)) =
  pure $ EPrim ann (PInt $ a * b)
interpretInfix ann OpEquals (EPrim _ a) (EPrim _ b) =
  pure $ EPrim ann (PBool $ a == b)
interpretInfix ann op a b = do
  iA <- interpret a
  iB <- interpret b
  interpretInfix ann op iA iB
We ended up with a MonadError constraint above - why’s that? It’s because the main interpret function can now “explode” if the predicate of an if statement doesn’t reduce to a Boolean. Our typechecker should stop this happening of course.
-- | just keep reducing the thing until the smallest thing
interpret ::
  ( MonadError (InterpreterError ann) m
  ) =>
  Expr ann ->
  m (Expr ann)
interpret (EPrim ann p) = pure (EPrim ann p)
interpret (EInfix ann op a b) =
  interpretInfix ann op a b
interpret (EIf ann predExpr thenExpr elseExpr) = do
  predA <- interpret predExpr
  case predA of
    (EPrim _ (PBool True)) -> interpret thenExpr
    (EPrim _ (PBool False)) -> interpret elseExpr
    other -> throwError (NonBooleanPredicate ann other)
We interpret if statements by reducing the predicate down to a boolean, then taking a peek, and then interpreting the appropriate branch. If we don’t need a branch, there’s no need to interpret it!
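The "only interpret the branch we need" behaviour can be tried out in isolation with a stripped-down interpreter. This is a sketch under assumptions rather than the post's code: annotations are removed, results are returned as a bare Prim in Either instead of going through MonadError, and BadInfixArguments and applyOp are names invented here.

```haskell
data Prim = PInt Integer | PBool Bool
  deriving (Eq, Show)

data Op = OpAdd | OpSubtract | OpMultiply | OpEquals
  deriving (Eq, Show)

data Expr
  = EPrim Prim
  | EInfix Op Expr Expr
  | EIf Expr Expr Expr
  deriving (Eq, Show)

data InterpreterError
  = NonBooleanPredicate Prim
  | BadInfixArguments Op Prim Prim -- not in the post, added for this sketch
  deriving (Eq, Show)

interpret :: Expr -> Either InterpreterError Prim
interpret (EPrim p) = Right p
interpret (EInfix op a b) = do
  pA <- interpret a
  pB <- interpret b
  applyOp op pA pB
interpret (EIf predExpr thenExpr elseExpr) = do
  -- reduce the predicate first, then interpret exactly one branch
  predP <- interpret predExpr
  case predP of
    PBool True -> interpret thenExpr
    PBool False -> interpret elseExpr
    other -> Left (NonBooleanPredicate other)

-- Apply an operator to two fully reduced values
applyOp :: Op -> Prim -> Prim -> Either InterpreterError Prim
applyOp OpAdd (PInt a) (PInt b) = Right (PInt (a + b))
applyOp OpSubtract (PInt a) (PInt b) = Right (PInt (a - b))
applyOp OpMultiply (PInt a) (PInt b) = Right (PInt (a * b))
applyOp OpEquals a b = Right (PBool (a == b))
applyOp op a b = Left (BadInfixArguments op a b)
```

Because the untaken branch is never touched, an if whose predicate is True still succeeds in this sketch even when its else branch contains something ill-typed such as 1 + False.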
You can see all of the interpreter code here.
OK, LLVM time
I feel like I’m rushing through all this, and maybe copy pasta-ing an entire typechecker in the preamble was somewhat undisciplined of me.
BUT, here we go.
Digression
Firstly, we’ll add a new function to our C “standard library”:
void printbool(int b) {
  printf(b ? "True" : "False");
}
It will take an LLVM boolean, and print False if it is 0 and True otherwise.
To the IR!
We’re going to start by looking at the LLVM IR for the following arbitrary expression:
if 2 == 1 then True else False
; ModuleID = 'example'
declare external ccc void @printbool(i1)
define external ccc i32 @main() {
%1 = icmp eq i32 2, 1
%2 = alloca i1
br i1 %1, label %then_0, label %else_0
then_0:
store i1 1, i1* %2
br label %done_0
else_0:
store i1 0, i1* %2
br label %done_0
done_0:
%3 = load i1, i1* %2
call ccc void @printbool(i1 %3)
ret i32 0
}
What a ride! Let’s take it line by line.
; ModuleID = 'example'
Once again, let’s ease ourselves in with a code comment.
declare external ccc void @printbool(i1)
Declaration for the new function in our standard library. It takes an i1 (a boolean, stored as 0 or 1) and returns void.
define external ccc i32 @main() {
We define the main function, which is the entry point of our program. It takes no arguments, and returns an i32 integer value (which becomes the exit code).
%1 = icmp eq i32 2, 1
Here we are making a new variable, %1, by comparing two integers, 2 and 1, using eq. This is our 2 == 1 expression, and maps across quite neatly.
%2 = alloca i1
To make control flow work, we are going to need to jump to different places. However, without reaching for phi nodes, we have no way of passing a value back between blocks. Therefore, we are going to create a mutable placeholder for the result, and each branch will be responsible for storing the result here. alloca is broadly “allocate memory on the stack” and i1 is the LLVM type for a Boolean.
br i1 %1, label %then_0, label %else_0
This is where we do the branching. br takes an i1 value for the predicate, and then two labels for blocks that we’ll jump to depending on the value of the predicate. Therefore if %1 is 1 we’ll jump to then_0, otherwise we’ll jump to else_0. We’ll define these shortly.
then_0:
store i1 1, i1* %2
br label %done_0
This defines a block labelled then_0. We will “jump” here in the “then” case of the if statement. We store 1 in the %2 variable, and then jump to the done_0 block.
else_0:
store i1 0, i1* %2
br label %done_0
This defines a block labelled else_0. We will “jump” here in the “else” case. This time we store 0 in the %2 variable, and then once again jump to done_0.
done_0:
%3 = load i1, i1* %2
This introduces a new block called done_0. As our if construct is an expression, we always need to return something, so we jump here when the then or else branches have finished doing their business, and load whatever they stored in %2.
call ccc void @printbool(i1 %3)
Call the printbool function from our standard library with the loaded value.
ret i32 0
As our program succeeded, we return 0, which becomes our exit code.
}
As a little palate cleanser, a nice closing brace.
Generating it from Haskell
Now we have Boolean as well as Integer values, we’ll need to represent them in LLVM. We’ll use a bit, which is a 1-bit LLVM number, to represent Booleans.
primToLLVM :: Prim -> LLVM.Operand
primToLLVM (PInt i) = LLVM.int32 (fromIntegral i)
primToLLVM (PBool True) = LLVM.bit 1
primToLLVM (PBool False) = LLVM.bit 0
Now we’ll need to choose the right printing function:
-- import the correct output function from our standard library
-- depending on the output type of our expression
printFunction :: (LLVM.MonadModuleBuilder m) => Type ann -> m LLVM.Operand
printFunction (TPrim _ TInt) = LLVM.extern "printint" [LLVM.i32] LLVM.void
printFunction (TPrim _ TBool) = LLVM.extern "printbool" [LLVM.i1] LLVM.void
The most interesting part is if expressions. We use the RecursiveDo extension, which gives us the mdo syntax. This lets us use bindings before they are created. This will allow us to use thenBlock and elseBlock before they’re defined. We create IR for the predExpr, then pass it to LLVM.condBr, which will then jump to the appropriate block depending on the value.
ifToLLVM ::
  (LLVM.MonadIRBuilder m, LLVM.MonadModuleBuilder m, MonadFix m) =>
  Type ann ->
  Expr (Type ann) ->
  Expr (Type ann) ->
  Expr (Type ann) ->
  m LLVM.Operand
ifToLLVM tyReturn predExpr thenExpr elseExpr = mdo
  -- create IR for predicate
  irPred <- exprToLLVM predExpr

  -- make variable for return value
  irReturnValue <- LLVM.alloca (typeToLLVM tyReturn) Nothing 0

  -- this does the switching
  -- we haven't created these blocks yet but RecursiveDo lets us do this with
  -- MonadFix magic
  LLVM.condBr irPred thenBlock elseBlock

  -- create a block for the 'then' branch
  thenBlock <- LLVM.block `LLVM.named` "then"
  -- create ir for the then branch
  irThen <- exprToLLVM thenExpr
  -- store the result in irReturnValue
  LLVM.store irReturnValue 0 irThen
  -- branch back to the 'done' block
  LLVM.br doneBlock

  -- create a block for the 'else' branch
  elseBlock <- LLVM.block `LLVM.named` "else"
  -- create ir for the else branch
  irElse <- exprToLLVM elseExpr
  -- store the result in irReturnValue
  LLVM.store irReturnValue 0 irElse
  -- branch back to the 'done' block
  LLVM.br doneBlock

  -- create a block for 'done' that we always branch to
  doneBlock <- LLVM.block `LLVM.named` "done"
  -- load the result and return it
  LLVM.load irReturnValue 0
To work out which kind of type to alloca, we take the return type and use it to work out which LLVM type to use.
typeToLLVM :: Type ann -> LLVM.Type
typeToLLVM (TPrim _ TBool) = LLVM.i1
typeToLLVM (TPrim _ TInt) = LLVM.i32
You can see all of the LLVM conversion code here.
Well that’s that
Congratulations, you are all bidirectional type checking experts now. Hopefully that was somewhat helpful. Next time we’ll be adding basic functions and variables. Great!
Make sense? If not, get in touch!
Further reading: