diff --git a/tutorials/03-advanced/image_captioning/train.py b/tutorials/03-advanced/image_captioning/train.py index 73007637..61cda919 100644 --- a/tutorials/03-advanced/image_captioning/train.py +++ b/tutorials/03-advanced/image_captioning/train.py @@ -72,9 +72,9 @@ def main(args): # Save the model checkpoints if (i+1) % args.save_step == 0: torch.save(decoder.state_dict(), os.path.join( - args.model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1))) + args.model_path, 'decoder-{}-{}.pth'.format(epoch+1, i+1))) torch.save(encoder.state_dict(), os.path.join( - args.model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1))) + args.model_path, 'encoder-{}-{}.pth'.format(epoch+1, i+1))) if __name__ == '__main__':