Skip to content

Commit

Permalink
made input shapes (None, None, 3)
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul committed Oct 23, 2022
1 parent 8990a26 commit 6364fd2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions convert_to_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,13 @@ def main(args):
dataset_name = args.ckpt_path.split("/")[-2].lower()
tf_params_path = f"{variant}_{task.lower()}_{dataset_name}.h5"

tf_model.save_weights(tf_params_path)
print(f"Model params serialized to {tf_params_path}.")
# tf_model.save_weights(tf_params_path)
# print(f"Model params serialized to {tf_params_path}.")
saved_model_path = tf_params_path.replace(".h5", "")
tf_model.save(saved_model_path)
print(f"SavedModel serialized to {saved_model_path}.")
push_to_hub_keras(tf_model, repo_path_or_name=f"sayakpaul/{saved_model_path}")
print("Model pushed to Hugging Face Hub.")
# push_to_hub_keras(tf_model, repo_path_or_name=f"sayakpaul/{saved_model_path}")
# print("Model pushed to Hugging Face Hub.")


def parse_args():
Expand Down
2 changes: 1 addition & 1 deletion create_maxim_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from maxim.configs import MAXIM_CONFIGS


def Model(variant=None, input_resolution=(256, 256), **kw) -> keras.Model:
def Model(variant=None, input_resolution=(None, None), **kw) -> keras.Model:
"""Factory function to easily create a Model variant like "S".
Args:
Expand Down

0 comments on commit 6364fd2

Please sign in to comment.