Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Canonical recursive #432

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 73 additions & 57 deletions cpmpy/transformations/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from .get_variables import get_variables
from ..exceptions import TransformationNotImplementedError

from ..expressions.core import Comparison, Operator, BoolVal
from ..expressions.core import Comparison, Operator, BoolVal, Expression
from ..expressions.globalconstraints import GlobalConstraint, DirectConstraint
from ..expressions.utils import is_any_list, is_num, eval_comparison, is_bool

Expand Down Expand Up @@ -255,71 +255,87 @@ def only_positive_bv(lst_of_expr):
return newlist

def canonical_comparison(lst_of_expr):

lst_of_expr = toplevel_list(lst_of_expr) # ensure it is a list
"""
Transform sum-comparison in canonical form:
- lhs contains only variables
- rhs consists of constant
(TODO: sorted args on lhs?)
"""

newlist = []
for cpm_expr in lst_of_expr:

if isinstance(cpm_expr, Operator) and cpm_expr.name == '->': # half reification of comparison
lhs, rhs = cpm_expr.args
if isinstance(rhs, Comparison):
rhs = canonical_comparison(rhs)[0]
newlist.append(lhs.implies(rhs))
elif isinstance(lhs, Comparison):
lhs = canonical_comparison(lhs)[0]
newlist.append(lhs.implies(rhs))
if isinstance(cpm_expr, _NumVarImpl) or not isinstance(cpm_expr, Expression):
newlist.append(cpm_expr)

if isinstance(cpm_expr, Comparison):
elif isinstance(cpm_expr, Comparison) and isinstance(cpm_expr.args[0], Expression) and \
(isinstance(cpm_expr.args[0], _NumVarImpl) or \
(isinstance(cpm_expr.args[0], Operator) and (cpm_expr.args[0].name == "sum" or cpm_expr.args[0].name == "wsum"))):

# LHS is sum/wsum/var
comp_name = cpm_expr.name
lhs, rhs = cpm_expr.args
if isinstance(lhs, Comparison) and cpm_expr.name == "==": # reification of comparison
lhs = canonical_comparison(lhs)[0]
elif is_num(lhs) or isinstance(lhs, _NumVarImpl) or (isinstance(lhs, Operator) and lhs.name in {"sum", "wsum"}):
# bring all vars to lhs
lhs2 = []
if isinstance(rhs, _NumVarImpl):
lhs2, rhs = [-1 * rhs], 0
elif isinstance(rhs, Operator) and rhs.name == "sum":
lhs2, rhs = [-1 * b if isinstance(b, _NumVarImpl) else 1 * b.args[0] for b in rhs.args
if isinstance(b, _NumVarImpl) or isinstance(b, Operator)], \
sum(b for b in rhs.args if is_num(b))
elif isinstance(rhs, Operator) and rhs.name == "wsum":
lhs2, rhs = [-a * b for a, b in zip(rhs.args[0], rhs.args[1])
if isinstance(b, _NumVarImpl)], \
sum(-a * b for a, b in zip(rhs.args[0], rhs.args[1])
if not isinstance(b, _NumVarImpl))
if isinstance(lhs, Operator) and lhs.name == "sum":
lhs, rhs = sum([1 * a for a in lhs.args] + lhs2), rhs
elif isinstance(lhs, _NumVarImpl) or (isinstance(lhs, Operator) and lhs.name == "wsum"):
lhs, rhs = lhs + lhs2, rhs
else:
raise ValueError(
f"unexpected expression on lhs of expression, should be sum,wsum or intvar but got {lhs}")

assert not is_num(lhs), "lhs cannot be an integer at this point!"
# put all vars to lhs
if not is_num(rhs) and (isinstance(lhs, _NumVarImpl) or lhs.name =="sum" or lhs.name == "wsum"):

