Skip to content

Commit

Permalink
Update trainer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
gagan3012 committed Jul 2, 2021
1 parent dc6242f commit a627829
Showing 1 changed file with 12 additions and 22 deletions.
34 changes: 12 additions & 22 deletions keytotext/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from pathlib import Path
from torchmetrics import Accuracy



torch.cuda.empty_cache()
pl.seed_everything(42)

Expand Down Expand Up @@ -92,8 +90,8 @@ def __getitem__(self, index: int):
class PLDataModule(LightningDataModule):
def __init__(
self,
train_df:pd.DataFrame,
test_df:pd.DataFrame,
train_df: pd.DataFrame,
test_df: pd.DataFrame,
tokenizer: T5Tokenizer,
source_max_token_len: int = 512,
target_max_token_len: int = 512,
Expand All @@ -118,7 +116,6 @@ def __init__(
self.source_max_token_len = source_max_token_len
self.tokenizer = tokenizer


def setup(self, stage=None):
self.train_dataset = DataModule(
self.train_df,
Expand Down Expand Up @@ -193,10 +190,11 @@ def training_step(self, batch, batch_size):
attention_mask=attention_mask,
decoder_attention_mask=labels_attention_mask,
labels=labels,
**batch
)
train_acc = self.val_acc(outputs.logits.argmax(1), labels)
acc = self.val_acc(outputs.logits.argmax(1), labels)
self.log("train_loss", loss, prog_bar=True, logger=True)
self.log(f"train_acc", train_acc, prog_bar=True,logger=True)
self.log(f"train_acc", acc, prog_bar=True, logger=True)
return loss

def validation_step(self, batch, batch_size):
Expand All @@ -211,10 +209,11 @@ def validation_step(self, batch, batch_size):
attention_mask=attention_mask,
decoder_attention_mask=labels_attention_mask,
labels=labels,
**batch
)
val_acc = self.val_acc(outputs.logits.argmax(1), labels)
acc = self.val_acc(outputs.logits.argmax(1), labels)
self.log("val_loss", loss, prog_bar=True, logger=True)
self.log(f"val_acc", val_acc, prog_bar=True,logger=True)
self.log(f"val_acc", acc, prog_bar=True, logger=True)
return loss

def test_step(self, batch, batch_size):
Expand All @@ -229,6 +228,7 @@ def test_step(self, batch, batch_size):
attention_mask=attention_mask,
decoder_attention_mask=labels_attention_mask,
labels=labels,
**batch
)

self.log("test_loss", loss, prog_bar=True, logger=True)
Expand All @@ -239,7 +239,6 @@ def configure_optimizers(self):
return AdamW(self.parameters(), lr=0.0001)



class trainer:
"""
Keytotext model trainer
Expand Down Expand Up @@ -293,7 +292,7 @@ def train(

self.data_module = PLDataModule(
train_df=train_df,
test_df= test_df,
test_df=test_df,
tokenizer=self.tokenizer,
batch_size=batch_size,
source_max_token_len=source_max_token_len,
Expand Down Expand Up @@ -439,15 +438,15 @@ def predict(
]
return preds[0]

def upload(self,model_name,hf_username):
def upload(self, model_name, hf_username):
hf_password = getpass("Enter your HuggingFace password")
token = HfApi().login(username=hf_username, password=hf_password)
del hf_password
model_url = HfApi().create_repo(token=token, name=model_name, exist_ok=True)
model_repo = Repository("./model", clone_from=model_url, use_auth_token=token,
git_email=f"{hf_username}@users.noreply.huggingface.co", git_user=hf_username)

readme_txt= f"""
readme_txt = f"""
---
language: "en"
thumbnail: "Keywords to Sentences"
Expand Down Expand Up @@ -497,12 +496,3 @@ def upload(self,model_name,hf_username):

print("Check out your model at:")
print(f"https://huggingface.co/{hf_username}/{model_name}")









0 comments on commit a627829

Please sign in to comment.