Skip to content

Commit f4c148a

Browse files
committed
feat: add install column to parquet schema for coherence with HuggingFace dataset
- parquet_writer.rs: add 'install' field to schema, write from install_config, read back on load
- dataset.rs: add 'install' to dataset card YAML features and markdown table
- All 380 tests pass, round-trip verified
1 parent 8478f40 commit f4c148a

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

src/export/dataset.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,8 @@ dataset_info:
331331
dtype: string
332332
- name: environment_setup_commit
333333
dtype: string
334+
- name: install
335+
dtype: string
334336
- name: language
335337
dtype: string
336338
- name: difficulty
@@ -366,6 +368,7 @@ with additional fields for multi-language support, difficulty scoring, and quali
366368
| `FAIL_TO_PASS` | string | JSON list of tests that must pass after fix |
367369
| `PASS_TO_PASS` | string | JSON list of regression tests |
368370
| `environment_setup_commit` | string | Commit for environment setup |
371+
| `install` | string | Verified install commands for environment setup |
369372
| `language` | string | Primary programming language |
370373
| `difficulty` | string | Difficulty level (easy/medium/hard) |
371374
| `difficulty_score` | uint8 | Numeric difficulty (1=easy, 2=medium, 3=hard) |

src/export/parquet_writer.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub fn swe_bench_schema() -> Schema {
3131
Field::new("PASS_TO_PASS", DataType::Utf8, false),
3232
Field::new("environment_setup_commit", DataType::Utf8, true),
3333
// swe-forge extensions
34+
Field::new("install", DataType::Utf8, true),
3435
Field::new("language", DataType::Utf8, false),
3536
Field::new("difficulty", DataType::Utf8, false),
3637
Field::new("difficulty_score", DataType::UInt8, false),
@@ -54,6 +55,7 @@ pub fn tasks_to_record_batch(tasks: &[SweTask]) -> anyhow::Result<RecordBatch> {
5455
let mut fail_to_pass = StringBuilder::new();
5556
let mut pass_to_pass = StringBuilder::new();
5657
let mut env_setup_commit = StringBuilder::new();
58+
let mut install = StringBuilder::new();
5759
let mut language = StringBuilder::new();
5860
let mut difficulty = StringBuilder::new();
5961
let mut difficulty_score = UInt8Builder::new();
@@ -103,6 +105,13 @@ pub fn tasks_to_record_batch(tasks: &[SweTask]) -> anyhow::Result<RecordBatch> {
103105
env_setup_commit.append_value(&env_commit);
104106
}
105107

108+
let install_cmd = task.install_config.get("install").cloned().unwrap_or_default();
109+
if install_cmd.is_empty() {
110+
install.append_null();
111+
} else {
112+
install.append_value(&install_cmd);
113+
}
114+
106115
language.append_value(&task.language);
107116

108117
let diff_label =
@@ -136,6 +145,7 @@ pub fn tasks_to_record_batch(tasks: &[SweTask]) -> anyhow::Result<RecordBatch> {
136145
Arc::new(fail_to_pass.finish()),
137146
Arc::new(pass_to_pass.finish()),
138147
Arc::new(env_setup_commit.finish()),
148+
Arc::new(install.finish()),
139149
Arc::new(language.finish()),
140150
Arc::new(difficulty.finish()),
141151
Arc::new(difficulty_score.finish()),
@@ -239,6 +249,7 @@ pub fn read_parquet(input_path: &Path) -> anyhow::Result<Vec<SweTask>> {
239249
let created_ats = get_string("created_at");
240250
let fail_to_passes = get_string("FAIL_TO_PASS");
241251
let pass_to_passes = get_string("PASS_TO_PASS");
252+
let installs = get_string("install");
242253
let languages = get_string("language");
243254
let difficulties = get_string("difficulty");
244255

@@ -311,6 +322,12 @@ pub fn read_parquet(input_path: &Path) -> anyhow::Result<Vec<SweTask>> {
311322
task.meta.insert("difficulty".to_string(), d.clone());
312323
}
313324

325+
if let Some(ref inst) = installs[i] {
326+
if !inst.is_empty() {
327+
task.install_config.insert("install".to_string(), inst.clone());
328+
}
329+
}
330+
314331
tasks.push(task);
315332
}
316333
}
@@ -342,6 +359,8 @@ mod tests {
342359
task.pass_to_pass = vec!["pytest tests/test_x.py::test_other".to_string()];
343360
task.meta
344361
.insert("difficulty".to_string(), "medium".to_string());
362+
task.install_config
363+
.insert("install".to_string(), "pip install -e .".to_string());
345364
task
346365
}
347366

@@ -354,15 +373,15 @@ mod tests {
354373
assert!(schema.field_with_name("PASS_TO_PASS").is_ok());
355374
assert!(schema.field_with_name("difficulty").is_ok());
356375
assert!(schema.field_with_name("quality_score").is_ok());
357-
assert_eq!(schema.fields().len(), 16);
376+
assert_eq!(schema.fields().len(), 17);
358377
}
359378

360379
#[test]
361380
fn test_tasks_to_record_batch() {
362381
let tasks = vec![make_test_task("task-001"), make_test_task("task-002")];
363382
let batch = tasks_to_record_batch(&tasks).unwrap();
364383
assert_eq!(batch.num_rows(), 2);
365-
assert_eq!(batch.num_columns(), 16);
384+
assert_eq!(batch.num_columns(), 17);
366385
}
367386

368387
#[test]
@@ -380,6 +399,10 @@ mod tests {
380399
assert_eq!(loaded[0].language, "python");
381400
assert_eq!(loaded[0].difficulty_score, 2);
382401
assert_eq!(loaded[0].fail_to_pass.len(), 1);
402+
assert_eq!(
403+
loaded[0].install_config.get("install").map(|s| s.as_str()),
404+
Some("pip install -e .")
405+
);
383406

384407
let _ = std::fs::remove_file(&tmp);
385408
}

0 commit comments

Comments
 (0)