Commit

Unnecessary arguments removed
Oleksii Prykhodko committed May 29, 2019
1 parent a20de0d commit 91e71ff
Showing 3 changed files with 4 additions and 11 deletions.
2 changes: 1 addition & 1 deletion create_model.py
@@ -7,7 +7,7 @@ def create_model():
 
     parser.add_argument("--input-data-path", "-i", help="The path to a data file.", type=str, required=True)
     parser.add_argument("--output-model-folder", "-o", help="Prefix to the folder to save output model.", type=str)
-    parser.add_argument("--latent_dim", "-ld", help="dimensionality of the latent space", type=int)
+    parser.add_argument("--latent_dim", "-ld", help="dimensionality of the noise", type=int)
     args = {k: v for k, v in vars(parser.parse_args()).items() if v is not None}
 
     runner = CreateModelRunner(**args)
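The dict comprehension above is what makes dropping arguments cheap: any optional flag the user omits parses as None, gets filtered out of args, and the runner's own constructor default takes over. A minimal standalone sketch of the idiom, with a hypothetical Runner standing in for CreateModelRunner and an assumed latent_dim default:

import argparse

class Runner:
    # Hypothetical stand-in for CreateModelRunner; the 128 default is assumed.
    def __init__(self, input_data_path, latent_dim=128):
        self.input_data_path = input_data_path
        self.latent_dim = latent_dim

parser = argparse.ArgumentParser()
parser.add_argument("--input-data-path", "-i", type=str, required=True)
parser.add_argument("--latent_dim", "-ld", type=int)  # optional, parses as None if omitted

# Drop None entries so omitted flags fall through to the constructor defaults.
args = {k: v for k, v in vars(parser.parse_args()).items() if v is not None}
runner = Runner(**args)  # "-i data.smi" alone leaves latent_dim at 128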
7 changes: 2 additions & 5 deletions runners/TrainModelRunner.py
@@ -18,7 +18,7 @@ class TrainModelRunner:
     lambda_gp = 10
 
     def __init__(self, input_data_path, output_model_folder, decode_mols_save_path='', n_epochs=200, starting_epoch=1,
-                 batch_size=64, lr=0.0002, b1=0.5, b2=0.999, n_cpu=4, n_critic=5, clip_value=0.01, sample_interval=10,
+                 batch_size=64, lr=0.0002, b1=0.5, b2=0.999, n_critic=5, sample_interval=10,
                  save_interval=100, sample_after_training=100, message=""):
         self.message = message
 
@@ -31,9 +31,7 @@ def __init__(self, input_data_path, output_model_folder, decode_mols_save_path='
         self.lr = lr
         self.b1 = b1
         self.b2 = b2
-        self.n_cpu = n_cpu
         self.n_critic = n_critic
-        self.clip_value = clip_value
         self.sample_interval = sample_interval
         self.save_interval = save_interval
         self.sample_after_training = sample_after_training
@@ -160,19 +158,18 @@ def run(self):
        with open(os.path.join(self.output_model_folder, 'gen_loss.json'), 'w') as json_file:
            json.dump(g_loss_log, json_file)

        # Sampling after training
        if self.sample_after_training > 0:
            # sampling mode
            torch.no_grad()
            self.G.eval()

            S = Sampler(generator=self.G)
            latent = S.sample(self.sample_after_training)
            # latent = ((latent + 1) * self.peak / 2) + self.min
            latent = latent.detach().cpu().numpy().tolist()

            sampled_mols_save_path = os.path.join(self.output_model_folder, 'sampled.json')
            with open(sampled_mols_save_path, 'w') as json_file:
                # array_fake_mols = fake_mols.data
                json.dump(latent, json_file)

            # decoding sampled mols
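One caveat in the sampling block above: torch.no_grad() on its own line builds a context manager and immediately discards it, so autograd stays enabled and it is the later .detach() that actually breaks the graph. A sketch of the idiomatic form, assuming (as the diff suggests) a sampler whose sample() returns a tensor:

import torch

def sample_after_training(generator, sampler, n_samples):
    generator.eval()                        # inference behavior for dropout/batch norm
    with torch.no_grad():                   # unlike a bare torch.no_grad() call, the
        latent = sampler.sample(n_samples)  # `with` block really disables autograd here
    return latent.cpu().numpy().tolist()    # .detach() is redundant under no_grad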
6 changes: 1 addition & 5 deletions train_model.py
@@ -14,16 +14,12 @@ def train_model():
     parser.add_argument("--lr", type=float, help="adam: learning rate")
     parser.add_argument("--b1", type=float, help="adam: decay of first order momentum of gradient")
     parser.add_argument("--b2", type=float, help="adam: decay of second order momentum of gradient")
-    parser.add_argument("--n-cpu", type=int, help="number of cpu threads to use during batch generation")
     parser.add_argument("--latent-dim", type=int, help="dimensionality of the latent space")
-    parser.add_argument("--img-size", type=int, help="size of each image dimension")
-    parser.add_argument("--channels", type=int, help="number of image channels")
     parser.add_argument("--n-critic", type=int, help="number of training steps for discriminator per iter")
-    parser.add_argument("--clip-value", type=float, help="lower and upper clip value for disc. weights")
     parser.add_argument("--sample-interval", type=int, help="interval between samples")
     parser.add_argument("--save-interval", type=int, help="interval between saving the model")
     parser.add_argument("--sample-after-training", type=int, help="Number of molecules to sample after training")
-    parser.add_argument("--message", "-m", type=str, help="Number of molecules to sample after training")
+    parser.add_argument("--message", "-m", type=str, help="The message to print before the training starts")
 
     args = {k: v for k, v in vars(parser.parse_args()).items() if v is not None}

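For context on why these flags were safe to drop: TrainModelRunner sets lambda_gp = 10, which points to a WGAN-GP-style critic regularized by a gradient penalty, so the weight clipping that --clip-value configured in the original WGAN recipe never applies, and --img-size/--channels are presumably leftovers from an image GAN with no meaning for latent molecule vectors. The penalty itself is not part of this diff; a generic sketch of the standard WGAN-GP term, assuming 2-D (batch x latent_dim) real and fake tensors:

import torch

def gradient_penalty(critic, real, fake, lambda_gp=10.0):
    # Interpolate between real and fake samples with a per-sample mixing factor.
    alpha = torch.rand(real.size(0), 1, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_hat = critic(x_hat)
    # Gradient of the critic output w.r.t. the interpolates.
    grads = torch.autograd.grad(outputs=d_hat, inputs=x_hat,
                                grad_outputs=torch.ones_like(d_hat),
                                create_graph=True)[0]
    # Penalize deviation of the gradient norm from 1 (the Lipschitz target).
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()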
