From 0df31e6570d9d707c91493878e9082d8918a4118 Mon Sep 17 00:00:00 2001
From: AdUhTkJm <2292398666@qq.com>
Date: Tue, 21 Jan 2025 15:20:06 +0000
Subject: [PATCH] Escape analysis

---
 src/riscv_opt_escape.ml | 153 ++++++++++++++++++++++++++++++++++++++++
 src/riscv_opt_gather.ml |   1 +
 src/riscv_ssa.ml        |   7 +-
 test/interpreter.cpp    |   8 ++
 4 files changed, 168 insertions(+), 1 deletion(-)
 create mode 100644 src/riscv_opt_escape.ml

Review notes (everything between the `---` line and the first diff header is
commentary and is not part of the commit):

* This copy of the patch had lost its line breaks and indentation; both were
  reconstructed. The indentation of the context lines in the hunks for
  riscv_opt_gather.ml, riscv_ssa.ml and interpreter.cpp is a best guess.
  TODO: regenerate with `git format-patch`, or apply with
  `git apply --ignore-whitespace`.
* src/riscv_opt_escape.ml was changed during review, so the blob id on its
  `index` line is stale.
* escape_analysis: the inner convergence test compared a table with itself
  (`last_out := new_out` aliased it), escape_in was computed but never used,
  Assign/Phi overwrote the destination state instead of joining in both
  directions, and an allocation escaping in a later block than its malloc was
  still lowered to alloca. It is now one monotone, flow-insensitive fixpoint
  over the whole function; the returned table has the same shape as before.
* NOTE(review): CallIndirect arguments and stores through function parameters
  are still not treated as escaping, and no backend lowering for Alloca is
  part of this patch -- confirm before enabling the pass.

diff --git a/src/riscv_opt_escape.ml b/src/riscv_opt_escape.ml
new file mode 100644
index 0000000..ac3532b
--- /dev/null
+++ b/src/riscv_opt_escape.ml
@@ -0,0 +1,153 @@
+(**
+Performs escape analysis and, based on the result, turns heap allocations
+that never leave their function into stack allocations.
+*)
+open Riscv_ssa
+open Riscv_opt
+
+(** How far a value may travel from the function that created it. *)
+type escape_state =
+| NoEscape (* Does not escape the function *)
+| LocalEscape (* Escapes by getting captured by some closure *)
+| GlobalEscape (* Escapes by storing into some place *)
+
+(** [join s1 s2] is the more-escaping of [s1] and [s2]. *)
+let join s1 s2 = match (s1, s2) with
+| GlobalEscape, _ | _, GlobalEscape -> GlobalEscape
+| LocalEscape, _ | _, LocalEscape -> LocalEscape
+| NoEscape, NoEscape -> NoEscape
+
+(** Prints every [variable: state] binding of [table] to stdout. *)
+let print_escape table =
+  Hashtbl.iter (fun var state ->
+    let text = match state with
+      | NoEscape -> "no escape"
+      | LocalEscape -> "local escape"
+      | GlobalEscape -> "global escape"
+    in
+    Printf.printf "%s: %s\n" var text
+  ) table
+
+(**
+[get_escape table var] is the state recorded for [var] in [table].
+A variable without an entry is inserted as [NoEscape], so a lookup may
+add a binding to [table].
+*)
+let get_escape table (var: string) =
+  match Hashtbl.find_opt table var with
+  | Some state -> state
+  | None ->
+    Hashtbl.add table var NoEscape;
+    NoEscape
+
+
+(**
+Does escape analysis on [fn].
+Returns a table from block name to a table from variable name to
+[escape_state].
+
+The analysis is flow-insensitive: a single set of states is computed for the
+whole function, and every block receives its own copy of it. An allocation
+made in one block can escape in another one, so a state that is only valid
+for one block is not enough to decide whether `malloc` may become `alloca`.
+
+This does not yet support analysis of LocalEscape;
+every variable is categorized into either No- or GlobalEscape.
+*)
+let escape_analysis fn =
+  let blocks = get_blocks fn in
+
+  (* One state per variable, shared by every block of the function. *)
+  let states = Hashtbl.create 1024 in
+  let changed = ref true in
+
+  (* Raises [var] to at least [state]. States only ever move up the
+     No < Local < Global order, so the fixpoint loop below terminates. *)
+  let raise_to var state =
+    let current = get_escape states var.name in
+    let merged = join current state in
+    if merged <> current then (
+      Hashtbl.replace states var.name merged;
+      changed := true
+    )
+  in
+
+  (* Variables that may refer to the same object get the join of their
+     states: if one of them escapes, all of them do. This must work in both
+     directions, e.g. for `b = a; return b` the escape of `b` reaches `a`. *)
+  let unify vars =
+    let state =
+      List.fold_left (fun acc var ->
+        join acc (get_escape states var.name)
+      ) NoEscape vars
+    in
+    List.iter (fun var -> raise_to var state) vars
+  in
+
+  let transfer inst = match inst with
+    | Assign { rd; rs } -> unify [rd; rs]
+
+    | AssignLabel { rd; _ } -> raise_to rd GlobalEscape
+    | Return x -> raise_to x GlobalEscape
+
+    | Call { rd; args }
+    | CallExtern { rd; args } ->
+      List.iter (fun arg -> raise_to arg GlobalEscape) args;
+      raise_to rd GlobalEscape
+
+    | Store { rd; rs }
+    | Addi { rd; rs } -> unify [rd; rs]
+
+    | Add { rd; rs1; rs2 }
+    | Sub { rd; rs1; rs2 } -> unify [rd; rs1; rs2]
+
+    | Phi { rd; rs } -> unify (rd :: List.map fst rs)
+
+    (* NOTE(review): every other instruction is assumed not to make its
+       operands escape. CallIndirect looks like it should be handled like
+       Call; its fields are not visible here -- TODO confirm and extend. *)
+    | _ -> ()
+  in
+
+  (* Iterate over the whole function until no state changes any more. *)
+  while !changed do
+    changed := false;
+    List.iter (fun name ->
+      let block = block_of name in
+      Basic_vec.iter transfer block.body
+    ) blocks
+  done;
+
+  let escape_out = Hashtbl.create 64 in
+  List.iter (fun name ->
+    Hashtbl.replace escape_out name (Hashtbl.copy states)
+  ) blocks;
+  escape_out
+
+(**
+Reforms `malloc` on heap to `alloca` on stack when possible, i.e. when the
+allocated variable is [NoEscape] according to [escape_analysis].
+*)
+let malloc_to_alloca fn =
+  let escape_data = escape_analysis fn in
+  let lower escaped inst = match inst with
+    | Malloc { rd; size } ->
+      if get_escape escaped rd.name = NoEscape then
+        Alloca { rd; size }
+      else
+        inst
+    | other -> other
+  in
+  List.iter (fun name ->
+    let block = block_of name in
+    let escaped = Hashtbl.find escape_data name in
+    block.body <-
+      block.body
+      |> Basic_vec.to_list
+      |> List.map (lower escaped)
+      |> Basic_vec.of_list
+  ) (get_blocks fn)
+
+(** Applies [malloc_to_alloca] to [ssa] through [iter_fn]. *)
+let lower_malloc ssa =
+  iter_fn malloc_to_alloca ssa
diff --git a/src/riscv_opt_gather.ml b/src/riscv_opt_gather.ml
index 4406bfa..53a6baa 100644
--- a/src/riscv_opt_gather.ml
+++ b/src/riscv_opt_gather.ml
@@ -11,6 +11,7 @@ let opt ssa =
   for i = 1 to 3 do
     Riscv_opt_inline.inline ssa;
     Riscv_opt_peephole.peephole ssa;
