Skip to content

Commit

Permalink
Extract ModuleHelper from const folding. (#7099)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Jan 16, 2025
1 parent 57c7421 commit bf87e46
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 63 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion crates/cairo-lang-lowering/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ num-bigint = { workspace = true, default-features = true }
num-integer = { workspace = true, default-features = true }
num-traits = { workspace = true, default-features = true }
salsa.workspace = true
smol_str.workspace = true

[dev-dependencies]
cairo-lang-plugins = { path = "../cairo-lang-plugins" }
Expand Down
75 changes: 18 additions & 57 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ mod test;

use std::sync::Arc;

use cairo_lang_defs::ids::{ExternFunctionId, ModuleId, ModuleItemId};
use cairo_lang_defs::ids::{ExternFunctionId, ModuleId};
use cairo_lang_semantic::helper::ModuleHelper;
use cairo_lang_semantic::items::constant::ConstValue;
use cairo_lang_semantic::items::imp::ImplLookupContext;
use cairo_lang_semantic::{GenericArgumentId, MatchArmSelector, TypeId, corelib};
Expand All @@ -17,10 +18,9 @@ use itertools::{chain, zip_eq};
use num_bigint::BigInt;
use num_integer::Integer;
use num_traits::Zero;
use smol_str::SmolStr;

use crate::db::LoweringGroup;
use crate::ids::{FunctionId, FunctionLongId};
use crate::ids::{FunctionId, SemanticFunctionIdEx};
use crate::{
BlockId, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchExternInfo, MatchInfo,
Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
Expand Down Expand Up @@ -276,10 +276,14 @@ impl ConstFoldingContext<'_> {
let input_var = stmt.inputs[0].var_id;
if let Some(ConstValue::Int(val, ty)) = self.as_const(input_var) {
stmt.inputs.clear();
stmt.function = ModuleHelper { db: self.db, id: self.storage_access_module }
.function_id("storage_base_address_const", vec![GenericArgumentId::Constant(
ConstValue::Int(val.clone(), *ty).intern(self.db),
)]);
stmt.function =
ModuleHelper { db: self.db.upcast(), id: self.storage_access_module }
.function_id("storage_base_address_const", vec![
GenericArgumentId::Constant(
ConstValue::Int(val.clone(), *ty).intern(self.db),
),
])
.lowered(self.db);
}
None
} else if id == self.into_box {
Expand Down Expand Up @@ -479,8 +483,9 @@ impl ConstFoldingContext<'_> {
let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
info.inputs.truncate(1);
info.function = ModuleHelper { db: self.db, id: self.array_module }
.function_id("array_snapshot_pop_front", generic_args);
info.function = ModuleHelper { db: self.db.upcast(), id: self.array_module }
.function_id("array_snapshot_pop_front", generic_args)
.lowered(self.db);
success.var_ids.insert(0, unused_arr_output0);
failure.var_ids.insert(0, unused_arr_output1);
}
Expand Down Expand Up @@ -539,52 +544,6 @@ pub fn priv_const_folding_info(
Arc::new(ConstFoldingLibfuncInfo::new(db))
}

/// Helper for getting functions in the corelib.
struct ModuleHelper<'a> {
/// The db.
db: &'a dyn LoweringGroup,
/// The current module id.
id: ModuleId,
}
impl<'a> ModuleHelper<'a> {
/// Returns a helper for the core module.
fn core(db: &'a dyn LoweringGroup) -> Self {
Self { db, id: corelib::core_module(db.upcast()) }
}
/// Returns a helper for a submodule named `name` of the current module.
fn submodule(&self, name: &str) -> Self {
let id = corelib::get_submodule(self.db.upcast(), self.id, name).unwrap_or_else(|| {
panic!("`{name}` missing in `{}`.", self.id.full_path(self.db.upcast()))
});
Self { db: self.db, id }
}
/// Returns the id of an extern function named `name` in the current module.
fn extern_function_id(&self, name: impl Into<SmolStr>) -> ExternFunctionId {
let name = name.into();
let Ok(Some(ModuleItemId::ExternFunction(id))) =
self.db.module_item_by_name(self.id, name.clone())
else {
panic!("`{}` not found in `{}`.", name, self.id.full_path(self.db.upcast()));
};
id
}
/// Returns the id of a function named `name` in the current module, with the given
/// `generic_args`.
fn function_id(
&self,
name: impl Into<SmolStr>,
generic_args: Vec<GenericArgumentId>,
) -> FunctionId {
FunctionLongId::Semantic(corelib::get_function_id(
self.db.upcast(),
self.id,
name.into(),
generic_args,
))
.intern(self.db)
}
}

