diff --git a/models/finetuned/gpt3.5/model.pt b/models/finetuned/gpt3.5/model.pt index 3459968..291dad4 100644 Binary files a/models/finetuned/gpt3.5/model.pt and b/models/finetuned/gpt3.5/model.pt differ diff --git a/models/finetuned/gpt4/model.pt b/models/finetuned/gpt4/model.pt index a5c5765..6570fa9 100644 Binary files a/models/finetuned/gpt4/model.pt and b/models/finetuned/gpt4/model.pt differ diff --git a/notebooks/demo.ipynb b/notebooks/demo.ipynb index bc4798c..c33c8a0 100644 --- a/notebooks/demo.ipynb +++ b/notebooks/demo.ipynb @@ -47,8 +47,8 @@ "GPT_4_DIR = os.path.join(MODEL_DIR, \"finetuned\", \"gpt4\")\n", "\n", "# Checkpoint paths\n", - "CKPT_GPT3_PATH = \"/Users/jonas-mika/epfl/coursework/projects/ml-project-2-mlp/logs/train/multiruns/2023-12-19_00-31-07/18/checkpoints/epoch_031.ckpt\" \n", - "CKPT_GPT4_PATH = \"/Users/jonas-mika/epfl/coursework/projects/ml-project-2-mlp/logs/train/multiruns/2023-12-20_22-48-11/9/checkpoints/epoch_039.ckpt\"" + "# CKPT_GPT3_PATH = \"/Users/jonas-mika/epfl/coursework/projects/ml-project-2-mlp/logs/train/multiruns/2023-12-21_10-50-22/66/checkpoints/epoch_031.ckpt\" \n", + "# CKPT_GPT4_PATH = \"/Users/jonas-mika/epfl/coursework/projects/ml-project-2-mlp/logs/train/multiruns/2023-12-21_10-50-22/73/checkpoints/epoch_046.ckpt\"" ] }, { @@ -67,6 +67,15 @@ "metadata": {}, "outputs": [], "source": [ + "def get_state_dict(ckpt_path: str) -> dict:\n", + " \"\"\"\n", + " Load state dict from PyTorch Lightning checkpoint.\n", + " \"\"\"\n", + " checkpoint = torch.load(ckpt_path)\n", + " state_dict = checkpoint[\"state_dict\"]\n", + " state_dict = {k.replace(\"model.\", \"\"): v for k, v in state_dict.items() if \"model\" in k}\n", + " return state_dict\n", + "\n", "def sort_scores(scores: dict) -> dict:\n", " \"\"\"Sort scores by value in descending order.\"\"\"\n", " return {k: v for k, v in sorted(scores.items(), key=lambda x: x[1], reverse=True)}" @@ -148,20 +157,20 @@ "output_type": "stream", "text": [ "Classes probabilities:\n", - "Science: 0.8412251472473145\n", - "Reference: 0.7838622331619263\n", - "Society: 0.5793622732162476\n", - "Kids_and_Teens: 0.4533769190311432\n", - "Arts: 0.44181838631629944\n", - "Computers: 0.4027433693408966\n", - "News: 0.37343618273735046\n", - "Health: 0.3063579201698303\n", - "Business: 0.24803754687309265\n", - "Recreation: 0.19114521145820618\n", - "Sports: 0.0962551087141037\n", - "Home: 0.037589993327856064\n", - "Shopping: 0.028511211276054382\n", - "Games: 0.028346922248601913\n" + "Science: 0.7964304685592651\n", + "Reference: 0.7635273933410645\n", + "Society: 0.5921807289123535\n", + "News: 0.5681739449501038\n", + "Arts: 0.5368830561637878\n", + "Kids_and_Teens: 0.508224606513977\n", + "Computers: 0.3853667974472046\n", + "Business: 0.3469756543636322\n", + "Health: 0.33045274019241333\n", + "Recreation: 0.2595757842063904\n", + "Sports: 0.11109738796949387\n", + "Home: 0.10465463995933533\n", + "Shopping: 0.06366318464279175\n", + "Games: 0.031982336193323135\n" ] } ], @@ -196,20 +205,20 @@ "output_type": "stream", "text": [ "Classes probabilities:\n", - "Science: 0.9493160843849182\n", - "Reference: 0.659612774848938\n", - "Society: 0.49875521659851074\n", - "Business: 0.3018222749233246\n", - "Computers: 0.18622104823589325\n", - "News: 0.09135214984416962\n", - "Recreation: 0.08869662880897522\n", - "Arts: 0.05620817095041275\n", - "Kids_and_Teens: 0.04272356256842613\n", - "Health: 0.03221626207232475\n", - "Sports: 0.01675873063504696\n", - "Home: 0.003997097257524729\n", - "Shopping: 0.002112273359671235\n", - "Games: 0.0012009597849100828\n" + "Science: 0.9800454378128052\n", + "Reference: 0.7734379768371582\n", + "Society: 0.369582861661911\n", + "Business: 0.17371125519275665\n", + "Recreation: 0.12732912600040436\n", + "Computers: 0.0544419139623642\n", + "Health: 0.02717871218919754\n", + "Arts: 0.008378472179174423\n", + "Sports: 0.007091645151376724\n", + "Kids_and_Teens: 0.002522727008908987\n", + "News: 0.00039661736809648573\n", + "Home: 0.000121747434604913\n", + "Shopping: 6.818987458245829e-05\n", + "Games: 1.6320313079631887e-05\n" ] } ],