Skip to content

Commit

Permalink
fix: Don't optimize away expressions which could possibly contain sid…
Browse files Browse the repository at this point in the history
…e effects (#523)
  • Loading branch information
ospencer committed Feb 4, 2021
1 parent 111b549 commit acc7d65
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 384 deletions.
216 changes: 79 additions & 137 deletions compiler/src/middle_end/analyze_purity.re
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@ open Grain_typed;
*/

type analysis +=
| Pure(bool)
| PurityTable(Ident.tbl(bool));

let purity_tbl: ref(Ident.tbl(bool)) = (
ref(Ident.empty): ref(Ident.tbl(bool))
);
| Pure(bool);

let rec get_purity = lst =>
switch (lst) {
Expand All @@ -23,137 +18,112 @@ let rec get_purity = lst =>
| [_, ...tl] => get_purity(tl)
};

let rec get_purity_tbl = lst =>
switch (lst) {
| [] => raise(Not_found)
| [PurityTable(t), ..._] => t
| [_, ...tl] => get_purity_tbl(tl)
};

let set_id_purity = (id, p) => purity_tbl := Ident.add(id, p, purity_tbl^);

module StringHash =
Hashtbl.Make({
type t = string;
let hash = x => Hashtbl.hash(x);
let equal = (a, b) => String.compare(a, b) === 0;
});

let pervasives_purity =
StringHash.of_seq(
List.to_seq([
("+", true),
("-", true),
("*", true),
("==", true),
("<", true),
(">", true),
("<=", true),
(">=", true),
("&&", true),
("||", true),
("void", true),
("[...]", true),
("[]", true),
]),
);

let get_id_purity = (id: Ident.t): bool =>
try(StringHash.find(pervasives_purity, Ident.name(id))) {
| Not_found => Ident.find_same(id, purity_tbl^)
};

let imm_expression_purity = ({imm_analyses}) => get_purity(imm_analyses^);
let comp_expression_purity = ({comp_analyses}) =>
get_purity(comp_analyses^);
let anf_expression_purity = ({anf_analyses}) => get_purity(anf_analyses^);

let pure_identifiers = ({analyses}) => get_purity_tbl(analyses^);

/* Quick accessors for known-existing values */
let imm_expression_purity_internal = i =>
Option.get(imm_expression_purity(i));
let comp_expression_purity_internal = c =>
Option.get(comp_expression_purity(c));
let anf_expression_purity_internal = a =>
Option.get(anf_expression_purity(a));

let push_purity = (lref, p) => lref := [Pure(p), ...lref^];

let analyze_imm_expression = ({imm_desc, imm_analyses}) =>
switch (imm_desc) {
| ImmId(id) => push_purity(imm_analyses, get_id_purity(id))
| ImmConst(_) => push_purity(imm_analyses, true)
};

let rec analyze_comp_expression =
({comp_desc: desc, comp_analyses: analyses}) => {
let purity =
switch (desc) {
| CImmExpr(i) =>
analyze_imm_expression(i);
imm_expression_purity_internal(i);
| CPrim1(Box, _)
| CPrim1(Unbox, _) => false
| CPrim1(_, a) =>
analyze_imm_expression(a);
imm_expression_purity_internal(a);
| CPrim2(_, a1, a2) =>
analyze_imm_expression(a1);
analyze_imm_expression(a2);
imm_expression_purity_internal(a1)
&& imm_expression_purity_internal(a2);
| CPrimN(_, args) =>
List.iter(arg => analyze_imm_expression(arg), args);
false;
| CBoxAssign(_, _)
| CAssign(_, _) =>
| CImmExpr(_) => true
| CPrim1(
Incr | Decr | Not | Box | Unbox | Ignore | ArrayLength |
Int64FromNumber |
Int64ToNumber |
Int32ToNumber |
Float64ToNumber |
Float32ToNumber |
Int64Lnot |
WasmFromGrain |
WasmToGrain |
WasmUnaryI32(_) |
WasmUnaryI64(_) |
WasmUnaryF32(_) |
WasmUnaryF64(_),
_,
) =>
true
| CPrim1(Assert | FailWith, _) => false
| CPrim2(
Plus | Minus | Times | Divide | Mod | Less | Greater | LessEq |
GreaterEq |
Is |
Eq |
And |
Or |
ArrayMake |
Int64Land |
Int64Lor |
Int64Lxor |
Int64Lsl |
Int64Lsr |
Int64Asr |
Int64Gt |
Int64Gte |
Int64Lt |
Int64Lte |
WasmLoadI32(_) |
WasmLoadI64(_) |
WasmLoadF32 |
WasmLoadF64 |
WasmBinaryI32(_) |
WasmBinaryI64(_) |
WasmBinaryF32(_) |
WasmBinaryF64(_),
_,
_,
) =>
true
| CPrim2(ArrayInit, _, _) => false // Array.init takes (and calls) a function which could have side effects
| CPrimN(
WasmStoreI32(_) | WasmStoreI64(_) | WasmStoreF32 | WasmStoreF64 |
WasmMemoryCopy |
WasmMemoryFill,
_,
) =>
false
| CArraySet(_)
| CBoxAssign(_)
| CAssign(_) =>
/* TODO: Would be nice if we could "scope" the purity analysis to local assignments */
false
| CArrayGet(a, i) =>
analyze_imm_expression(a);
analyze_imm_expression(i);
imm_expression_purity_internal(a) && imm_expression_purity_internal(i);
| CArrayGet(_)
| CArray(_)
| CArraySet(_) => false
| CTuple(args)
| CAdt(_, _, args) =>
let arg_purities =
List.map(
arg => {
analyze_imm_expression(arg);
imm_expression_purity_internal(arg);
},
args,
);
List.for_all(x => x, arg_purities);
| CRecord(_) => false
| CGetTupleItem(_, a) =>
analyze_imm_expression(a);
imm_expression_purity_internal(a);
| CTuple(_)
| CAdt(_)
| CRecord(_)
| CGetTupleItem(_) => true
| CSetTupleItem(_) => false
| CGetAdtItem(_, a) =>
analyze_imm_expression(a);
imm_expression_purity_internal(a);
| CGetAdtTag(_) => true
| CGetRecordItem(_, r) =>
analyze_imm_expression(r);
imm_expression_purity_internal(r);
| CGetAdtItem(_)
| CGetAdtTag(_)
| CGetRecordItem(_) => true
| CSetRecordItem(_) => false
| CIf(c, t, f) =>
analyze_imm_expression(c);
| CIf(_, t, f) =>
analyze_anf_expression(t);
analyze_anf_expression(f);
imm_expression_purity_internal(c)
&& anf_expression_purity_internal(t)
&& anf_expression_purity_internal(f);
anf_expression_purity_internal(t) && anf_expression_purity_internal(f);
| CWhile(c, body) =>
analyze_anf_expression(c);
analyze_anf_expression(body);
anf_expression_purity_internal(c)
&& anf_expression_purity_internal(body);
| CSwitch(exp, branches) =>
analyze_imm_expression(exp);
| CSwitch(_, branches) =>
let branches_purities =
List.map(
((t, b)) => {
Expand All @@ -162,26 +132,12 @@ let rec analyze_comp_expression =
},
branches,
);
imm_expression_purity_internal(exp)
&& List.for_all(x => x, branches_purities);
| CApp((f, _), args, _) =>
let arg_purities =
List.map(
arg => {
analyze_imm_expression(arg);
imm_expression_purity_internal(arg);
},
args,
);
analyze_imm_expression(f);
imm_expression_purity_internal(f) && List.for_all(x => x, arg_purities);
| CAppBuiltin(_module, f, args) =>
List.iter(arg => analyze_imm_expression(arg), args);
false;
| CLambda(args, (body, _)) =>
List.iter(((i, _)) => set_id_purity(i, true), args);
List.for_all(x => x, branches_purities);
| CApp(_) => false
| CAppBuiltin(_) => false
| CLambda(_, (body, _)) =>
analyze_anf_expression(body);
anf_expression_purity_internal(body);
true;
| CNumber(_)
| CInt32(_)
| CInt64(_)
Expand All @@ -199,24 +155,17 @@ and analyze_anf_expression = ({anf_desc: desc, anf_analyses: analyses}) =>
| AELet(g, Nonrecursive, binds, body) =>
let process_bind = ((id, bind)) => {
analyze_comp_expression(bind);
let purity = comp_expression_purity_internal(bind);
set_id_purity(id, purity);
purity;
comp_expression_purity_internal(bind);
};

let bind_purity = List.for_all(x => x, List.map(process_bind, binds));
analyze_anf_expression(body);
let purity = anf_expression_purity_internal(body) && bind_purity;
push_purity(analyses, purity);
| AELet(g, Recursive, binds, body) =>
/* Initialize purity to true, just so they're in scope */
List.iter(((id, _)) => set_id_purity(id, true), binds);
/* Do the actual purity analysis */
let process_bind = ((id, bind)) => {
analyze_comp_expression(bind);
let purity = comp_expression_purity_internal(bind);
set_id_purity(id, purity);
purity;
comp_expression_purity_internal(bind);
};

let bind_purity = List.for_all(x => x, List.map(process_bind, binds));
Expand All @@ -237,12 +186,5 @@ and analyze_anf_expression = ({anf_desc: desc, anf_analyses: analyses}) =>
};

let analyze = ({imports, body, analyses}) => {
purity_tbl := Ident.empty;
let process_import = ({imp_use_id}) =>
/* TODO: pure imports */
set_id_purity(imp_use_id, false);

List.iter(process_import, imports);
analyze_anf_expression(body);
analyses := [PurityTable(purity_tbl^), ...analyses^];
};
3 changes: 0 additions & 3 deletions compiler/src/middle_end/analyze_purity.rei
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
open Anftree;
open Grain_typed;

let imm_expression_purity: imm_expression => option(bool);
let comp_expression_purity: comp_expression => option(bool);
let anf_expression_purity: anf_expression => option(bool);

let pure_identifiers: anf_program => Ident.tbl(bool);

let analyze: Analysis_pass.t;
1 change: 0 additions & 1 deletion compiler/src/middle_end/optimize.re
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ let optimization_passes = [
Optimize_tail_calls.optimize,
Optimize_constants.optimize,
Optimize_simple_binops.optimize,
Optimize_common_subexpressions.optimize,
Optimize_dead_assignments.optimize,
Optimize_dead_branches.optimize,
Optimize_inline_wasm.optimize,
Expand Down
Loading

0 comments on commit acc7d65

Please sign in to comment.