|
128 | 128 | "# Load BBBP dataset for classification\n",
|
129 | 129 | "print(\"Loading BBBP dataset for classification...\")\n",
|
130 | 130 | "bbbp_df = pd.read_csv(\n",
|
131 |
| - " test_data_dir / \"molecule_net_bbbp.tsv.gz\", sep=\"\\t\", compression=\"gzip\",\n", |
| 131 | + " test_data_dir / \"molecule_net_bbbp.tsv.gz\",\n", |
| 132 | + " sep=\"\\t\",\n", |
| 133 | + " compression=\"gzip\",\n", |
132 | 134 | ")\n",
|
133 | 135 | "\n",
|
134 | 136 | "# Load LogD dataset for regression\n",
|
135 | 137 | "print(\"Loading LogD dataset for regression...\")\n",
|
136 | 138 | "logd_df = pd.read_csv(\n",
|
137 |
| - " test_data_dir / \"molecule_net_logd.tsv.gz\", sep=\"\\t\", compression=\"gzip\",\n", |
| 139 | + " test_data_dir / \"molecule_net_logd.tsv.gz\",\n", |
| 140 | + " sep=\"\\t\",\n", |
| 141 | + " compression=\"gzip\",\n", |
138 | 142 | ")\n",
|
139 | 143 | "\n",
|
140 | 144 | "print(f\"BBBP dataset shape: {bbbp_df.shape}\")\n",
|
|
340 | 344 | " y_true_class_singleton = y_true[class_singleton_mask]\n",
|
341 | 345 | " y_pred_class_singleton = singleton_predictions[class_singleton_mask]\n",
|
342 | 346 | " singleton_accuracies[cls] = accuracy_score(\n",
|
343 |
| - " y_true_class_singleton, y_pred_class_singleton,\n", |
| 347 | + " y_true_class_singleton,\n", |
| 348 | + " y_pred_class_singleton,\n", |
344 | 349 | " )\n",
|
345 | 350 | " else:\n",
|
346 | 351 | " singleton_accuracies[cls] = (\n",
|
|
415 | 420 | " stratify_y = y if len(np.unique(y)) < limit else None\n",
|
416 | 421 | "\n",
|
417 | 422 | " x_train_all, x_test, y_train_all, y_test = train_test_split(\n",
|
418 |
| - " x, y, test_size=test_size, random_state=random_state, stratify=stratify_y,\n", |
| 423 | + " x,\n", |
| 424 | + " y,\n", |
| 425 | + " test_size=test_size,\n", |
| 426 | + " random_state=random_state,\n", |
| 427 | + " stratify=stratify_y,\n", |
419 | 428 | " )\n",
|
420 | 429 | "\n",
|
421 | 430 | " # Use stratification for the second split only if appropriate\n",
|
|
498 | 507 | "# Create split conformal predictor\n",
|
499 | 508 | "confidence_level = 0.9\n",
|
500 | 509 | "split_cp_clf = ConformalPredictor(\n",
|
501 |
| - " base_clf, estimator_type=\"classifier\", confidence_level=confidence_level,\n", |
| 510 | + " base_clf,\n", |
| 511 | + " estimator_type=\"classifier\",\n", |
| 512 | + " confidence_level=confidence_level,\n", |
502 | 513 | ")\n",
|
503 | 514 | "\n",
|
504 | 515 | "# Fit and calibrate\n",
|
|
509 | 520 | "# Make predictions\n",
|
510 | 521 | "y_pred_proba_split = split_cp_clf.predict_proba(x_test_clf)\n",
|
511 | 522 | "prediction_sets_split = split_cp_clf.predict_conformal_set(\n",
|
512 |
| - " x_test_clf, confidence=confidence_level,\n", |
| 523 | + " x_test_clf,\n", |
| 524 | + " confidence=confidence_level,\n", |
513 | 525 | ")\n",
|
514 | 526 | "\n",
|
515 | 527 | "# Evaluate\n",
|
516 | 528 | "results_split_clf = evaluate_classification_conformal(\n",
|
517 |
| - " y_test_clf, prediction_sets_split,\n", |
| 529 | + " y_test_clf,\n", |
| 530 | + " prediction_sets_split,\n", |
518 | 531 | ")\n",
|
519 | 532 | "\n",
|
520 | 533 | "print(f\"\\nSplit Conformal Prediction Results (confidence={confidence_level}):\")\n",
|
|
584 | 597 | "# Make predictions\n",
|
585 | 598 | "y_pred_proba_mondrian = mondrian_cp_clf.predict_proba(x_test_clf)\n",
|
586 | 599 | "prediction_sets_mondrian = mondrian_cp_clf.predict_conformal_set(\n",
|
587 |
| - " x_test_clf, confidence=confidence_level,\n", |
| 600 | + " x_test_clf,\n", |
| 601 | + " confidence=confidence_level,\n", |
588 | 602 | ")\n",
|
589 | 603 | "\n",
|
590 | 604 | "# Evaluate\n",
|
591 | 605 | "results_mondrian_clf = evaluate_classification_conformal(\n",
|
592 |
| - " y_test_clf, prediction_sets_mondrian,\n", |
| 606 | + " y_test_clf,\n", |
| 607 | + " prediction_sets_mondrian,\n", |
593 | 608 | ")\n",
|
594 | 609 | "\n",
|
595 | 610 | "print(f\"\\nMondrian Conformal Prediction Results (confidence={confidence_level}):\")\n",
|
|
663 | 678 | "# Make predictions\n",
|
664 | 679 | "y_pred_proba_cross = cross_cp_clf.predict_proba(x_test_clf)\n",
|
665 | 680 | "prediction_sets_cross = cross_cp_clf.predict_conformal_set(\n",
|
666 |
| - " x_test_clf, confidence=confidence_level,\n", |
| 681 | + " x_test_clf,\n", |
| 682 | + " confidence=confidence_level,\n", |
667 | 683 | ")\n",
|
668 | 684 | "\n",
|
669 | 685 | "# Evaluate\n",
|
670 | 686 | "results_cross_clf = evaluate_classification_conformal(\n",
|
671 |
| - " y_test_clf, prediction_sets_cross,\n", |
| 687 | + " y_test_clf,\n", |
| 688 | + " prediction_sets_cross,\n", |
672 | 689 | ")\n",
|
673 | 690 | "\n",
|
674 | 691 | "print(f\"\\nCross Conformal Prediction Results (confidence={confidence_level}):\")\n",
|
|
796 | 813 | "\n",
|
797 | 814 | "# Average set size comparison\n",
|
798 | 815 | "bars2 = axes[0, 1].bar(\n",
|
799 |
| - " comparison_clf[\"Method\"], comparison_clf[\"Avg Set Size\"], alpha=0.7, color=\"orange\",\n", |
| 816 | + " comparison_clf[\"Method\"],\n", |
| 817 | + " comparison_clf[\"Avg Set Size\"],\n", |
| 818 | + " alpha=0.7,\n", |
| 819 | + " color=\"orange\",\n", |
800 | 820 | ")\n",
|
801 | 821 | "axes[0, 1].set_title(\"Average Prediction Set Size\")\n",
|
802 | 822 | "axes[0, 1].set_ylabel(\"Set Size\")\n",
|
|
874 | 894 | " fontsize=9,\n",
|
875 | 895 | " )\n",
|
876 | 896 | "# Add values on top of bars for Class 1\n",
|
877 |
| - "for bar in (bars5):\n", |
| 897 | + "for bar in bars5:\n", |
878 | 898 | " height = bar.get_height()\n",
|
879 | 899 | " if not np.isnan(height):\n",
|
880 | 900 | " axes[1, 1].text(\n",
|
|
996 | 1016 | "for conf_level in confidence_levels:\n",
|
997 | 1017 | " # Use the already trained cross conformal predictor\n",
|
998 | 1018 | " prediction_sets_conf = cross_cp_clf.predict_conformal_set(\n",
|
999 |
| - " x_test_clf, confidence=conf_level,\n", |
| 1019 | + " x_test_clf,\n", |
| 1020 | + " confidence=conf_level,\n", |
1000 | 1021 | " )\n",
|
1001 | 1022 | "\n",
|
1002 | 1023 | " # Evaluate\n",
|
1003 | 1024 | " results_conf = evaluate_classification_conformal(\n",
|
1004 |
| - " y_test_clf, prediction_sets_conf,\n", |
| 1025 | + " y_test_clf,\n", |
| 1026 | + " prediction_sets_conf,\n", |
1005 | 1027 | " )\n",
|
1006 | 1028 | "\n",
|
1007 | 1029 | " clf_confidence_results.append(\n",
|
|
1191 | 1213 | "\n",
|
1192 | 1214 | "# Create split conformal predictor for regression\n",
|
1193 | 1215 | "split_cp_reg = ConformalPredictor(\n",
|
1194 |
| - " base_reg, estimator_type=\"regressor\", confidence_level=confidence_level,\n", |
| 1216 | + " base_reg,\n", |
| 1217 | + " estimator_type=\"regressor\",\n", |
| 1218 | + " confidence_level=confidence_level,\n", |
1195 | 1219 | ")\n",
|
1196 | 1220 | "\n",
|
1197 | 1221 | "# Fit and calibrate\n",
|
|
1205 | 1229 | "\n",
|
1206 | 1230 | "# Evaluate\n",
|
1207 | 1231 | "results_split_reg = evaluate_regression_conformal(\n",
|
1208 |
| - " y_test_reg, y_pred_split_reg, intervals_split,\n", |
| 1232 | + " y_test_reg,\n", |
| 1233 | + " y_pred_split_reg,\n", |
| 1234 | + " intervals_split,\n", |
1209 | 1235 | ")\n",
|
1210 | 1236 | "\n",
|
1211 | 1237 | "print(f\"\\nSplit Conformal Prediction Results (confidence={confidence_level}):\")\n",
|
|
1268 | 1294 | "\n",
|
1269 | 1295 | "# Evaluate\n",
|
1270 | 1296 | "results_cross_reg = evaluate_regression_conformal(\n",
|
1271 |
| - " y_test_reg, y_pred_cross_reg, intervals_cross,\n", |
| 1297 | + " y_test_reg,\n", |
| 1298 | + " y_pred_cross_reg,\n", |
| 1299 | + " intervals_cross,\n", |
1272 | 1300 | ")\n",
|
1273 | 1301 | "\n",
|
1274 | 1302 | "print(f\"\\nCross Conformal Prediction Results (confidence={confidence_level}):\")\n",
|
|
1394 | 1422 | "\n",
|
1395 | 1423 | "# MAE comparison\n",
|
1396 | 1424 | "bars3 = axes[1, 0].bar(\n",
|
1397 |
| - " comparison_reg[\"Method\"], comparison_reg[\"MAE\"], alpha=0.7, color=\"green\",\n", |
| 1425 | + " comparison_reg[\"Method\"],\n", |
| 1426 | + " comparison_reg[\"MAE\"],\n", |
| 1427 | + " alpha=0.7,\n", |
| 1428 | + " color=\"green\",\n", |
1398 | 1429 | ")\n",
|
1399 | 1430 | "axes[1, 0].set_title(\"Mean Absolute Error\")\n",
|
1400 | 1431 | "axes[1, 0].set_ylabel(\"MAE\")\n",
|
|
1412 | 1443 | "\n",
|
1413 | 1444 | "# RMSE comparison\n",
|
1414 | 1445 | "bars4 = axes[1, 1].bar(\n",
|
1415 |
| - " comparison_reg[\"Method\"], comparison_reg[\"RMSE\"], alpha=0.7, color=\"red\",\n", |
| 1446 | + " comparison_reg[\"Method\"],\n", |
| 1447 | + " comparison_reg[\"RMSE\"],\n", |
| 1448 | + " alpha=0.7,\n", |
| 1449 | + " color=\"red\",\n", |
1416 | 1450 | ")\n",
|
1417 | 1451 | "axes[1, 1].set_title(\"Root Mean Squared Error\")\n",
|
1418 | 1452 | "axes[1, 1].set_ylabel(\"RMSE\")\n",
|
|
1434 | 1468 | "# R² comparison in a separate smaller plot\n",
|
1435 | 1469 | "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n",
|
1436 | 1470 | "bars5 = ax.bar(\n",
|
1437 |
| - " comparison_reg[\"Method\"], comparison_reg[\"R²\"], alpha=0.7, color=\"purple\",\n", |
| 1471 | + " comparison_reg[\"Method\"],\n", |
| 1472 | + " comparison_reg[\"R²\"],\n", |
| 1473 | + " alpha=0.7,\n", |
| 1474 | + " color=\"purple\",\n", |
1438 | 1475 | ")\n",
|
1439 | 1476 | "ax.set_title(\"R² Score\")\n",
|
1440 | 1477 | "ax.set_ylabel(\"R²\")\n",
|
|
1520 | 1557 | " label=f\"Prediction Intervals ({confidence_level:.0%})\",\n",
|
1521 | 1558 | ")\n",
|
1522 | 1559 | "axes[0].scatter(\n",
|
1523 |
| - " range(n_plot), y_test_sorted[:n_plot], alpha=0.7, label=\"True Values\", s=30,\n", |
| 1560 | + " range(n_plot),\n", |
| 1561 | + " y_test_sorted[:n_plot],\n", |
| 1562 | + " alpha=0.7,\n", |
| 1563 | + " label=\"True Values\",\n", |
| 1564 | + " s=30,\n", |
1524 | 1565 | ")\n",
|
1525 | 1566 | "axes[0].scatter(\n",
|
1526 |
| - " range(n_plot), y_pred_split_sorted[:n_plot], alpha=0.7, label=\"Predictions\", s=30,\n", |
| 1567 | + " range(n_plot),\n", |
| 1568 | + " y_pred_split_sorted[:n_plot],\n", |
| 1569 | + " alpha=0.7,\n", |
| 1570 | + " label=\"Predictions\",\n", |
| 1571 | + " s=30,\n", |
1527 | 1572 | ")\n",
|
1528 | 1573 | "axes[0].set_title(\"Split Conformal Prediction Intervals\")\n",
|
1529 | 1574 | "axes[0].set_xlabel(\"Sample Index\")\n",
|
|
1539 | 1584 | " label=f\"Prediction Intervals ({confidence_level:.0%})\",\n",
|
1540 | 1585 | ")\n",
|
1541 | 1586 | "axes[1].scatter(\n",
|
1542 |
| - " range(n_plot), y_test_sorted[:n_plot], alpha=0.7, label=\"True Values\", s=30,\n", |
| 1587 | + " range(n_plot),\n", |
| 1588 | + " y_test_sorted[:n_plot],\n", |
| 1589 | + " alpha=0.7,\n", |
| 1590 | + " label=\"True Values\",\n", |
| 1591 | + " s=30,\n", |
1543 | 1592 | ")\n",
|
1544 | 1593 | "axes[1].scatter(\n",
|
1545 |
| - " range(n_plot), y_pred_cross_sorted[:n_plot], alpha=0.7, label=\"Predictions\", s=30,\n", |
| 1594 | + " range(n_plot),\n", |
| 1595 | + " y_pred_cross_sorted[:n_plot],\n", |
| 1596 | + " alpha=0.7,\n", |
| 1597 | + " label=\"Predictions\",\n", |
| 1598 | + " s=30,\n", |
1546 | 1599 | ")\n",
|
1547 | 1600 | "axes[1].set_title(\"Cross Conformal Prediction Intervals\")\n",
|
1548 | 1601 | "axes[1].set_xlabel(\"Sample Index\")\n",
|
|
1648 | 1701 | "\n",
|
1649 | 1702 | " # Evaluate\n",
|
1650 | 1703 | " results_conf = evaluate_regression_conformal(\n",
|
1651 |
| - " y_test_reg, y_pred_cross_reg, intervals_conf,\n", |
| 1704 | + " y_test_reg,\n", |
| 1705 | + " y_pred_cross_reg,\n", |
| 1706 | + " intervals_conf,\n", |
1652 | 1707 | " )\n",
|
1653 | 1708 | "\n",
|
1654 | 1709 | " reg_confidence_results.append(\n",
|
|
0 commit comments