+    Riscv_opt_escape.lower_malloc ssa;
   done;
 
   let s = map_fn ssa_of_cfg ssa in
diff --git a/src/riscv_ssa.ml b/src/riscv_ssa.ml
index 5954e49..41f8a32 100644
--- a/src/riscv_ssa.ml
+++ b/src/riscv_ssa.ml
@@ -216,7 +216,8 @@ and t =
 | ExtArray of extern_array (* An array in `.data` section *)
 | CallExtern of call_data (* Call a C function *)
 | CallIndirect of call_indirect (* Call a function pointer *)
-| Malloc of malloc
+| Malloc of malloc (* Allocate on heap *)
+| Alloca of malloc (* Allocate on stack *)
 | Return of var
 
 (* Note: *)
@@ -434,6 +435,9 @@ let to_string t =
 
   | Malloc { rd; size } ->
     Printf.sprintf "malloc %s %d" rd.name size
+
+  | Alloca { rd; size } ->
+    Printf.sprintf "alloca %s %d" rd.name size
 
   | FnDecl { fn; args; body; } ->
     let args_str = String.concat ", " (List.map (fun x -> x.name) args) in
@@ -511,6 +515,7 @@ let rec reg_map fd fs t = match t with
 | GlobalVarDecl var -> GlobalVarDecl var
 | ExtArray arr -> ExtArray arr
 | Malloc { rd; size } -> Malloc { rd = fd rd; size }
+| Alloca { rd; size } -> Alloca { rd = fd rd; size }
 | Return var -> Return (fs var)
 
 let reg_iter fd fs t =
diff --git a/test/interpreter.cpp b/test/interpreter.cpp
index f456fb2..1db6f88 100644
--- a/test/interpreter.cpp
+++ b/test/interpreter.cpp
@@ -333,6 +333,14 @@ int64_t interpret(std::string label) {
       continue;
     }
 
+    if (op == "alloca") {
+      auto len = int_of(args[2]);
+
+      VAL(1) = (int64_t) alloca(len);
+      OUTPUT(args[1], VAL(1));
+      continue;
+    }
+
     if (op == "phi") {
       bool is_bad = true;
 