Skip to content

Commit

Permalink
Add kind-inference algorithm
Browse files Browse the repository at this point in the history
Previously kinds needed to be annotated on type synonyms, i.e.
```
type option_syn('a : Type) -> Type = option('a)
```
and on type constructors
```
struct S('a : Int, 'b : Type) = ...
```

This commit adds a kind-inference algorithm, so that these can become
```
type option_syn('a) = option('a)

struct S('a, 'b) = ...
```
  • Loading branch information
Alasdair committed Jul 23, 2024
1 parent 13d9458 commit 49b5852
Show file tree
Hide file tree
Showing 20 changed files with 1,080 additions and 325 deletions.
2 changes: 1 addition & 1 deletion src/lib/chunk_ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ let string_of_kind (K_aux (k, _)) =
match k with K_type -> "Type" | K_int -> "Int" | K_order -> "Order" | K_bool -> "Bool"

(* Right now, let's just assume we never break up kinded-identifiers *)
let chunk_of_kopt (KOpt_aux (KOpt_kind (special, vars, kind), l)) =
let chunk_of_kopt (KOpt_aux (KOpt_kind (special, vars, kind, _), l)) =
match (special, kind) with
| Some c, Some k ->
Atom (Printf.sprintf "(%s %s : %s)" c (Util.string_of_list " " string_of_var vars) (string_of_kind k))
Expand Down
1,175 changes: 884 additions & 291 deletions src/lib/initial_check.ml

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/lib/parse_ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ and atyp = ATyp_aux of atyp_aux * l

and kinded_id_aux =
(* optionally kind-annotated identifier *)
| KOpt_kind of string option * kid list * kind option (* kind-annotated variable *)
| KOpt_kind of string option * kid list * kind option * int option (* kind-annotated variable *)

and kinded_id = KOpt_aux of kinded_id_aux * l

Expand Down Expand Up @@ -399,7 +399,7 @@ type fundef_aux = (* Function definition *)

type type_def_aux =
(* Type definition body *)
| TD_abbrev of id * typquant * kind * atyp (* type abbreviation *)
| TD_abbrev of id * typquant * kind option * atyp (* type abbreviation *)
| TD_record of id * typquant * (atyp * id) list (* struct type definition *)
| TD_variant of id * typquant * type_union list (* union type definition *)
| TD_enum of id * (id * atyp) list * (id * exp option) list (* enumeration type definition *)
Expand Down
18 changes: 9 additions & 9 deletions src/lib/parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,11 @@ kid_list:

kopt:
| Lparen Constant kid_list Colon kind Rparen
{ KOpt_aux (KOpt_kind (Some "constant", $3, Some $5), loc $startpos $endpos) }
{ KOpt_aux (KOpt_kind (Some "constant", $3, Some $5, None), loc $startpos $endpos) }
| Lparen kid_list Colon kind Rparen
{ KOpt_aux (KOpt_kind (None, $2, Some $4), loc $startpos $endpos) }
{ KOpt_aux (KOpt_kind (None, $2, Some $4, None), loc $startpos $endpos) }
| kid
{ KOpt_aux (KOpt_kind (None, [$1], None), loc $startpos $endpos) }
{ KOpt_aux (KOpt_kind (None, [$1], None, None), loc $startpos $endpos) }

kopt_list:
| kopt
Expand Down Expand Up @@ -989,9 +989,9 @@ r_def_body:

param_kopt:
| kid Colon kind
{ KOpt_aux (KOpt_kind (None, [$1], Some $3), loc $startpos $endpos) }
{ KOpt_aux (KOpt_kind (None, [$1], Some $3, None), loc $startpos $endpos) }
| kid
{ KOpt_aux (KOpt_kind (None, [$1], None), loc $startpos $endpos) }
{ KOpt_aux (KOpt_kind (None, [$1], None, None), loc $startpos $endpos) }

typaram:
| Lparen separated_nonempty_list(Comma, param_kopt) Rparen Comma typ
Expand All @@ -1002,13 +1002,13 @@ typaram:

type_def:
| Typedef id typaram Eq typ
{ mk_td (TD_abbrev ($2, $3, K_aux (K_type, Parse_ast.Unknown), $5)) $startpos $endpos }
{ mk_td (TD_abbrev ($2, $3, None, $5)) $startpos $endpos }
| Typedef id Eq typ
{ mk_td (TD_abbrev ($2, mk_typqn, K_aux (K_type, Parse_ast.Unknown), $4)) $startpos $endpos }
{ mk_td (TD_abbrev ($2, mk_typqn, None, $4)) $startpos $endpos }
| Typedef id typaram MinusGt kind Eq typ
{ mk_td (TD_abbrev ($2, $3, $5, $7)) $startpos $endpos }
{ mk_td (TD_abbrev ($2, $3, Some $5, $7)) $startpos $endpos }
| Typedef id Colon kind Eq typ
{ mk_td (TD_abbrev ($2, mk_typqn, $4, $6)) $startpos $endpos }
{ mk_td (TD_abbrev ($2, mk_typqn, Some $4, $6)) $startpos $endpos }
| Typedef id Colon kind
{ mk_td (TD_abstract ($2, $4)) $startpos $endpos }
| Struct id Eq Lcurly struct_fields Rcurly
Expand Down
30 changes: 30 additions & 0 deletions src/lib/util.ml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,36 @@ module Option_monad = struct
let ( let+ ) = Option.map
end

module State_monad (S : sig
type t
end) =
struct
type 'a monad = S.t -> 'a * S.t

let ( let* ) state f env =
let y, env' = state env in
f y env'

let return x env = (x, env)

let fmap f m =
let* x = m in
return (f x)

let ( let+ ) = fmap

let rec mapM f = function
| [] -> return []
| x :: xs ->
let* y = f x in
let* ys = mapM f xs in
return (y :: ys)

let get_state s = (s, s)

let put_state s _ = ((), s)
end

module Duplicate (S : Set.S) = struct
type dups = No_dups of S.t | Has_dups of S.elt

Expand Down
19 changes: 19 additions & 0 deletions src/lib/util.mli
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,25 @@ module Option_monad : sig
val ( let+ ) : ('a -> 'b) -> 'a option -> 'b option
end

module State_monad : functor
(S : sig
type t
end)
-> sig
type 'a monad = S.t -> 'a * S.t

val get_state : S.t monad
val put_state : S.t -> unit monad

val fmap : ('a -> 'b) -> 'a monad -> 'b monad
val return : 'a -> 'a monad

val ( let* ) : 'a monad -> ('a -> 'b monad) -> 'b monad
val ( let+ ) : ('a -> 'b) -> 'a monad -> 'b monad

val mapM : ('a -> 'b monad) -> 'a list -> 'b list monad
end

(** Mixed useful things *)
module Duplicate (S : Set.S) : sig
type dups = No_dups of S.t | Has_dups of S.elt
Expand Down
16 changes: 3 additions & 13 deletions test/typecheck/fail/synonym_rec.expect
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
Type error:
fail/synonym_rec.sail:2.0-10:
fail/synonym_rec.sail:2.9-10:
2 |type T = T
 |^--------^
 | Types are not well-formed within this type definition. Note that recursive types are forbidden.
 |
 | Caused by fail/synonym_rec.sail:2.9-10:
 | 2 |type T = T
 |  | ^
 |  | Well-formedness check failed for type
 |  |
 |  | Caused by fail/synonym_rec.sail:2.9-10:
 |  | 2 |type T = T
 |  |  | ^
 |  |  | Undefined type T
 | ^
 | Failed to infer kind for this type
9 changes: 9 additions & 0 deletions test/typecheck/pass/fn_kind_infer.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
default Order dec

$include <prelude.sail>

val f : forall 'b. bool('b) -> unit

