Skip to content

Commit

Permalink
Added notebook outlining inference on new data
Browse files Browse the repository at this point in the history
  • Loading branch information
mkranzlein committed Feb 18, 2024
1 parent 9fc9aea commit f467138
Showing 1 changed file with 147 additions and 0 deletions.
147 changes: 147 additions & 0 deletions scripts/notebooks/model_inference.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%cd -q ../.."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
}
],
"source": [
"import json\n",
"\n",
"import torch\n",
"from transformers import BertTokenizerFast\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print('Using device:', device)\n",
"\n",
"# Load model and tokenizer\n",
"sentence_model = torch.load(\"models/curiam/sentence_level_model_nohipool.pt\")\n",
"token_model = torch.load(\"models/curiam/working_model_nohipool.pt\")\n",
"bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"\n",
"with open(\"data/curiam.json\", \"r\", encoding=\"utf-8\") as f:\n",
" json_data = json.load(f)\n",
"\n",
"# Each document is a list of sentences, and each sentence is a list of tokens.\n",
"documents = []\n",
"\n",
"# labels[i] is an [n, k] tensor where n is the number of tokens in the i-th sentence and\n",
"# k is the number of binary labels assigned to each token.\n",
"\n",
"for raw_document in json_data:\n",
" doc_sentences = [[token[\"text\"].lower() for token in sentence[\"tokens\"]]\n",
" for sentence in raw_document[\"sentences\"]]\n",
" documents.append(doc_sentences)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sentence: This is a sentence\n",
"Token FT MC DQ LeS \n",
"[CLS] N N N N \n",
"this N N N N \n",
"is N N N N \n",
"a N N N N \n",
"sentence N N N N \n",
"[SEP] N N N N \n"
]
}
],
"source": [
"def predict_sentence_toks(sentence: list[str]):\n",
" y = bert_tokenizer(sentence, is_split_into_words=True, return_attention_mask=True, return_token_type_ids=True, add_special_tokens=True, return_tensors=\"pt\")\n",
" output = token_model(y[\"input_ids\"].cuda(), mask=y[\"attention_mask\"].cuda(), token_type_ids=y[\"token_type_ids\"].cuda())\n",
" sigmoid_outputs = torch.nn.functional.sigmoid(output)\n",
" print(\"Sentence:\", \" \".join(sentence))\n",
" print(f\"{'Token':<20}{'FT':<4}{'MC':<4}{'DQ':<4}{'LeS':<4}\")\n",
" for token, preds in zip(bert_tokenizer.convert_ids_to_tokens(y[\"input_ids\"][0]), sigmoid_outputs[0]):\n",
" line = [token]\n",
" for pred in preds:\n",
" if pred > .5:\n",
" line.append(\"Y\")\n",
" else:\n",
" line.append(\"N\")\n",
" print(f\"{line[0]:<20}{line[1]:<4}{line[2]:<4}{line[3]:<4}{line[4]:<4}\")\n",
"\n",
"sample_sentence = [\"This\", \"is\", \"a\", \"sentence\"]\n",
"predict_sentence_toks(sample_sentence)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# TODO: fix output alignment like in previous func\n",
"def predict_meta_sentence(sample):\n",
" y = bert_tokenizer(sample, is_split_into_words=True, return_attention_mask=True, return_token_type_ids=True, add_special_tokens=True, return_tensors=\"pt\")\n",
" y = bert_tokenizer(sample, is_split_into_words=True, return_attention_mask=True, return_token_type_ids=True, add_special_tokens=True, return_tensors=\"pt\")\n",
" output = sentence_model(y[\"input_ids\"].cuda(), mask=y[\"attention_mask\"].cuda(), token_type_ids=y[\"token_type_ids\"].cuda())\n",
" sigmoid_outputs = torch.nn.functional.sigmoid(output)\n",
" print(' '.join(sample))\n",
" print('FT\\tMC\\tDQ\\tLeS')\n",
" line_out = \"\"\n",
" for pred in sigmoid_outputs[0]:\n",
" if pred >=.5:\n",
" line_out += f\"Y\\t\"\n",
" else:\n",
" line_out += f\"N\\t\"\n",
" print(line_out)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "hipool",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit f467138

Please sign in to comment.