Skip to content

Commit

Permalink
Merge branch 'main' of github.com:CS-433/ml-project-2-mlp
Browse files Browse the repository at this point in the history
  • Loading branch information
peternutter committed Dec 21, 2023
2 parents 09fce67 + 7cb7c35 commit 05a7255
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 30 deletions.
Binary file modified models/finetuned/gpt3.5/model.pt
Binary file not shown.
Binary file modified models/finetuned/gpt4/model.pt
Binary file not shown.
69 changes: 39 additions & 30 deletions notebooks/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
"GPT_4_DIR = os.path.join(MODEL_DIR, \"finetuned\", \"gpt4\")\n",
"\n",
"# Checkpoint paths\n",
"CKPT_GPT3_PATH = \"/Users/jonas-mika/epfl/coursework/projects/ml-project-2-mlp/logs/train/multiruns/2023-12-19_00-31-07/18/checkpoints/epoch_031.ckpt\" \n",
"CKPT_GPT4_PATH = \"/Users/jonas-mika/epfl/coursework/projects/ml-project-2-mlp/logs/train/multiruns/2023-12-20_22-48-11/9/checkpoints/epoch_039.ckpt\""
"# CKPT_GPT3_PATH = \"/Users/jonas-mika/epfl/coursework/projects/ml-project-2-mlp/logs/train/multiruns/2023-12-21_10-50-22/66/checkpoints/epoch_031.ckpt\" \n",
"# CKPT_GPT4_PATH = \"/Users/jonas-mika/epfl/coursework/projects/ml-project-2-mlp/logs/train/multiruns/2023-12-21_10-50-22/73/checkpoints/epoch_046.ckpt\""
]
},
{
Expand All @@ -67,6 +67,15 @@
"metadata": {},
"outputs": [],
"source": [
"def get_state_dict(ckpt_path: str) -> dict:\n",
" \"\"\"\n",
" Load state dict from PyTorch Lightning checkpoint.\n",
" \"\"\"\n",
" checkpoint = torch.load(ckpt_path)\n",
" state_dict = checkpoint[\"state_dict\"]\n",
" state_dict = {k.replace(\"model.\", \"\"): v for k, v in state_dict.items() if \"model\" in k}\n",
" return state_dict\n",
"\n",
"def sort_scores(scores: dict) -> dict:\n",
" \"\"\"Sort scores by value in descending order.\"\"\"\n",
" return {k: v for k, v in sorted(scores.items(), key=lambda x: x[1], reverse=True)}"
Expand Down Expand Up @@ -148,20 +157,20 @@
"output_type": "stream",
"text": [
"Classes probabilities:\n",
"Science: 0.8412251472473145\n",
"Reference: 0.7838622331619263\n",
"Society: 0.5793622732162476\n",
"Kids_and_Teens: 0.4533769190311432\n",
"Arts: 0.44181838631629944\n",
"Computers: 0.4027433693408966\n",
"News: 0.37343618273735046\n",
"Health: 0.3063579201698303\n",
"Business: 0.24803754687309265\n",
"Recreation: 0.19114521145820618\n",
"Sports: 0.0962551087141037\n",
"Home: 0.037589993327856064\n",
"Shopping: 0.028511211276054382\n",
"Games: 0.028346922248601913\n"
"Science: 0.7964304685592651\n",
"Reference: 0.7635273933410645\n",
"Society: 0.5921807289123535\n",
"News: 0.5681739449501038\n",
"Arts: 0.5368830561637878\n",
"Kids_and_Teens: 0.508224606513977\n",
"Computers: 0.3853667974472046\n",
"Business: 0.3469756543636322\n",
"Health: 0.33045274019241333\n",
"Recreation: 0.2595757842063904\n",
"Sports: 0.11109738796949387\n",
"Home: 0.10465463995933533\n",
"Shopping: 0.06366318464279175\n",
"Games: 0.031982336193323135\n"
]
}
],
Expand Down Expand Up @@ -196,20 +205,20 @@
"output_type": "stream",
"text": [
"Classes probabilities:\n",
"Science: 0.9493160843849182\n",
"Reference: 0.659612774848938\n",
"Society: 0.49875521659851074\n",
"Business: 0.3018222749233246\n",
"Computers: 0.18622104823589325\n",
"News: 0.09135214984416962\n",
"Recreation: 0.08869662880897522\n",
"Arts: 0.05620817095041275\n",
"Kids_and_Teens: 0.04272356256842613\n",
"Health: 0.03221626207232475\n",
"Sports: 0.01675873063504696\n",
"Home: 0.003997097257524729\n",
"Shopping: 0.002112273359671235\n",
"Games: 0.0012009597849100828\n"
"Science: 0.9800454378128052\n",
"Reference: 0.7734379768371582\n",
"Society: 0.369582861661911\n",
"Business: 0.17371125519275665\n",
"Recreation: 0.12732912600040436\n",
"Computers: 0.0544419139623642\n",
"Health: 0.02717871218919754\n",
"Arts: 0.008378472179174423\n",
"Sports: 0.007091645151376724\n",
"Kids_and_Teens: 0.002522727008908987\n",
"News: 0.00039661736809648573\n",
"Home: 0.000121747434604913\n",
"Shopping: 6.818987458245829e-05\n",
"Games: 1.6320313079631887e-05\n"
]
}
],
Expand Down

0 comments on commit 05a7255

Please sign in to comment.