Skip to content

Commit

Permalink
Add script to print model architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
research4pan committed Apr 4, 2024
1 parent 88cbaaa commit 9bdd4f6
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions scripts/tools/print_model_architecture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env python
#coding=utf-8
import argparse
import sys
from transformers import AutoModel

def parse_argument(sys_argv):
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model_name_or_path", type=str, default='gpt2')
args = parser.parse_args(sys_argv[1:])
return args

def main():
args = parse_argument(sys.argv)
model_name = args.model_name_or_path
model = AutoModel.from_pretrained(model_name)

print(model.config)
print(model)

if __name__ == "__main__":
main()

0 comments on commit 9bdd4f6

Please sign in to comment.