Skip to content

Commit 06692ab

Browse files
committed
uvx reformat
1 parent 80ec4c6 commit 06692ab

File tree

1 file changed

+81
-26
lines changed

1 file changed

+81
-26
lines changed

notebooks/advanced_04_conformal_prediction.ipynb

Lines changed: 81 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,17 @@
128128
"# Load BBBP dataset for classification\n",
129129
"print(\"Loading BBBP dataset for classification...\")\n",
130130
"bbbp_df = pd.read_csv(\n",
131-
" test_data_dir / \"molecule_net_bbbp.tsv.gz\", sep=\"\\t\", compression=\"gzip\",\n",
131+
" test_data_dir / \"molecule_net_bbbp.tsv.gz\",\n",
132+
" sep=\"\\t\",\n",
133+
" compression=\"gzip\",\n",
132134
")\n",
133135
"\n",
134136
"# Load LogD dataset for regression\n",
135137
"print(\"Loading LogD dataset for regression...\")\n",
136138
"logd_df = pd.read_csv(\n",
137-
" test_data_dir / \"molecule_net_logd.tsv.gz\", sep=\"\\t\", compression=\"gzip\",\n",
139+
" test_data_dir / \"molecule_net_logd.tsv.gz\",\n",
140+
" sep=\"\\t\",\n",
141+
" compression=\"gzip\",\n",
138142
")\n",
139143
"\n",
140144
"print(f\"BBBP dataset shape: {bbbp_df.shape}\")\n",
@@ -340,7 +344,8 @@
340344
" y_true_class_singleton = y_true[class_singleton_mask]\n",
341345
" y_pred_class_singleton = singleton_predictions[class_singleton_mask]\n",
342346
" singleton_accuracies[cls] = accuracy_score(\n",
343-
" y_true_class_singleton, y_pred_class_singleton,\n",
347+
" y_true_class_singleton,\n",
348+
" y_pred_class_singleton,\n",
344349
" )\n",
345350
" else:\n",
346351
" singleton_accuracies[cls] = (\n",
@@ -415,7 +420,11 @@
415420
" stratify_y = y if len(np.unique(y)) < limit else None\n",
416421
"\n",
417422
" x_train_all, x_test, y_train_all, y_test = train_test_split(\n",
418-
" x, y, test_size=test_size, random_state=random_state, stratify=stratify_y,\n",
423+
" x,\n",
424+
" y,\n",
425+
" test_size=test_size,\n",
426+
" random_state=random_state,\n",
427+
" stratify=stratify_y,\n",
419428
" )\n",
420429
"\n",
421430
" # Use stratification for the second split only if appropriate\n",
@@ -498,7 +507,9 @@
498507
"# Create split conformal predictor\n",
499508
"confidence_level = 0.9\n",
500509
"split_cp_clf = ConformalPredictor(\n",
501-
" base_clf, estimator_type=\"classifier\", confidence_level=confidence_level,\n",
510+
" base_clf,\n",
511+
" estimator_type=\"classifier\",\n",
512+
" confidence_level=confidence_level,\n",
502513
")\n",
503514
"\n",
504515
"# Fit and calibrate\n",
@@ -509,12 +520,14 @@
509520
"# Make predictions\n",
510521
"y_pred_proba_split = split_cp_clf.predict_proba(x_test_clf)\n",
511522
"prediction_sets_split = split_cp_clf.predict_conformal_set(\n",
512-
" x_test_clf, confidence=confidence_level,\n",
523+
" x_test_clf,\n",
524+
" confidence=confidence_level,\n",
513525
")\n",
514526
"\n",
515527
"# Evaluate\n",
516528
"results_split_clf = evaluate_classification_conformal(\n",
517-
" y_test_clf, prediction_sets_split,\n",
529+
" y_test_clf,\n",
530+
" prediction_sets_split,\n",
518531
")\n",
519532
"\n",
520533
"print(f\"\\nSplit Conformal Prediction Results (confidence={confidence_level}):\")\n",
@@ -584,12 +597,14 @@
584597
"# Make predictions\n",
585598
"y_pred_proba_mondrian = mondrian_cp_clf.predict_proba(x_test_clf)\n",
586599
"prediction_sets_mondrian = mondrian_cp_clf.predict_conformal_set(\n",
587-
" x_test_clf, confidence=confidence_level,\n",
600+
" x_test_clf,\n",
601+
" confidence=confidence_level,\n",
588602
")\n",
589603
"\n",
590604
"# Evaluate\n",
591605
"results_mondrian_clf = evaluate_classification_conformal(\n",
592-
" y_test_clf, prediction_sets_mondrian,\n",
606+
" y_test_clf,\n",
607+
" prediction_sets_mondrian,\n",
593608
")\n",
594609
"\n",
595610
"print(f\"\\nMondrian Conformal Prediction Results (confidence={confidence_level}):\")\n",
@@ -663,12 +678,14 @@
663678
"# Make predictions\n",
664679
"y_pred_proba_cross = cross_cp_clf.predict_proba(x_test_clf)\n",
665680
"prediction_sets_cross = cross_cp_clf.predict_conformal_set(\n",
666-
" x_test_clf, confidence=confidence_level,\n",
681+
" x_test_clf,\n",
682+
" confidence=confidence_level,\n",
667683
")\n",
668684
"\n",
669685
"# Evaluate\n",
670686
"results_cross_clf = evaluate_classification_conformal(\n",
671-
" y_test_clf, prediction_sets_cross,\n",
687+
" y_test_clf,\n",
688+
" prediction_sets_cross,\n",
672689
")\n",
673690
"\n",
674691
"print(f\"\\nCross Conformal Prediction Results (confidence={confidence_level}):\")\n",
@@ -796,7 +813,10 @@
796813
"\n",
797814
"# Average set size comparison\n",
798815
"bars2 = axes[0, 1].bar(\n",
799-
" comparison_clf[\"Method\"], comparison_clf[\"Avg Set Size\"], alpha=0.7, color=\"orange\",\n",
816+
" comparison_clf[\"Method\"],\n",
817+
" comparison_clf[\"Avg Set Size\"],\n",
818+
" alpha=0.7,\n",
819+
" color=\"orange\",\n",
800820
")\n",
801821
"axes[0, 1].set_title(\"Average Prediction Set Size\")\n",
802822
"axes[0, 1].set_ylabel(\"Set Size\")\n",
@@ -874,7 +894,7 @@
874894
" fontsize=9,\n",
875895
" )\n",
876896
"# Add values on top of bars for Class 1\n",
877-
"for bar in (bars5):\n",
897+
"for bar in bars5:\n",
878898
" height = bar.get_height()\n",
879899
" if not np.isnan(height):\n",
880900
" axes[1, 1].text(\n",
@@ -996,12 +1016,14 @@
9961016
"for conf_level in confidence_levels:\n",
9971017
" # Use the already trained cross conformal predictor\n",
9981018
" prediction_sets_conf = cross_cp_clf.predict_conformal_set(\n",
999-
" x_test_clf, confidence=conf_level,\n",
1019+
" x_test_clf,\n",
1020+
" confidence=conf_level,\n",
10001021
" )\n",
10011022
"\n",
10021023
" # Evaluate\n",
10031024
" results_conf = evaluate_classification_conformal(\n",
1004-
" y_test_clf, prediction_sets_conf,\n",
1025+
" y_test_clf,\n",
1026+
" prediction_sets_conf,\n",
10051027
" )\n",
10061028
"\n",
10071029
" clf_confidence_results.append(\n",
@@ -1191,7 +1213,9 @@
11911213
"\n",
11921214
"# Create split conformal predictor for regression\n",
11931215
"split_cp_reg = ConformalPredictor(\n",
1194-
" base_reg, estimator_type=\"regressor\", confidence_level=confidence_level,\n",
1216+
" base_reg,\n",
1217+
" estimator_type=\"regressor\",\n",
1218+
" confidence_level=confidence_level,\n",
11951219
")\n",
11961220
"\n",
11971221
"# Fit and calibrate\n",
@@ -1205,7 +1229,9 @@
12051229
"\n",
12061230
"# Evaluate\n",
12071231
"results_split_reg = evaluate_regression_conformal(\n",
1208-
" y_test_reg, y_pred_split_reg, intervals_split,\n",
1232+
" y_test_reg,\n",
1233+
" y_pred_split_reg,\n",
1234+
" intervals_split,\n",
12091235
")\n",
12101236
"\n",
12111237
"print(f\"\\nSplit Conformal Prediction Results (confidence={confidence_level}):\")\n",
@@ -1268,7 +1294,9 @@
12681294
"\n",
12691295
"# Evaluate\n",
12701296
"results_cross_reg = evaluate_regression_conformal(\n",
1271-
" y_test_reg, y_pred_cross_reg, intervals_cross,\n",
1297+
" y_test_reg,\n",
1298+
" y_pred_cross_reg,\n",
1299+
" intervals_cross,\n",
12721300
")\n",
12731301
"\n",
12741302
"print(f\"\\nCross Conformal Prediction Results (confidence={confidence_level}):\")\n",
@@ -1394,7 +1422,10 @@
13941422
"\n",
13951423
"# MAE comparison\n",
13961424
"bars3 = axes[1, 0].bar(\n",
1397-
" comparison_reg[\"Method\"], comparison_reg[\"MAE\"], alpha=0.7, color=\"green\",\n",
1425+
" comparison_reg[\"Method\"],\n",
1426+
" comparison_reg[\"MAE\"],\n",
1427+
" alpha=0.7,\n",
1428+
" color=\"green\",\n",
13981429
")\n",
13991430
"axes[1, 0].set_title(\"Mean Absolute Error\")\n",
14001431
"axes[1, 0].set_ylabel(\"MAE\")\n",
@@ -1412,7 +1443,10 @@
14121443
"\n",
14131444
"# RMSE comparison\n",
14141445
"bars4 = axes[1, 1].bar(\n",
1415-
" comparison_reg[\"Method\"], comparison_reg[\"RMSE\"], alpha=0.7, color=\"red\",\n",
1446+
" comparison_reg[\"Method\"],\n",
1447+
" comparison_reg[\"RMSE\"],\n",
1448+
" alpha=0.7,\n",
1449+
" color=\"red\",\n",
14161450
")\n",
14171451
"axes[1, 1].set_title(\"Root Mean Squared Error\")\n",
14181452
"axes[1, 1].set_ylabel(\"RMSE\")\n",
@@ -1434,7 +1468,10 @@
14341468
"# R² comparison in a separate smaller plot\n",
14351469
"fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n",
14361470
"bars5 = ax.bar(\n",
1437-
" comparison_reg[\"Method\"], comparison_reg[\"R²\"], alpha=0.7, color=\"purple\",\n",
1471+
" comparison_reg[\"Method\"],\n",
1472+
" comparison_reg[\"R²\"],\n",
1473+
" alpha=0.7,\n",
1474+
" color=\"purple\",\n",
14381475
")\n",
14391476
"ax.set_title(\"R² Score\")\n",
14401477
"ax.set_ylabel(\"R²\")\n",
@@ -1520,10 +1557,18 @@
15201557
" label=f\"Prediction Intervals ({confidence_level:.0%})\",\n",
15211558
")\n",
15221559
"axes[0].scatter(\n",
1523-
" range(n_plot), y_test_sorted[:n_plot], alpha=0.7, label=\"True Values\", s=30,\n",
1560+
" range(n_plot),\n",
1561+
" y_test_sorted[:n_plot],\n",
1562+
" alpha=0.7,\n",
1563+
" label=\"True Values\",\n",
1564+
" s=30,\n",
15241565
")\n",
15251566
"axes[0].scatter(\n",
1526-
" range(n_plot), y_pred_split_sorted[:n_plot], alpha=0.7, label=\"Predictions\", s=30,\n",
1567+
" range(n_plot),\n",
1568+
" y_pred_split_sorted[:n_plot],\n",
1569+
" alpha=0.7,\n",
1570+
" label=\"Predictions\",\n",
1571+
" s=30,\n",
15271572
")\n",
15281573
"axes[0].set_title(\"Split Conformal Prediction Intervals\")\n",
15291574
"axes[0].set_xlabel(\"Sample Index\")\n",
@@ -1539,10 +1584,18 @@
15391584
" label=f\"Prediction Intervals ({confidence_level:.0%})\",\n",
15401585
")\n",
15411586
"axes[1].scatter(\n",
1542-
" range(n_plot), y_test_sorted[:n_plot], alpha=0.7, label=\"True Values\", s=30,\n",
1587+
" range(n_plot),\n",
1588+
" y_test_sorted[:n_plot],\n",
1589+
" alpha=0.7,\n",
1590+
" label=\"True Values\",\n",
1591+
" s=30,\n",
15431592
")\n",
15441593
"axes[1].scatter(\n",
1545-
" range(n_plot), y_pred_cross_sorted[:n_plot], alpha=0.7, label=\"Predictions\", s=30,\n",
1594+
" range(n_plot),\n",
1595+
" y_pred_cross_sorted[:n_plot],\n",
1596+
" alpha=0.7,\n",
1597+
" label=\"Predictions\",\n",
1598+
" s=30,\n",
15461599
")\n",
15471600
"axes[1].set_title(\"Cross Conformal Prediction Intervals\")\n",
15481601
"axes[1].set_xlabel(\"Sample Index\")\n",
@@ -1648,7 +1701,9 @@
16481701
"\n",
16491702
" # Evaluate\n",
16501703
" results_conf = evaluate_regression_conformal(\n",
1651-
" y_test_reg, y_pred_cross_reg, intervals_conf,\n",
1704+
" y_test_reg,\n",
1705+
" y_pred_cross_reg,\n",
1706+
" intervals_conf,\n",
16521707
" )\n",
16531708
"\n",
16541709
" reg_confidence_results.append(\n",

0 commit comments

Comments
 (0)