Tensorflow error when loading checkpoints #117

Open
Sal2040 opened this issue Jun 14, 2021 · 0 comments

Sal2040 commented Jun 14, 2021

I am trying to rebuild the SciBERT model with the TensorFlow Model Garden (tf-models-official) and load its checkpoint like so:

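import json

import tensorflow as tf
# `import official` alone does not load the nested submodules, so import them explicitly.
from official.nlp.bert import bert_models, configs

# Build a BertConfig from the SciBERT config file.
with tf.io.gfile.GFile('/content/drive/MyDrive/TREC-COVID/scibert_scivocab_uncased/bert_config.json') as f:
    config_dict = json.load(f)
bert_config = configs.BertConfig.from_dict(config_dict)

# Build the TF2 Keras encoder described by that config.
bert_model = bert_models.get_transformer_encoder(bert_config)

# Restore the encoder weights from the downloaded checkpoint.
checkpoint = tf.train.Checkpoint(encoder=bert_model)
checkpoint.read('/content/drive/MyDrive/TREC-COVID/scibert_scivocab_uncased/bert_model.ckpt').assert_consumed()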

The last line fails with the following error:

AssertionError                            Traceback (most recent call last)
<ipython-input-65-2c94af81d21f> in <module>()
      1 checkpoint = tf.train.Checkpoint(encoder=bert_model)
----> 2 checkpoint.read('/content/drive/MyDrive/TREC-COVID/scibert_scivocab_uncased/bert_model.ckpt').assert_consumed()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/training/tracking/util.py in assert_consumed(self)
   1000       raise AssertionError(
   1001           "Some objects had attributes which were not restored:{}".format(
-> 1002               "".join(unused_attribute_strings)))
   1003     for trackable in self._graph_view.list_objects():
   1004       # pylint: disable=protected-access

AssertionError: Some objects had attributes which were not restored:
    <tf.Variable 'word_embeddings/embeddings:0' shape=(31090, 768) dtype=float32, numpy=
array([[ 0.02573318, -0.00267772,  0.01776482, ..., -0.02813556,
        -0.0021598 , -0.02582178],
       [ 0.00280955, -0.01805187,  0.03772264, ...,  0.02741825,
         0.00221546,  0.01261247],
       [ 0.00038168,  0.00612852, -0.02045917, ...,  0.00615935,
        -0.01438048, -0.00059851],
       ...,
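The message says the encoder variable word_embeddings/embeddings:0 found no matching value in the checkpoint. One likely cause (an assumption, not confirmed in this issue) is that bert_model.ckpt is a TF1 name-based checkpoint written by the original BERT code, with keys such as bert/embeddings/word_embeddings, while the Model Garden encoder names its variables differently, so tf.train.Checkpoint cannot match them. A minimal sketch for checking the format: tf.train.list_variables returns the stored (name, shape) pairs, and TF2 object-based checkpoints include a _CHECKPOINTABLE_OBJECT_GRAPH entry that name-based ones lack. The helper names is_object_based and describe_checkpoint are hypothetical.

```python
import sys
from typing import List, Sequence, Tuple

# TF2 object-based checkpoints (written by tf.train.Checkpoint) store the
# serialized object graph under this key; TF1 name-based checkpoints
# (written by tf.compat.v1.train.Saver) do not.
OBJECT_GRAPH_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"

VariableList = Sequence[Tuple[str, List[int]]]


def is_object_based(variables: VariableList) -> bool:
    """Return True if the (name, shape) list contains the object-graph key."""
    return any(name == OBJECT_GRAPH_KEY for name, _ in variables)


def describe_checkpoint(variables: VariableList, limit: int = 5) -> List[str]:
    """Summarize a checkpoint's format and its first few variable names."""
    kind = "object-based (TF2)" if is_object_based(variables) else "name-based (TF1)"
    lines = [f"format: {kind}"]
    # Skip the object-graph entry itself; it is metadata, not a weight.
    names = sorted(name for name, _ in variables if name != OBJECT_GRAPH_KEY)
    lines.extend(names[:limit])
    return lines


def main(ckpt_path: str) -> None:
    # Imported lazily so the helpers above can be used without TensorFlow.
    import tensorflow as tf

    # tf.train.list_variables returns (name, shape) pairs for every stored key.
    for line in describe_checkpoint(tf.train.list_variables(ckpt_path)):
        print(line)


if __name__ == "__main__" and len(sys.argv) > 1:
    main(sys.argv[1])
```

Run it with the checkpoint prefix as the only argument, for example python inspect_ckpt.py path/to/bert_model.ckpt (the script name and path are placeholders). If it reports a name-based checkpoint, the weights would need converting to an object-based checkpoint first; the Model Garden's BERT directory shipped a tf2_encoder_checkpoint_converter.py script for this around that time, but check its flags against your installed version.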

Could you please advise where the problem might be?

Thanks.
