We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 78abb38 + aab5b7d commit e292773Copy full SHA for e292773
src/ecco/lm.py
@@ -86,7 +86,10 @@ def __init__(self,
86
self.model_type = self.model_config['type']
87
embeddings_layer_name = self.model_config['embedding']
88
embed_retriever = attrgetter(embeddings_layer_name)
89
- self.model_embeddings = embed_retriever(self.model)
+ if type(embed_retriever(self.model)) == torch.nn.Embedding:
90
+ self.model_embeddings = embed_retriever(self.model).weight
91
+ else:
92
+ self.model_embeddings = embed_retriever(self.model)
93
self.collect_activations_layer_name_sig = self.model_config['activations'][0]
94
except KeyError:
95
raise ValueError(
0 commit comments