Skip to content

Commit

Permalink
refactor: adapt py back to old pyo3 0.21.2 API
Browse files — browse the repository at this point in the history
  • Loading branch information
jqnatividad committed Dec 2, 2024
1 parent 7f9fc8a commit ad1c0c8
Showing 1 changed file with 33 additions and 50 deletions.
83 changes: 33 additions & 50 deletions src/cmd/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ py options:
-b, --batch <size> The number of rows per batch to process before
releasing memory and acquiring a new GILpool.
Set to 0 to process the entire file in one batch.
[default: 50000]
See https://pyo3.rs/v0.21.0/memory.html#gil-bound-memory
for more info. [default: 50000]
Common options:
-h, --help Display this message
Expand All @@ -122,7 +123,7 @@ Common options:
-p, --progressbar Show progress bars. Not valid for stdin.
"#;

use std::{ffi::CString, fs};
use std::fs;

use indicatif::{ProgressBar, ProgressDrawTarget};
use pyo3::{
Expand Down Expand Up @@ -205,22 +206,23 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
});
}

let expression = if let Some(expression_filepath) = args.arg_expression.strip_prefix("file:") {
match fs::read_to_string(expression_filepath) {
Ok(file_contents) => file_contents,
Err(e) => return fail_clierror!("Cannot load Python expression from file: {e}"),
}
} else if std::path::Path::new(&args.arg_expression)
.extension()
.is_some_and(|ext| ext.eq_ignore_ascii_case("py"))
{
match fs::read_to_string(args.arg_expression.clone()) {
Ok(file_contents) => file_contents,
Err(e) => return fail_clierror!("Cannot load .py file: {e}"),
}
} else {
args.arg_expression.clone()
};
let arg_expression =
if let Some(expression_filepath) = args.arg_expression.strip_prefix("file:") {
match fs::read_to_string(expression_filepath) {
Ok(file_contents) => file_contents,
Err(e) => return fail_clierror!("Cannot load Python expression from file: {e}"),
}
} else if std::path::Path::new(&args.arg_expression)
.extension()
.is_some_and(|ext| ext.eq_ignore_ascii_case("py"))
{
match fs::read_to_string(args.arg_expression.clone()) {
Ok(file_contents) => file_contents,
Err(e) => return fail_clierror!("Cannot load .py file: {e}"),
}
} else {
args.arg_expression.clone()
};

let mut helper_text = String::new();
if let Some(helper_file) = args.flag_helper {
Expand Down Expand Up @@ -278,22 +280,8 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
// reuse batch buffers
let mut batch = Vec::with_capacity(batch_size);

// safety: safe to unwrap as these are statically defined
let helpers_code = CString::new(HELPERS).unwrap();
let helpers_filename = CString::new("qsv_helpers.py").unwrap();
let helpers_module_name = CString::new("qsv_helpers").unwrap();

let user_helpers_code = CString::new(helper_text)
.map_err(|e| format!("Failed to create CString from helper text: {e}"))?;

// safety: safe to unwrap as these are statically defined
let user_helpers_filename = CString::new("qsv_user_helpers.py").unwrap();
let user_helpers_module_name = CString::new("qsv_uh").unwrap();

let arg_expression = CString::new(expression)
.map_err(|e| format!("Failed to create CString from expression: {e}"))?;

let mut row_number = 0_u64;
let debug_flag = log::log_enabled!(log::Level::Debug);

// main loop to read CSV and construct batches.
// we batch python operations so that the GILPool does not get very large
Expand Down Expand Up @@ -324,25 +312,19 @@ pub fn run(argv: &[&str]) -> CliResult<()> {

Python::with_gil(|py| -> PyResult<()> {
let batch_ref = &mut batch;
let helpers = PyModule::from_code_bound(py, HELPERS, "qsv_helpers.py", "qsv_helpers")?;
let batch_globals = PyDict::new_bound(py);
let batch_locals = PyDict::new_bound(py);

let helpers =
PyModule::from_code(py, &helpers_code, &helpers_filename, &helpers_module_name)?;
let batch_globals = PyDict::new(py);
let batch_locals = PyDict::new(py);

let user_helpers = PyModule::from_code(
py,
&user_helpers_code,
&user_helpers_filename,
&user_helpers_module_name,
)?;
let user_helpers =
PyModule::from_code_bound(py, &helper_text, "qsv_user_helpers.py", "qsv_uh")?;
batch_globals.set_item(intern!(py, "qsv_uh"), user_helpers)?;

// Global imports
let builtins = PyModule::import(py, "builtins")?;
let math_module = PyModule::import(py, "math")?;
let random_module = PyModule::import(py, "random")?;
let datetime_module = PyModule::import(py, "datetime")?;
let builtins = PyModule::import_bound(py, "builtins")?;
let math_module = PyModule::import_bound(py, "math")?;
let random_module = PyModule::import_bound(py, "random")?;
let datetime_module = PyModule::import_bound(py, "datetime")?;

batch_globals.set_item("__builtins__", builtins)?;
batch_globals.set_item("math", math_module)?;
Expand Down Expand Up @@ -385,8 +367,9 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
py_row.call_method1(intern!(py, "_update_underlying_data"), (row_data,))?;

let result =
match py.eval(&arg_expression, Some(&batch_globals), Some(&batch_locals)) {
Ok(r) => r,
match py.eval_bound(&arg_expression, Some(&batch_globals), Some(&batch_locals))
{
Ok(result) => result,
Err(e) => {
error_count += 1;
if debug_flag {
Expand Down

0 comments on commit ad1c0c8

Please sign in to comment.