/// Holds static information about libfuncs required for the optimization.
#[derive(Debug, PartialEq, Eq)]
pub struct ConstFoldingLibfuncInfo {
Expand Down Expand Up @@ -633,7 +592,7 @@ pub struct ConstFoldingLibfuncInfo {
}
impl ConstFoldingLibfuncInfo {
fn new(db: &dyn LoweringGroup) -> Self {
let core = ModuleHelper::core(db);
let core = ModuleHelper::core(db.upcast());
let felt_sub = core.extern_function_id("felt252_sub");
let box_module = core.submodule("box");
let into_box = box_module.extern_function_id("into_box");
Expand Down Expand Up @@ -707,7 +666,9 @@ impl ConstFoldingLibfuncInfo {
let info = TypeInfo {
min,
max,
is_zero: integer_module.function_id(format!("{ty}_is_zero"), vec![]),
is_zero: integer_module
.function_id(format!("{ty}_is_zero"), vec![])
.lowered(db),
};
(corelib::get_core_ty_by_name(db.upcast(), ty.into(), vec![]), info)
}),
Expand Down
45 changes: 45 additions & 0 deletions crates/cairo-lang-semantic/src/helper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use cairo_lang_defs::ids::{ExternFunctionId, ModuleId, ModuleItemId};
use smol_str::SmolStr;

use crate::db::SemanticGroup;
use crate::{FunctionId, GenericArgumentId, corelib};

/// Helper for getting functions in the corelib.
pub struct ModuleHelper<'a> {
/// The db.
pub db: &'a dyn SemanticGroup,
/// The current module id.
pub id: ModuleId,
}
impl<'a> ModuleHelper<'a> {
/// Returns a helper for the core module.
pub fn core(db: &'a dyn SemanticGroup) -> Self {
Self { db, id: db.core_module() }
}
/// Returns a helper for a submodule named `name` of the current module.
pub fn submodule(&self, name: &str) -> Self {
let id = corelib::get_submodule(self.db, self.id, name).unwrap_or_else(|| {
panic!("`{name}` missing in `{}`.", self.id.full_path(self.db.upcast()))
});
Self { db: self.db, id }
}
/// Returns the id of an extern function named `name` in the current module.
pub fn extern_function_id(&self, name: impl Into<SmolStr>) -> ExternFunctionId {
let name = name.into();
let Ok(Some(ModuleItemId::ExternFunction(id))) =
self.db.module_item_by_name(self.id, name.clone())
else {
panic!("`{}` not found in `{}`.", name, self.id.full_path(self.db.upcast()));
};
id
}
/// Returns the id of a function named `name` in the current module, with the given
/// `generic_args`.
pub fn function_id(
&self,
name: impl Into<SmolStr>,
generic_args: Vec<GenericArgumentId>,
) -> FunctionId {
corelib::get_function_id(self.db, self.id, name.into(), generic_args)
}
}
9 changes: 5 additions & 4 deletions crates/cairo-lang-semantic/src/items/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ use smol_str::SmolStr;
use super::functions::{GenericFunctionId, GenericFunctionWithBodyId};
use super::imp::{ImplId, ImplLongId};
use crate::corelib::{
CoreTraitContext, LiteralError, core_box_ty, core_nonzero_ty, false_variant,
get_core_function_id, get_core_trait, get_core_ty_by_name, true_variant,
try_extract_nz_wrapped_type, unit_ty, validate_literal,
CoreTraitContext, LiteralError, core_box_ty, core_nonzero_ty, false_variant, get_core_trait,
get_core_ty_by_name, true_variant, try_extract_nz_wrapped_type, unit_ty, validate_literal,
};
use crate::db::SemanticGroup;
use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
Expand All @@ -37,6 +36,7 @@ use crate::expr::compute::{
};
use crate::expr::inference::conform::InferenceConform;
use crate::expr::inference::{ConstVar, InferenceId};
use crate::helper::ModuleHelper;
use crate::literals::try_extract_minus_literal;
use crate::resolve::{Resolver, ResolverData};
use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter};
Expand Down Expand Up @@ -1232,6 +1232,7 @@ impl ConstCalcInfo {
db.trait_function_by_name(trait_id, name.into()).unwrap().unwrap()
};
let unit_const = ConstValue::Struct(vec![], unit_ty(db));
let core = ModuleHelper::core(db);
Self {
const_traits: [
neg_trait,
Expand Down Expand Up @@ -1270,7 +1271,7 @@ impl ConstCalcInfo {
true_const: ConstValue::Enum(true_variant(db), unit_const.clone().into()),
false_const: ConstValue::Enum(false_variant(db), unit_const.clone().into()),
unit_const,
panic_with_felt252: get_core_function_id(db, "panic_with_felt252".into(), vec![]),
panic_with_felt252: core.function_id("panic_with_felt252", vec![]),
}
}
}
1 change: 1 addition & 0 deletions crates/cairo-lang-semantic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod corelib;
pub mod db;
pub mod diagnostic;
pub mod expr;
pub mod helper;
pub mod inline_macros;
pub mod items;
pub mod literals;
Expand Down

0 comments on commit bf87e46

Please sign in to comment.