diff --git a/source/models.py b/source/models.py
index 117f568..879ecd1 100644
--- a/source/models.py
+++ b/source/models.py
@@ -1,5 +1,9 @@
+from constants import (
+    PROJECT_DIRECTORY_PATH
+)
 import utils
 
+import os
 import torch
 import numpy as np
 
@@ -13,6 +17,8 @@ def __init__(
         padding_idx: int = None
     ):
         super().__init__()
+        self.device = device
+        self.filepath = os.path.join(PROJECT_DIRECTORY_PATH, "data", "cbow", "model.pt")
         # init embedding layers
         self.input_embeddings = torch.nn.Embedding(
             num_embeddings=vocabulary_size,
@@ -32,9 +38,11 @@ def __init__(
         self.input_embeddings.weight.data.uniform_(-0.5, 0.5)
         self.output_embeddings.weight.data.uniform_(-0.5, 0.5)
         # send model to device
-        self.device = device
         self.to(self.device)
 
+    def save(self):
+        torch.save(self.state_dict(), self.filepath)
+
     def get_embeddings(self) -> np.ndarray:
         embeddings = self.input_embeddings.weight.cpu().detach().numpy()
         embeddings = utils.normalize(embeddings, axis=1, keepdims=True)