Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/cli/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -737,12 +737,12 @@ async fn run_swe_mine_command(args: SweMineArgs) -> anyhow::Result<()> {
None => None,
};

let (effective_max_tasks, effective_difficulty_filter) = if let Some(ref dt) = difficulty_targets
{
(dt.total_tasks(), None)
} else {
(args.max_tasks, args.difficulty.clone())
};
let (effective_max_tasks, effective_difficulty_filter) =
if let Some(ref dt) = difficulty_targets {
(dt.total_tasks(), None)
} else {
(args.max_tasks, args.difficulty.clone())
};

let hf_upload = match (&args.hf_repo, &args.hf_token) {
(Some(repo), Some(token)) => Some(crate::export::HfUploadConfig {
Expand Down Expand Up @@ -819,7 +819,8 @@ async fn run_swe_mine_command(args: SweMineArgs) -> anyhow::Result<()> {
);

// Show per-difficulty breakdown
let mut per_level: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
let mut per_level: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
for task in &result.tasks {
if task.quality_passed {
let level = task
Expand Down Expand Up @@ -868,10 +869,16 @@ async fn run_swe_load_command(args: SweLoadArgs) -> anyhow::Result<()> {
}

// Compute stats
let mut by_difficulty: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
let mut by_language: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
let mut by_difficulty: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
let mut by_language: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
for task in &tasks {
let diff = task.meta.get("difficulty").cloned().unwrap_or_else(|| "unknown".to_string());
let diff = task
.meta
.get("difficulty")
.cloned()
.unwrap_or_else(|| "unknown".to_string());
*by_difficulty.entry(diff).or_insert(0) += 1;
*by_language.entry(task.language.clone()).or_insert(0) += 1;
}
Expand Down
39 changes: 28 additions & 11 deletions src/export/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,21 @@ impl DatasetManager {
let mut all_tasks = Vec::new();
let mut entries: Vec<_> = std::fs::read_dir(&data_dir)?
.filter_map(|e| e.ok())
.filter(|e| e.path().extension().map(|x| x == "parquet").unwrap_or(false))
.filter(|e| {
e.path()
.extension()
.map(|x| x == "parquet")
.unwrap_or(false)
})
.collect();
entries.sort_by_key(|e| e.file_name());

for entry in &entries {
match parquet_writer::read_parquet(&entry.path()) {
Ok(tasks) => all_tasks.extend(tasks),
Err(e) => tracing::warn!(path = %entry.path().display(), error = %e, "Failed to read shard"),
Err(e) => {
tracing::warn!(path = %entry.path().display(), error = %e, "Failed to read shard")
}
}
}

Expand All @@ -188,12 +195,17 @@ impl DatasetManager {

// Upload combined + splits to HF
if let Some(ref uploader) = self.uploader {
let combined_bytes = std::fs::read(self.config.output_dir.join("train.parquet"))?;
let combined_bytes =
std::fs::read(self.config.output_dir.join("train.parquet"))?;
let _ = uploader
.upload_file("train.parquet", &combined_bytes, "Add combined train.parquet")
.upload_file(
"train.parquet",
&combined_bytes,
"Add combined train.parquet",
)
.await;

for (diff, _) in &by_diff {
for diff in by_diff.keys() {
let split_path = self.config.output_dir.join(format!("{}.parquet", diff));
if let Ok(bytes) = std::fs::read(&split_path) {
let _ = uploader
Expand Down Expand Up @@ -373,7 +385,12 @@ pub fn load_dataset(path: &Path) -> anyhow::Result<Vec<SweTask>> {
let mut all_tasks = Vec::new();
let mut entries: Vec<_> = std::fs::read_dir(path)?
.filter_map(|e| e.ok())
.filter(|e| e.path().extension().map(|x| x == "parquet").unwrap_or(false))
.filter(|e| {
e.path()
.extension()
.map(|x| x == "parquet")
.unwrap_or(false)
})
.collect();
entries.sort_by_key(|e| e.file_name());

Expand All @@ -384,7 +401,10 @@ pub fn load_dataset(path: &Path) -> anyhow::Result<Vec<SweTask>> {
return Ok(all_tasks);
}

anyhow::bail!("Path is neither a parquet file nor a directory: {}", path.display());
anyhow::bail!(
"Path is neither a parquet file nor a directory: {}",
path.display()
);
}

/// Download a dataset from HuggingFace and return the tasks.
Expand All @@ -406,10 +426,7 @@ pub async fn download_dataset(
tracing::info!(repo = repo_id, file = %filename, "Downloading dataset from HuggingFace");

let client = reqwest::Client::new();
let resp = client
.get(&url)
.send()
.await?;
let resp = client.get(&url).send().await?;

if !resp.status().is_success() {
anyhow::bail!(
Expand Down
19 changes: 9 additions & 10 deletions src/export/hf_uploader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,7 @@ impl HfUploader {
HF_API_BASE, self.config.repo_id
);

let encoded = base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
content,
);
let encoded = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, content);

let body = CommitRequest {
summary: commit_message.to_string(),
Expand All @@ -140,7 +137,10 @@ impl HfUploader {
repo = %self.config.repo_id,
"Uploaded file to HF"
);
self.uploaded_files.lock().await.push(path_in_repo.to_string());
self.uploaded_files
.lock()
.await
.push(path_in_repo.to_string());
Ok(())
} else {
let status = resp.status();
Expand All @@ -157,7 +157,8 @@ impl HfUploader {
commit_message: &str,
) -> anyhow::Result<()> {
let content = std::fs::read(local_path)?;
self.upload_file(path_in_repo, &content, commit_message).await
self.upload_file(path_in_repo, &content, commit_message)
.await
}

/// Upload multiple files in a single commit (more efficient).
Expand All @@ -178,10 +179,8 @@ impl HfUploader {
let actions: Vec<CommitAction> = files
.iter()
.map(|(path, content)| {
let encoded = base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
content,
);
let encoded =
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, content);
CommitAction {
action: "file".to_string(),
path: path.to_string(),
Expand Down
73 changes: 48 additions & 25 deletions src/export/parquet_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,7 @@ pub fn tasks_to_record_batch(tasks: &[SweTask]) -> anyhow::Result<RecordBatch> {

created_at.append_value(task.created_at.to_rfc3339());

let ver = task
.meta
.get("version")
.cloned()
.unwrap_or_default();
let ver = task.meta.get("version").cloned().unwrap_or_default();
if ver.is_empty() {
version.append_null();
} else {
Expand All @@ -109,15 +105,15 @@ pub fn tasks_to_record_batch(tasks: &[SweTask]) -> anyhow::Result<RecordBatch> {

language.append_value(&task.language);

let diff_label = task
.meta
.get("difficulty")
.cloned()
.unwrap_or_else(|| match task.difficulty_score {
0..=1 => "easy".to_string(),
2 => "medium".to_string(),
_ => "hard".to_string(),
});
let diff_label =
task.meta
.get("difficulty")
.cloned()
.unwrap_or_else(|| match task.difficulty_score {
0..=1 => "easy".to_string(),
2 => "medium".to_string(),
_ => "hard".to_string(),
});
difficulty.append_value(&diff_label);
difficulty_score.append_value(task.difficulty_score);

Expand Down Expand Up @@ -219,9 +215,17 @@ pub fn read_parquet(input_path: &Path) -> anyhow::Result<Vec<SweTask>> {
batch
.column_by_name(name)
.and_then(|col| col.as_any().downcast_ref::<StringArray>())
.map(|arr| (0..num_rows).map(|i| {
if arr.is_null(i) { None } else { Some(arr.value(i).to_string()) }
}).collect())
.map(|arr| {
(0..num_rows)
.map(|i| {
if arr.is_null(i) {
None
} else {
Some(arr.value(i).to_string())
}
})
.collect()
})
.unwrap_or_else(|| vec![None; num_rows])
};

Expand All @@ -241,15 +245,27 @@ pub fn read_parquet(input_path: &Path) -> anyhow::Result<Vec<SweTask>> {
let difficulty_scores: Vec<u8> = batch
.column_by_name("difficulty_score")
.and_then(|col| col.as_any().downcast_ref::<arrow::array::UInt8Array>())
.map(|arr| (0..num_rows).map(|i| if arr.is_null(i) { 1 } else { arr.value(i) }).collect())
.map(|arr| {
(0..num_rows)
.map(|i| if arr.is_null(i) { 1 } else { arr.value(i) })
.collect()
})
.unwrap_or_else(|| vec![1; num_rows]);

let quality_scores: Vec<Option<f64>> = batch
.column_by_name("quality_score")
.and_then(|col| col.as_any().downcast_ref::<arrow::array::Float64Array>())
.map(|arr| (0..num_rows).map(|i| {
if arr.is_null(i) { None } else { Some(arr.value(i)) }
}).collect())
.map(|arr| {
(0..num_rows)
.map(|i| {
if arr.is_null(i) {
None
} else {
Some(arr.value(i))
}
})
.collect()
})
.unwrap_or_else(|| vec![None; num_rows]);

for i in 0..num_rows {
Expand All @@ -259,8 +275,12 @@ pub fn read_parquet(input_path: &Path) -> anyhow::Result<Vec<SweTask>> {
continue;
}

let f2p_str = fail_to_passes[i].clone().unwrap_or_else(|| "[]".to_string());
let p2p_str = pass_to_passes[i].clone().unwrap_or_else(|| "[]".to_string());
let f2p_str = fail_to_passes[i]
.clone()
.unwrap_or_else(|| "[]".to_string());
let p2p_str = pass_to_passes[i]
.clone()
.unwrap_or_else(|| "[]".to_string());
let fail_to_pass: Vec<String> = serde_json::from_str(&f2p_str).unwrap_or_default();
let pass_to_pass: Vec<String> = serde_json::from_str(&p2p_str).unwrap_or_default();

Expand All @@ -276,7 +296,9 @@ pub fn read_parquet(input_path: &Path) -> anyhow::Result<Vec<SweTask>> {
task.test_patch = test_patches[i].clone().unwrap_or_default();
task.prompt = problem_statements[i].clone().unwrap_or_default();
task.original_pr_body = hints[i].clone().unwrap_or_default();
task.language = languages[i].clone().unwrap_or_else(|| "unknown".to_string());
task.language = languages[i]
.clone()
.unwrap_or_else(|| "unknown".to_string());
task.difficulty_score = difficulty_scores[i];
task.quality_score = quality_scores[i];
task.quality_passed = true;
Expand Down Expand Up @@ -318,7 +340,8 @@ mod tests {
task.quality_passed = true;
task.fail_to_pass = vec!["pytest tests/test_x.py::test_fix".to_string()];
task.pass_to_pass = vec!["pytest tests/test_x.py::test_other".to_string()];
task.meta.insert("difficulty".to_string(), "medium".to_string());
task.meta
.insert("difficulty".to_string(), "medium".to_string());
task
}

Expand Down
10 changes: 9 additions & 1 deletion src/swe/docker_sandbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,15 @@ impl DockerSandbox {
let result = tokio::time::timeout(
std::time::Duration::from_millis(timeout_ms),
Command::new("docker")
.args(["exec", "-w", "/repo", &self.container_name, "bash", "-c", cmd])
.args([
"exec",
"-w",
"/repo",
&self.container_name,
"bash",
"-c",
cmd,
])
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.output(),
Expand Down
10 changes: 2 additions & 8 deletions src/swe/enricher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,8 @@ async fn fetch_pr_files_info(
if let Some(path) = file.get("filename").and_then(Value::as_str) {
info.file_paths.push(path.to_string());
}
info.added_lines += file
.get("additions")
.and_then(Value::as_u64)
.unwrap_or(0) as usize;
info.removed_lines += file
.get("deletions")
.and_then(Value::as_u64)
.unwrap_or(0) as usize;
info.added_lines += file.get("additions").and_then(Value::as_u64).unwrap_or(0) as usize;
info.removed_lines += file.get("deletions").and_then(Value::as_u64).unwrap_or(0) as usize;
}
Ok(info)
}
15 changes: 9 additions & 6 deletions src/swe/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,19 @@ impl PatchExtractor {
(Some(base), Some(merge)) if !base.is_empty() && !merge.is_empty() => {
// Fetch the merge commit (shallow clone may not have it)
sandbox
.exec(&format!("git fetch origin {} --depth=1 2>&1", merge), 60_000)
.exec(
&format!("git fetch origin {} --depth=1 2>&1", merge),
60_000,
)
.await;
format!("{base}..{merge}")
}
(_, Some(merge)) if !merge.is_empty() => {
sandbox
.exec(&format!("git fetch origin {} --depth=1 2>&1", merge), 60_000)
.exec(
&format!("git fetch origin {} --depth=1 2>&1", merge),
60_000,
)
.await;
merge.to_string()
}
Expand All @@ -231,10 +237,7 @@ impl PatchExtractor {
sandbox.destroy().await;

if result.exit_code != 0 {
anyhow::bail!(
"git show failed in Docker: {}",
&result.stderr
);
anyhow::bail!("git show failed in Docker: {}", &result.stderr);
}

Ok(result.stdout)
Expand Down
Loading
Loading