Skip to content

Commit

Permalink
fix cell outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
laugustyniak committed Nov 14, 2023
1 parent d5b1dff commit ec53a9c
Showing 1 changed file with 140 additions and 18 deletions.
158 changes: 140 additions & 18 deletions nbs/03_training_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -99,6 +99,27 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset mms downloaded and prepared to /root/.cache/huggingface/datasets/Brand24___mms/default/0.2.0/70532fdd01f149ff84a280b7d9cfb661643abf4837b4f0f3aa1128064e870d65. Subsequent calls will reuse this data.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5663e74add1d415c853bcb257a15963a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
Expand All @@ -115,19 +136,36 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'mms_dataset' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/app/nbs/03_training_example.ipynb Cell 7\u001b[0m line \u001b[0;36m2\n\u001b[1;32m <a href='vscode-notebook-cell://attached-container%2B7b22636f6e7461696e65724e616d65223a222f6d6d732d62656e63686d61726b227d/app/nbs/03_training_example.ipynb#X36sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39m#| eval: false\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell://attached-container%2B7b22636f6e7461696e65724e616d65223a222f6d6d732d62656e63686d61726b227d/app/nbs/03_training_example.ipynb#X36sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1'>2</a>\u001b[0m mms_dataset\u001b[39m.\u001b[39mcolumn_names\n",
"\u001b[0;31mNameError\u001b[0m: name 'mms_dataset' is not defined"
]
"data": {
"text/plain": [
"{'train': ['_id',\n",
" 'text',\n",
" 'label',\n",
" 'original_dataset',\n",
" 'domain',\n",
" 'language',\n",
" 'Family',\n",
" 'Genus',\n",
" 'Definite articles',\n",
" 'Indefinite articles',\n",
" 'Number of cases',\n",
" 'Order of subject, object, verb',\n",
" 'Negative morphemes',\n",
" 'Polar questions',\n",
" 'Position of negative word wrt SOV',\n",
" 'Prefixing vs suffixing',\n",
" 'Coding of nominal plurality',\n",
" 'Grammatical genders',\n",
" 'cleanlab_self_confidence']}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
Expand All @@ -146,7 +184,22 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "31c9a75e3c5e4e349120702866a1ca1f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Filter: 0%| | 0/6165262 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#| eval: false\n",
"pl_sm = mms_dataset[\"train\"].filter(lambda x: x[\"language\"] == \"pl\" and x[\"domain\"] == \"social_media\")"
Expand All @@ -163,7 +216,22 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3728a8e151e94b6392ea27436255e6d7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Filter: 0%| | 0/169576 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#| eval: false\n",
"pl_sm_high_confidence = pl_sm.filter(lambda x: x[\"cleanlab_self_confidence\"] > 0.6)"
Expand All @@ -173,7 +241,18 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"73227"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#| eval: false\n",
"len(pl_sm_high_confidence)"
Expand Down Expand Up @@ -203,7 +282,22 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "28e61f6f723044e2bc8824d7f808f63e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/73227 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#| eval: false\n",
"tokenized_dataset = pl_sm_high_confidence.map(tokenize, batched=True, batch_size=512)"
Expand All @@ -213,7 +307,19 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.sso.sso_relationship.weight', 'cls.predictions.decoder.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at allegro/herbert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"#| eval: false\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"allegro/herbert-base-cased\", num_labels=3)"
Expand Down Expand Up @@ -241,7 +347,7 @@
"training_args = TrainingArguments(\n",
" output_dir=\"PL_SM_SENT\",\n",
" evaluation_strategy=\"epoch\",\n",
" num_train_epochs=5,\n",
" num_train_epochs=1,\n",
")\n",
"metric = evaluate.load(\"accuracy\")\n",
"\n",
Expand Down Expand Up @@ -272,11 +378,27 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n"
]
}
],
"source": [
"#| eval: false\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit ec53a9c

Please sign in to comment.