diff --git a/src/commands/build.rs b/src/commands/build.rs index 6f0d420..f6261bd 100644 --- a/src/commands/build.rs +++ b/src/commands/build.rs @@ -81,9 +81,120 @@ pub fn build(args: BuildArgs) -> Result<()> { // Create necessary directories create_dir_all(deploy)?; + + fn is_word_char(c: char) -> bool { + c.is_alphanumeric() || c == '_' + } + + fn replace_whole_word(haystack: &str, label: &str, replacement: &str) -> String { + let mut result = String::with_capacity(haystack.len()); + let mut i = 0; + let label_len = label.len(); + while i <= haystack.len().saturating_sub(label_len) { + if haystack[i..].starts_with(label) { + let start = i; + let end = i + label_len; + let before_ok = + start == 0 || !is_word_char(haystack[..start].chars().last().unwrap()); + let after_ok = + end >= haystack.len() || !is_word_char(haystack[end..].chars().next().unwrap()); + if before_ok && after_ok { + result.push_str(replacement); + i = end; + continue; + } + } + let ch = haystack[i..].chars().next().unwrap(); + result.push(ch); + i += ch.len_utf8(); + } + result.push_str(&haystack[i..]); + result + } + + fn prefix_data_labels(content: &str, prefix: &str) -> String { + let mut data_labels: Vec = Vec::new(); + let mut in_data_section = false; + + for line in content.lines() { + let trimmed = line.trim(); + + if trimmed.starts_with(".text") { + in_data_section = false; + } else if trimmed.starts_with(".rodata") || trimmed.starts_with(".data") { + in_data_section = true; + } + + if in_data_section && let Some(colon_pos) = trimmed.find(':') { + let before_colon = trimmed[..colon_pos].trim(); + if !before_colon.is_empty() + && before_colon + .chars() + .all(|c| c.is_alphanumeric() || c == '_') + && !before_colon.starts_with('.') + { + data_labels.push(before_colon.to_string()); + } + } + } + + data_labels.sort_by_key(|b| std::cmp::Reverse(b.len())); + data_labels.dedup(); + + let mut result = content.to_string(); + for label in data_labels { + let replacement = format!("{}{}", prefix, label); + result = replace_whole_word(&result, &label, &replacement); + } + result + } + + fn expand_includes(source: &str, base_dir: &Path, is_included: bool) -> Result { + let mut result = String::new(); + for line in source.lines() { + let trimmed = line.trim(); + if let Some(rest) = trimmed.strip_prefix(".include") { + let rest = rest.trim(); + if let Some(path) = rest.strip_prefix('"').and_then(|s| s.strip_suffix('"')) { + let include_path = base_dir.join(path); + let included = fs::read_to_string(&include_path).map_err(|e| { + Error::msg(format!("Failed to read include '{}': {}", path, e)) + })?; + let include_base = include_path.parent().unwrap_or(base_dir); + let expanded = expand_includes(&included, include_base, true)?; + let path_stem = path.strip_suffix(".s").unwrap_or(path); + let prefix = path_stem.replace('/', "_") + "___"; + result.push_str(&prefix_data_labels(&expanded, &prefix)); + } else { + result.push_str(line); + result.push('\n'); + } + } else if is_included + && (trimmed.starts_with(".globl") || trimmed.starts_with(".global")) + { + let label = trimmed + .strip_prefix(".globl") + .or_else(|| trimmed.strip_prefix(".global")) + .unwrap_or("") + .trim(); + return Err(Error::msg(format!( + ".globl '{}' is not allowed in included files. Only the main entrypoint file \ + should declare .globl symbols.", + label + ))); + } else { + result.push_str(line); + result.push('\n'); + } + } + Ok(result) + } + // Function to compile assembly fn compile_assembly(src: &str, deploy: &str, debug: bool, arch: SbpfArch) -> Result<()> { let source_code = std::fs::read_to_string(src).unwrap(); + let base_dir = Path::new(src).parent().unwrap_or(Path::new(".")); + let source_code = expand_includes(&source_code, base_dir, false)?; let file = SimpleFile::new(src.to_string(), source_code.clone()); // Build assembler options diff --git a/tests/test_include.rs b/tests/test_include.rs new file mode 100644 index 0000000..3d74fe5 --- /dev/null +++ b/tests/test_include.rs @@ -0,0 +1,359 @@ +mod utils; + +use { + std::process::Command, + utils::{ + TestEnv, init_project, run_build, update_assembly_file, verify_project_structure, + verify_so_files, write_include_file, + }, +}; + +#[test] +fn test_include_directive() { + let env = TestEnv::new("include_test"); + + init_project(&env, "include_test"); + verify_project_structure(&env, "include_test"); + + write_include_file( + &env, + "include_test", + "log.s", + r#"custom_log: + lddw r1, message + lddw r2, 14 + call sol_log_ + exit + +.rodata + message: .ascii "Hello, Solana!" +"#, + ); + + update_assembly_file( + &env, + "include_test", + r#".globl entrypoint +.include "log.s" +.text +entrypoint: + call custom_log + exit +"#, + ); + + run_build(&env); + verify_so_files(&env); + + env.cleanup(); +} + +#[test] +fn test_include_nested() { + let env = TestEnv::new("include_nested"); + + init_project(&env, "include_nested"); + verify_project_structure(&env, "include_nested"); + + write_include_file( + &env, + "include_nested", + "log_impl.s", + r#"custom_log: + lddw r1, message + lddw r2, 14 + call sol_log_ + exit + +.rodata + message: .ascii "Nested!" +"#, + ); + + write_include_file( + &env, + "include_nested", + "log.s", + r#".include "log_impl.s" +"#, + ); + + update_assembly_file( + &env, + "include_nested", + r#".globl entrypoint +.include "log.s" +.text +entrypoint: + call custom_log + exit +"#, + ); + + run_build(&env); + verify_so_files(&env); + + env.cleanup(); +} + +#[test] +fn test_include_missing_file_fails() { + let env = TestEnv::new("include_missing"); + + init_project(&env, "include_missing"); + verify_project_structure(&env, "include_missing"); + + update_assembly_file( + &env, + "include_missing", + r#".globl entrypoint +.include "nonexistent.s" +.text +entrypoint: + exit +"#, + ); + + let output = Command::new(&env.sbpf_bin) + .current_dir(&env.project_dir) + .arg("build") + .output() + .expect("Failed to run sbpf build"); + + assert!( + !output.status.success(), + "Build should fail when include file is missing" + ); + + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains("nonexistent") || stderr.contains("Failed to read"), + "Error message should mention the missing file: {}", + stderr + ); + + env.cleanup(); +} + +#[test] +fn test_include_rejects_globl_in_included_file() { + let env = TestEnv::new("include_globl"); + + init_project(&env, "include_globl"); + verify_project_structure(&env, "include_globl"); + + write_include_file( + &env, + "include_globl", + "helper.s", + r#".globl helper_fn +helper_fn: + mov64 r0, 0 + exit +"#, + ); + + update_assembly_file( + &env, + "include_globl", + r#".globl entrypoint +.include "helper.s" +.text +entrypoint: + call helper_fn + exit +"#, + ); + + let output = Command::new(&env.sbpf_bin) + .current_dir(&env.project_dir) + .arg("build") + .output() + .expect("Failed to run sbpf build"); + + assert!( + !output.status.success(), + "Build should fail when .globl is used in an included file" + ); + + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains(".globl") && stderr.contains("not allowed"), + "Error should mention .globl is not allowed in included files: {}", + stderr + ); + + env.cleanup(); +} + +#[test] +fn test_include_rejects_global_in_included_file() { + let env = TestEnv::new("include_global"); + + init_project(&env, "include_global"); + verify_project_structure(&env, "include_global"); + + write_include_file( + &env, + "include_global", + "helper.s", + r#".global helper_fn +helper_fn: + mov64 r0, 0 + exit +"#, + ); + + update_assembly_file( + &env, + "include_global", + r#".globl entrypoint +.include "helper.s" +.text +entrypoint: + call helper_fn + exit +"#, + ); + + let output = Command::new(&env.sbpf_bin) + .current_dir(&env.project_dir) + .arg("build") + .output() + .expect("Failed to run sbpf build"); + + assert!( + !output.status.success(), + "Build should fail when .global is used in an included file" + ); + + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains(".globl") && stderr.contains("not allowed"), + "Error should mention .globl is not allowed in included files: {}", + stderr + ); + + env.cleanup(); +} + +#[test] +fn test_include_auto_prefixes_data_labels() { + let env = TestEnv::new("include_prefix"); + + init_project(&env, "include_prefix"); + verify_project_structure(&env, "include_prefix"); + + write_include_file( + &env, + "include_prefix", + "log.s", + r#".text +log_msg: + lddw r1, msg + lddw r2, 13 + call sol_log_ + exit + +.rodata + msg: .ascii "from log" +"#, + ); + + write_include_file( + &env, + "include_prefix", + "math.s", + r#".text +add_one: + lddw r1, msg + lddw r2, 9 + call sol_log_ + exit + +.rodata + msg: .ascii "from math" +"#, + ); + + update_assembly_file( + &env, + "include_prefix", + r#".globl entrypoint +.include "log.s" +.include "math.s" +.text +entrypoint: + call log_msg + call add_one + mov64 r0, 0 + exit +"#, + ); + + run_build(&env); + verify_so_files(&env); + + env.cleanup(); +} + +#[test] +fn test_include_prefix_uses_path_for_same_filename_in_different_dirs() { + let env = TestEnv::new("include_path_prefix"); + + init_project(&env, "include_path_prefix"); + verify_project_structure(&env, "include_path_prefix"); + + // Two format.s files in different instruction modules - both define msg + write_include_file( + &env, + "include_path_prefix", + "instructions/transfer/format.s", + r#".text +log_transfer: + lddw r1, msg + lddw r2, 12 + call sol_log_ + exit + +.rodata + msg: .ascii "Transfer ok" +"#, + ); + + write_include_file( + &env, + "include_path_prefix", + "instructions/swap/format.s", + r#".text +log_swap: + lddw r1, msg + lddw r2, 8 + call sol_log_ + exit + +.rodata + msg: .ascii "Swap ok" +"#, + ); + + update_assembly_file( + &env, + "include_path_prefix", + r#".globl entrypoint +.include "instructions/transfer/format.s" +.include "instructions/swap/format.s" +.text +entrypoint: + call log_transfer + call log_swap + mov64 r0, 0 + exit +"#, + ); + + run_build(&env); + verify_so_files(&env); + + env.cleanup(); +} diff --git a/tests/utils.rs b/tests/utils.rs index 8911813..29d0343 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -230,3 +230,19 @@ pub fn update_assembly_file(env: &TestEnv, project_name: &str, content: &str) { .unwrap_or_else(|_| panic!("Failed to write new {}.s content", project_name)); println!("✅ Updated {}.s with specified content", project_name); } + +/// Write an include file alongside the main assembly file (supports subdirs, e.g. "transactions/utils.s") +#[allow(dead_code)] +pub fn write_include_file(env: &TestEnv, project_name: &str, filename: &str, content: &str) { + let include_path = env + .project_dir + .join(format!("src/{}/{}", project_name, filename)); + if let Some(parent) = include_path.parent() { + fs::create_dir_all(parent).unwrap_or_else(|_| { + panic!("Failed to create parent dir for include file {}", filename) + }); + } + fs::write(&include_path, content) + .unwrap_or_else(|_| panic!("Failed to write include file {}", filename)); + println!("✅ Wrote include file {}", filename); +}