Skip to content

Commit

Permalink
Add save method to ModelCBOW class
Browse files Browse the repository at this point in the history
  • Loading branch information
sindre0830 committed Oct 15, 2023
1 parent 231387b commit 6f11fc9
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion source/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from constants import (
PROJECT_DIRECTORY_PATH
)
import utils

import os
import torch
import numpy as np

Expand All @@ -13,6 +17,8 @@ def __init__(
padding_idx: int = None
):
super().__init__()
self.device = device
self.filepath = os.path.join(PROJECT_DIRECTORY_PATH, "data", "cbow", "model.pt")
# init embedding layers
self.input_embeddings = torch.nn.Embedding(
num_embeddings=vocabulary_size,
Expand All @@ -32,9 +38,11 @@ def __init__(
self.input_embeddings.weight.data.uniform_(-0.5, 0.5)
self.output_embeddings.weight.data.uniform_(-0.5, 0.5)
# send model to device
self.device = device
self.to(self.device)

def save(self):
torch.save(self.state_dict(), self.filepath)

def get_embeddings(self) -> np.ndarray:
embeddings = self.input_embeddings.weight.cpu().detach().numpy()
embeddings = utils.normalize(embeddings, axis=1, keepdims=True)
Expand Down

0 comments on commit 6f11fc9

Please sign in to comment.