# bring all const to rhs
if lhs.name == "sum":
new_args = []
for i, arg in enumerate(lhs.args):
if is_num(arg):
rhs -= arg
else:
new_args.append(arg)
lhs = Operator("sum", new_args)
if isinstance(rhs, Operator):
if rhs.name == "sum":
extra_weights = [-1] * len(rhs.args)
extra_vars = rhs.args
elif rhs.name == "wsum":
extra_weights= [-w for w in rhs.args[0]]
extra_vars = rhs.args[1]
elif isinstance(rhs, _NumVarImpl):
extra_weights = [-1]
extra_vars = [rhs]

if lhs.name == "sum":
lhs_weights = [1]*len(lhs.args) + extra_weights
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra_weights will not be defined if the rhs is an operator other than sum or wsum (like div or mul or whatever)

lhs_args = lhs.args + extra_vars
elif lhs.name == "wsum":
new_weights, new_args = [], []
for i, (w, arg) in enumerate(zip(*lhs.args)):
if is_num(arg):
rhs -= w * arg
else:
new_weights.append(w)
new_args.append(arg)
lhs = Operator("wsum", [new_weights, new_args])

newlist.append(eval_comparison(cpm_expr.name, lhs, rhs))
else: # rest of expressions
lhs_weights = lhs.args[0] + extra_weights
lhs_args = lhs.args[1] + extra_vars
else: # lhs is constant
lhs_weights = [1] + extra_weights
lhs_args = [lhs] + extra_vars

lhs = Operator("wsum", [lhs_weights, lhs_args])
rhs = 0

# bring all consts to rhs
if lhs.name == "sum":
new_args = []
for arg in lhs.args:
if is_num(arg):
rhs -= arg
else:
new_args.append(arg)
lhs = Operator("sum", canonical_comparison(new_args))

elif lhs.name == "wsum":
new_weights, new_args = [], []
for i, (w, arg) in enumerate(zip(*lhs.args)):
if is_num(arg):
rhs -= w * arg
else:
new_weights.append(w)
new_args.append(arg)

lhs = Operator("wsum", [new_weights, canonical_comparison(new_args)])

newlist.append(eval_comparison(comp_name, lhs, rhs))

elif isinstance(cpm_expr, DirectConstraint):
newlist.append(cpm_expr) # we do not mess with direct constraints

elif isinstance(cpm_expr, Expression) and hasattr(cpm_expr, "args"):
# look in args of other expressions, including global constraints
cpm_expr = copy.copy(cpm_expr)
cpm_expr.args = canonical_comparison(cpm_expr.args)
newlist.append(cpm_expr)

return newlist
else:
raise ValueError("Reached uncovered case, if you reach this, please report on github ")

return newlist
10 changes: 8 additions & 2 deletions tests/test_trans_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ def test_sum(self):
cons = canonical_comparison([cp.sum([a,b,c,10]) <= rhs])[0]
self.assertEqual("sum([a, b, c]) <= -5", str(cons))

cons = canonical_comparison([(cp.sum([a, b, c, 10]) <= rhs) == (cp.sum([a]) <= b)])[0]
self.assertEqual('(sum([a, b, c]) <= -5) == (sum([1, -1] * [a, b]) <= 0)', str(cons))

rhs = cp.sum([b,c])
cons = canonical_comparison([cp.sum([a, b]) <= rhs])[0]
self.assertEqual("sum([1, 1, -1, -1] * [a, b, b, c]) <= 0", str(cons))
Expand All @@ -209,8 +212,11 @@ def test_div(self):
self.assertEqual("(a) // (b) <= 5", str(cons))

#when adding division
#cons = canonical_comparison([a / b <= c / rhs])[0]
#cons = canonical_comparison([a + b <= c/rhs])[0]
cons = canonical_comparison([a / b <= c / rhs])[0]
self.assertEqual('?', str(cons))

cons = canonical_comparison([a + b <= c * a])[0]
self.assertEqual('(sum([a, b]) - c * a <= 0)', str(cons))


def test_wsum(self):
Expand Down
Loading