Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

M2 #879

Merged
merged 5 commits into from
Feb 11, 2024
Merged

M2 #879

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 130 additions & 130 deletions nbs/006_data.core.ipynb

Large diffs are not rendered by default.

422 changes: 280 additions & 142 deletions nbs/010_data.transforms.ipynb

Large diffs are not rendered by default.

55 changes: 28 additions & 27 deletions nbs/012_data.image.ipynb

Large diffs are not rendered by default.

189 changes: 92 additions & 97 deletions nbs/022_tslearner.ipynb

Large diffs are not rendered by default.

76 changes: 39 additions & 37 deletions nbs/026_callback.noisy_student.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"metadata": {},
"outputs": [],
"source": [
"#|export \n",
"#|export\n",
"from tsai.imports import *\n",
"from tsai.utils import *\n",
"from tsai.data.preprocessing import *\n",
Expand All @@ -61,26 +61,26 @@
"#|export\n",
"\n",
"# This is an unofficial implementation of noisy student based on:\n",
"# Xie, Q., Luong, M. T., Hovy, E., & Le, Q. V. (2020). Self-training with noisy student improves imagenet classification. \n",
"# Xie, Q., Luong, M. T., Hovy, E., & Le, Q. V. (2020). Self-training with noisy student improves imagenet classification.\n",
"# In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10687-10698).\n",
"# Official tensorflow implementation available in https://github.com/google-research/noisystudent\n",
"\n",
"\n",
"class NoisyStudent(Callback):\n",
" \"\"\"A callback to implement the Noisy Student approach. In the original paper this was used in combination with noise: \n",
" \"\"\"A callback to implement the Noisy Student approach. In the original paper this was used in combination with noise:\n",
" - stochastic depth: .8\n",
" - RandAugment: N=2, M=27\n",
" - dropout: .5\n",
" \n",
"\n",
" Steps:\n",
" 1. Build the dl you will use as a teacher\n",
" 2. Create dl2 with the pseudolabels (either soft or hard preds)\n",
" 3. Pass any required batch_tfms to the callback\n",
" \n",
"\n",
" \"\"\"\n",
" \n",
" def __init__(self, dl2:DataLoader, bs:Optional[int]=None, l2pl_ratio:int=1, batch_tfms:Optional[list]=None, do_setup:bool=True, \n",
" pseudolabel_sample_weight:float=1., verbose=False): \n",
"\n",
" def __init__(self, dl2:DataLoader, bs:Optional[int]=None, l2pl_ratio:int=1, batch_tfms:Optional[list]=None, do_setup:bool=True,\n",
" pseudolabel_sample_weight:float=1., verbose=False):\n",
" r'''\n",
" Args:\n",
" dl2: dataloader with the pseudolabels\n",
Expand All @@ -90,18 +90,18 @@
" do_setup: perform a transform setup on the labeled dataset.\n",
" pseudolabel_sample_weight: weight of each pseudolabel sample relative to the labeled one of the loss.\n",
" '''\n",
" \n",
"\n",
" self.dl2, self.bs, self.l2pl_ratio, self.batch_tfms, self.do_setup, self.verbose = dl2, bs, l2pl_ratio, batch_tfms, do_setup, verbose\n",
" self.pl_sw = pseudolabel_sample_weight\n",
" \n",
"\n",
" def before_fit(self):\n",
" if self.batch_tfms is None: self.batch_tfms = self.dls.train.after_batch\n",
" self.old_bt = self.dls.train.after_batch # Remove and store dl.train.batch_tfms\n",
" self.old_bs = self.dls.train.bs\n",
" self.dls.train.after_batch = noop \n",
" self.dls.train.after_batch = noop\n",
"\n",
" if self.do_setup and self.batch_tfms:\n",
" for bt in self.batch_tfms: \n",
" for bt in self.batch_tfms:\n",
" bt.setup(self.dls.train)\n",
"\n",
" if self.bs is None: self.bs = self.dls.train.bs\n",
Expand All @@ -111,12 +111,12 @@
" pv(f'labels / pseudolabels per training batch : {self.dls.train.bs} / {self.dl2.bs}', self.verbose)\n",
" rel_weight = (self.dls.train.bs/self.dl2.bs) * (len(self.dl2.dataset)/len(self.dls.train.dataset))\n",
" pv(f'relative labeled/ pseudolabel sample weight in dataset: {rel_weight:.1f}', self.verbose)\n",
" \n",
"\n",
" self.dl2iter = iter(self.dl2)\n",
" \n",
"\n",
" self.old_loss_func = self.learn.loss_func\n",
" self.learn.loss_func = self.loss\n",
" \n",
"\n",
" def before_batch(self):\n",
" if self.training:\n",
" X, y = self.x, self.y\n",
Expand All @@ -125,26 +125,26 @@
" self.dl2iter = iter(self.dl2)\n",
" X2, y2 = next(self.dl2iter)\n",
" if y.ndim == 1 and y2.ndim == 2: y = torch.eye(self.learn.dls.c, device=y.device)[y]\n",
" \n",
"\n",
" X_comb, y_comb = concat(X, X2), concat(y, y2)\n",
" \n",
" if self.batch_tfms is not None: \n",
"\n",
" if self.batch_tfms is not None:\n",
" X_comb = compose_tfms(X_comb, self.batch_tfms, split_idx=0)\n",
" y_comb = compose_tfms(y_comb, self.batch_tfms, split_idx=0)\n",
" self.learn.xb = (X_comb,)\n",
" self.learn.yb = (y_comb,)\n",
" pv(f'\\nX: {X.shape} X2: {X2.shape} X_comb: {X_comb.shape}', self.verbose)\n",
" pv(f'y: {y.shape} y2: {y2.shape} y_comb: {y_comb.shape}', self.verbose)\n",
" \n",
" def loss(self, output, target): \n",
"\n",
" def loss(self, output, target):\n",
" if target.ndim == 2: _, target = target.max(dim=1)\n",
" if self.training and self.pl_sw != 1: \n",
" if self.training and self.pl_sw != 1:\n",
" loss = (1 - self.pl_sw) * self.old_loss_func(output[:self.dls.train.bs], target[:self.dls.train.bs])\n",
" loss += self.pl_sw * self.old_loss_func(output[self.dls.train.bs:], target[self.dls.train.bs:])\n",
" return loss \n",
" else: \n",
" return loss\n",
" else:\n",
" return self.old_loss_func(output, target)\n",
" \n",
"\n",
" def after_fit(self):\n",
" self.dls.train.after_batch = self.old_bt\n",
" self.learn.loss_func = self.old_loss_func\n",
Expand All @@ -170,7 +170,8 @@
"outputs": [],
"source": [
"dsid = 'NATOPS'\n",
"X, y, splits = get_UCR_data(dsid, return_split=False)"
"X, y, splits = get_UCR_data(dsid, return_split=False)\n",
"X = X.astype(np.float32)"
]
},
{
Expand Down Expand Up @@ -229,10 +230,10 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.884984</td>\n",
" <td>1.809759</td>\n",
" <td>0.166667</td>\n",
" <td>00:06</td>\n",
" <td>1.782144</td>\n",
" <td>1.758471</td>\n",
" <td>0.250000</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
Expand All @@ -249,7 +250,7 @@
"output_type": "stream",
"text": [
"\n",
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 58])\n",
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 41])\n",
"y: torch.Size([171]) y2: torch.Size([85]) y_comb: torch.Size([256])\n"
]
}
Expand Down Expand Up @@ -323,10 +324,10 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.894964</td>\n",
" <td>1.814770</td>\n",
" <td>0.177778</td>\n",
" <td>00:03</td>\n",
" <td>1.898401</td>\n",
" <td>1.841182</td>\n",
" <td>0.155556</td>\n",
" <td>00:00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
Expand All @@ -343,7 +344,7 @@
"output_type": "stream",
"text": [
"\n",
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 45])\n",
"X: torch.Size([171, 24, 51]) X2: torch.Size([85, 24, 51]) X_comb: torch.Size([256, 24, 51])\n",
"y: torch.Size([171, 6]) y2: torch.Size([85, 6]) y_comb: torch.Size([256, 6])\n"
]
}
Expand All @@ -353,6 +354,7 @@
"soft_preds = False\n",
"\n",
"pseudolabels = ToNumpyCategory()(y) if soft_preds else OneHot()(y)\n",
"pseudolabels = pseudolabels.astype(np.float32)\n",
"dsets2 = TSDatasets(pseudolabeled_data, pseudolabels)\n",
"dl2 = TSDataLoader(dsets2, num_workers=0)\n",
"noisy_student_cb = NoisyStudent(dl2, bs=256, l2pl_ratio=2, verbose=True)\n",
Expand Down Expand Up @@ -380,9 +382,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/nacho/notebooks/tsai/nbs/026_callback.noisy_student.ipynb saved at 2023-01-21 14:30:23\n",
"/Users/nacho/notebooks/tsai/nbs/026_callback.noisy_student.ipynb saved at 2024-02-10 21:53:24\n",
"Correct notebook to script conversion! 😃\n",
"Saturday 21/01/23 14:30:25 CET\n"
"Saturday 10/02/24 21:53:27 CET\n"
]
},
{
Expand Down
Loading
Loading