Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
Fix the bug #1 (temporary)
  • Loading branch information
sungyoon-lee authored Jul 11, 2022
1 parent b805939 commit fc7eb63
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def Train_calloss(model, t, i, data, labels, loader, eps,end_eps, max_eps, norm,
data_u = np.inf
else:
data_l = data_u = None
f = DualNetwork(model, data, convex_eps, proj = proj, norm_type = norm_type, bounded_input = kwargs["bounded_input"], data_l = data_l, data_u = data_u)
# f = DualNetwork(model, data, convex_eps, proj = proj, norm_type = norm_type, bounded_input = kwargs["bounded_input"], data_l = data_l, data_u = data_u)
lb = f(c)
elif kwargs["bound_type"] == "interval":
ub, lb, relu_activity, unstable, dead, alive = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps, C=c, method_opt="interval_range")
Expand Down Expand Up @@ -438,7 +438,7 @@ def Train(model, t, loader, eps_scheduler, max_eps, norm, logger, verbose, train
data_u = np.inf
else:
data_l = data_u = None
f = DualNetwork(model, data, convex_eps, proj = proj, norm_type = norm_type, bounded_input = kwargs["bounded_input"], data_l = data_l, data_u = data_u)
# f = DualNetwork(model, data, convex_eps, proj = proj, norm_type = norm_type, bounded_input = kwargs["bounded_input"], data_l = data_l, data_u = data_u)
lb = f(c)
elif kwargs["bound_type"] == "interval":
ub, lb, relu_activity, unstable, dead, alive = model(norm=norm, x_U=data_ub, x_L=data_lb, eps=eps, C=c, method_opt="interval_range")
Expand Down

0 comments on commit fc7eb63

Please sign in to comment.