val g : forall 'a. 'a -> 'a

val h : forall 'n 'a. vector('n, 'a) -> unit
16 changes: 16 additions & 0 deletions test/typecheck/pass/fn_kind_infer_body.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
default Order dec

$include <prelude.sail>

val f : forall 'b. int('b) -> unit

function f(b) = {
let 'a = 3;
if b == 3 then {
let _ : bool('a == 'b) = true;
}
}

function g forall 'b. (b: bool('b)) -> unit = {
()
}
8 changes: 8 additions & 0 deletions test/typecheck/pass/fn_kind_infer_body/v1.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Type error:
pass/fn_kind_infer_body/v1.sail:14.18-28:
14 |function g forall ('b : Int). (b: bool('b)) -> unit = {
 | ^--------^ Inferred kind Int from this
pass/fn_kind_infer_body/v1.sail:14.39-41:
14 |function g forall ('b : Int). (b: bool('b)) -> unit = {
 | ^^
 | Expected this type to have kind Int but found kind Bool
16 changes: 16 additions & 0 deletions test/typecheck/pass/fn_kind_infer_body/v1.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
default Order dec

$include <prelude.sail>

val f : forall 'b. int('b) -> unit

function f(b) = {
let 'a = 3;
if b == 3 then {
let _ : bool('a == 'b) = true;
}
}

