From a6278295c2a95f1f96b4e7bff05d315ba2021568 Mon Sep 17 00:00:00 2001 From: Gagan Bhatia <49101362+gagan3012@users.noreply.github.com> Date: Fri, 2 Jul 2021 19:33:11 -0400 Subject: [PATCH] Update trainer.py --- keytotext/trainer.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/keytotext/trainer.py b/keytotext/trainer.py index 814e7c9..1e94d17 100644 --- a/keytotext/trainer.py +++ b/keytotext/trainer.py @@ -18,8 +18,6 @@ from pathlib import Path from torchmetrics import Accuracy - - torch.cuda.empty_cache() pl.seed_everything(42) @@ -92,8 +90,8 @@ def __getitem__(self, index: int): class PLDataModule(LightningDataModule): def __init__( self, - train_df:pd.DataFrame, - test_df:pd.DataFrame, + train_df: pd.DataFrame, + test_df: pd.DataFrame, tokenizer: T5Tokenizer, source_max_token_len: int = 512, target_max_token_len: int = 512, @@ -118,7 +116,6 @@ def __init__( self.source_max_token_len = source_max_token_len self.tokenizer = tokenizer - def setup(self, stage=None): self.train_dataset = DataModule( self.train_df, @@ -193,10 +190,11 @@ def training_step(self, batch, batch_size): attention_mask=attention_mask, decoder_attention_mask=labels_attention_mask, labels=labels, + **batch ) - train_acc = self.val_acc(outputs.logits.argmax(1), labels) + acc = self.val_acc(outputs.logits.argmax(1), labels) self.log("train_loss", loss, prog_bar=True, logger=True) - self.log(f"train_acc", train_acc, prog_bar=True,logger=True) + self.log(f"train_acc", acc, prog_bar=True, logger=True) return loss def validation_step(self, batch, batch_size): @@ -211,10 +209,11 @@ def validation_step(self, batch, batch_size): attention_mask=attention_mask, decoder_attention_mask=labels_attention_mask, labels=labels, + **batch ) - val_acc = self.val_acc(outputs.logits.argmax(1), labels) + acc = self.val_acc(outputs.logits.argmax(1), labels) self.log("val_loss", loss, prog_bar=True, logger=True) - self.log(f"val_acc", val_acc, prog_bar=True,logger=True) + self.log(f"val_acc", acc, prog_bar=True, logger=True) return loss def test_step(self, batch, batch_size): @@ -229,6 +228,7 @@ def test_step(self, batch, batch_size): attention_mask=attention_mask, decoder_attention_mask=labels_attention_mask, labels=labels, + **batch ) self.log("test_loss", loss, prog_bar=True, logger=True) @@ -239,7 +239,6 @@ def configure_optimizers(self): return AdamW(self.parameters(), lr=0.0001) - class trainer: """ Keytotext model trainer @@ -293,7 +292,7 @@ def train( self.data_module = PLDataModule( train_df=train_df, - test_df= test_df, + test_df=test_df, tokenizer=self.tokenizer, batch_size=batch_size, source_max_token_len=source_max_token_len, @@ -439,7 +438,7 @@ def predict( ] return preds[0] - def upload(self,model_name,hf_username): + def upload(self, model_name, hf_username): hf_password = getpass("Enter your HuggingFace password") token = HfApi().login(username=hf_username, password=hf_password) del hf_password @@ -447,7 +446,7 @@ def upload(self,model_name,hf_username): model_repo = Repository("./model", clone_from=model_url, use_auth_token=token, git_email=f"{hf_username}@users.noreply.huggingface.co", git_user=hf_username) - readme_txt= f""" + readme_txt = f""" --- language: "en" thumbnail: "Keywords to Sentences" @@ -497,12 +496,3 @@ def upload(self,model_name,hf_username): print("Check out your model at:") print(f"https://huggingface.co/{hf_username}/{model_name}") - - - - - - - - -