Skip to content

Commit

Permalink
Basic match for integers; bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
AdUhTkJm committed Jan 24, 2025
1 parent ed167bd commit 269eb1b
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 50 deletions.
198 changes: 152 additions & 46 deletions src/riscv_generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,8 @@ let deal_with_prim tac rd (prim: Primitive.prim) args =
(* Load from the offset plus 4 for the tag *)
Vec.push tac (Load { rd; rs = arg; offset = offset + 4; byte = size })

| Pignore -> ()
| Pignore ->
Vec.push tac (Assign { rd; rs = unit })

(* Calculates whether two references are equal; gets a boolean value *)
| Prefeq ->
Expand Down Expand Up @@ -710,9 +711,9 @@ let rec do_convert tac (expr: Mcore.expr) =

(* If this is a `Join`, then we must jump to the corresponding letfn *)
if kind = Join then (
Vec.push tac (Jump !current_join);
Vec.push tac (Assign { rd = !current_join_ret; rs = rd });
unit
Vec.push tac (Jump !current_join);
!current_join_ret
) else (
Vec.append tac after;
rd
Expand Down Expand Up @@ -1020,10 +1021,10 @@ let rec do_convert tac (expr: Mcore.expr) =

current_join := join;
current_join_ret := rd;

let ret = do_convert tac afterwards in

(* This is definitely a unit *)
let _ = do_convert tac afterwards in

Vec.push tac (Assign { rd; rs = ret });
Vec.push tac (Jump join);
Vec.push tac (Label join);

Expand Down Expand Up @@ -1072,68 +1073,82 @@ let rec do_convert tac (expr: Mcore.expr) =
let index = new_temp Mtype.T_int in
Vec.push tac (Load { rd = index; rs = obj; offset = 0; byte = 4 });

let tag_offsets = Hashtbl.find variants (nameof obj.ty) in

(* Generate a jump table *)
let label = new_label "jumptable_" in
let jumps = List.init (List.length cases) (fun _ -> new_label "jumptable_") in
let jumps = List.init (List.length tag_offsets) (fun _ -> new_label "jumptable_") in
let out = new_label "jumptable_out_" in
Vec.push global_inst (ExtArray { label; values = jumps; elem_size = 8 });
let default_lbl = new_label "jumptable_default_" in

(* Choose which place to jump to *)
let jtable = new_temp Mtype.T_bytes in
let ptr_sz = new_temp Mtype.T_int in
let off = new_temp Mtype.T_int in
let place = new_temp Mtype.T_bytes in
let off = new_temp Mtype.T_bytes in
let altered = new_temp Mtype.T_bytes in
let target = new_temp Mtype.T_bytes in

(* Assign all these different possibilities into rd *)
let rd = new_temp ty in

(* Load the address *)
Vec.push tac (AssignLabel { rd = jtable; imm = label });
Vec.push tac (AssignInt { rd = ptr_sz; imm = pointer_size });
Vec.push tac (Mul { rd = off; rs1 = index; rs2 = ptr_sz });
Vec.push tac (Add { rd = place; rs1 = jtable; rs2 = off });
Vec.push tac (Load { rd = target; rs = place; offset = 0; byte = pointer_size });
(* Jump to that address *)
Vec.push tac (JumpIndirect { rs = target; possibilities = jumps });
Vec.push tac (Add { rd = altered; rs1 = jtable; rs2 = off });
Vec.push tac (Load { rd = target; rs = altered; offset = 0; byte = pointer_size });

let tag_offsets = Hashtbl.find variants (nameof obj.ty) in
let returns = Vec.empty () in
let visited = Vec.empty () in
let correspondence = Array.make (List.length tag_offsets) "_uninit" in
(* For each label, generate the code for each label *)

(* For each label, generate the code of it *)
let tac_cases = Vec.empty () in

List.iter (fun ((tag: Tag.t), ident, expr) ->
let lbl = List.nth jumps tag.index in

Vec.push tac (Label lbl);
Vec.push tac_cases (Label lbl);
(match ident with
| None -> ()
| Some x ->
Vec.push tac (Assign { rd = { name = Ident.to_string x; ty = obj.ty }; rs = obj }));
let ret = do_convert tac expr in
Vec.push tac (Jump out);
Vec.push returns (ret, lbl);
Vec.push tac_cases (Assign { rd = { name = Ident.to_string x; ty = obj.ty }; rs = obj }));
let ret = do_convert tac_cases expr in
Vec.push tac_cases (Assign { rd; rs = ret });
Vec.push tac_cases (Jump out);
Vec.push visited tag.index;
correspondence.(tag.index) <- lbl
) cases;

(match default with
| None -> ()
| Some x ->
let default_lbl = new_label "jumptable_default_" in
let visited = visited |> Vec.to_list in

Vec.push tac (Label default_lbl);
let ret = do_convert tac expr in
Vec.push tac (Jump out);
Vec.push tac_cases (Label default_lbl);
let ret = do_convert tac_cases x in
Vec.push tac_cases (Assign { rd; rs = ret });
Vec.push tac_cases (Jump out);

List.iteri (fun i x ->
if not (List.mem i visited) then (
Vec.push returns (ret, default_lbl);
correspondence.(i) <- default_lbl
)
) tag_offsets);
) tag_offsets;);

