From 820f1a2b1dd22a10096352d274b17398bdaae76d Mon Sep 17 00:00:00 2001 From: zmr233 Date: Fri, 24 Jan 2025 13:59:42 +0800 Subject: [PATCH 1/2] RPO Module --- src/label.ml | 45 ++- src/riscv_reg.ml | 392 +++++++++++----------- src/riscv_reg_alloc.ml | 13 +- src/riscv_reg_util.ml | 32 ++ src/riscv_virtasm.ml | 604 ++++++++++++++++++---------------- src/riscv_virtasm_generate.ml | 2 +- 6 files changed, 606 insertions(+), 482 deletions(-) create mode 100644 src/riscv_reg_util.ml diff --git a/src/label.ml b/src/label.ml index ccff07a..7b0e7e5 100644 --- a/src/label.ml +++ b/src/label.ml @@ -12,9 +12,11 @@ . *) - module Label = struct - type t = { name : string; [@ceh.ignore] stamp : int } + type t = + { name : string [@ceh.ignore] + ; stamp : int + } include struct let _ = fun (_ : t) -> () @@ -24,33 +26,36 @@ module Label = struct let bnds__001_ = ([] : _ Stdlib.List.t) in let bnds__001_ = let arg__005_ = Moon_sexp_conv.sexp_of_int stamp__004_ in - (S.List [ S.Atom "stamp"; arg__005_ ] :: bnds__001_ - : _ Stdlib.List.t) + (S.List [ S.Atom "stamp"; arg__005_ ] :: bnds__001_ : _ Stdlib.List.t) in let bnds__001_ = let arg__003_ = Moon_sexp_conv.sexp_of_string name__002_ in (S.List [ S.Atom "name"; arg__003_ ] :: bnds__001_ : _ Stdlib.List.t) in S.List bnds__001_ - : t -> S.t) + : t -> S.t) + ;; let _ = sexp_of_t let equal = (fun a__006_ b__007_ -> - if Stdlib.( == ) a__006_ b__007_ then true + if Stdlib.( == ) a__006_ b__007_ + then true else Stdlib.( = ) (a__006_.stamp : int) b__007_.stamp - : t -> t -> bool) + : t -> t -> bool) + ;; let _ = equal let (hash_fold_t : Ppx_base.state -> t -> Ppx_base.state) = - fun hsv arg -> + fun hsv arg -> let hsv = let hsv = hsv in hsv in Ppx_base.hash_fold_int hsv arg.stamp + ;; let _ = hash_fold_t @@ -61,14 +66,17 @@ module Label = struct hash_fold_t hsv arg) in fun x -> func x + ;; let _ = hash let compare = (fun a__008_ b__009_ -> - if Stdlib.( == ) a__008_ b__009_ then 0 + if Stdlib.( == ) a__008_ b__009_ + then 0 else Stdlib.compare (a__008_.stamp : int) b__009_.stamp - : t -> t -> int) + : t -> t -> int) + ;; let _ = compare end @@ -82,14 +90,31 @@ let rename t = { name = t.name; stamp = Basic_uuid.next () } let to_wasm_name (t : t) = Stdlib.String.concat "" [ "$"; t.name; "/"; Int.to_string t.stamp ] +;; let to_wasm_label_loop t = let x = t.stamp in ("$loop:" ^ Int.to_string x : Stdlib.String.t) +;; let to_wasm_label_break t = let x = t.stamp in ("$break:" ^ Int.to_string x : Stdlib.String.t) +;; + +(** Used for generating function label in RISCV asm. *) +let to_riscv_label_func t = + let fun_name = t.name in + let fun_version = t.stamp in + ("_" ^ fun_name ^ Int.to_string fun_version : Stdlib.String.t) +;; + +(** Used for generating block label in RISCV asm. *) +let to_riscv_label_block t = + let bl_name = t.name in + let bl_num = t.stamp in + ("." ^ bl_name ^ Int.to_string bl_num : Stdlib.String.t) +;; module Hash = Basic_hashf.Make (Label) module Hashset = Basic_hashsetf.Make (Label) diff --git a/src/riscv_reg.ml b/src/riscv_reg.ml index 441eebc..4f5b3e7 100644 --- a/src/riscv_reg.ml +++ b/src/riscv_reg.ml @@ -1,16 +1,170 @@ (** Registers for RV64GC **) -(* Define a label type*) -type label_t = string +(* Module for physical registers (reg_t) *) +module Reg = struct + type t = + | Zero (* zero register *) + | Ra (* caller return address *) + | Sp (* caller (S0) stack pointer *) + | Gp (* global pointer *) + | Tp (* thread pointer *) + | T0 (* caller temporary register *) + | T1 + | T2 + | Fp (* callee stack bottom register *) + | S1 (* callee saved register *) + | A0 (* caller argument register *) + | A1 + | A2 + | A3 + | A4 + | A5 + | A6 + | A7 + | S2 (* callee saved register *) + | S3 + | S4 + | S5 + | S6 + | S7 + | S8 + | S9 + | S10 + | S11 + | T3 (* caller temporary register *) + | T4 + | T5 + | T6 (* caller swap register *) + + (* Convert physical register to string *) + let to_string r = + match r with + | Zero -> "zero" + | Ra -> "ra" + | Sp -> "sp" + | Gp -> "gp" + | Tp -> "tp" + | T0 -> "t0" + | T1 -> "t1" + | T2 -> "t2" + | Fp -> "fp" + | S1 -> "s1" + | A0 -> "a0" + | A1 -> "a1" + | A2 -> "a2" + | A3 -> "a3" + | A4 -> "a4" + | A5 -> "a5" + | A6 -> "a6" + | A7 -> "a7" + | S2 -> "s2" + | S3 -> "s3" + | S4 -> "s4" + | S5 -> "s5" + | S6 -> "s6" + | S7 -> "s7" + | S8 -> "s8" + | S9 -> "s9" + | S10 -> "s10" + | S11 -> "s11" + | T3 -> "t3" + | T4 -> "t4" + | T5 -> "t5" + | T6 -> "t6" + ;; +end + +(* Module for floating-point registers (freg_t) *) +module FReg = struct + type t = + | Ft0 (* caller floating-point temporary register *) + | Ft1 + | Ft2 + | Ft3 + | Ft4 + | Ft5 + | Ft6 + | Ft7 + | Fs0 (* callee floating-point saved register *) + | Fs1 + | Fa0 (* caller floating-point argument register *) + | Fa1 + | Fa2 + | Fa3 + | Fa4 + | Fa5 + | Fa6 + | Fa7 + | Fs2 (* callee floating-point saved register *) + | Fs3 + | Fs4 + | Fs5 + | Fs6 + | Fs7 + | Fs8 + | Fs9 + | Fs10 + | Fs11 + | Ft8 + | Ft9 + | Ft10 + | Ft11 (* caller swap floating-point register *) + + (* Convert floating-point register to string *) + let to_string fr = + match fr with + | Ft0 -> "ft0" + | Ft1 -> "ft1" + | Ft2 -> "ft2" + | Ft3 -> "ft3" + | Ft4 -> "ft4" + | Ft5 -> "ft5" + | Ft6 -> "ft6" + | Ft7 -> "ft7" + | Fs0 -> "fs0" + | Fs1 -> "fs1" + | Fa0 -> "fa0" + | Fa1 -> "fa1" + | Fa2 -> "fa2" + | Fa3 -> "fa3" + | Fa4 -> "fa4" + | Fa5 -> "fa5" + | Fa6 -> "fa6" + | Fa7 -> "fa7" + | Fs2 -> "fs2" + | Fs3 -> "fs3" + | Fs4 -> "fs4" + | Fs5 -> "fs5" + | Fs6 -> "fs6" + | Fs7 -> "fs7" + | Fs8 -> "fs8" + | Fs9 -> "fs9" + | Fs10 -> "fs10" + | Fs11 -> "fs11" + | Ft8 -> "ft8" + | Ft9 -> "ft9" + | Ft10 -> "ft10" + | Ft11 -> "ft11" + ;; +end + (** Defines an immediate value type. Immediate values can either be an integer (`IntImm`) or a floating-point number (`FloatImm`). *) -type imm_t = - | IntImm of int (* Integer immediate value *) - | FloatImm of float (* Floating-point immediate value *) +module Imm = struct + type t = + | IntImm of int (* Integer immediate value *) + | FloatImm of float (* Floating-point immediate value *) + + let to_string imm = + match imm with + | IntImm i -> string_of_int i + | FloatImm f -> string_of_float f + ;; +end (** Defines a slot type for both virtual and physical registers. @@ -18,182 +172,52 @@ Defines a slot type for both virtual and physical registers. This type encapsulates different kinds of registers, including general-purpose registers, floating-point registers, and specific slots representing values like `Unit` (no return value). *) -type slot_t = - | Unit (* Represents no return value, often used in function calls or returns *) - | Slot of int - | FSlot of int - | Reg of reg_t - | FReg of freg_t - -and ret_type = - | IntRet - | FloatRet - | UnitRet - -and reg_t = - | Zero (* zero register *) - | Ra (* caller return address *) - | Sp (* caller (S0) stack pointer *) - | Gp (* global pointer *) - | Tp (* thread pointer *) - | T0 (* caller temporary register *) - | T1 - | T2 - | Fp (* callee stack bottom register *) - | S1 (* callee saved register *) - | A0 (* caller argument register *) - | A1 - | A2 - | A3 - | A4 - | A5 - | A6 - | A7 - | S2 (* callee saved register *) - | S3 - | S4 - | S5 - | S6 - | S7 - | S8 - | S9 - | S10 - | S11 - | T3 (* caller temporary register *) - | T4 - | T5 - | T6 (* caller swap register *) - -and freg_t = - | Ft0 (* caller floating-point temporary register *) - | Ft1 - | Ft2 - | Ft3 - | Ft4 - | Ft5 - | Ft6 - | Ft7 - | Fs0 (* callee floating-point saved register *) - | Fs1 - | Fa0 (* caller floating-point argument register *) - | Fa1 - | Fa2 - | Fa3 - | Fa4 - | Fa5 - | Fa6 - | Fa7 - | Fs2 (* callee floating-point saved register *) - | Fs3 - | Fs4 - | Fs5 - | Fs6 - | Fs7 - | Fs8 - | Fs9 - | Fs10 - | Fs11 - | Ft8 - | Ft9 - | Ft10 - | Ft11 (* caller swap floating-point register *) - -(* Convert reg_t to string representation *) -let reg_to_string r = - match r with - | Zero -> "zero" - | Ra -> "ra" - | Sp -> "sp" - | Gp -> "gp" - | Tp -> "tp" - | T0 -> "t0" - | T1 -> "t1" - | T2 -> "t2" - | Fp -> "s0" - | S1 -> "s1" - | A0 -> "a0" - | A1 -> "a1" - | A2 -> "a2" - | A3 -> "a3" - | A4 -> "a4" - | A5 -> "a5" - | A6 -> "a6" - | A7 -> "a7" - | S2 -> "s2" - | S3 -> "s3" - | S4 -> "s4" - | S5 -> "s5" - | S6 -> "s6" - | S7 -> "s7" - | S8 -> "s8" - | S9 -> "s9" - | S10 -> "s10" - | S11 -> "s11" - | T3 -> "t3" - | T4 -> "t4" - | T5 -> "t5" - | T6 -> "t6" -;; - -(* Convert freg_t to string representation *) -let freg_to_string fr = - match fr with - | Ft0 -> "ft0" - | Ft1 -> "ft1" - | Ft2 -> "ft2" - | Ft3 -> "ft3" - | Ft4 -> "ft4" - | Ft5 -> "ft5" - | Ft6 -> "ft6" - | Ft7 -> "ft7" - | Fs0 -> "fs0" - | Fs1 -> "fs1" - | Fa0 -> "fa0" - | Fa1 -> "fa1" - | Fa2 -> "fa2" - | Fa3 -> "fa3" - | Fa4 -> "fa4" - | Fa5 -> "fa5" - | Fa6 -> "fa6" - | Fa7 -> "fa7" - | Fs2 -> "fs2" - | Fs3 -> "fs3" - | Fs4 -> "fs4" - | Fs5 -> "fs5" - | Fs6 -> "fs6" - | Fs7 -> "fs7" - | Fs8 -> "fs8" - | Fs9 -> "fs9" - | Fs10 -> "fs10" - | Fs11 -> "fs11" - | Ft8 -> "ft8" - | Ft9 -> "ft9" - | Ft10 -> "ft10" - | Ft11 -> "ft11" -;; - -let to_string (s : slot_t) : string = - match s with - | Slot i -> Printf.sprintf "%%%d" i - | FSlot i -> Printf.sprintf "%%f%d" i - | Reg r -> reg_to_string r - | FReg fr -> freg_to_string fr - | Unit -> "_" -;; - -(** Counter of temporaries. *) -let slot_cnt = ref 0 - -let fslot_cnt = ref 0 - -let new_slot () = - let i = !slot_cnt in - slot_cnt := i + 1; - Slot i -;; - -let new_fslot () = - let i = !fslot_cnt in - fslot_cnt := i + 1; - FSlot i -;; +module Slot = struct + (* Slot type, which can include different kinds of registers *) + type t = + | Unit (* No return value, used in function calls or returns *) + | Slot of int (* Integer register slot *) + | FSlot of int (* Floating-point register slot *) + | Reg of Reg.t (* Physical register *) + | FReg of FReg.t (* Floating-point physical register *) + + include struct + let _ = fun (_ : t) -> () + let compare = compare + let equal = (=) + let hash = Hashtbl.hash + end + + (* Convert t to string representation *) + let to_string (s : t) : string = + match s with + | Slot i -> Printf.sprintf "%%%d" i + | FSlot i -> Printf.sprintf "%%f%d" i + | Reg r -> Reg.to_string r + | FReg fr -> FReg.to_string fr + | Unit -> "_" + ;; + + (* Counter for integer slots *) + let slot_cnt = ref 0 + + (* Counter for floating-point slots *) + let fslot_cnt = ref 0 + + (* Create a new integer slot *) + let new_slot () = + let i = !slot_cnt in + slot_cnt := i + 1; + Slot i + ;; + + (* Create a new floating-point slot *) + let new_fslot () = + let i = !fslot_cnt in + fslot_cnt := i + 1; + FSlot i + ;; +end + +(* Key t*) +module SlotSet = Basic_hashsetf.Make (Slot) diff --git a/src/riscv_reg_alloc.ml b/src/riscv_reg_alloc.ml index 69a704a..eefe338 100644 --- a/src/riscv_reg_alloc.ml +++ b/src/riscv_reg_alloc.ml @@ -1,12 +1,9 @@ +open Riscv_reg open Riscv_virtasm +open Riscv_reg_util -let reg_alloc (vprog : vprog_t) = - let vprog : vprog_t = - { blocks = VBlockMap.empty - ; funcs = VFuncMap.empty - ; consts = VSymbolMap.empty - ; loop_vars = VBlockMap.empty - } - in + +let reg_alloc (vprog: VProg.t) = + let rpo = RPO.calculate_rpo vprog in vprog ;; diff --git a/src/riscv_reg_util.ml b/src/riscv_reg_util.ml new file mode 100644 index 0000000..c54b88b --- /dev/null +++ b/src/riscv_reg_util.ml @@ -0,0 +1,32 @@ +open Riscv_reg +open Riscv_virtasm + +(** RPO, Reverse Postorder used for + RPO (Reverse Postorder) is an ordering of basic blocks in a control flow graph, + used to respect control flow during program analysis and optimization. +*) +module RPO = struct + type t = VBlockLabel.t list VFuncMap.t + + let calculate_rpo (vprog : VProg.t) = + let visited = VBlockSet.create 128 in + let cal_func_rpo (funn : VFuncLabel.t) (func : VFunc.t) (acc : t) : t = + let order = Vec.empty () in + let rec dfs (bl : VBlockLabel.t) = + if VBlockSet.mem visited bl + then () + else ( + VBlockSet.add visited bl; + let block = VProg.get_block vprog bl in + let succs = VBlock.get_successors block in + List.iter dfs succs; + Vec.push order bl) + in + dfs func.entry; + order |> Vec.to_list |> List.rev |> VFuncMap.add acc funn + in + VFuncMap.fold vprog.funcs VFuncMap.empty cal_func_rpo + ;; +end + +(* let cal_next_use_distance (vprog: VProg.t) (rpo: RPO.t) = *) diff --git a/src/riscv_virtasm.ml b/src/riscv_virtasm.ml index 8c991b7..43c9f85 100644 --- a/src/riscv_virtasm.ml +++ b/src/riscv_virtasm.ml @@ -23,298 +23,344 @@ let deblist (listn : string) (f : 'a -> string) (lst : 'a list) : unit = print_endline @@ "]" ;; -(** Similar to R-type instructions in RISC-V. *) -type r_slot = - { rd : slot_t - ; rs1 : slot_t - ; rs2 : slot_t - } - -(** R-type instructions for floating-point registers. *) -type r_fslot = - { frd : slot_t - ; frs1 : slot_t - ; frs2 : slot_t - } - -(** I-type, with one destination register, one source and one immediate. *) -type i_slot = - { rd : slot_t - ; rs1 : slot_t - ; imm : imm_t - } - -(** Defines a single floating-point register assignment with a destination register `frd`. *) -type single_fslot = { frd : slot_t } - -(** Defines a direct assignment between general-purpose and floating-point register.*) -type assign_direct = - { frd : slot_t - ; rs : slot_t - } - -type assign_slot = - { rd : slot_t - ; rs : slot_t - } - -type assign_fslot = - { frd : slot_t - ; frs : slot_t - } - -(** For special floating-point operation*) -type triple_fslot = - { frd : slot_t - ; frs1 : slot_t - ; frs2 : slot_t - ; frs3 : slot_t - } - -(** Immediate value `imm` to the destination register `rd`. *) -type assign_int64 = - { rd : slot_t - ; imm : imm_t - } - -(** Defines an assignment of a label (address or function) to a register `rd`. *) -type assign_label = - { rd : slot_t - ; label : label_t - } - -(** -Defines a conversion between floating-point and integer registers. - -Converts the value in the source floating-point register `frs` to the destination integer register `rd`. -*) -type convert_slot = - { rd : slot_t - ; frs : slot_t - } - -(** -Defines a conversion from an integer register to a floating-point register. - -Converts the integer value from `rs` to the destination floating-point register `frd`. -*) -type convert_fslot = - { frd : slot_t - ; rs : slot_t - } - -(** -Defines a comparison between two floating-point registers `frs1` and `frs2`, - with the result stored in the destination register `rd`. -*) -type compare_fslot = - { rd : slot_t - ; frs1 : slot_t - ; frs2 : slot_t - } - -(** Calls function named `fn` with arguments `args`, and store the result in `rd`. *) -type call_data = - { rd : slot_t - ; fn : label_t - ; args : slot_t list - ; fargs : slot_t list - } - -(** Call function pointer with address `rs` and arguments `args`, and returns in `rd` *) -type call_indirect = - { rd : slot_t - ; fn : slot_t - ; args : slot_t list - ; fargs : slot_t list - } - -(** -Similar to `ld` and `st` in RISC-V. - -`rd` and `rs` have different meanings in loads and stores: -We load `byte` bytes from `rs` into `rd`, -and store `byte` bytes from `rd` into `rs`. -*) -type mem_slot = - { rd : slot_t - ; base : slot_t - ; offset : imm_t - } - -type mem_fslot = - { frd : slot_t - ; base : slot_t - ; offset : imm_t - } - -(** -Defines a stack slot used for register spilling and reloading. -The `target` is the register being spilled/reloaded, +(** Slot types for different operations (integer, floating-point, etc.) *) +module Slots = struct + (** Similar to R-type instructions in RISC-V. *) + type r_slot = + { rd : Slot.t + ; rs1 : Slot.t + ; rs2 : Slot.t + } + + (** R-type instructions for floating-point registers. *) + type r_fslot = + { frd : Slot.t + ; frs1 : Slot.t + ; frs2 : Slot.t + } + + (** I-type, with one destination register, one source and one immediate. *) + type i_slot = + { rd : Slot.t + ; rs1 : Slot.t + ; imm : Imm.t + } + + (** Defines a single floating-point register assignment with a destination register `frd`. *) + type single_fslot = { frd : Slot.t } + + (** Defines a direct assignment between general-purpose and floating-point register.*) + type assign_direct = + { frd : Slot.t + ; rs : Slot.t + } + + type assign_slot = + { rd : Slot.t + ; rs : Slot.t + } + + type assign_fslot = + { frd : Slot.t + ; frs : Slot.t + } + + (** For special floating-point operation*) + type triple_fslot = + { frd : Slot.t + ; frs1 : Slot.t + ; frs2 : Slot.t + ; frs3 : Slot.t + } + + (** Immediate value `imm` to the destination register `rd`. *) + type assign_int64 = + { rd : Slot.t + ; imm : Imm.t + } + + (** Defines an assignment of a label (address or function) to a register `rd`. *) + type assign_label = + { rd : Slot.t + ; label : Label.t + } + + (** + Defines a conversion between floating-point and integer registers. + Converts floating-point register `frs` to the destination integer register `rd`. + *) + type convert_slot = + { rd : Slot.t + ; frs : Slot.t + } + + (** + Defines a conversion from an integer register to a floating-point register. + Converts the integer value from `rs` to the destination floating-point register `frd`. + *) + type convert_fslot = + { frd : Slot.t + ; rs : Slot.t + } + + (** + Defines a comparison between two floating-point registers `frs1` and `frs2`, + with the result stored in the destination register `rd`. + *) + type compare_fslot = + { rd : Slot.t + ; frs1 : Slot.t + ; frs2 : Slot.t + } + + (** Calls function named `fn` with arguments `args`, and store the result in `rd`. *) + type call_data = + { rd : Slot.t + ; fn : Label.t + ; args : Slot.t list + ; fargs : Slot.t list + } + + (** Call function pointer with address `rs` and arguments `args`, and returns in `rd` *) + type call_indirect = + { rd : Slot.t + ; fn : Slot.t + ; args : Slot.t list + ; fargs : Slot.t list + } + + (** + Similar to `ld` and `st` in RISC-V. + `rd` and `rs` have different meanings in loads and stores: + We load `byte` bytes from `rs` into `rd`, + and store `byte` bytes from `rd` into `rs`. + *) + type mem_slot = + { rd : Slot.t + ; base : Slot.t + ; offset : Imm.t + } + + type mem_fslot = + { frd : Slot.t + ; base : Slot.t + ; offset : Imm.t + } + + (** + Defines a stack slot used for register spilling and reloading. + The `target` is the register being spilled/reloaded, and `origin` is the original register it corresponds to. -Used in the final assembly generation to manage stack offsets. -*) + Used in the final assembly generation to manage stack offsets. + *) -type stack_slot = - { target : slot_t - ; origin : slot_t - } + type stack_slot = + { target : Slot.t + ; origin : Slot.t + } -type stack_fslot = - { target : slot_t - ; origin : slot_t - } + type stack_fslot = + { target : Slot.t + ; origin : Slot.t + } +end (** Virtual RISC-V Instructions *) -type t = - (* Integer Arithmetic Instructions *) - | Add of r_slot - | Sub of r_slot - | Addi of i_slot - (* Logical and Shift Instructions *) - | And of r_slot - | Or of r_slot - | Xor of r_slot - | Sll of r_slot (* shift left logical *) - | Srl of r_slot (* shift right logical *) - | Sra of r_slot (* shift right arithmetic *) - | Slli of i_slot (* shift left logical immediate *) - | Srli of i_slot (* shift right logical immediate *) - | Srai of i_slot (* shift right arithmetic immediate *) - (* Multiplication and Division Instructions *) - | Mul of r_slot - | Div of r_slot (* signed divide *) - | Divu of r_slot (* unsigned divide *) - | Rem of r_slot (* signed remainder *) - | Remu of r_slot (* unsigned remainder *) - (* Memory Access Instructions *) - | Lw of mem_slot (* load word 32-bit *) - | Ld of mem_slot (* load doubleword 64-bit *) - | Sw of mem_slot (* store word 32-bit *) - | Sd of mem_slot (* store doubleword 64-bit *) - (* Floating-Point Arithmetic Instructions *) - | FaddD of r_fslot - | FsubD of r_fslot - | FmulD of r_fslot - | FdivD of r_fslot - | FmaddD of triple_fslot (* fmadd.d => f[rd] = f[rs1]×f[rs2]+f[rs3] *) - | FmsubD of triple_fslot (* fmsub.d => f[rd] = f[rs1]×f[rs2]-f[rs3] *) - | FnmaddD of triple_fslot (* fnmadd.d => f[rd] = -f[rs1]×f[rs2]+f[rs3] *) - | FnmsubD of triple_fslot (* fnmsub.d => f[rd] = -f[rs1]×f[rs2]-f[rs3] *) - (* Floating-Point Compare Instructions *) - | FeqD of compare_fslot (* == *) - | FltD of compare_fslot (* < *) - | FleD of compare_fslot (* <= *) - (* Floating-Point Conversion *) - | FcvtDW of convert_fslot (* convert int32 to float *) - | FcvtDL of convert_fslot (* convert int64 to float *) - | FcvtLD of convert_slot (* convert float to int64 *) - | FcvtWDRtz of convert_slot (* convert float to int, round towards zero *) - (* Floating-Point Misc Instructions *) - | FsqrtD of assign_fslot (* square root *) - | FabsD of assign_fslot (* absolute value *) - (* Floating-Point Memory Instructions *) - | Fld of mem_fslot (* load doubleword 64-bit *) - | Fsd of mem_fslot (* store doubleword 64-bit *) - (* Movement Instructions *) - | La of assign_label (* load address *) - | Li of assign_int64 (* load immediate *) - | Neg of assign_slot - | Mv of assign_slot - | FnegD of assign_fslot - | FmvD of assign_fslot - | FmvDX of assign_direct (* move integer slot -> float slot (bitwise) *) - | FmvDXZero of single_fslot (* move x0 -> float slot (bitwise), i.e. 0.0 *) - (* Call / Function Invocation Instructions *) - | Call of call_data - | CallIndirect of call_indirect - (* Register Allocation Directives *) - | Spill of stack_slot - | Reload of stack_slot - | FSpill of stack_fslot - | FReload of stack_fslot - -(** Branching is done based on the comparison of registers `rs1`, `rs2` or a single register `rs`. *) -type branch_slot = - { rs1 : slot_t - ; rs2 : slot_t - ; ifso : label_t - ; ifnot : label_t - } - -(** rd stores return address, label is the jump target *) -type jal_label = - { rd : slot_t - ; label : label_t - } - -(** jump address is calculated by rs1 + offset, rd stores return address *) -type jalr_label = - { rd : slot_t - ; rs1 : slot_t - ; offset : imm_t - } - -(** These include conditional branches, unconditional jumps, function returns, and tail calls. *) -type term_t = - | Beq of branch_slot (* Branch if equal *) - | Bne of branch_slot (* Branch if not equal *) - | Blt of branch_slot (* Branch if less than *) - | Bge of branch_slot (* Branch if greater than or equal *) - | Bltu of branch_slot (* Branch if less than unsigned *) - | Bgeu of branch_slot (* Branch if greater than or equal unsigned *) - | Jal of label_t (* jump and link (store return address) *) - | Jalr of jalr_label (* jump and link register (store return address) *) - | TailCall of call_data - | TailCallIndirect of call_indirect - | Ret of slot_t (* Unit for no return*) - -(* Note: *) -(* Riscv Virtual ASM still retains the structure of control flow (CFG), *) -(* while its VirtualASM instructions are closer to real assembly. *) -(* It also includes pseudo-instructions for convenient register allocation and defines the slot_t type, *) -(* which aims to allow virtual registers of Slots to coexist with real registers of Regs. *) - -(* Key int*) -module IntMap = Basic_map_int -module VBlockMap = Basic_map_int -module VSymbolMap = Basic_map_int - -(* Key string*) -module StringMap = Basic_map_string -module VFuncMap = Basic_map_string +module Inst = struct + open Slots + + type t = + (* Integer Arithmetic Instructions *) + | Add of r_slot + | Sub of r_slot + | Addi of i_slot + (* Logical and Shift Instructions *) + | And of r_slot + | Or of r_slot + | Xor of r_slot + | Sll of r_slot (* shift left logical *) + | Srl of r_slot (* shift right logical *) + | Sra of r_slot (* shift right arithmetic *) + | Slli of i_slot (* shift left logical immediate *) + | Srli of i_slot (* shift right logical immediate *) + | Srai of i_slot (* shift right arithmetic immediate *) + (* Multiplication and Division Instructions *) + | Mul of r_slot + | Div of r_slot (* signed divide *) + | Divu of r_slot (* unsigned divide *) + | Rem of r_slot (* signed remainder *) + | Remu of r_slot (* unsigned remainder *) + (* Memory Access Instructions *) + | Lw of mem_slot (* load word 32-bit *) + | Ld of mem_slot (* load doubleword 64-bit *) + | Sw of mem_slot (* store word 32-bit *) + | Sd of mem_slot (* store doubleword 64-bit *) + (* Floating-Point Arithmetic Instructions *) + | FaddD of r_fslot + | FsubD of r_fslot + | FmulD of r_fslot + | FdivD of r_fslot + | FmaddD of triple_fslot (* fmadd.d => f[rd] = f[rs1]×f[rs2]+f[rs3] *) + | FmsubD of triple_fslot (* fmsub.d => f[rd] = f[rs1]×f[rs2]-f[rs3] *) + | FnmaddD of triple_fslot (* fnmadd.d => f[rd] = -f[rs1]×f[rs2]+f[rs3] *) + | FnmsubD of triple_fslot (* fnmsub.d => f[rd] = -f[rs1]×f[rs2]-f[rs3] *) + (* Floating-Point Compare Instructions *) + | FeqD of compare_fslot (* == *) + | FltD of compare_fslot (* < *) + | FleD of compare_fslot (* <= *) + (* Floating-Point Conversion *) + | FcvtDW of convert_fslot (* convert int32 to float *) + | FcvtDL of convert_fslot (* convert int64 to float *) + | FcvtLD of convert_slot (* convert float to int64 *) + | FcvtWDRtz of convert_slot (* convert float to int, round towards zero *) + (* Floating-Point Misc Instructions *) + | FsqrtD of assign_fslot (* square root *) + | FabsD of assign_fslot (* absolute value *) + (* Floating-Point Memory Instructions *) + | Fld of mem_fslot (* load doubleword 64-bit *) + | Fsd of mem_fslot (* store doubleword 64-bit *) + (* Movement Instructions *) + | La of assign_label (* load address *) + | Li of assign_int64 (* load immediate *) + | Neg of assign_slot + | Mv of assign_slot + | FnegD of assign_fslot + | FmvD of assign_fslot + | FmvDX of assign_direct (* move integer slot -> float slot (bitwise) *) + | FmvDXZero of single_fslot (* move x0 -> float slot (bitwise), i.e. 0.0 *) + (* Call / Function Invocation Instructions *) + | Call of call_data + | CallIndirect of call_indirect + (* Register Allocation Directives *) + | Spill of stack_slot + | Reload of stack_slot + | FSpill of stack_fslot + | FReload of stack_fslot +end (* Vector alias*) module Vec = Basic_vec -type vblock_label = int -type vfunc_label = string -type vsymbol_label = int - -(* Count for VirtSymbol*) -let vsymbol_cnt = ref 0 +(* VBlock *) +module VBlockLabel = Label +module VBlockSet = Label.Hashset +module VBlockMap = Label.Map + +(* VFunc *) +module VFuncLabel = Label +module VFuncSet = Label.Hashset +module VFuncMap = Label.Map + +(* VProg *) + +(* VSymbol *) +module VSymbolLabel = Label +module VSymbolSet = Label.Hashset +module VSymbolMap = Label.Map + +(** Control Flow module *) +module Term = struct + open Slots + + (** Branch slot for conditional branches *) + type branch_slot = + { rs1 : Slot.t + ; rs2 : Slot.t + ; ifso : VBlockLabel.t + ; ifnot : VBlockLabel.t + } + + (** rd stores return address, label is the jump target *) + type jal_label = + { rd : Slot.t + ; label : VBlockLabel.t + } + + (** jump address is calculated by rs1 + offset, rd stores return address *) + type jalr_label = + { rd : Slot.t + ; rs1 : Slot.t + ; offset : Imm.t + } + + (** These include conditional branches, unconditional jumps, function returns, and tail calls. *) + type t = + | Beq of branch_slot (* Branch if equal *) + | Bne of branch_slot (* Branch if not equal *) + | Blt of branch_slot (* Branch if less than *) + | Bge of branch_slot (* Branch if greater than or equal *) + | Bltu of branch_slot (* Branch if less than unsigned *) + | Bgeu of branch_slot (* Branch if greater than or equal unsigned *) + | J of VBlockLabel.t (* jump (not stort return address) *) + | Jal of VBlockLabel.t (* jump and link (store return address) *) + | Jalr of jalr_label (* jump and link register (store return address) *) + | TailCall of call_data + | TailCallIndirect of call_indirect + | Ret of Slot.t (* Unit for no return*) +end (** VirtRvBlock*) -type vblock_t = - { body : t Vec.t - ; term : term_t (* Single Terminator*) - ; preds : vblock_label Vec.t (* Predecessors*) - } +module VBlock = struct + type t = + { body : t Vec.t + ; term : Term.t (* Single Terminator*) + ; preds : VBlockLabel.t Vec.t (* Predecessors*) + } + + let get_successors (block : t) : VBlockLabel.t list = + match block.term with + | Beq branch_slot + | Bne branch_slot + | Blt branch_slot + | Bge branch_slot + | Bltu branch_slot + | Bgeu branch_slot -> [ branch_slot.ifso; branch_slot.ifnot ] + | J label | Jal label -> [ label ] + | Jalr _ -> + [] (* Since Jalr is a computed jump, we might not have a static successor *) + | TailCall _ | TailCallIndirect _ -> + [] + (* Tail calls typically transfer control to another function, no direct successor *) + | Ret _ -> [] (* Return terminates the current function, so no successor blocks *) + ;; +end (** VirtRvFunc*) -type vfunc_t = - { result : ret_type option - ; args : slot_t list - ; fargs : slot_t list - ; entry : vblock_label - } +module VFunc = struct + type t = + { funn : VFuncLabel.t + ; args : Slot.t list + ; fargs : Slot.t list + ; entry : VBlockLabel.t + } +end (** VirtRvProg*) -type vprog_t = - { blocks : vblock_t VBlockMap.t - ; funcs : vfunc_t VFuncMap.t - ; consts : imm_t VSymbolMap.t - ; loop_vars : slot_t VBlockMap.t - (* Loop internal variables - +module VProg = struct + type t = + { blocks : VBlock.t VBlockMap.t + ; funcs : VFunc.t VFuncMap.t + ; consts : Imm.t VSymbolMap.t + ; loop_vars : Slot.t VBlockMap.t + (* Loop internal variables - used for register allocation special identification*) - } + } + + let get_block (vprog : t) (bl : VBlockLabel.t) : VBlock.t = + match VBlockMap.find_opt vprog.blocks bl with + | None -> failwith "get_block: block not found" + | Some x -> x + ;; + + let get_func (vprog : t) (fn : VFuncLabel.t) : VFunc.t = + match VFuncMap.find_opt vprog.funcs fn with + | None -> failwith "get_func: function not found" + | Some x -> x + ;; +end + +(* Note: *) +(* Riscv Virtual ASM still retains the structure of control flow (CFG), *) +(* while its VirtualASM instructions are closer to real assembly. *) +(* It also includes pseudo-instructions for convenient register allocation and defines the Slot.t type, *) +(* which aims to allow virtual registers of Slots to coexist with real registers of Regs. *) diff --git a/src/riscv_virtasm_generate.ml b/src/riscv_virtasm_generate.ml index 2fe4f67..f695bd9 100644 --- a/src/riscv_virtasm_generate.ml +++ b/src/riscv_virtasm_generate.ml @@ -2,7 +2,7 @@ open Riscv_virtasm module Ssa = Riscv_ssa let virtasm_of_ssa (ssa : Ssa.t list) = - let vprog : vprog_t = + let vprog : VProg.t = { blocks = VBlockMap.empty ; funcs = VFuncMap.empty ; consts = VSymbolMap.empty From b7c72929f2504f9346a144b15ef8d74e3a008855 Mon Sep 17 00:00:00 2001 From: zmr233 Date: Sat, 25 Jan 2025 15:06:01 +0800 Subject: [PATCH 2/2] Liveness & Next use distance analysis --- src/riscv_reg.ml | 86 +++++++++++++++-- src/riscv_reg_alloc.ml | 1 + src/riscv_reg_util.ml | 212 ++++++++++++++++++++++++++++++++++++++++- src/riscv_virtasm.ml | 79 ++++++++++++++- 4 files changed, 366 insertions(+), 12 deletions(-) diff --git a/src/riscv_reg.ml b/src/riscv_reg.ml index 4f5b3e7..3897864 100644 --- a/src/riscv_reg.ml +++ b/src/riscv_reg.ml @@ -148,7 +148,6 @@ module FReg = struct ;; end - (** Defines an immediate value type. @@ -181,13 +180,6 @@ module Slot = struct | Reg of Reg.t (* Physical register *) | FReg of FReg.t (* Floating-point physical register *) - include struct - let _ = fun (_ : t) -> () - let compare = compare - let equal = (=) - let hash = Hashtbl.hash - end - (* Convert t to string representation *) let to_string (s : t) : string = match s with @@ -198,6 +190,16 @@ module Slot = struct | Unit -> "_" ;; + include struct + let _ = fun (_ : t) -> () + let compare = compare + let equal = ( = ) + let hash = Hashtbl.hash + + (* Map function*) + let sexp_of_t (x : t) : S.t = Atom (to_string x) + end + (* Counter for integer slots *) let slot_cnt = ref 0 @@ -217,7 +219,71 @@ module Slot = struct fslot_cnt := i + 1; FSlot i ;; + + let is_int = function + | Slot _ -> true + | Reg _ -> true + | _ -> false + ;; + + let is_float = function + | FSlot _ -> true + | FReg _ -> true + | _ -> false + ;; +end + +module SlotSet = struct + include Basic_setf.Make (Slot) + + (* Clone function to create a deep copy of the SlotSet *) + let clone s = of_list (to_list s) + + (* A comprehensive and efficient equal function *) + let equal s1 s2 = + (* 1. Physical reference check: same object, directly equal *) + if s1 == s2 + then true + else if + (* 2. If number of elements differs, not equal *) + cardinal s1 <> cardinal s2 + then false + else + (* 3. Check if each element in s1 exists in s2 *) + for_all s1 (fun x -> mem s2 x) + ;; end -(* Key t*) -module SlotSet = Basic_hashsetf.Make (Slot) +module SlotMap = struct + include Basic_mapf.Make (Slot) + + (* Clone function to create a deep copy of the SlotMap *) + let clone m = of_array (to_sorted_array m) + + (* + [equal eqv m1 m2] + returns true if and only if: + 1) [m1] and [m2] contain exactly the same set of keys, and + 2) for every key k, values in m1 and m2 are considered equal by [eqv]. + *) + let equal (eqv : 'a -> 'a -> bool) (m1 : 'a t) (m2 : 'a t) : bool = + (* If number of elements differs, directly return false *) + if cardinal m1 <> cardinal m2 + then false + else ( + (* Otherwise iterate through m1 to check each key-value pair *) + try + iter m1 (fun k v1 -> + match find_opt m2 k with + | None -> + (* If key not found in m2, determine unequal and break *) + raise Exit + | Some v2 -> + (* If custom eqv function determines unequal values, break *) + if not (eqv v1 v2) then raise Exit); + (* If no mismatches found, return true *) + true + with + | Exit -> false) + ;; +end diff --git a/src/riscv_reg_alloc.ml b/src/riscv_reg_alloc.ml index eefe338..32b87ee 100644 --- a/src/riscv_reg_alloc.ml +++ b/src/riscv_reg_alloc.ml @@ -5,5 +5,6 @@ open Riscv_reg_util let reg_alloc (vprog: VProg.t) = let rpo = RPO.calculate_rpo vprog in + let liveinfo = Liveness.liveness_analysis vprog rpo in vprog ;; diff --git a/src/riscv_reg_util.ml b/src/riscv_reg_util.ml index c54b88b..c92e526 100644 --- a/src/riscv_reg_util.ml +++ b/src/riscv_reg_util.ml @@ -27,6 +27,216 @@ module RPO = struct in VFuncMap.fold vprog.funcs VFuncMap.empty cal_func_rpo ;; + + let get_func_rpo (funn : VFuncLabel.t) (rpo : t) : VBlockLabel.t list = + match VFuncMap.find_opt rpo funn with + | Some x -> x + | None -> failwith "RPO.get_func_rpo: function not found" + ;; end -(* let cal_next_use_distance (vprog: VProg.t) (rpo: RPO.t) = *) +module Liveness = struct + (** + For maxPressure_I / maxPressure_F, + we count the number of Int/Float slots by checking each slot in liveIn set. + *) + type live_info = + { maxPressure_I : int (* Maximum pressure of integer registers *) + ; maxPressure_F : int (* Maximum pressure of floating-point registers *) + ; liveIn : SlotSet.t + ; liveOut : SlotSet.t + ; exitNextUse : int SlotMap.t + } + + (** + Liveness information for all basic blocks in the program. + VBlockMap maps VBlockLabel.t to live_info + *) + type t = live_info VBlockMap.t + + (** + Get live_info for a specific block. + Returns default values or fails if block not found in map, depending on context + *) + let get_liveinfo (liveness : t) (bl : VBlockLabel.t) : live_info = + match VBlockMap.find_opt liveness bl with + | Some x -> x + | None -> + { maxPressure_I = 0 + ; maxPressure_F = 0 + ; liveIn = SlotSet.empty + ; liveOut = SlotSet.empty + ; exitNextUse = SlotMap.empty + } + ;; + + (** Helper function: checks if an int is "not infinite" *) + let not_inf (x : int) = x < max_int + + (** + Helper function: saturating addition + Returns x + y if both are not max_int, otherwise returns max_int + *) + let sat_add (x : int) (y : int) = if not_inf x && not_inf y then x + y else max_int + + (** + Helper function: increment all values in map by one (with saturation) + *) + let incr_all_values_by_one (mp : int SlotMap.t) : int SlotMap.t = + SlotMap.fold mp SlotMap.empty (fun slot dist acc -> + let new_dist = sat_add dist 1 in + SlotMap.add acc slot new_dist) + ;; + + (** + Helper function: count the number of int slots and float slots in liveIn set + Returns a tuple (int_count, float_count) + *) + let count_int_and_float (s : SlotSet.t) : int * int = + SlotSet.fold s (0, 0) (fun slot (cnt_i, cnt_f) -> + if Slot.is_int slot + then cnt_i + 1, cnt_f + else if Slot.is_float slot + then cnt_i, cnt_f + 1 + else cnt_i, cnt_f) + ;; + + (** + Main logic: Performs liveness analysis on the entire program + Returns a VBlockMap.t where each block label maps to its live_info + *) + let liveness_analysis (vprog : VProg.t) (rpo : RPO.t) : t = + (* Reference to store analysis results for all blocks, updated in each iteration *) + let liveness_ref : t ref = ref VBlockMap.empty in + (* Returns current live_info for specified block (default if not exists) *) + let get_current_info (bl : VBlockLabel.t) : live_info = + get_liveinfo !liveness_ref bl + in + (* Sets or updates live_info for a block *) + let set_current_info (bl : VBlockLabel.t) (info : live_info) = + liveness_ref := VBlockMap.add !liveness_ref bl info + in + let changed = ref true in + (* Scan all basic blocks within a function f *) + let cal_func (f_label : VFuncLabel.t) (bls : VBlockLabel.t list) = + let process_block (bl : VBlockLabel.t) = + let block = VProg.get_block vprog bl in + let old_info = get_current_info bl in + (**********************************************************) + (* PartA: Update block's liveOut & exitNextUse based on successors *) + (**********************************************************) + let b_liveOut = ref SlotSet.empty in + let b_exitNextUse = ref SlotMap.empty in + let successors = VBlock.get_successors block in + List.iter + (fun succ -> + let succ_info = get_current_info succ in + (* liveOut = union of successors' liveIn *) + b_liveOut := SlotSet.union !b_liveOut succ_info.liveIn; + (* exitNextUse = accumulate based on successors' entryNextUse *) + SlotMap.iter succ_info.exitNextUse (fun slot dist_in_succ -> + let new_dist = sat_add dist_in_succ 1 in + match SlotMap.find_opt !b_exitNextUse slot with + | Some old_d -> + (* if old_d is less than new_dist, keep old_d as it's better; otherwise update *) + if not_inf old_d && old_d <= new_dist + then () + else b_exitNextUse := SlotMap.add !b_exitNextUse slot new_dist + | None -> b_exitNextUse := SlotMap.add !b_exitNextUse slot new_dist)) + successors; + (**********************************************************) + (* PartB: Backward propagation to compute block's liveIn & entryNextUse *) + (**********************************************************) + + (* First add all sources from terminator instruction to liveIn *) + let b_liveIn = + ref (SlotSet.union !b_liveOut (Term.get_srcs block.term |> SlotSet.of_list)) + in + (* Initialize entry nextUse by copying from exitNextUse and increment all by 1 *) + let b_entryNextUse = ref (incr_all_values_by_one !b_exitNextUse) in + (* Set nextUse = 0 for terminator's sources *) + let term_srcs = Term.get_srcs block.term in + List.iter + (fun src -> b_entryNextUse := SlotMap.add !b_entryNextUse src 0) + term_srcs; + (* Initialize pressure counters *) + let maxP_I = ref 0 in + let maxP_F = ref 0 in + (* Update initial pressure based on current liveIn *) + let cnt_i, cnt_f = count_int_and_float !b_liveIn in + maxP_I := max !maxP_I cnt_i; + maxP_F := max !maxP_F cnt_f; + (***********************************************) + (* Traverse normal instructions (body) in reverse order *) + (***********************************************) + let body_insts = VBlock.get_body_insts block in + let reversed = List.rev body_insts in + List.iter + (fun inst -> + (* Increment all nextUse distances by 1 *) + b_entryNextUse := incr_all_values_by_one !b_entryNextUse; + (* Get dest/src from the current instruction *) + let dests = Inst.get_dests inst in + let srcs = Inst.get_srcs inst in + (* 1) remove dest from liveIn & entryNextUse *) + List.iter + (fun d -> + b_liveIn := SlotSet.remove !b_liveIn d; + b_entryNextUse := SlotMap.remove !b_entryNextUse d) + dests; + (* 2) add src to liveIn, set entryNextUse for these src to 0 *) + List.iter + (fun s -> + b_liveIn := SlotSet.add !b_liveIn s; + b_entryNextUse := SlotMap.add !b_entryNextUse s 0) + srcs; + (* 3) update maxPressure *) + let cnt_i, cnt_f = count_int_and_float !b_liveIn in + maxP_I := max !maxP_I cnt_i; + maxP_F := max !maxP_F cnt_f) + reversed; + (**********************************************************) + (* Record new b_liveIn, b_liveOut, b_entryNextUse, and maximum pressures *) + (**********************************************************) + let new_info = + { maxPressure_I = !maxP_I + ; maxPressure_F = !maxP_F + ; liveIn = !b_liveIn + ; liveOut = !b_liveOut + ; exitNextUse = !b_exitNextUse + } + in + (* + Compare new_info with old_info for changes + If changes detected, mark changed = true in outer scope + *) + let diff_liveIn = not (SlotSet.equal new_info.liveIn old_info.liveIn) in + let diff_liveOut = not (SlotSet.equal new_info.liveOut old_info.liveOut) in + let diff_exitNextUse = + not (SlotMap.equal ( = ) new_info.exitNextUse old_info.exitNextUse) + in + let diff_pressure_I = new_info.maxPressure_I <> old_info.maxPressure_I in + let diff_pressure_F = new_info.maxPressure_F <> old_info.maxPressure_F in + if + diff_liveIn + || diff_liveOut + || diff_exitNextUse + || diff_pressure_I + || diff_pressure_F + then ( + changed := true; + set_current_info bl new_info) + else + () + in + List.iter process_block bls + in + (* Outer loop: continue until no changes occur *) + while !changed do + changed := false; + VFuncMap.iter rpo (fun fun_lbl block_list -> cal_func fun_lbl block_list) + done; + (* final liveness analysis results for the entire program *) + !liveness_ref + ;; +end diff --git a/src/riscv_virtasm.ml b/src/riscv_virtasm.ml index 43c9f85..fb79060 100644 --- a/src/riscv_virtasm.ml +++ b/src/riscv_virtasm.ml @@ -236,6 +236,66 @@ module Inst = struct | Reload of stack_slot | FSpill of stack_fslot | FReload of stack_fslot + + let inst_map (inst : t) (rd : Slot.t -> Slot.t list) (rs : Slot.t -> Slot.t list) = + match inst with + | Add r_slot + | Sub r_slot + | And r_slot + | Or r_slot + | Xor r_slot + | Sll r_slot + | Srl r_slot + | Sra r_slot + | Mul r_slot + | Div r_slot + | Divu r_slot + | Rem r_slot + | Remu r_slot -> rd r_slot.rd @ rs r_slot.rs1 @ rs r_slot.rs2 + | Addi i_slot | Slli i_slot | Srli i_slot | Srai i_slot -> + rd i_slot.rd @ rs i_slot.rs1 + | Lw mem_slot | Ld mem_slot | Sw mem_slot | Sd mem_slot -> + rd mem_slot.rd @ rs mem_slot.base + | FaddD r_fslot | FsubD r_fslot | FmulD r_fslot | FdivD r_fslot -> + rd r_fslot.frd @ rs r_fslot.frs1 @ rs r_fslot.frs2 + | FmaddD triple_fslot + | FmsubD triple_fslot + | FnmaddD triple_fslot + | FnmsubD triple_fslot -> + rd triple_fslot.frd + @ rs triple_fslot.frs1 + @ rs triple_fslot.frs2 + @ rs triple_fslot.frs3 + | FeqD compare_fslot | FltD compare_fslot | FleD compare_fslot -> + rd compare_fslot.rd @ rs compare_fslot.frs1 @ rs compare_fslot.frs2 + | FcvtDW convert_fslot | FcvtDL convert_fslot -> + rd convert_fslot.frd @ rs convert_fslot.rs + | FcvtLD convert_slot | FcvtWDRtz convert_slot -> + rd convert_slot.rd @ rs convert_slot.frs + | FsqrtD assign_fslot | FabsD assign_fslot | FnegD assign_fslot | FmvD assign_fslot -> + rd assign_fslot.frd @ rs assign_fslot.frs + | Fld mem_fslot | Fsd mem_fslot -> rd mem_fslot.frd @ rs mem_fslot.base + | La assign_label -> [] + | Li assign_int64 -> [] + | Neg assign_slot | Mv assign_slot -> rd assign_slot.rd @ rs assign_slot.rs + | FmvDX assign_direct -> rd assign_direct.frd @ rs assign_direct.rs + | FmvDXZero single_fslot -> [] + | Call call_data -> + rd call_data.rd + @ List.concat_map rs call_data.args + @ List.concat_map rs call_data.fargs + | CallIndirect call_indirect -> + rd call_indirect.rd + @ rs call_indirect.fn + @ List.concat_map rs call_indirect.args + @ List.concat_map rs call_indirect.fargs + | Spill stack_slot | Reload stack_slot -> rd stack_slot.target @ rs stack_slot.origin + | FSpill stack_fslot | FReload stack_fslot -> + rd stack_fslot.target @ rs stack_fslot.origin + ;; + + let get_srcs (inst : t) : Slot.t list = inst_map inst (fun x -> []) (fun x -> [ x ]) + let get_dests (inst : t) : Slot.t list = inst_map inst (fun x -> [ x ]) (fun x -> []) end (* Vector alias*) @@ -297,12 +357,27 @@ module Term = struct | TailCall of call_data | TailCallIndirect of call_indirect | Ret of Slot.t (* Unit for no return*) + + let get_srcs (term : t) : Slot.t list = + match term with + | Beq branch_slot + | Bne branch_slot + | Blt branch_slot + | Bge branch_slot + | Bltu branch_slot + | Bgeu branch_slot -> [ branch_slot.rs1; branch_slot.rs2 ] + | Jalr jalr_label -> [ jalr_label.rs1 ] + | TailCall call_data -> call_data.args + | TailCallIndirect call_indirect -> call_indirect.args + | Ret _ -> [] + | J _ | Jal _ -> [] + ;; end (** VirtRvBlock*) module VBlock = struct type t = - { body : t Vec.t + { body : Inst.t Vec.t ; term : Term.t (* Single Terminator*) ; preds : VBlockLabel.t Vec.t (* Predecessors*) } @@ -323,6 +398,8 @@ module VBlock = struct (* Tail calls typically transfer control to another function, no direct successor *) | Ret _ -> [] (* Return terminates the current function, so no successor blocks *) ;; + + let get_body_insts (block : t) : Inst.t list = Vec.to_list block.body end (** VirtRvFunc*)