diff --git a/pyrefly/lib/commands/infer.rs b/pyrefly/lib/commands/infer.rs index 6ae8dd84e..9a8b65865 100644 --- a/pyrefly/lib/commands/infer.rs +++ b/pyrefly/lib/commands/infer.rs @@ -295,7 +295,7 @@ impl InferArgs { let imports: Vec<(TextSize, String, String)> = transaction .search_exports_exact(unknown_name) .into_iter() - .map(|handle_to_import_from| { + .map(|(handle_to_import_from, _)| { insert_import_edit_with_forced_import_format( &ast, handle.dupe(), diff --git a/pyrefly/lib/lsp/non_wasm/server.rs b/pyrefly/lib/lsp/non_wasm/server.rs index b1f6dce06..d38ffab59 100644 --- a/pyrefly/lib/lsp/non_wasm/server.rs +++ b/pyrefly/lib/lsp/non_wasm/server.rs @@ -2376,7 +2376,7 @@ impl Server { let range = self.from_lsp_range(uri, &module_info, params.range); let mut actions = Vec::new(); if let Some(quickfixes) = - transaction.local_quickfix_code_actions(&handle, range, import_format) + transaction.local_quickfix_code_actions_sorted(&handle, range, import_format) { actions.extend( quickfixes diff --git a/pyrefly/lib/state/lsp.rs b/pyrefly/lib/state/lsp.rs index 885c17d21..af60f631a 100644 --- a/pyrefly/lib/state/lsp.rs +++ b/pyrefly/lib/state/lsp.rs @@ -5,6 +5,7 @@ * LICENSE file in the root directory of this source tree. */ +use std::cmp::Ordering; use std::cmp::Reverse; use std::collections::BTreeMap; @@ -1679,7 +1680,7 @@ impl<'a> Transaction<'a> { } /// Produce code actions that makes edits local to the file. - pub fn local_quickfix_code_actions( + pub fn local_quickfix_code_actions_sorted( &self, handle: &Handle, range: TextRange, @@ -1695,7 +1696,9 @@ impl<'a> Transaction<'a> { let error_range = error.range(); if error_range.contains_range(range) { let unknown_name = module_info.code_at(error_range); - for handle_to_import_from in self.search_exports_exact(unknown_name) { + for (handle_to_import_from, export) in + self.search_exports_exact(unknown_name) + { let (position, insert_text, _) = insert_import_edit( &ast, self.config_finder(), @@ -1705,7 +1708,12 @@ impl<'a> Transaction<'a> { import_format, ); let range = TextRange::at(position, TextSize::new(0)); - let title = format!("Insert import: `{}`", insert_text.trim()); + let title = format!( + "Insert import: `{}`{}", + insert_text.trim(), + export.deprecation.map_or("", |_| " (deprecated)") + ); + code_actions.push((title, module_info.dupe(), range, insert_text)); } @@ -1728,7 +1736,16 @@ impl<'a> Transaction<'a> { _ => {} } } - code_actions.sort_by(|(title1, _, _, _), (title2, _, _, _)| title1.cmp(title2)); + + // Sort code actions: non-deprecated first, then alphabetically + code_actions.sort_by(|(title1, _, _, _), (title2, _, _, _)| { + match (title1.contains("deprecated"), title2.contains("deprecated")) { + (true, false) => Ordering::Greater, + (false, true) => Ordering::Less, + _ => title1.cmp(title2), + } + }); + Some(code_actions) } @@ -2963,11 +2980,11 @@ impl<'a> Transaction<'a> { (result, is_incomplete) } - pub fn search_exports_exact(&self, name: &str) -> Vec { + pub fn search_exports_exact(&self, name: &str) -> Vec<(Handle, Export)> { self.search_exports(|handle, exports| { - if let Some(export) = exports.get(&Name::new(name)) { - match export { - ExportLocation::ThisModule(_) => vec![handle.dupe()], + if let Some(export_location) = exports.get(&Name::new(name)) { + match export_location { + ExportLocation::ThisModule(export) => vec![(handle.dupe(), export.clone())], // Re-exported modules like `foo` in `from from_module import foo` // should likely be ignored in autoimport suggestions // because the original export in from_module will show it. diff --git a/pyrefly/lib/test/lsp/code_actions.rs b/pyrefly/lib/test/lsp/code_actions.rs index 4ed51003a..3af8c9de5 100644 --- a/pyrefly/lib/test/lsp/code_actions.rs +++ b/pyrefly/lib/test/lsp/code_actions.rs @@ -33,7 +33,7 @@ fn get_test_report(state: &State, handle: &Handle, position: TextSize) -> String let mut report = "Code Actions Results:\n".to_owned(); let transaction = state.transaction(); for (title, info, range, patch) in transaction - .local_quickfix_code_actions( + .local_quickfix_code_actions_sorted( handle, TextRange::new(position, position), ImportFormat::Absolute, @@ -320,6 +320,90 @@ my_export ); } +#[test] +fn test_import_from_stdlib() { + let report = get_batched_lsp_operations_report_allow_error( + &[("a", "TypeVar('T')\n# ^")], + get_test_report, + ); + // TODO: Ideally `typing` would be preferred over `ast`. + assert_eq!( + r#" +# a.py +1 | TypeVar('T') + ^ +Code Actions Results: +# Title: Insert import: `from ast import TypeVar` + +## Before: +TypeVar('T') +# ^ +## After: +from ast import TypeVar +TypeVar('T') +# ^ +# Title: Insert import: `from typing import TypeVar` + +## Before: +TypeVar('T') +# ^ +## After: +from typing import TypeVar +TypeVar('T') +# ^ +"# + .trim(), + report.trim() + ); +} + +#[test] +fn test_take_deprecation_into_account_in_sorting_of_actions() { + let report = get_batched_lsp_operations_report_allow_error( + &[ + ( + "a", + "from warnings import deprecated\n@deprecated('')\ndef my_func(): pass", + ), + ("b", "def my_func(): pass"), + ("c", "my_func()\n# ^"), + ], + get_test_report, + ); + assert_eq!( + r#" +# a.py + +# b.py + +# c.py +1 | my_func() + ^ +Code Actions Results: +# Title: Insert import: `from b import my_func` + +## Before: +my_func() +# ^ +## After: +from b import my_func +my_func() +# ^ +# Title: Insert import: `from a import my_func` (deprecated) + +## Before: +my_func() +# ^ +## After: +from a import my_func +my_func() +# ^ +"# + .trim(), + report.trim() + ); +} + #[test] fn extract_function_basic_refactor() { let code = r#"