Skip to content

Commit

Permalink
fix detach tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
HessTaha committed May 31, 2020
1 parent e1cd7a8 commit 4f5770c
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

Expand Down Expand Up @@ -33,7 +34,7 @@ def main(path_to_data: str,
optimal_length = get_length(df, texts_col)
X, vocab_size = encode_texts(df, texts_col, max_seq_length=optimal_length, return_vocab_size=True)

y = get_labels(df, labels_col)
y = get_labels(df, labels_col, n_classes)

train_loader, test_loader = create_TorchLoaders(X, y, test_size=0.10, batch_size=batch_size, batch_size_eval=batch_size_eval)

Expand Down Expand Up @@ -86,15 +87,12 @@ def main(path_to_data: str,
labels = labels.to(device, dtype=torch.float)

preds = Model(inputs)

loss = criterion(preds, labels)

## Metrics computation


metrics["training_loss"].append(loss.item())

tmp_f1 = f1_score(labels.to("cpu").numpy(), preds.to("cpu").numpy())
tmp_f1 = f1_score(labels.to("cpu").detach().numpy(), preds.to("cpu").detach().numpy())

metrics["training_f1"].append(tmp_f1)

Expand All @@ -119,13 +117,12 @@ def main(path_to_data: str,
labels = labels.to(device, dtype=torch.float)

preds = Model(inputs)

eval_loss = criterion(preds, labels)

## Eval metrics
metrics["eval_loss"].append(eval_loss.item())

tmp_f1 = f1_score(labels.to("cpu").numpy(), preds.to("cpu").numpy()) ## detach
tmp_f1 = f1_score(labels.to("cpu").detach().numpy(), preds.to("cpu").detach().numpy()) ## detach

metrics["eval_f1"].append(tmp_f1)

Expand Down

0 comments on commit 4f5770c

Please sign in to comment.