diff --git "a/day5/\346\274\224\347\277\2222/main.py" "b/day5/\346\274\224\347\277\2222/main.py" index 776b70e75..2b256b6c5 100644 --- "a/day5/\346\274\224\347\277\2222/main.py" +++ "b/day5/\346\274\224\347\277\2222/main.py" @@ -11,6 +11,7 @@ import time import great_expectations as gx + class DataLoader: """データロードを行うクラス""" diff --git "a/day5/\346\274\224\347\277\2222/models/titanic_model.pkl" "b/day5/\346\274\224\347\277\2222/models/titanic_model.pkl" index 9e1859fdf..02a80d86e 100644 Binary files "a/day5/\346\274\224\347\277\2222/models/titanic_model.pkl" and "b/day5/\346\274\224\347\277\2222/models/titanic_model.pkl" differ diff --git "a/day5/\346\274\224\347\277\2223/tests/test_model.py" "b/day5/\346\274\224\347\277\2223/tests/test_model.py" index e11a19a5c..bb8bdf5de 100644 --- "a/day5/\346\274\224\347\277\2223/tests/test_model.py" +++ "b/day5/\346\274\224\347\277\2223/tests/test_model.py" @@ -11,6 +11,8 @@ from sklearn.preprocessing import OneHotEncoder, StandardScaler from sklearn.compose import ColumnTransformer from sklearn.pipeline import Pipeline +import json + # テスト用データとモデルパスを定義 DATA_PATH = os.path.join(os.path.dirname(__file__), "../data/Titanic.csv") @@ -171,3 +173,39 @@ def test_model_reproducibility(sample_data, preprocessor): assert np.array_equal( predictions1, predictions2 ), "モデルの予測結果に再現性がありません" + + +######### 案1 + +BASELINE_PATH = os.path.join(os.path.dirname(__file__), "baseline_metrics.json") + + +def load_baseline(): + if os.path.exists(BASELINE_PATH): + with open(BASELINE_PATH, "r") as f: + return json.load(f) + else: + return None + + +def save_baseline(metrics): + with open(BASELINE_PATH, "w") as f: + json.dump(metrics, f, indent=2) + + +def test_accuracy_regression(train_model): + """モデル精度の劣化チェック""" + model, X_test, y_test = train_model + baseline = load_baseline() + + y_pred = model.predict(X_test) + accuracy = accuracy_score(y_test, y_pred) + + if baseline is None: + # 初回はベースラインを保存してスキップ + save_baseline({"accuracy": accuracy}) + pytest.skip("ベースライン精度を保存しました。次回から劣化チェックを行います。") + + assert ( + accuracy >= baseline["accuracy"] - 0.02 + ), f"精度がベースラインより低下しています: {accuracy} < {baseline['accuracy']}"