Skip to content

Commit

Permalink
Fixed generated functions in trait-fns. (#7081)
Browse files Browse the repository at this point in the history
feat(corelib): Iterator::enumerate (#7048)

fix: Fix handling of --skip-first argument Update release_crates.sh (#7087)

chore: orthographic correction in file if_else (#7088)

prevents closure parameters from being declared as refrences (#7078)

Refactored bounded_int_trim. (#7062)

Added const for starknet types. (#6961)

feat(corelib): Iterator::fold (#7084)

feat(corelib): Iterator::advance_by (#7059)

fix(corelib): Add the #[test] annotation to enumerate test (#7098)

feat(corelib): storage vectors iterators (#6941)

Extract ModuleHelper from const folding. (#7099)

Added support for basic `Into`s in consts. (#7100)

Removed taking value for `validate_literal`. (#7101)

added closure params to semantic defs in lowering (#7085)

Added support for `downcast` in constant context. (#7102)

fix(corelib): Add the #[test] annotation to enumerate test (#7098)
  • Loading branch information
orizi authored and dean-starkware committed Jan 16, 2025
1 parent e580d6c commit 654355d
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 24 deletions.
87 changes: 69 additions & 18 deletions crates/cairo-lang-semantic/src/expr/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ use crate::expr::inference::{ImplVarTraitItemMappings, InferenceId};
use crate::items::constant::{ConstValue, resolve_const_expr_and_evaluate, validate_const_expr};
use crate::items::enm::SemanticEnumEx;
use crate::items::feature_kind::extract_item_feature_config;
use crate::items::functions::function_signature_params;
use crate::items::functions::{concrete_function_closure_params, function_signature_params};
use crate::items::imp::{ImplLookupContext, filter_candidate_traits, infer_impl_by_self};
use crate::items::modifiers::compute_mutability;
use crate::items::us::get_use_path_segments;
Expand All @@ -84,8 +84,8 @@ use crate::types::{
};
use crate::usage::Usages;
use crate::{
ConcreteEnumId, GenericArgumentId, GenericParam, LocalItem, Member, Mutability, Parameter,
PatternStringLiteral, PatternStruct, Signature, StatementItemKind,
ConcreteEnumId, ConcreteFunction, GenericArgumentId, GenericParam, LocalItem, Member,
Mutability, Parameter, PatternStringLiteral, PatternStruct, Signature, StatementItemKind,
};

/// Expression with its id.
Expand Down Expand Up @@ -424,7 +424,7 @@ pub fn maybe_compute_expr_semantic(
ast::Expr::Indexed(expr) => compute_expr_indexed_semantic(ctx, expr),
ast::Expr::FixedSizeArray(expr) => compute_expr_fixed_size_array_semantic(ctx, expr),
ast::Expr::For(expr) => compute_expr_for_semantic(ctx, expr),
ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr),
ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr, None),
}
}

Expand Down Expand Up @@ -882,7 +882,7 @@ fn compute_expr_function_call_semantic(
let mut arg_types = vec![];
for arg_syntax in args_iter {
let stable_ptr = arg_syntax.stable_ptr();
let arg = compute_named_argument_clause(ctx, arg_syntax);
let arg = compute_named_argument_clause(ctx, arg_syntax, None);
if arg.2 != Mutability::Immutable {
return Err(ctx.diagnostics.report(stable_ptr, RefClosureArgument));
}
Expand Down Expand Up @@ -930,7 +930,7 @@ fn compute_expr_function_call_semantic(
let named_args: Vec<_> = args_syntax
.elements(syntax_db)
.into_iter()
.map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax))
.map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax, None))
.collect();
if named_args.len() != 1 {
return Err(ctx.diagnostics.report(syntax, WrongNumberOfArguments {
Expand Down Expand Up @@ -979,16 +979,22 @@ fn compute_expr_function_call_semantic(
let mut args_iter = args_syntax.elements(syntax_db).into_iter();
// Normal parameters
let mut named_args = vec![];
for _ in function_parameter_types(ctx, function)? {
let ConcreteFunction { .. } = function.lookup_intern(db).function;
let closure_params = concrete_function_closure_params(db, function)?;
for ty in function_parameter_types(ctx, function)? {
let Some(arg_syntax) = args_iter.next() else {
continue;
};
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(
ctx,
arg_syntax,
closure_params.get(&ty).cloned(),
));
}

// Maybe coupon
if let Some(arg_syntax) = args_iter.next() {
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(ctx, arg_syntax, None));
}

expr_function_call(ctx, function, named_args, syntax, syntax.stable_ptr().into())
Expand All @@ -1006,6 +1012,7 @@ fn compute_expr_function_call_semantic(
pub fn compute_named_argument_clause(
ctx: &mut ComputationContext<'_>,
arg_syntax: ast::Arg,
closure_param_types: Option<TypeId>,
) -> NamedArg {
let syntax_db = ctx.db.upcast();

Expand All @@ -1018,12 +1025,38 @@ pub fn compute_named_argument_clause(
let arg_clause = arg_syntax.arg_clause(syntax_db);
let (expr, arg_name_identifier) = match arg_clause {
ast::ArgClause::Unnamed(arg_unnamed) => {
(compute_expr_semantic(ctx, &arg_unnamed.value(syntax_db)), None)
let arg_expr = arg_unnamed.value(syntax_db);
if let ast::Expr::Closure(expr_closure) = arg_expr {
let expr = compute_expr_closure_semantic(ctx, &expr_closure, closure_param_types);
let expr = wrap_maybe_with_missing(
ctx,
expr,
ast::ExprPtr::from(expr_closure.stable_ptr()),
);
let id = ctx.arenas.exprs.alloc(expr.clone());
(ExprAndId { expr, id }, None)
} else {
(compute_expr_semantic(ctx, &arg_unnamed.value(syntax_db)), None)
}
}
ast::ArgClause::Named(arg_named) => {
let arg_expr = arg_named.value(syntax_db);
if let ast::Expr::Closure(expr_closure) = arg_expr {
let expr = compute_expr_closure_semantic(ctx, &expr_closure, closure_param_types);
let expr = wrap_maybe_with_missing(
ctx,
expr,
ast::ExprPtr::from(expr_closure.stable_ptr()),
);
let id = ctx.arenas.exprs.alloc(expr.clone());
(ExprAndId { expr, id }, None)
} else {
(
compute_expr_semantic(ctx, &arg_named.value(syntax_db)),
Some(arg_named.name(syntax_db)),
)
}
}
ast::ArgClause::Named(arg_named) => (
compute_expr_semantic(ctx, &arg_named.value(syntax_db)),
Some(arg_named.name(syntax_db)),
),
ast::ArgClause::FieldInitShorthand(arg_field_init_shorthand) => {
let name_expr = arg_field_init_shorthand.name(syntax_db);
let stable_ptr: ast::ExprPtr = name_expr.stable_ptr().into();
Expand Down Expand Up @@ -1645,6 +1678,7 @@ fn compute_loop_body_semantic(
fn compute_expr_closure_semantic(
ctx: &mut ComputationContext<'_>,
syntax: &ast::ExprClosure,
param_types: Option<TypeId>,
) -> Maybe<Expr> {
ctx.are_closures_in_context = true;
let syntax_db = ctx.db.upcast();
Expand All @@ -1663,6 +1697,18 @@ fn compute_expr_closure_semantic(
} else {
vec![]
};
let closure_type =
TypeLongId::Tuple(params.iter().map(|param| param.ty).collect()).intern(new_ctx.db);
if let Some(param_types) = param_types {
if let Err(err_set) = new_ctx.resolver.inference().conform_ty(closure_type, param_types)
{
new_ctx.resolver.inference().consume_error_without_reporting(err_set);
}
}

params.iter().filter(|param| param.mutability == Mutability::Reference).for_each(|param| {
new_ctx.diagnostics.report(param.stable_ptr(ctx.db.upcast()), RefClosureParam);
});

params.iter().filter(|param| param.mutability == Mutability::Reference).for_each(|param| {
new_ctx.diagnostics.report(param.stable_ptr(ctx.db.upcast()), RefClosureParam);
Expand Down Expand Up @@ -2834,16 +2880,22 @@ fn method_call_expr(
// Self argument.
let mut named_args = vec![NamedArg(fixed_lexpr, None, mutability)];
// Other arguments.
for _ in function_parameter_types(ctx, function_id)?.skip(1) {
let ConcreteFunction { .. } = function_id.lookup_intern(ctx.db).function;
let closure_params = concrete_function_closure_params(ctx.db, function_id)?;
for ty in function_parameter_types(ctx, function_id)?.skip(1) {
let Some(arg_syntax) = args_iter.next() else {
break;
};
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(
ctx,
arg_syntax,
closure_params.get(&ty).cloned(),
));
}

// Maybe coupon
if let Some(arg_syntax) = args_iter.next() {
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(ctx, arg_syntax, None));
}

expr_function_call(ctx, function_id, named_args, &expr, stable_ptr)
Expand Down Expand Up @@ -3263,7 +3315,6 @@ fn expr_function_call(

// Check argument names and types.
check_named_arguments(&named_args, &signature, ctx)?;

let mut args = Vec::new();
for (NamedArg(arg, _name, mutability), param) in
named_args.into_iter().zip(signature.params.iter())
Expand Down
20 changes: 20 additions & 0 deletions crates/cairo-lang-semantic/src/expr/test_data/closure
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,23 @@ error: Closure parameters cannot be references
--> lib.cairo:2:14
let _ = |ref a| {
^^^^^

//! > ==========================================================================

//! > Passing closures as args with less explicit typing.

//! > test_runner_name
test_function_diagnostics(expect_diagnostics: false)

//! > function
fn foo() -> Option<u32> {
let x: Option<Array<i32>> = Option::Some(array![1, 2, 3]);
x.map(|x| x.len())
}

//! > function_name
foo

//! > module_code

//! > expected_diagnostics
58 changes: 52 additions & 6 deletions crates/cairo-lang-semantic/src/items/functions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::fmt::Debug;
use std::sync::Arc;

use cairo_lang_debug::DebugWithDb;
Expand All @@ -14,6 +13,7 @@ use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
use cairo_lang_syntax as syntax;
use cairo_lang_syntax::attribute::structured::Attribute;
use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::{
Intern, LookupIntern, OptionFrom, define_short_id, require, try_extract_matches,
};
Expand All @@ -27,16 +27,16 @@ use super::generics::{fmt_generic_args, generic_params_to_args};
use super::imp::{ImplId, ImplLongId};
use super::modifiers;
use super::trt::ConcreteTraitGenericFunctionId;
use crate::corelib::{panic_destruct_trait_fn, unit_ty};
use crate::corelib::{fn_traits, panic_destruct_trait_fn, unit_ty};
use crate::db::SemanticGroup;
use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
use crate::expr::compute::Environment;
use crate::resolve::{Resolver, ResolverData};
use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter};
use crate::types::resolve_type;
use crate::{
ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericParam, SemanticDiagnostic,
TypeId, semantic, semantic_object_for_id,
ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericArgumentId, GenericParam,
SemanticDiagnostic, TypeId, semantic, semantic_object_for_id,
};

/// A generic function of an impl.
Expand Down Expand Up @@ -124,6 +124,36 @@ impl GenericFunctionId {
}
}
}

pub fn get_closure_params(
&self,
db: &dyn SemanticGroup,
) -> Maybe<OrderedHashMap<TypeId, TypeId>> {
let mut closure_params_map = OrderedHashMap::default();
let generic_params = self.generic_params(db)?;

for param in generic_params {
if let GenericParam::Impl(generic_param_impl) = param {
let trait_id = generic_param_impl.concrete_trait?.trait_id(db);

if fn_traits(db).contains(&trait_id) {
if let Ok(concrete_trait) = generic_param_impl.concrete_trait {
let [
GenericArgumentId::Type(closure_type),
GenericArgumentId::Type(params_type),
] = *concrete_trait.generic_args(db)
else {
unreachable!()
};

closure_params_map.insert(closure_type, params_type);
}
}
}
}
Ok(closure_params_map)
}

pub fn generic_signature(&self, db: &dyn SemanticGroup) -> Maybe<Signature> {
match *self {
GenericFunctionId::Free(id) => db.free_function_signature(id),
Expand All @@ -146,8 +176,11 @@ impl GenericFunctionId {
GenericFunctionId::Extern(id) => db.extern_function_declaration_generic_params(id),
GenericFunctionId::Impl(id) => {
let concrete_trait_id = db.impl_concrete_trait(id.impl_id)?;
let id = ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function);
db.concrete_trait_function_generic_params(id)
let concrete_id =
ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function);
let substitution = GenericSubstitution::from_impl(id.impl_id);
let mut rewriter = SubstitutionRewriter { db, substitution: &substitution };
rewriter.rewrite(db.concrete_trait_function_generic_params(concrete_id)?)
}
GenericFunctionId::Trait(id) => db.concrete_trait_function_generic_params(id),
}
Expand Down Expand Up @@ -860,6 +893,19 @@ pub fn concrete_function_signature(
SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_signature)
}

/// Query implementation of [crate::db::SemanticGroup::concrete_function_closure_params].
pub fn concrete_function_closure_params(
db: &dyn SemanticGroup,
function_id: FunctionId,
) -> Maybe<OrderedHashMap<semantic::TypeId, semantic::TypeId>> {
let ConcreteFunction { generic_function, generic_args, .. } =
function_id.lookup_intern(db).function;
let generic_params = generic_function.generic_params(db)?;
let generic_closure_params = generic_function.get_closure_params(db)?;
let substitution = GenericSubstitution::new(&generic_params, &generic_args);
SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_closure_params)
}

/// For a given list of AST parameters, returns the list of semantic parameters along with the
/// corresponding environment.
fn update_env_with_ast_params(
Expand Down

0 comments on commit 654355d

Please sign in to comment.