(* Now assign all these different things into rd *)
let rd = new_temp ty in
Vec.push tac_cases (Label out);

(* Deduplicate all possible targets *)
let possibilities =
Array.to_list correspondence |> Stringset.of_list |> Stringset.to_seq |> List.of_seq
in
(* Jump to the correct target *)
Vec.push tac (JumpIndirect { rs = target; possibilities });

(* Emit all match cases *)
Vec.append tac tac_cases;

Vec.push tac (Label out);
Vec.push tac (Phi { rd; rs = Vec.to_list returns });
(* Record the correct label order *)
Vec.push global_inst (ExtArray { label; values = Array.to_list correspondence; elem_size = 8 });
rd

| Cexpr_letrec _ ->
Expand All @@ -1145,7 +1160,7 @@ let rec do_convert tac (expr: Mcore.expr) =
unit

| Cexpr_switch_constant { obj; cases; default; ty; _ } ->
let obj = do_convert tac obj in
let index = do_convert tac obj in

let die () =
failwith "riscv_generate.ml: bad match on constants"
Expand All @@ -1170,20 +1185,106 @@ let rec do_convert tac (expr: Mcore.expr) =
) cases
in

let mx = List.fold_left (fun mx x -> max mx x) 0 values in
let mn = List.fold_left (fun mn x -> min mx x) 0 values in

let value = new_temp Mtype.T_int in
let mx = List.fold_left (fun mx x -> max mx x) (-2147483647-1) values in
let mn = List.fold_left (fun mn x -> min mn x) 2147483647 values in

(* Sparse values, generate a hash function *)
if mx - mn >= 10 then (
failwith "TODO: large "
) else
Vec.push tac (Assign { rd = value; rs = obj });

(* Compile into jump table *)
failwith "TODO: jump table";
()
if mx - mn >= 20 then (
failwith "TODO: large"
)

(* Dense values, just get a jump table *)
else (
let table = new_label "jumptable_int_" in
let jump = new_label "do_jump_int_" in
let jumps = List.init (mx - mn + 1) (fun _ -> new_label "jumptable_int_") in
let out = new_label "jumptable_int_out_" in
let default_lbl = new_label "jumptable_default_" in

(* If the value is outside the min/max range, jump to default *)
let inrange = new_temp Mtype.T_bool in
let maximum = new_temp Mtype.T_int in
let minimum = new_temp Mtype.T_int in
let _1 = new_temp Mtype.T_bool in
let _2 = new_temp Mtype.T_bool in

(* Evaluate (x < max) && (x > min), which is the range where we can use jump table *)
Vec.push tac (AssignInt { rd = maximum; imm = mx });
Vec.push tac (AssignInt { rd = minimum; imm = mn });
Vec.push tac (Leq { rd = _1; rs1 = index; rs2 = maximum });
Vec.push tac (Geq { rd = _2; rs1 = index; rs2 = minimum });
Vec.push tac (And { rd = inrange; rs1 = _1; rs2 = _2 });
Vec.push tac (Branch { cond = inrange; ifso = jump; ifnot = default_lbl });

(* Load the address *)
Vec.push tac (Label jump);

let jtable = new_temp Mtype.T_bytes in
let ptr_sz = new_temp Mtype.T_int in
let off = new_temp Mtype.T_int in
let altered = new_temp Mtype.T_bytes in
let target = new_temp Mtype.T_bytes in

Vec.push tac (AssignLabel { rd = jtable; imm = table });
Vec.push tac (AssignInt { rd = ptr_sz; imm = pointer_size });

(* We must also minus the minimum, unlike switch_constr *)
let min_var = new_temp Mtype.T_int in
let ind_2 = new_temp Mtype.T_int in

Vec.push tac (AssignInt { rd = min_var; imm = mn });
Vec.push tac (Sub { rd = ind_2; rs1 = index; rs2 = min_var });

(* Now find which address to jump to *)
Vec.push tac (Mul { rd = off; rs1 = ind_2; rs2 = ptr_sz });
Vec.push tac (Add { rd = altered; rs1 = jtable; rs2 = off });
Vec.push tac (Load { rd = target; rs = altered; offset = 0; byte = pointer_size });

let visited = Vec.empty () in
let correspondence = Array.make (List.length cases) "_uninit" in

(* For each label, generate the code of it *)
let tac_cases = Vec.empty () in

List.iter2 (fun value (_, expr) ->
let lbl = List.nth jumps (value - mn) in

Vec.push tac_cases (Label lbl);
let ret = do_convert tac_cases expr in
Vec.push tac_cases (Assign { rd; rs = ret });
Vec.push tac_cases (Jump out);
Vec.push visited value;
correspondence.(value - mn) <- lbl
) values cases;

(* For each values in the (min, max) range, redirect them into default *)
let visited = visited |> Vec.to_list in

Vec.push tac_cases (Label default_lbl);
let ret = do_convert tac_cases default in
Vec.push tac_cases (Assign { rd; rs = ret });
Vec.push tac_cases (Jump out);

List.iter (fun i ->
if not (List.mem i visited) then (
correspondence.(i - mn) <- default_lbl
)
) (List.init (mx - mn) (fun i -> i + mn));

(* Store the correct order of jump table *)
Vec.push tac_cases (Label out);
Vec.push global_inst (ExtArray
{ label = table; values = Array.to_list correspondence; elem_size = 8 });

(* Deduplicate possibilities and jump there *)
let possibilities =
Array.to_list correspondence |> Stringset.of_list |> Stringset.to_seq |> List.of_seq
in

Vec.push tac (JumpIndirect { rs = target; possibilities });
Vec.append tac tac_cases;

)

