Hi! Thank you for your reply to the question about the 'CROWN' repository!
I want to compute the minimum r the same way 'CROWN' does, and I tried the following code:
def test(input, model):
    eps = 0
    gap_gx = 100
    eps_LB = -1
    eps_UB = 1
    counter = 0
    is_pos = True
    is_neg = True
    # perform binary search
    eps_gx_UB = 1000000.0
    eps_gx_LB = 0.0
    is_pos = True
    is_neg = True
    eps = eps_gx_LB*2
    eps = args.eps
    while eps_gx_UB - eps_gx_LB > 0.00001:
        ptb = PerturbationLpNorm(norm=2, eps=eps)
        image = BoundedTensor(input, ptb)
        pred = model(image)
        label = torch.argmax(pred, dim=1).cpu().numpy()
        # for method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)']:
        lb, ub = model.compute_bounds(x=(image,), method='IBP+backward')
        gap_gx = torch.min(lb)
        lb = lb.detach().cpu().numpy()
        ub = ub.detach().cpu().numpy()
        print("Bounding method:", method)
        for i in range(N):
            print("Image {} top-1 prediction {} ground-truth {}".format(i, label[i], true_label[i]))
            for j in range(n_classes):
                indicator = '(ground-truth)' if j == true_label[i] else ''
                print("f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}".format(
                    j=j, l=lb[i][j], u=ub[i][j], ind=indicator))
            print()
        if gap_gx > 0:
            if gap_gx < 0.01:
                eps_gx_LB = eps
                return eps
                break
            if is_pos:  # so far always > 0, haven't found eps_UB
                eps_gx_LB = eps
                eps *= 10
            else:
                eps_gx_LB = eps
                eps = (eps_gx_LB + eps_gx_UB) / 2
            is_neg = False
        else:
            if is_neg:  # so far always < 0, haven't found eps_LB
                eps_gx_UB = eps
                eps /= 10
            else:
                eps_gx_UB = eps
                eps = (eps_gx_LB + eps_gx_UB) / 2
            is_pos = False
        counter += 1
        if counter >= 500:
            return eps
            break
        print("[L2][binary search] step = {}, eps = {:.5f}, gap_gx = {:.2f}".format(counter, eps, gap_gx))
But it does not work: the lp does not change with eps. I'm confused about this. Thank you!
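To make the intent clearer, here is a minimal standalone sketch of the bracket-then-bisect search I am trying to reproduce, with the bound computation left out. `search_max_eps` and `is_verified` are hypothetical names of my own, not part of CROWN or auto_LiRPA. `is_verified(eps)` stands in for the `gap_gx > 0` check and is assumed to be monotone (true for small eps, false for large eps). The toy threshold 0.37 is made up.

```python
from typing import Callable


def search_max_eps(
    is_verified: Callable[[float], bool],
    eps_init: float = 0.1,
    tol: float = 1e-5,
    max_steps: int = 500,
) -> float:
    """Return the largest eps found for which is_verified(eps) is True.

    Assumes is_verified is monotone: True for small eps, False for large eps.
    """
    lo, hi = 0.0, float("inf")  # largest verified eps, smallest failing eps
    eps = eps_init
    for _ in range(max_steps):
        if hi - lo <= tol:
            break
        if is_verified(eps):
            lo = eps
            # No failing eps seen yet: grow geometrically; otherwise bisect.
            eps = eps * 10 if hi == float("inf") else (lo + hi) / 2
        else:
            hi = eps
            # No verified eps seen yet: shrink geometrically; otherwise bisect.
            eps = eps / 10 if lo == 0.0 else (lo + hi) / 2
    return lo


# Toy check: pretend the lower bound stays positive exactly when eps < 0.37.
r = search_max_eps(lambda e: e < 0.37)
print(f"estimated r = {r:.5f}")
```

In this sketch the search only makes progress if the result of `is_verified` actually depends on `eps`, which is the part that does not seem to happen in my code above.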