
update & fix on ipfp backends
rogerwwww committed Apr 26, 2024
1 parent a6dc0c7 commit 003b097
Showing 6 changed files with 80 additions and 48 deletions.
22 changes: 15 additions & 7 deletions pygmtools/jittor_backend.py
@@ -283,6 +283,8 @@ def ipfp(K: Var, n1: Var, n2: Var, n1max, n2max, x0: Var,
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = jt.full((batch_num, 1, 1), -1)

     def comp_obj_score(v1, K, v2):
         return jt.bmm(jt.bmm(v1.view(batch_num, 1, -1), K), v2)
@@ -293,19 +295,25 @@ def comp_obj_score(v1, K, v2):
         binary_v = binary_sol.transpose(1, 2).view(batch_num, -1, 1)
         alpha = comp_obj_score(v, K, binary_v - v)
         beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        cond = jt.logical_or(beta <= 0, t0 >= 1)
+        t0 = - alpha / beta
+        cond = jt.logical_or(beta >= 0, t0 >= 1)
         if cond.shape != binary_v.shape:
             cond = cond.expand(binary_v.shape)
         v = jt.where(cond, binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if jt.max(jt.abs(
-                last_v_sol - jt.bmm(cost.reshape(batch_num, 1, -1), binary_sol.reshape(batch_num, -1, 1))
-        ) / last_v_sol) < 1e-3:
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        cond = current_obj > best_obj
+        if cond.shape != binary_v.shape:
+            cond = cond.expand(binary_v.shape)
+        best_v = jt.where(cond, binary_v, best_v)  # current_obj > best_obj
+        best_obj = jt.where(current_obj > best_obj, current_obj, best_obj)
+
+        if jt.max(jt.abs(last_v_obj - current_obj) / last_v_obj) < 1e-3:
             break
         last_v = v

-    pred_x = binary_sol
+    pred_x = best_v.reshape((batch_num, int(n2max), int(n1max))).transpose(1, 2)
     return pred_x
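
For context, the corrected lines implement the standard IPFP line search: with d = binary_v - v, the objective along the segment is f(t) = f(0) + 2*t*alpha + t^2*beta for a symmetric affinity matrix K, where alpha = v^T K d and beta = d^T K d, so the stationary point is t0 = -alpha / beta; when beta >= 0 (no interior maximum) or t0 >= 1, the maximizer on [0, 1] is the endpoint binary_v. The diff also keeps the best discrete objective seen so far instead of returning the last Hungarian projection, and stops once the relative change in objective drops below 1e-3. The same pattern repeats in the remaining backends below. A minimal, non-batched NumPy sketch under these assumptions (the `discretize` helper is illustrative, standing in for the Hungarian projection used in the actual backends):

import numpy as np

def ipfp_iteration(K, v, discretize):
    """One IPFP iteration for a single problem (sketch; assumes K is symmetric).

    `discretize` maps the continuous score vector K @ v to a binary solution,
    playing the role of the Hungarian step in the backends above.
    """
    binary_v = discretize(K @ v)   # project the continuous iterate onto the discrete set
    d = binary_v - v
    alpha = v @ K @ d              # v^T K d
    beta = d @ K @ d               # d^T K d
    if beta >= 0:                  # convex (or flat) along d: the endpoint is the maximizer
        v_new = binary_v
    else:
        t0 = -alpha / beta         # stationary point of the concave quadratic
        v_new = binary_v if t0 >= 1 else v + t0 * d
    return v_new, binary_v, float(binary_v @ K @ binary_v)

def ipfp_keep_best(K, v0, discretize, max_iter=50, tol=1e-3):
    """Outer loop mirroring the diff: keep the best binary objective seen so far
    and stop when the objective change relative to the previous iterate is small."""
    v = last_v = best_v = v0
    best_obj = -np.inf
    for _ in range(max_iter):
        v, binary_v, cur_obj = ipfp_iteration(K, v, discretize)
        if cur_obj > best_obj:
            best_v, best_obj = binary_v, cur_obj
        last_obj = float(last_v @ K @ last_v)
        if abs(last_obj - cur_obj) / last_obj < tol:
            break
        last_v = v
    return best_v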


22 changes: 13 additions & 9 deletions pygmtools/mindspore_backend.py
@@ -280,6 +280,8 @@ def ipfp(K: mindspore.Tensor, n1: mindspore.Tensor, n2: mindspore.Tensor, n1max,
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = -1

     def comp_obj_score(v1, K, v2):
         return mindspore.ops.BatchMatMul()(mindspore.ops.BatchMatMul()(v1.view(batch_num, 1, -1), K), v2)
@@ -288,19 +290,21 @@ def comp_obj_score(v1, K, v2):
         cost = mindspore.ops.BatchMatMul()(K, v).reshape(batch_num, int(n2max), int(n1max)).swapaxes(1, 2)
         binary_sol = hungarian(cost, n1, n2)
         binary_v = binary_sol.swapaxes(1, 2).view(batch_num, -1, 1)
-        alpha = comp_obj_score(v, K, binary_v - v)  # + torch.mm(k_diag.view(1, -1), (binary_sol - v).view(-1, 1))
+        alpha = comp_obj_score(v, K, binary_v - v)
         beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        v = mindspore.numpy.where(mindspore.ops.logical_or(beta <= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if (mindspore.ops.max(mindspore.ops.abs(
-                last_v_sol - mindspore.ops.BatchMatMul()(cost.reshape((batch_num, 1, -1)),
-                                                         binary_sol.reshape((batch_num, -1, 1)))
-        ) / last_v_sol)[1] < 1e-3).any():
+        t0 = - alpha / beta
+        v = mindspore.numpy.where(mindspore.ops.logical_or(beta >= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        best_v = mindspore.numpy.where(current_obj > best_obj, binary_v, best_v)
+        best_obj = mindspore.numpy.where(current_obj > best_obj, current_obj, best_obj)
+
+        if (mindspore.ops.max(mindspore.ops.abs(last_v_obj - current_obj) / last_v_obj)[1] < 1e-3).any():
             break
         last_v = v

-    pred_x = binary_sol
+    pred_x = best_v.reshape(batch_num, int(n2max), int(n1max)).swapaxes(1, 2)
     return pred_x


19 changes: 12 additions & 7 deletions pygmtools/numpy_backend.py
@@ -293,6 +293,8 @@ def ipfp(K: np.ndarray, n1: np.ndarray, n2: np.ndarray, n1max, n2max, x0: np.nda
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = -1

     def comp_obj_score(v1, K, v2):
         return np.matmul(np.matmul(v1.reshape((batch_num, 1, -1)), K), v2)
@@ -303,16 +305,19 @@ def comp_obj_score(v1, K, v2):
         binary_v = binary_sol.transpose((0, 2, 1)).reshape((batch_num, -1, 1))
         alpha = comp_obj_score(v, K, binary_v - v)
         beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        v = np.where(np.logical_or(beta <= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if np.max(np.abs(
-                last_v_sol - np.matmul(cost.reshape((batch_num, 1, -1)), binary_sol.reshape((batch_num, -1, 1)))
-        ) / last_v_sol) < 1e-3:
+        t0 = - alpha / beta
+        v = np.where(np.logical_or(beta >= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        best_v = np.where(current_obj > best_obj, binary_v, best_v)
+        best_obj = np.where(current_obj > best_obj, current_obj, best_obj)
+
+        if np.max(np.abs(last_v_obj - current_obj) / last_v_obj) < 1e-3:
            break
         last_v = v

-    pred_x = binary_sol
+    pred_x = best_v.reshape((batch_num, n2max, n1max)).transpose((0, 2, 1))
     return pred_x


19 changes: 12 additions & 7 deletions pygmtools/paddle_backend.py
@@ -275,6 +275,8 @@ def ipfp(K: paddle.Tensor, n1: paddle.Tensor, n2: paddle.Tensor, n1max, n2max, x
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = paddle.to_tensor(paddle.full((batch_num, 1, 1), -1.), place=K.place)

     def comp_obj_score(v1, K, v2):
         return paddle.bmm(paddle.bmm(paddle.reshape(v1, (batch_num, 1, -1)), K), v2)
@@ -285,16 +287,19 @@ def comp_obj_score(v1, K, v2):
         binary_v = paddle.reshape(binary_sol.transpose((0, 2, 1)),(batch_num, -1, 1))
         alpha = comp_obj_score(v, K, binary_v - v)
         beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        v = paddle.where(paddle.logical_or(beta <= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if paddle.max(paddle.abs(
-                last_v_sol - paddle.bmm(paddle.reshape(cost,(batch_num, 1, -1)), paddle.reshape(binary_sol, (batch_num, -1, 1)))
-        ) / last_v_sol) < 1e-3:
+        t0 = - alpha / beta
+        v = paddle.where(paddle.logical_or(beta >= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        best_v = paddle.where(current_obj > best_obj, binary_v, best_v)
+        best_obj = paddle.where(current_obj > best_obj, current_obj, best_obj)
+
+        if paddle.max(paddle.abs(last_v_obj - current_obj) / last_v_obj) < 1e-3:
            break
         last_v = v

-    pred_x = binary_sol
+    pred_x = paddle.reshape(best_v, (batch_num, n2max, n1max)).transpose((0, 2, 1))
     return pred_x


21 changes: 13 additions & 8 deletions pygmtools/pytorch_backend.py
@@ -282,6 +282,8 @@ def ipfp(K: Tensor, n1: Tensor, n2: Tensor, n1max, n2max, x0: Tensor,
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = -1

     def comp_obj_score(v1, K, v2):
         return torch.bmm(torch.bmm(v1.view(batch_num, 1, -1), K), v2)
@@ -290,18 +292,21 @@ def comp_obj_score(v1, K, v2):
         cost = torch.bmm(K, v).reshape(batch_num, n2max, n1max).transpose(1, 2)
         binary_sol = hungarian(cost, n1, n2)
         binary_v = binary_sol.transpose(1, 2).view(batch_num, -1, 1)
-        alpha = comp_obj_score(v, K, binary_v - v)  # + torch.mm(k_diag.view(1, -1), (binary_sol - v).view(-1, 1))
+        alpha = comp_obj_score(v, K, binary_v - v)
         beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        v = torch.where(torch.logical_or(beta <= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if torch.max(torch.abs(
-                last_v_sol - torch.bmm(cost.reshape(batch_num, 1, -1), binary_sol.reshape(batch_num, -1, 1))
-        ) / last_v_sol) < 1e-3:
+        t0 = - alpha / beta
+        v = torch.where(torch.logical_or(beta >= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        best_v = torch.where(current_obj > best_obj, binary_v, best_v)
+        best_obj = torch.where(current_obj > best_obj, current_obj, best_obj)
+
+        if torch.max(torch.abs(last_v_obj - current_obj) / last_v_obj) < 1e-3:
            break
         last_v = v

-    pred_x = binary_sol
+    pred_x = best_v.reshape(batch_num, n2max, n1max).transpose(1, 2)
     return pred_x


25 changes: 15 additions & 10 deletions pygmtools/tensorflow_backend.py
@@ -231,7 +231,7 @@ def sinkhorn(s: tf.Tensor, nrows: tf.Tensor=None, ncols: tf.Tensor=None,
 def rrwm(K: tf.Tensor, n1: tf.Tensor, n2: tf.Tensor, n1max, n2max, x0: tf.Tensor,
          max_iter: int, sk_iter: int, alpha: float, beta: float) -> tf.Tensor:
     """
-    Pytorch implementation of RRWM algorithm.
+    Tensorflow implementation of RRWM algorithm.
     """
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     # rescale the values in K
@@ -283,11 +283,13 @@ def sm(K: tf.Tensor, n1: tf.Tensor, n2: tf.Tensor, n1max, n2max, x0: tf.Tensor,
 def ipfp(K: tf.Tensor, n1: tf.Tensor, n2: tf.Tensor, n1max, n2max, x0: tf.Tensor,
          max_iter) -> tf.Tensor:
     """
-    Pytorch implementation of IPFP algorithm
+    Tensorflow implementation of IPFP algorithm
     """
     batch_num, n1, n2, n1max, n2max, n1n2, v0 = _check_and_init_gm(K, n1, n2, n1max, n2max, x0)
     v = v0
     last_v = v
+    best_v = v
+    best_obj = -1

     def comp_obj_score(v1, K, v2):
         return tf.matmul(tf.matmul(tf.reshape(v1, [batch_num, 1, -1]), K), v2)
@@ -296,18 +298,21 @@ def comp_obj_score(v1, K, v2):
         cost = tf.transpose(tf.reshape(tf.matmul(K, v), [batch_num, n2max, n1max]), [0, 2, 1])
         binary_sol = hungarian(cost, n1, n2)
         binary_v = tf.reshape(tf.transpose(binary_sol, [0, 2, 1]), [batch_num, -1, 1])
-        alpha = comp_obj_score(v, K, binary_v - v)  # + torch.mm(k_diag.view(1, -1), (binary_sol - v).view(-1, 1))
+        alpha = comp_obj_score(v, K, binary_v - v)
        beta = comp_obj_score(binary_v - v, K, binary_v - v)
-        t0 = alpha / beta
-        v = tf.where(tf.math.logical_or(beta <= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
-        last_v_sol = comp_obj_score(last_v, K, last_v)
-        if tf.reduce_max(tf.abs(
-                last_v_sol - tf.matmul(tf.reshape(cost, [batch_num, 1, -1]), tf.reshape(binary_sol, [batch_num, -1, 1]))
-        ) / last_v_sol) < 1e-3:
+        t0 = - alpha / beta
+        v = tf.where(tf.math.logical_or(beta >= 0, t0 >= 1), binary_v, v + t0 * (binary_v - v))
+        last_v_obj = comp_obj_score(last_v, K, last_v)
+
+        current_obj = comp_obj_score(binary_v, K, binary_v)
+        best_v = tf.where(current_obj > best_obj, binary_v, best_v)
+        best_obj = tf.where(current_obj > best_obj, current_obj, best_obj)
+
+        if tf.reduce_max(tf.abs(last_v_obj - current_obj) / last_v_obj) < 1e-3:
            break
         last_v = v

-    pred_x = binary_sol
+    pred_x = tf.transpose(tf.reshape(best_v, [batch_num, n2max, n1max]), [0, 2, 1])
     return pred_x


