diff --git a/crates/steel-core/src/compiler/passes/analysis.rs b/crates/steel-core/src/compiler/passes/analysis.rs index 340aee7a0..d06bc8e2c 100644 --- a/crates/steel-core/src/compiler/passes/analysis.rs +++ b/crates/steel-core/src/compiler/passes/analysis.rs @@ -470,6 +470,14 @@ impl Analysis { Some(id) } + pub fn resolve_reference(&self, mut id: SyntaxObjectId) -> SyntaxObjectId { + while let Some(next) = self.info.get(&id).and_then(|x| x.refers_to) { + id = next; + } + + id + } + pub fn visit_top_level_define_function_without_body( &mut self, define: &crate::parser::ast::Define, @@ -4818,6 +4826,10 @@ impl<'a> SemanticAnalysis<'a> { self.analysis.resolve_alias(id) } + pub fn resolve_reference(&self, id: SyntaxObjectId) -> SyntaxObjectId { + self.analysis.resolve_reference(id) + } + pub fn flatten_anonymous_functions(&mut self) { FlattenAnonymousFunctionCalls::flatten(&self.analysis, self.exprs); } @@ -5744,6 +5756,35 @@ mod analysis_pass_tests { } } + #[test] + fn resolve_reference() { + let script = r#" + (define (double number) + (+ number number)) + "#; + + let mut exprs = Parser::parse(script).unwrap(); + let analysis = SemanticAnalysis::new(&mut exprs); + + let identifiers = analysis + .analysis + .identifier_info() + .iter() + .filter(|(_, semantic_info)| semantic_info.kind == IdentifierStatus::Local) + .map(|(id, _)| analysis.resolve_reference(*id)) + .collect::>(); + + assert_eq!(identifiers.len(), 3); + + for window in identifiers.windows(2) { + let [left, right] = window else { + unreachable!() + }; + + assert_eq!(left, right); + } + } + #[test] fn test_capture() { let script = r#" diff --git a/crates/steel-language-server/src/backend.rs b/crates/steel-language-server/src/backend.rs index 7df624b1e..31f292161 100644 --- a/crates/steel-language-server/src/backend.rs +++ b/crates/steel-language-server/src/backend.rs @@ -18,18 +18,22 @@ use steel::{ compiler::{ modules::{steel_home, MANGLER_PREFIX, MODULE_PREFIX}, passes::analysis::{ - query_top_level_define, query_top_level_define_on_condition, + query_top_level_define, query_top_level_define_on_condition, IdentifierStatus, RequiredIdentifierInformation, SemanticAnalysis, }, }, parser::{ - ast::ExprKind, expander::SteelMacro, interner::InternedString, parser::SourceId, - span::Span, tryfrom_visitor::SyntaxObjectFromExprKindRef, + ast::ExprKind, + expander::SteelMacro, + interner::InternedString, + parser::{Parser, SourceId}, + span::Span, + tryfrom_visitor::SyntaxObjectFromExprKindRef, }, rvals::{FromSteelVal, SteelString}, steel_vm::{builtin::BuiltInModule, engine::Engine, register_fn::RegisterFn}, }; -use tower_lsp::jsonrpc::Result; +use tower_lsp::jsonrpc::{self, Result}; use tower_lsp::lsp_types::notification::Notification; use tower_lsp::lsp_types::*; use tower_lsp::{Client, LanguageServer}; @@ -70,6 +74,7 @@ pub const LEGEND_TYPE: &[SemanticTokenType] = &[ pub struct Backend { pub client: Client, pub ast_map: DashMap>, + pub lowered_ast_map: DashMap>, pub document_map: DashMap, // TODO: This needs to hold macros to help with resolving definitions pub _macro_map: DashMap>, @@ -135,7 +140,12 @@ impl LanguageServer for Backend { ), definition_provider: Some(OneOf::Left(true)), references_provider: Some(OneOf::Left(true)), - rename_provider: Some(OneOf::Left(true)), + rename_provider: Some(OneOf::Right(RenameOptions { + prepare_provider: Some(true), + work_done_progress_options: WorkDoneProgressOptions { + work_done_progress: None, + }, + })), hover_provider: Some(HoverProviderCapability::Simple(true)), ..ServerCapabilities::default() }, @@ -610,8 +620,107 @@ impl LanguageServer for Backend { Ok(completions.map(CompletionResponse::Array)) } - async fn rename(&self, _params: RenameParams) -> Result> { - Ok(None) + async fn prepare_rename( + &self, + params: TextDocumentPositionParams, + ) -> Result> { + let uri = params.text_document.uri; + let position = params.position; + + let Some((identifier, range)) = || -> Option<_> { + let rope = self.document_map.get(uri.as_str())?; + let mut ast = self.lowered_ast_map.get_mut(uri.as_str())?; + + let offset = position_to_offset(position, &rope)?; + let semantic = SemanticAnalysis::new(&mut ast); + let (_, identifier) = + semantic.find_identifier_at_offset(offset, uri_to_source_id(&uri)?)?; + + let range = Range::new( + offset_to_position(identifier.span.start, &rope)?, + offset_to_position(identifier.span.end, &rope)?, + ); + + Some((identifier.clone(), range)) + }() else { + return Ok(None); + }; + + if identifier.builtin { + return Err(jsonrpc::Error::invalid_params("cannot rename builtin")); + } + + if !matches!( + identifier.kind, + IdentifierStatus::Local + | IdentifierStatus::LetVar + | IdentifierStatus::LocallyDefinedFunction + ) { + return Err(jsonrpc::Error::invalid_params(format!( + "cannot rename symbol of kind {:?}", + identifier.kind + ))); + } + + Ok(Some(PrepareRenameResponse::Range(range))) + } + + async fn rename(&self, params: RenameParams) -> Result> { + let uri = params.text_document_position.text_document.uri; + let position = params.text_document_position.position; + + let changes = || -> Option> { + let rope = self.document_map.get(uri.as_str())?; + let mut ast = self.lowered_ast_map.get_mut(uri.as_str())?; + + let offset = position_to_offset(position, &rope)?; + let semantic = SemanticAnalysis::new(&mut ast); + let (syntax_object_id, semantic_information) = + semantic.find_identifier_at_offset(offset, uri_to_source_id(&uri)?)?; + + // it should probaby not be possible to rename builtins ... + if semantic_information.builtin { + return None; + } + + let syntax_object_id = semantic.analysis.resolve_reference(*syntax_object_id); + let semantic_information = semantic.get_identifier(syntax_object_id).unwrap(); + + // it might make sense to be able to rename other things as well, + // but i think this is at least good start + if !matches!( + semantic_information.kind, + IdentifierStatus::Local + | IdentifierStatus::LetVar + | IdentifierStatus::LocallyDefinedFunction, + ) { + return None; + } + + let identifier_info = semantic.analysis.identifier_info(); + let identifiers = identifier_info + .iter() + .filter(|(&id, _)| semantic.analysis.resolve_reference(id) == syntax_object_id) + .filter(|(_, info)| info.kind == semantic_information.kind) + .map(|(_, information)| (information.span.start, information.span.end)) + .filter_map(|(start, end)| { + Some(Range::new( + offset_to_position(start, &rope)?, + offset_to_position(end, &rope)?, + )) + }) + .map(|range| TextEdit::new(range, params.new_name.clone())) + .collect::>(); + + Some(identifiers) + }(); + + let Some(changes) = changes else { + return Ok(None); + }; + + let changes = HashMap::from_iter([(uri, changes)]); + Ok(Some(WorkspaceEdit::new(changes))) } async fn did_change_configuration(&self, _: DidChangeConfigurationParams) { @@ -814,6 +923,17 @@ impl Backend { self.ast_map.insert(params.uri.to_string(), ast); + // the ast that is parsed for the `ast_map` is parsed with the `.without_lowering` + // argument to the `Parser`. but for things like `rename` (and `prepare_rename`), + // i need an ast that is parsed without that argument, so instead of having to recalculate it on-demand, + // just do it here, once. + if let Ok(lowered_ast) = + Parser::new(&expression, id).collect::, _>>() + { + self.lowered_ast_map + .insert(params.uri.to_string(), lowered_ast); + } + diagnostics }; diff --git a/crates/steel-language-server/src/main.rs b/crates/steel-language-server/src/main.rs index f59964d0b..5c628cf69 100644 --- a/crates/steel-language-server/src/main.rs +++ b/crates/steel-language-server/src/main.rs @@ -75,6 +75,7 @@ async fn main() { let (service, socket) = LspService::build(|client| Backend { client, ast_map: DashMap::new(), + lowered_ast_map: DashMap::new(), document_map: DashMap::new(), _macro_map: DashMap::new(), globals_set,