| _ -> failwith "TODO: unsupported switch constant type");

Expand Down Expand Up @@ -1275,6 +1376,7 @@ let rec do_convert tac (expr: Mcore.expr) =
rd

| Cexpr_function _ ->
Printf.printf "unconverted: %s\n" (Mcore.sexp_of_expr expr |> S.to_string);
failwith "riscv_generate.ml: Cexpr_function should have been converted into letfn"

let generate_vtables () =
Expand Down Expand Up @@ -1455,6 +1557,10 @@ let analyze_closure (top: Mcore.top_item) =
iter_expr find_closures func.body;
Vec.iter process_closure worklist

| Ctop_expr { expr } ->
iter_expr find_closures expr;
Vec.iter process_closure worklist

| _ -> ()

let convert_toplevel _start (top: Mcore.top_item) =
Expand Down Expand Up @@ -1548,7 +1654,7 @@ let ssa_of_mcore (core: Mcore.t) =
(* Deal with main *)
let with_main = match core.main with
| Some (main_expr, _) ->
let lambda_removed = convert_lambda main_expr in
let lambda_removed = map_expr convert_lambda main_expr in

(* Find closures in main *)
let closures = Vec.empty () in
Expand Down
3 changes: 3 additions & 0 deletions src/riscv_opt_gather.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ open Riscv_opt
open Riscv_ssa

let opt tac =
let out_noopt = Printf.sprintf "%s-no-opt.ssa" !Driver_config.Linkcore_Opt.output_file in
Basic_io.write out_noopt (String.concat "\n" (List.map Riscv_ssa.to_string tac));

List.iter (fun top -> match top with
| FnDecl { fn; args } -> Hashtbl.add params fn args
| _ -> ()) tac;
Expand Down
8 changes: 8 additions & 0 deletions src/riscv_tac2ssa.ml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ let output_idom idom =
Printf.printf "idom(%s) = %s\n" x y;
) idom

let output_frontier frontier =
print_endline "Frontiers:";
Hashtbl.iter (fun x y ->
Printf.printf "%s: \n" x;
Stringset.iter (fun z -> Printf.printf "%s " z) y;
Printf.printf "\n\n"
) frontier

(**
Calculate dominator.
Uses the classic data-flow approach, rather than the compilcated Lengauer-Tarjan algorithm.
Expand Down
8 changes: 4 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def try_remove(path):

# Linkage emits target code.
ret = os.system(f"{debug} moonc link-core {bundled}/core.core build/{src}.core -o build/{dest} -pkg-config-path {src}/moon.pkg.json -pkg-sources {core}:{src} -target {target}")

# Remove intermediate files that we don't need.
try_remove(f"build/{src}.core")
try_remove(f"build/{src}.mi")

if ret != 0:
print("Compiler generated an error. Failed.")
Expand All @@ -83,10 +87,6 @@ def try_remove(path):
if args.wasm:
print("WASM target does not support testing. Exit.")
break;

# Remove intermediate files that we don't need.
try_remove(f"build/{src}.core")
try_remove(f"build/{src}.mi")

# Test.
if not args.compile_only:
Expand Down
1 change: 1 addition & 0 deletions test/src/match02/match02.ans
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
other
other
one or two
one or two
one or two
Expand Down
File renamed without changes.
File renamed without changes.
4 changes: 4 additions & 0 deletions test/src/match05/match05.ans
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Some(88)
Some(125)
None
Some(97)
36 changes: 36 additions & 0 deletions test/src/match05/match05.mbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
enum IntOption {
None
Some(Int)
} derive (Show)

fn map(x: IntOption, f: (Int) -> Int) -> IntOption {
match x {
Some(z) => Some(f(z))
_ => None
}
}

fn filter(x: IntOption, f: (Int) -> Bool) -> IntOption {
match x {
Some(z) => if f(z) { x } else { None }
_ => None
}
}

fn main {
fn make_adder(x) {
fn adder(y) { x + y }
adder
}

let option = Some(88);
println(option);

let add = make_adder(37);
println(map(option, add));

let another = make_adder(9);
let mapped = map(option, another);
println(filter(mapped, fn (x) { x > 100 }));
println(filter(mapped, fn (x) { x > 90 }));
}

0 comments on commit 269eb1b

Please sign in to comment.