Skip to content

Commit a76f05b

Browse files
authored
Added const-folding for bounded_int_trim_* functions. (#8747)
1 parent cea729e commit a76f05b

File tree

3 files changed

+369
-3
lines changed

3 files changed

+369
-3
lines changed

crates/cairo-lang-lowering/src/optimizations/const_folding.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,42 @@ impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
11021102
let output = info.arms[arm_idx].var_ids[0];
11031103
statements.push(self.propagate_const_and_get_statement(value.clone(), output));
11041104
Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1105+
} else if id == self.bounded_int_trim_min {
1106+
let input_var = info.inputs[0].var_id;
1107+
let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
1108+
return None;
1109+
};
1110+
let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
1111+
range.min == *value
1112+
} else {
1113+
corelib::try_extract_bounded_int_type_ranges(db, *ty)?.0 == *value
1114+
};
1115+
let arm_idx = if is_trimmed {
1116+
0
1117+
} else {
1118+
let output = info.arms[1].var_ids[0];
1119+
statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1120+
1
1121+
};
1122+
Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1123+
} else if id == self.bounded_int_trim_max {
1124+
let input_var = info.inputs[0].var_id;
1125+
let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
1126+
return None;
1127+
};
1128+
let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
1129+
range.max == *value
1130+
} else {
1131+
corelib::try_extract_bounded_int_type_ranges(db, *ty)?.1 == *value
1132+
};
1133+
let arm_idx = if is_trimmed {
1134+
0
1135+
} else {
1136+
let output = info.arms[1].var_ids[0];
1137+
statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1138+
1
1139+
};
1140+
Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
11051141
} else if id == self.array_get {
11061142
let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
11071143
if let Some(VarInfo::Snapshot(arr_info)) = self.var_info.get(&info.inputs[0].var_id)
@@ -1414,6 +1450,10 @@ pub struct ConstFoldingLibfuncInfo<'db> {
14141450
bounded_int_sub: ExternFunctionId<'db>,
14151451
/// The `bounded_int_constrain` libfunc.
14161452
bounded_int_constrain: ExternFunctionId<'db>,
1453+
/// The `bounded_int_trim_min` libfunc.
1454+
bounded_int_trim_min: ExternFunctionId<'db>,
1455+
/// The `bounded_int_trim_max` libfunc.
1456+
bounded_int_trim_max: ExternFunctionId<'db>,
14171457
/// The `array_get` libfunc.
14181458
array_get: ExternFunctionId<'db>,
14191459
/// The `array_snapshot_pop_front` libfunc.
@@ -1556,6 +1596,8 @@ impl<'db> ConstFoldingLibfuncInfo<'db> {
15561596
bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
15571597
bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
15581598
bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
1599+
bounded_int_trim_min: bounded_int_module.extern_function_id("bounded_int_trim_min"),
1600+
bounded_int_trim_max: bounded_int_module.extern_function_id("bounded_int_trim_max"),
15591601
array_get: array_module.extern_function_id("array_get"),
15601602
array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
15611603
array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),

crates/cairo-lang-lowering/src/optimizations/const_folding_test.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,18 @@ fn test_match_optimizer(
3737
let mut before = db
3838
.lowered_body(function_id, LoweringStage::PreOptimizations)
3939
.unwrap_or_else(|_| {
40-
let semantic_diags = db.module_semantic_diagnostics(test_function.module_id).unwrap();
41-
let lowering_diags = db.module_lowering_diagnostics(test_function.module_id);
40+
let semantic_diags = db
41+
.module_semantic_diagnostics(test_function.module_id)
42+
.unwrap_or_default()
43+
.format(db);
44+
let lowering_diags = db
45+
.module_lowering_diagnostics(test_function.module_id)
46+
.unwrap_or_default()
47+
.format(db);
4248

4349
panic!(
4450
"Failed to get lowered body for function {function_id:?}.\nSemantic diagnostics: \
45-
{semantic_diags:?}\nLowering diagnostics: {lowering_diags:?}",
51+
{semantic_diags}\nLowering diagnostics: {lowering_diags}",
4652
)
4753
})
4854
.clone();

0 commit comments

Comments
 (0)