@@ -31,6 +31,7 @@ pub fn swe_bench_schema() -> Schema {
3131 Field :: new( "PASS_TO_PASS" , DataType :: Utf8 , false ) ,
3232 Field :: new( "environment_setup_commit" , DataType :: Utf8 , true ) ,
3333 // swe-forge extensions
34+ Field :: new( "install" , DataType :: Utf8 , true ) ,
3435 Field :: new( "language" , DataType :: Utf8 , false ) ,
3536 Field :: new( "difficulty" , DataType :: Utf8 , false ) ,
3637 Field :: new( "difficulty_score" , DataType :: UInt8 , false ) ,
@@ -54,6 +55,7 @@ pub fn tasks_to_record_batch(tasks: &[SweTask]) -> anyhow::Result<RecordBatch> {
5455 let mut fail_to_pass = StringBuilder :: new ( ) ;
5556 let mut pass_to_pass = StringBuilder :: new ( ) ;
5657 let mut env_setup_commit = StringBuilder :: new ( ) ;
58+ let mut install = StringBuilder :: new ( ) ;
5759 let mut language = StringBuilder :: new ( ) ;
5860 let mut difficulty = StringBuilder :: new ( ) ;
5961 let mut difficulty_score = UInt8Builder :: new ( ) ;
@@ -103,6 +105,13 @@ pub fn tasks_to_record_batch(tasks: &[SweTask]) -> anyhow::Result<RecordBatch> {
103105 env_setup_commit. append_value ( & env_commit) ;
104106 }
105107
108+ let install_cmd = task. install_config . get ( "install" ) . cloned ( ) . unwrap_or_default ( ) ;
109+ if install_cmd. is_empty ( ) {
110+ install. append_null ( ) ;
111+ } else {
112+ install. append_value ( & install_cmd) ;
113+ }
114+
106115 language. append_value ( & task. language ) ;
107116
108117 let diff_label =
@@ -136,6 +145,7 @@ pub fn tasks_to_record_batch(tasks: &[SweTask]) -> anyhow::Result<RecordBatch> {
136145 Arc :: new( fail_to_pass. finish( ) ) ,
137146 Arc :: new( pass_to_pass. finish( ) ) ,
138147 Arc :: new( env_setup_commit. finish( ) ) ,
148+ Arc :: new( install. finish( ) ) ,
139149 Arc :: new( language. finish( ) ) ,
140150 Arc :: new( difficulty. finish( ) ) ,
141151 Arc :: new( difficulty_score. finish( ) ) ,
@@ -239,6 +249,7 @@ pub fn read_parquet(input_path: &Path) -> anyhow::Result<Vec<SweTask>> {
239249 let created_ats = get_string ( "created_at" ) ;
240250 let fail_to_passes = get_string ( "FAIL_TO_PASS" ) ;
241251 let pass_to_passes = get_string ( "PASS_TO_PASS" ) ;
252+ let installs = get_string ( "install" ) ;
242253 let languages = get_string ( "language" ) ;
243254 let difficulties = get_string ( "difficulty" ) ;
244255
@@ -311,6 +322,12 @@ pub fn read_parquet(input_path: &Path) -> anyhow::Result<Vec<SweTask>> {
311322 task. meta . insert ( "difficulty" . to_string ( ) , d. clone ( ) ) ;
312323 }
313324
325+ if let Some ( ref inst) = installs[ i] {
326+ if !inst. is_empty ( ) {
327+ task. install_config . insert ( "install" . to_string ( ) , inst. clone ( ) ) ;
328+ }
329+ }
330+
314331 tasks. push ( task) ;
315332 }
316333 }
@@ -342,6 +359,8 @@ mod tests {
342359 task. pass_to_pass = vec ! [ "pytest tests/test_x.py::test_other" . to_string( ) ] ;
343360 task. meta
344361 . insert ( "difficulty" . to_string ( ) , "medium" . to_string ( ) ) ;
362+ task. install_config
363+ . insert ( "install" . to_string ( ) , "pip install -e ." . to_string ( ) ) ;
345364 task
346365 }
347366
@@ -354,15 +373,15 @@ mod tests {
354373 assert ! ( schema. field_with_name( "PASS_TO_PASS" ) . is_ok( ) ) ;
355374 assert ! ( schema. field_with_name( "difficulty" ) . is_ok( ) ) ;
356375 assert ! ( schema. field_with_name( "quality_score" ) . is_ok( ) ) ;
357- assert_eq ! ( schema. fields( ) . len( ) , 16 ) ;
376+ assert_eq ! ( schema. fields( ) . len( ) , 17 ) ;
358377 }
359378
360379 #[ test]
361380 fn test_tasks_to_record_batch ( ) {
362381 let tasks = vec ! [ make_test_task( "task-001" ) , make_test_task( "task-002" ) ] ;
363382 let batch = tasks_to_record_batch ( & tasks) . unwrap ( ) ;
364383 assert_eq ! ( batch. num_rows( ) , 2 ) ;
365- assert_eq ! ( batch. num_columns( ) , 16 ) ;
384+ assert_eq ! ( batch. num_columns( ) , 17 ) ;
366385 }
367386
368387 #[ test]
@@ -380,6 +399,10 @@ mod tests {
380399 assert_eq ! ( loaded[ 0 ] . language, "python" ) ;
381400 assert_eq ! ( loaded[ 0 ] . difficulty_score, 2 ) ;
382401 assert_eq ! ( loaded[ 0 ] . fail_to_pass. len( ) , 1 ) ;
402+ assert_eq ! (
403+ loaded[ 0 ] . install_config. get( "install" ) . map( |s| s. as_str( ) ) ,
404+ Some ( "pip install -e ." )
405+ ) ;
383406
384407 let _ = std:: fs:: remove_file ( & tmp) ;
385408 }
0 commit comments