Commit 8cbe913
Fix logger append.
rcamino committed Aug 20, 2018
1 parent 5e6146d commit 8cbe913
Showing 6 changed files with 7 additions and 7 deletions.

multi_categorical_gans/methods/arae/trainer.py (2 changes: 1 addition & 1 deletion)

@@ -63,7 +63,7 @@ def train(autoencoder,
     optim_gen = Adam(generator.parameters(), weight_decay=l2_regularization, lr=learning_rate)
     optim_disc = Adam(discriminator.parameters(), weight_decay=l2_regularization, lr=learning_rate)

-    logger = Logger(output_loss_path)
+    logger = Logger(output_loss_path, append=start_epoch > 0)

     for epoch_index in range(start_epoch, num_epochs):
         logger.start_timer()
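
The same one-line change appears in each of the five trainers: the loss log is opened in append mode only when training resumes from a saved state (start_epoch > 0); a fresh run still starts a new file. A minimal usage sketch of that pattern (the path and epoch values below are hypothetical, for illustration only):

    from multi_categorical_gans.utils.logger import Logger

    output_loss_path = "loss.csv"  # hypothetical log path
    start_epoch = 5                # e.g. resuming after a checkpoint at epoch 5

    # start_epoch > 0, so append=True: rows for epochs 5..num_epochs-1 are
    # added after the rows already written by epochs 0..4 instead of
    # replacing them. A fresh run (start_epoch == 0) truncates the file.
    logger = Logger(output_loss_path, append=start_epoch > 0)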

multi_categorical_gans/methods/mc_gumbel/trainer.py (2 changes: 1 addition & 1 deletion)

@@ -46,7 +46,7 @@ def train(generator,

     criterion = BCELoss()

-    logger = Logger(output_loss_path)
+    logger = Logger(output_loss_path, append=start_epoch > 0)

     for epoch_index in range(start_epoch, num_epochs):
         logger.start_timer()

multi_categorical_gans/methods/mc_wgan_gp/trainer.py (2 changes: 1 addition & 1 deletion)

@@ -44,7 +44,7 @@ def train(generator,
     optim_gen = Adam(generator.parameters(), weight_decay=l2_regularization, lr=learning_rate)
     optim_disc = Adam(discriminator.parameters(), weight_decay=l2_regularization, lr=learning_rate)

-    logger = Logger(output_loss_path)
+    logger = Logger(output_loss_path, append=start_epoch > 0)

     for epoch_index in range(start_epoch, num_epochs):
         logger.start_timer()

multi_categorical_gans/methods/medgan/pre_trainer.py (2 changes: 1 addition & 1 deletion)

@@ -37,7 +37,7 @@ def pre_train(autoencoder,

     optim = Adam(autoencoder.parameters(), weight_decay=l2_regularization, lr=learning_rate)

-    logger = Logger(output_loss_path)
+    logger = Logger(output_loss_path, append=start_epoch > 0)

     for epoch_index in range(start_epoch, num_epochs):
         logger.start_timer()

multi_categorical_gans/methods/medgan/trainer.py (2 changes: 1 addition & 1 deletion)

@@ -51,7 +51,7 @@ def train(autoencoder,

     criterion = BCELoss()

-    logger = Logger(output_loss_path)
+    logger = Logger(output_loss_path, append=start_epoch > 0)

     for epoch_index in range(start_epoch, num_epochs):
         logger.start_timer()

multi_categorical_gans/utils/logger.py (4 changes: 2 additions & 2 deletions)

@@ -10,8 +10,8 @@ class Logger(object):

     start_time = None

-    def __init__(self, output_path):
-        if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
+    def __init__(self, output_path, append=False):
+        if append and os.path.exists(output_path) and os.path.getsize(output_path) > 0:
             self.output_file = open(output_path, "a")
             self.output_writer = csv.DictWriter(self.output_file, fieldnames=self.CSV_COLUMNS)
         else:
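
The hunk ends before the else branch, so here is a minimal sketch of the whole constructor after this commit, assuming the hidden branch opens the file for writing and emits the CSV header; the CSV_COLUMNS values are stand-ins, not the repository's actual column names:

    import csv
    import os

    class Logger(object):
        CSV_COLUMNS = ["epoch", "time", "loss"]  # stand-in columns (assumption)

        start_time = None

        def __init__(self, output_path, append=False):
            if append and os.path.exists(output_path) and os.path.getsize(output_path) > 0:
                # Resuming: keep the existing rows and skip the header.
                self.output_file = open(output_path, "a")
                self.output_writer = csv.DictWriter(self.output_file, fieldnames=self.CSV_COLUMNS)
            else:
                # Fresh run, or a missing/empty file: truncate and write a header.
                self.output_file = open(output_path, "w")
                self.output_writer = csv.DictWriter(self.output_file, fieldnames=self.CSV_COLUMNS)
                self.output_writer.writeheader()

Before this commit, any non-empty existing log was appended to, even on a fresh run; the new append flag makes reuse opt-in, so restarting from epoch 0 no longer mixes old and new rows in the same file.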
