Skip to content

Commit

Permalink
Merge pull request #2 from dsbuddy/training
Browse files Browse the repository at this point in the history
Training
  • Loading branch information
dsbuddy authored Feb 2, 2023
2 parents e580c27 + ffadd60 commit 2ba155e
Showing 1 changed file with 31 additions and 18 deletions.
49 changes: 31 additions & 18 deletions driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def numpy_to_pil(img):
def main():
wandb.init(project="med-seg-diff")
parser = argparse.ArgumentParser()
parser.add_argument('--learning_rate', type=float, default=5e-4, help='learning rate')
parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument(
"--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
)
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
parser.add_argument('-ic', '--input-channels', type=int, default=1, help='input channels for training (default: 3)')
parser.add_argument('-c', '--channels', type=int, default=3, help='output channels for training (default: 3)')
parser.add_argument('-is', '--image-size', type=int, default=128, help='input image size (default: 128)')
Expand Down Expand Up @@ -57,31 +64,37 @@ def main():
ds,
batch_size=args.batch_size,
shuffle=True)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
data = iter(datal)
for _ in range(args.epochs):
for i in tqdm(range(len(data))):
img, mask = next(data)

for i in tqdm(range(len(data))):
img, mask = next(data)


### SETUP DATA ##
#segmented_imgs = torch.rand(8, 3, 128, 128) # inputs are normalized from 0 to 1
#input_imgs = torch.rand(8, 3, 128, 128)

#print("Seg: {}".format(segmented_imgs.shape))
#print("Inp: {}".format(input_imgs.shape))

#print("Img: {}".format(img.shape))
#print("Mask: {}".format(mask.shape))
### SETUP DATA ##
#segmented_imgs = torch.rand(8, 3, 128, 128) # inputs are normalized from 0 to 1
#input_imgs = torch.rand(8, 3, 128, 128)

#print("Seg: {}".format(segmented_imgs.shape))
#print("Inp: {}".format(input_imgs.shape))

#print("Img: {}".format(img.shape))
#print("Mask: {}".format(mask.shape))

## TRAIN MODEL ##
#loss = diffusion(segmented_imgs, input_imgs)
loss = diffusion(mask, img)
wandb.log({'loss': loss}) # Log loss to wanbd
loss.backward()


## TRAIN MODEL ##
#loss = diffusion(segmented_imgs, input_imgs)
loss = diffusion(mask, img)
wandb.log({'loss': loss}) # Log loss to wandb
loss.backward()
optimizer.step()
optimizer.zero_grad()
# after a lot of training


Expand Down

0 comments on commit 2ba155e

Please sign in to comment.