function g forall ('b : Int). (b: bool('b)) -> unit = {
()
}
8 changes: 8 additions & 0 deletions test/typecheck/pass/fn_kind_infer_body/v2.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Type error:
pass/fn_kind_infer_body/v2.sail:14.15-17:
14 |val g : forall 'b. bool('b) -> unit
 | ^^ 'b defined with kind Bool here
pass/fn_kind_infer_body/v2.sail:16.18-28:
16 |function g forall ('b : Int). (b: int('b)) -> unit = {
 | ^--------^
 | 'b defined here with kind Int in the function body, which is inconsistent with the function header
18 changes: 18 additions & 0 deletions test/typecheck/pass/fn_kind_infer_body/v2.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
default Order dec

$include <prelude.sail>

val f : forall 'b. int('b) -> unit

function f(b) = {
let 'a = 3;
if b == 3 then {
let _ : bool('a == 'b) = true;
}
}

val g : forall 'b. bool('b) -> unit

function g forall ('b : Int). (b: int('b)) -> unit = {
()
}
8 changes: 8 additions & 0 deletions test/typecheck/pass/fn_kind_infer_body/v3.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Type error:
pass/fn_kind_infer_body/v3.sail:14.4-5:
14 |val g : forall 'b. bool('b) -> unit
 | ^ declared here
pass/fn_kind_infer_body/v3.sail:16.11-43:
16 |function g forall 'b. (b: bool('b)) -> unit = {
 | ^------------------------------^
 | Duplicate quantifier between inline annotation and 'val' declaration
18 changes: 18 additions & 0 deletions test/typecheck/pass/fn_kind_infer_body/v3.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
default Order dec

$include <prelude.sail>

val f : forall 'b. int('b) -> unit

function f(b) = {
let 'a = 3;
if b == 3 then {
let _ : bool('a == 'b) = true;
}
}

val g : forall 'b. bool('b) -> unit

function g forall 'b. (b: bool('b)) -> unit = {
()
}
8 changes: 8 additions & 0 deletions test/typecheck/pass/struct_kind_infer.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
default Order dec

$include <prelude.sail>

struct S('a, 'n) = {
field1 : 'a,
field2 : bitvector('n),
}
5 changes: 5 additions & 0 deletions test/typecheck/pass/syn_kind_infer.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
default Order dec

$include <prelude.sail>

type xlen = 32
14 changes: 14 additions & 0 deletions test/typecheck/pass/union_infer_kind.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
default Order dec

$include <prelude.sail>

union my_option('a) = {
My_none : unit,
Ny_some : 'a,
}

union test_union('a, 'b, 'c) = {
A : list('a),
B : bool('b),
C : int('c),
}
2 changes: 1 addition & 1 deletion test/typecheck/pass/vector_subrange_gen.sail
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ overload __size = {bitvector_length}

default Order inc

val test : forall 'n 'm, 'n >= 5.
val test : forall 'n, 'n >= 5.
bitvector('n, inc) -> bitvector('n - 1, inc)

function test v = {
Expand Down
11 changes: 3 additions & 8 deletions test/typecheck/pass/wf_register_type/v1.expect
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
Type error:
pass/wf_register_type/v1.sail:7.13-24:
pass/wf_register_type/v1.sail:7.18-23:
7 |register r : bits(x / 2) = 0x0
 | ^---------^
 | Well-formedness check failed for type
 |
 | Caused by pass/wf_register_type/v1.sail:7.18-23:
 | 7 |register r : bits(x / 2) = 0x0
 |  | ^---^
 |  | Unknown type level operator or function (operator /)
 | ^---^
 | Unknown type level operator or function (operator /)

0 comments on commit 49b5852

Please sign in to comment.