Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 26 additions & 25 deletions src/modules/builtin/mv.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::mem::swap;


use crate::fragments;
use crate::modules::command::modifier::CommandModifier;
Expand Down Expand Up @@ -42,29 +42,31 @@ impl SyntaxModule<ParserMetadata> for Mv {

impl TypeCheckModule for Mv {
fn typecheck(&mut self, meta: &mut ParserMetadata) -> SyntaxResult {
self.source.typecheck(meta)?;
self.destination.typecheck(meta)?;
self.failure_handler.typecheck(meta)?;
self.modifier.use_modifiers(meta, |_, meta| {
self.source.typecheck(meta)?;
self.destination.typecheck(meta)?;
self.failure_handler.typecheck(meta)?;

let source_type = self.source.get_type();
if source_type != Type::Text {
let position = self.source.get_position();
return error_pos!(meta, position => {
message: "Builtin function `mv` can only be used with values of type Text",
comment: format!("Given type: {}, expected type: {}", source_type, Type::Text)
});
}
let source_type = self.source.get_type();
if source_type != Type::Text {
let position = self.source.get_position();
return error_pos!(meta, position => {
message: "Builtin function `mv` can only be used with values of type Text",
comment: format!("Given type: {}, expected type: {}", source_type, Type::Text)
});
}

let dest_type = self.destination.get_type();
if dest_type != Type::Text {
let position = self.destination.get_position();
return error_pos!(meta, position => {
message: "Builtin function `mv` can only be used with values of type Text",
comment: format!("Given type: {}, expected type: {}", dest_type, Type::Text)
});
}
let dest_type = self.destination.get_type();
if dest_type != Type::Text {
let position = self.destination.get_position();
return error_pos!(meta, position => {
message: "Builtin function `mv` can only be used with values of type Text",
comment: format!("Given type: {}, expected type: {}", dest_type, Type::Text)
});
}

Ok(())
Ok(())
})
}
}

Expand All @@ -73,10 +75,9 @@ impl TranslateModule for Mv {
let source = self.source.translate(meta);
let destination = self.destination.translate(meta);
let handler = self.failure_handler.translate(meta);
let mut is_silent = self.modifier.is_silent || meta.silenced;
swap(&mut is_silent, &mut meta.silenced);
let silent = meta.gen_silent().to_frag();
swap(&mut is_silent, &mut meta.silenced);
let silent = meta.with_silenced(self.modifier.is_silent || meta.silenced, |meta| {
meta.gen_silent().to_frag()
});
BlockFragment::new(vec![
fragments!("mv ", source, " ", destination, silent),
handler,
Expand Down
11 changes: 7 additions & 4 deletions src/modules/command/cmd.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::modules::types::{Type, Typed};

use crate::modules::condition::failure_handler::FailureHandler;
use crate::modules::expression::expr::Expr;
use crate::modules::expression::interpolated_region::{InterpolatedRegionType, parse_interpolated_region};
Expand Down Expand Up @@ -60,10 +61,12 @@ impl SyntaxModule<ParserMetadata> for Command {

impl TypeCheckModule for Command {
fn typecheck(&mut self, meta: &mut ParserMetadata) -> SyntaxResult {
for interp in self.interps.iter_mut() {
interp.typecheck(meta)?;
}
self.failure_handler.typecheck(meta)
self.modifier.use_modifiers(meta, |_, meta| {
for interp in self.interps.iter_mut() {
interp.typecheck(meta)?;
}
self.failure_handler.typecheck(meta)
})
}
}

Expand Down
16 changes: 6 additions & 10 deletions src/modules/command/modifier.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::mem::swap;
use amber_meta::ContextManager;
use heraclitus_compiler::prelude::*;
use crate::modules::prelude::*;
use crate::modules::block::Block;

#[derive(Debug, Clone)]
#[derive(Debug, Clone, ContextManager)]
pub struct CommandModifier {
pub block: Option<Box<Block>>,
#[context]
pub is_trust: bool,
pub is_silent: bool,
pub is_sudo: bool
Expand All @@ -24,15 +25,10 @@ impl CommandModifier {
pub fn use_modifiers<F>(
&mut self, meta: &mut ParserMetadata, context: F
) -> SyntaxResult where F: FnOnce(&mut Self, &mut ParserMetadata) -> SyntaxResult {
let mut is_trust_holder = self.is_trust;
if self.is_trust {
swap(&mut is_trust_holder, &mut meta.context.is_trust_ctx);
}
// The setter returns the old value
let old_trust = meta.context.set_is_trust_ctx(self.is_trust || meta.context.is_trust_ctx);
let result = context(self, meta);
// Swap back the value
if self.is_trust {
swap(&mut is_trust_holder, &mut meta.context.is_trust_ctx);
}
meta.context.set_is_trust_ctx(old_trust);
result
}

Expand Down
142 changes: 71 additions & 71 deletions src/modules/function/invocation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::mem::swap;


use heraclitus_compiler::prelude::*;
use crate::{fragments, raw_fragment};
Expand Down Expand Up @@ -100,65 +100,67 @@ impl SyntaxModule<ParserMetadata> for FunctionInvocation {

impl TypeCheckModule for FunctionInvocation {
fn typecheck(&mut self, meta: &mut ParserMetadata) -> SyntaxResult {
// Type-check all arguments first
for arg in &mut self.args {
arg.typecheck(meta)?;
}
self.modifier.use_modifiers(meta, |_, meta| {
// Type-check all arguments first
for arg in &mut self.args {
arg.typecheck(meta)?;
}

// Look up the function declaration (this requires typecheck phase context)
self.id = handle_function_reference(meta, self.name_tok.clone(), &self.name)?;

let function_unit = meta.get_fun_declaration(&self.name).unwrap().clone();
let expected_arg_count = function_unit.args.len();
let actual_arg_count = self.args.len();
let optional_count = function_unit.args.iter().filter(|arg| arg.optional.is_some()).count();

// Handle missing arguments by filling with optional defaults
if actual_arg_count < expected_arg_count {
// Check if we can compensate with optional arguments stored in fun_unit
if actual_arg_count >= expected_arg_count - optional_count {
let missing = expected_arg_count - actual_arg_count;
let provided_optional = optional_count - missing;
let optionals: Vec<_> = function_unit.args.iter().filter_map(|arg| arg.optional.as_ref()).collect();
for exp in optionals.iter().skip(provided_optional){
self.args.push((*exp).clone());
// Look up the function declaration (this requires typecheck phase context)
self.id = handle_function_reference(meta, self.name_tok.clone(), &self.name)?;

let function_unit = meta.get_fun_declaration(&self.name).unwrap().clone();
let expected_arg_count = function_unit.args.len();
let actual_arg_count = self.args.len();
let optional_count = function_unit.args.iter().filter(|arg| arg.optional.is_some()).count();

// Handle missing arguments by filling with optional defaults
if actual_arg_count < expected_arg_count {
// Check if we can compensate with optional arguments stored in fun_unit
if actual_arg_count >= expected_arg_count - optional_count {
let missing = expected_arg_count - actual_arg_count;
let provided_optional = optional_count - missing;
let optionals: Vec<_> = function_unit.args.iter().filter_map(|arg| arg.optional.as_ref()).collect();
for exp in optionals.iter().skip(provided_optional){
self.args.push((*exp).clone());
}
}
}
}

// Validate arguments and get function variant
let types = self.args.iter().map(Expr::get_type).collect::<Vec<Type>>();
let var_refs = self.args.iter().map(is_ref).collect::<Vec<bool>>();
self.refs = function_unit.args.iter().map(|arg| arg.is_ref).collect();
(self.kind, self.variant_id) = handle_function_parameters(meta, self.id, function_unit.clone(), &types, &var_refs, self.name_tok.clone())?;

// Mark variables passed as reference as modified and used
for (arg, is_ref) in izip!(self.args.iter(), self.refs.iter()) {
if *is_ref {
if let Some(ExprType::VariableGet(var)) = &arg.value {
meta.mark_var_modified(&var.name);
// Validate arguments and get function variant
let types = self.args.iter().map(Expr::get_type).collect::<Vec<Type>>();
let var_refs = self.args.iter().map(is_ref).collect::<Vec<bool>>();
self.refs = function_unit.args.iter().map(|arg| arg.is_ref).collect();
(self.kind, self.variant_id) = handle_function_parameters(meta, self.id, function_unit.clone(), &types, &var_refs, self.name_tok.clone())?;

// Mark variables passed as reference as modified and used
for (arg, is_ref) in izip!(self.args.iter(), self.refs.iter()) {
if *is_ref {
if let Some(ExprType::VariableGet(var)) = &arg.value {
meta.mark_var_modified(&var.name);
}
}
}
}

// Handle failable function logic
self.is_failable = function_unit.is_failable;
if self.is_failable {
if !self.failure_handler.is_parsed && !meta.context.is_trust_ctx {
return error!(meta, self.name_tok.clone() => {
message: format!("Function '{}' can potentially fail but is left unhandled.", self.name),
comment: "You can use '?' to propagate failure, 'failed' block to handle failure, 'succeeded' block to handle success, or 'exited' block to handle both"
});
// Handle failable function logic
self.is_failable = function_unit.is_failable;
if self.is_failable {
if !self.failure_handler.is_parsed && !meta.context.is_trust_ctx {
return error!(meta, self.name_tok.clone() => {
message: format!("Function '{}' can potentially fail but is left unhandled.", self.name),
comment: "You can use '?' to propagate failure, 'failed' block to handle failure, 'succeeded' block to handle success, 'exited' block to handle both"
});
}
self.failure_handler.typecheck(meta)?;
} else if self.failure_handler.is_parsed && !meta.context.is_trust_ctx {
let message = Message::new_warn_at_token(meta, self.name_tok.clone())
.message(format!("Function '{}' cannot fail", &self.name))
.comment("You can remove the failure handler block or '?' at the end");
meta.add_message(message);
}
self.failure_handler.typecheck(meta)?;
} else if self.failure_handler.is_parsed && !meta.context.is_trust_ctx {
let message = Message::new_warn_at_token(meta, self.name_tok.clone())
.message(format!("Function '{}' cannot fail", &self.name))
.comment("You can remove the failure handler block or '?' at the end");
meta.add_message(message);
}

Ok(())
Ok(())
})
}
}

Expand All @@ -167,26 +169,24 @@ impl TranslateModule for FunctionInvocation {
// Get the variable prefix based on function name casing
let prefix = meta.gen_variable_prefix(&self.name);
let name = raw_fragment!("{}{}__{}_v{}", prefix, self.name, self.id, self.variant_id);
let mut is_silent = self.modifier.is_silent || meta.silenced;
swap(&mut is_silent, &mut meta.silenced);
let silent = meta.gen_silent().to_frag();

let args = izip!(self.args.iter(), self.refs.iter()).map(| (arg, is_ref) | match arg.translate(meta) {
FragmentKind::VarExpr(var) if *is_ref => var.with_render_type(VarRenderType::BashRef).to_frag(),
FragmentKind::VarExpr(var) if var.kind.is_array() && var.index.is_some() => {
let id = meta.gen_value_id();
let temp_name = format!("{}_{id}", var.get_index_typename());
let stmt = VarStmtFragment::new(&temp_name, var.kind.clone(), FragmentKind::VarExpr(var.clone()));
let temp_var = meta.push_ephemeral_variable(stmt);
fragments!(temp_var.with_render_type(VarRenderType::BashRef).to_frag().with_quotes(false), "[@]")
},
FragmentKind::VarExpr(var) if var.kind.is_array() => fragments!(var.with_render_type(VarRenderType::BashRef).to_frag().with_quotes(false), "[@]"),
_ if *is_ref => panic!("Reference value accepts only variables"),
var => var
}).collect::<Vec<FragmentKind>>();
let args = ListFragment::new(args).with_spaces().to_frag();
meta.stmt_queue.push_back(fragments!(name, " ", args, silent));
swap(&mut is_silent, &mut meta.silenced);
meta.with_silenced(self.modifier.is_silent || meta.silenced, |meta| {
let silent = meta.gen_silent().to_frag();
let args = izip!(self.args.iter(), self.refs.iter()).map(| (arg, is_ref) | match arg.translate(meta) {
FragmentKind::VarExpr(var) if *is_ref => var.with_render_type(VarRenderType::BashRef).to_frag(),
FragmentKind::VarExpr(var) if var.kind.is_array() && var.index.is_some() => {
let id = meta.gen_value_id();
let temp_name = format!("{}_{id}", var.get_index_typename());
let stmt = VarStmtFragment::new(&temp_name, var.kind.clone(), FragmentKind::VarExpr(var.clone()));
let temp_var = meta.push_ephemeral_variable(stmt);
fragments!(temp_var.with_render_type(VarRenderType::BashRef).to_frag().with_quotes(false), "[@]")
},
FragmentKind::VarExpr(var) if var.kind.is_array() => fragments!(var.with_render_type(VarRenderType::BashRef).to_frag().with_quotes(false), "[@]"),
_ if *is_ref => panic!("Reference value accepts only variables"),
var => var
}).collect::<Vec<FragmentKind>>();
let args = ListFragment::new(args).with_spaces().to_frag();
meta.stmt_queue.push_back(fragments!(name.clone(), " ", args, silent));
});
if self.is_failable && self.failure_handler.is_parsed {
let handler = self.failure_handler.translate(meta);
meta.stmt_queue.push_back(handler);
Expand Down
14 changes: 14 additions & 0 deletions src/tests/warning/nested_trust.ab
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Output
// Done

fun foo(arg) {
if arg == 0: fail 1
return arg
}

fun bar() {
return 42
}

trust foo(bar())
echo "Done"
10 changes: 10 additions & 0 deletions src/tests/warning/unnecessary_handler.ab
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Output
// Function 'foo' cannot fail

fun foo(): Num {
return 1
}

foo() failed {
echo "This should warn"
}
1 change: 1 addition & 0 deletions src/utils/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ pub struct Context {
/// Determines if the context is in the main block
pub is_main_ctx: bool,
/// Determines if the context is in a trust block
#[context]
pub is_trust_ctx: bool,
/// This is a list of ids of all the public functions in the file
pub pub_funs: Vec<FunctionDecl>,
Expand Down