Skip to content

Commit ab1f02c

Browse files
committed
bind values of previously picked struct fields while picking new ones.
1 parent d7683d7 commit ab1f02c

File tree

3 files changed

+125
-86
lines changed

3 files changed

+125
-86
lines changed

lib/dialect/include/rlc/dialect/DynamicArgumentAnalysis.hpp

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include <cassert>
22
#include <cstdint>
3+
#include <ranges>
34
#include <strings.h>
5+
#include "llvm/ADT/DenseMap.h"
46
#include "llvm/ADT/SmallVector.h"
57
#include "mlir/IR/Builders.h"
68
#include "mlir/IR/Location.h"
@@ -18,7 +20,7 @@ struct DeducedConstraints {
1820
};
1921

2022
enum TermType {
21-
DEPENDS_ON_UNBOUND_VALUE,
23+
DEPENDS_ON_UNBOUND,
2224
DEPENDS_ON_OTHER_UNKNOWNS,
2325
KNOWN_VALUE
2426
};
@@ -65,27 +67,82 @@ enum TermType {
6567
6668
and similarly for the maximum.
6769
*/
70+
71+
72+
/*
73+
memberAddress is the "path" from the arg to the value we want to pick.
74+
Example: arg = arg2, memberAddress = [2, 1] maps to MemberAccess(MemberAccess(arg,2), 1)
75+
*/
76+
struct UnboundValue {
77+
mlir::Value argument;
78+
llvm::SmallVector<uint64_t> memberAddress;
79+
80+
/*
81+
Returns whether this unbound value corressponds to the term.
82+
*/
83+
bool matches(mlir::Value term) {
84+
mlir::Value current = term;
85+
// walk the member address in reverse, test if it leads to the argument.
86+
for(uint64_t & index : std::ranges::reverse_view(memberAddress)) {
87+
auto definingOp = current.getDefiningOp();
88+
if( not llvm::detail::isPresent(definingOp))
89+
return false;
90+
if(auto memberAccess = mlir::dyn_cast<mlir::rlc::MemberAccess>(definingOp)) {
91+
if (memberAccess.getMemberIndex() != index) {
92+
return false;
93+
}
94+
current = memberAccess.getValue();
95+
} else {
96+
return false;
97+
}
98+
}
99+
return current == argument;
100+
}
101+
102+
mlir::Type getType() {
103+
auto type = argument.getType();
104+
for (auto index : memberAddress) {
105+
type = type.cast<mlir::rlc::EntityType>().getBody()[index];
106+
}
107+
return type;
108+
}
109+
};
110+
68111
class DynamicArgumentAnalysis
69112
{
70113
public:
71-
explicit DynamicArgumentAnalysis(mlir::rlc::FunctionOp op, mlir::ValueRange knownArgs, mlir::Value argPicker, mlir::OpBuilder builder, mlir::Location loc);
114+
explicit DynamicArgumentAnalysis(mlir::rlc::FunctionOp op, mlir::ValueRange boundArgs, mlir::Value argPicker, mlir::OpBuilder builder, mlir::Location loc);
72115
mlir::Value pickArg(int argIndex);
73116

74117
private:
75-
DeducedConstraints deduceIntegerUnboundValueConstraints(mlir::Value arg, llvm::SmallVector<uint64_t> memberAddress);
118+
DeducedConstraints deduceIntegerUnboundValueConstraints(UnboundValue unbound);
76119
llvm::SmallVector<llvm::SmallVector<mlir::Value>> expandToDNF(mlir::Value constraint);
77-
TermType decideTermType(mlir::Value term, mlir::Value argument, mlir::SmallVector<uint64_t> memberAddress);
120+
TermType decideTermType(mlir::Value term, UnboundValue unbound);
78121
mlir::Value compute(mlir::Value expression);
79-
DeducedConstraints findImposedConstraints(mlir::Value constraint, mlir::Value arg, mlir::SmallVector<uint64_t> memberAddress);
80-
DeducedConstraints findImposedConstraints(mlir::Operation *binaryOperation, mlir::Value arg, mlir::SmallVector<uint64_t> memberAddress);
81-
DeducedConstraints findImposedConstraints(mlir::rlc::CallOp call, mlir::Value arg, mlir::SmallVector<uint64_t> memberAddress);
122+
DeducedConstraints findImposedConstraints(mlir::Value constraint, UnboundValue unbound);
123+
DeducedConstraints findImposedConstraints(mlir::Operation *binaryOperation, UnboundValue unbound);
124+
DeducedConstraints findImposedConstraints(mlir::rlc::CallOp call, UnboundValue unbound);
82125

83-
mlir::Value pickIntegerUnboundValue(mlir::Value arg, llvm::SmallVector<uint64_t> memberAddress);
84-
mlir::Value pickUnboundValue(mlir::Value arg, llvm::SmallVector<uint64_t> memberAddress);
126+
mlir::Value pickIntegerUnboundValue(UnboundValue unbound);
127+
mlir::Value pickUnboundValue(UnboundValue unbound);
85128

86129
mlir::rlc::FunctionOp function;
87130
mlir::Region& precondition;
88-
mlir::ValueRange knownArgs;
131+
132+
llvm::SmallVector<std::pair<UnboundValue, mlir::Value>> bindings;
133+
/*
134+
If the passed value matches an UnboundValue that has a binding,
135+
return the value bound to it.
136+
Otherwise returns nullptr.
137+
*/
138+
mlir::Value getBoundValue(mlir::Value expr) {
139+
for(auto binding : bindings) {
140+
if(binding.first.matches(expr))
141+
return binding.second;
142+
}
143+
return nullptr;
144+
}
145+
89146
mlir::Value argPicker;
90147
mlir::OpBuilder builder;
91148
mlir::Location loc;

lib/dialect/src/DynamicArgumentAnalysisPass.cpp

Lines changed: 49 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
#include <cassert>
22
#include <cstdint>
3-
#include <iterator>
4-
#include <ranges>
53
#include <strings.h>
4+
#include <utility>
65
#include "llvm/ADT/STLExtras.h"
76
#include "llvm/ADT/SmallVector.h"
87
#include "llvm/Support/Casting.h"
@@ -20,13 +19,17 @@
2019
#include "rlc/dialect/DynamicArgumentAnalysis.hpp"
2120
#include "rlc/dialect/Types.hpp"
2221

23-
DynamicArgumentAnalysis::DynamicArgumentAnalysis(mlir::rlc::FunctionOp op, mlir::ValueRange knownArgs, mlir::Value argPicker, mlir::OpBuilder builder, mlir::Location loc):
22+
DynamicArgumentAnalysis::DynamicArgumentAnalysis(mlir::rlc::FunctionOp op, mlir::ValueRange boundArgs, mlir::Value argPicker, mlir::OpBuilder builder, mlir::Location loc):
2423
function(op),
2524
precondition(op.getPrecondition()),
26-
knownArgs(knownArgs),
2725
argPicker(argPicker),
2826
builder(builder),
2927
loc(loc) {
28+
for (auto boundArg : llvm::enumerate(boundArgs)) {
29+
UnboundValue arg {op.getPrecondition().getArgument(boundArg.index()), {}};
30+
bindings.emplace_back(std::pair(arg, boundArg.value()));
31+
}
32+
3033
auto yield = mlir::dyn_cast<mlir::rlc::Yield>(precondition.getBlocks().front().back());
3134
assert(yield->getNumOperands() == 1);
3235
conjunctions = expandToDNF(yield->getOperand(0));
@@ -82,35 +85,17 @@ llvm::SmallVector<llvm::SmallVector<mlir::Value>> DynamicArgumentAnalysis::expan
8285
return conjunctions;
8386
}
8487

85-
bool matches(mlir::Value term, mlir::Value argument, llvm::SmallVector<uint64_t> memberAddress) {
86-
mlir::Value current = term;
87-
// walk the member address in reverse, test if it leads to the argument.
88-
for(uint64_t & index : std::ranges::reverse_view(memberAddress)) {
89-
auto definingOp = current.getDefiningOp();
90-
if( not llvm::detail::isPresent(definingOp))
91-
return false;
92-
if(auto memberAccess = mlir::dyn_cast<mlir::rlc::MemberAccess>(definingOp)) {
93-
if (memberAccess.getMemberIndex() != index) {
94-
return false;
95-
}
96-
current = memberAccess.getValue();
97-
} else {
98-
return false;
99-
}
100-
}
101-
return current == argument;
102-
}
88+
TermType DynamicArgumentAnalysis::decideTermType(mlir::Value term, UnboundValue unbound) {
89+
if(unbound.matches(term))
90+
return DEPENDS_ON_UNBOUND;
10391

104-
TermType DynamicArgumentAnalysis::decideTermType(mlir::Value term, mlir::Value argument, llvm::SmallVector<uint64_t> memberAddress) {
105-
if(matches(term, argument, memberAddress))
106-
return DEPENDS_ON_UNBOUND_VALUE;
92+
auto boundValue = getBoundValue(term);
93+
if(boundValue != nullptr)
94+
return KNOWN_VALUE;
10795

108-
if(auto arg = llvm::dyn_cast<mlir::BlockArgument>(term)) {
109-
if(arg.getArgNumber() < knownArgs.size())
110-
return KNOWN_VALUE;
111-
96+
if(term.isa<mlir::BlockArgument>())
11297
return DEPENDS_ON_OTHER_UNKNOWNS;
113-
}
98+
11499

115100
auto definingOp = term.getDefiningOp();
116101

@@ -124,23 +109,24 @@ TermType DynamicArgumentAnalysis::decideTermType(mlir::Value term, mlir::Value a
124109
bool dependsOnOtherUnkowns = false;
125110

126111
for(auto operand : definingOp->getOperands()) {
127-
auto operandType = decideTermType(operand, argument, memberAddress);
128-
if(operandType == DEPENDS_ON_UNBOUND_VALUE)
112+
auto operandType = decideTermType(operand, unbound);
113+
if(operandType == DEPENDS_ON_UNBOUND)
129114
dependsOnArg = true;
130115
else if (operandType == DEPENDS_ON_OTHER_UNKNOWNS)
131116
dependsOnOtherUnkowns = true;
132117
}
133118

134-
if (dependsOnArg) { return DEPENDS_ON_UNBOUND_VALUE; }
119+
if (dependsOnArg) { return DEPENDS_ON_UNBOUND; }
135120
if (dependsOnOtherUnkowns) { return DEPENDS_ON_OTHER_UNKNOWNS; };
136121
return KNOWN_VALUE;
137122
}
138123

139124
mlir::Value DynamicArgumentAnalysis::compute(mlir::Value expression) {
140-
if(auto arg = llvm::dyn_cast<mlir::BlockArgument>(expression)) {
141-
assert(arg.getArgNumber() < knownArgs.size());
142-
return knownArgs[arg.getArgNumber()];
143-
}
125+
auto boundValue = getBoundValue(expression);
126+
if(boundValue != nullptr)
127+
return boundValue;
128+
129+
assert(not expression.isa<mlir::BlockArgument>() && "The expression to be computed depends on an unbound argument.");
144130

145131
if (expression.getDefiningOp()->getParentRegion() != precondition)
146132
return expression;
@@ -166,8 +152,8 @@ mlir::Value DynamicArgumentAnalysis::compute(mlir::Value expression) {
166152
const int64_t min_int = -800;
167153
const int64_t max_int = 800;
168154

169-
DeducedConstraints DynamicArgumentAnalysis::findImposedConstraints(mlir::Operation *binaryOperation, mlir::Value arg, mlir::SmallVector<uint64_t> memberAddress) {
170-
if (matches(binaryOperation->getOperand(0), arg, memberAddress)) {
155+
DeducedConstraints DynamicArgumentAnalysis::findImposedConstraints(mlir::Operation *binaryOperation, UnboundValue unbound) {
156+
if (unbound.matches(binaryOperation->getOperand(0))) {
171157
auto rhs = compute(binaryOperation->getOperand(1));
172158

173159
if (mlir::isa<mlir::rlc::LessOp>(binaryOperation))
@@ -213,7 +199,7 @@ DeducedConstraints DynamicArgumentAnalysis::findImposedConstraints(mlir::Operati
213199
}
214200
}
215201

216-
if(matches(binaryOperation->getOperand(1), arg, memberAddress)) {
202+
if(unbound.matches(binaryOperation->getOperand(1))) {
217203
auto lhs = compute(binaryOperation->getOperand(0));
218204

219205
if (mlir::isa<mlir::rlc::LessOp>(binaryOperation))
@@ -264,7 +250,7 @@ DeducedConstraints DynamicArgumentAnalysis::findImposedConstraints(mlir::Operati
264250
};
265251
}
266252

267-
DeducedConstraints DynamicArgumentAnalysis::findImposedConstraints(mlir::rlc::CallOp call, mlir::Value arg, mlir::SmallVector<uint64_t> memberAddress) {
253+
DeducedConstraints DynamicArgumentAnalysis::findImposedConstraints(mlir::rlc::CallOp call, UnboundValue unbound) {
268254
// LARGE TODO think about this part.
269255
if( mlir::rlc::CanOp can = llvm::dyn_cast<mlir::rlc::CanOp>(call.getCallee().getDefiningOp())) {
270256
auto underlyingFunction = llvm::dyn_cast<mlir::rlc::FunctionOp>(*can.getCallee().getDefiningOp());
@@ -279,14 +265,15 @@ DeducedConstraints DynamicArgumentAnalysis::findImposedConstraints(mlir::rlc::Ca
279265
continue;
280266
}
281267
// and find the index of the argument we're interested in.
282-
if(matches(current.value(), arg, memberAddress)) {
268+
if(unbound.matches(current.value())) {
283269
argIndex = current.index();
284270
break;
285271
}
286272
}
287273
assert(argIndex != -1 && "Expected to find the argument.");
288274
DynamicArgumentAnalysis analysis(underlyingFunction, knownArgsOfUnderlyingFunction, argPicker, builder, loc);
289-
return analysis.deduceIntegerUnboundValueConstraints(underlyingFunction.getPrecondition().getArgument(argIndex), {});
275+
UnboundValue correspongindUnboundValue {underlyingFunction.getPrecondition().getArgument(argIndex), {}};
276+
return analysis.deduceIntegerUnboundValueConstraints(correspongindUnboundValue);
290277
}
291278
return {
292279
builder.create<mlir::rlc::Constant>(loc, min_int),
@@ -295,14 +282,14 @@ DeducedConstraints DynamicArgumentAnalysis::findImposedConstraints(mlir::rlc::Ca
295282
}
296283

297284

298-
DeducedConstraints DynamicArgumentAnalysis::findImposedConstraints(mlir::Value constraint, mlir::Value arg, llvm::SmallVector<uint64_t> memberAddress) {
285+
DeducedConstraints DynamicArgumentAnalysis::findImposedConstraints(mlir::Value constraint, UnboundValue unbound) {
299286
auto *definingOp = constraint.getDefiningOp();
300287
if (definingOp->getOperands().size() == 2) {
301-
return findImposedConstraints(definingOp, arg, memberAddress);
288+
return findImposedConstraints(definingOp, unbound);
302289
}
303290

304291
if( mlir::rlc::CallOp call = llvm::dyn_cast<mlir::rlc::CallOp>(definingOp)) {
305-
return findImposedConstraints(call, arg, memberAddress);
292+
return findImposedConstraints(call, unbound);
306293
}
307294

308295
return {
@@ -387,16 +374,8 @@ void maybeAssignMax(mlir::Value currentMax, mlir::Value aggregateMax, mlir::Valu
387374
builder.setInsertionPointAfter(maybeAssignMax);
388375
}
389376

390-
mlir::Type getUnboundValueType( mlir::Value arg, llvm::SmallVector<uint64_t> memberAddress) {
391-
auto type = arg.getType();
392-
for (auto index : memberAddress) {
393-
type = type.cast<mlir::rlc::EntityType>().getBody()[index];
394-
}
395-
return type;
396-
}
397-
398-
DeducedConstraints DynamicArgumentAnalysis::deduceIntegerUnboundValueConstraints(mlir::Value arg, llvm::SmallVector<uint64_t> memberAddress) {
399-
auto type = getUnboundValueType(arg, memberAddress);
377+
DeducedConstraints DynamicArgumentAnalysis::deduceIntegerUnboundValueConstraints(UnboundValue unbound) {
378+
auto type = unbound.getType();
400379
assert(type.isa<mlir::rlc::IntegerType>() && "Expected an integer.");
401380

402381
auto minVal = builder.create<mlir::rlc::UninitializedConstruct>(loc, type);
@@ -412,8 +391,8 @@ DeducedConstraints DynamicArgumentAnalysis::deduceIntegerUnboundValueConstraints
412391
llvm::SmallVector<mlir::Value> conditions;
413392
// categorize the terms in the conjunction
414393
for(auto term : conjunction) {
415-
TermType type = decideTermType(term, arg, memberAddress);
416-
if(type == DEPENDS_ON_UNBOUND_VALUE) {
394+
TermType type = decideTermType(term, unbound);
395+
if(type == DEPENDS_ON_UNBOUND) {
417396
constraints.emplace_back(term);
418397
} else if (type == KNOWN_VALUE){
419398
conditions.emplace_back(term);
@@ -452,7 +431,7 @@ DeducedConstraints DynamicArgumentAnalysis::deduceIntegerUnboundValueConstraints
452431

453432
builder.createBlock(&ifStatement.getTrueBranch());
454433
for(auto constraint : constraints) {
455-
auto imposedConstraints = findImposedConstraints(constraint, arg, memberAddress);
434+
auto imposedConstraints = findImposedConstraints(constraint, unbound);
456435

457436
// if the minimum imposed by this constraint is greater than the current minimum, set the current minimum.
458437
assignIfGreaterthan(imposedConstraints.min, minForThisConjunction, builder, loc);
@@ -475,8 +454,8 @@ DeducedConstraints DynamicArgumentAnalysis::deduceIntegerUnboundValueConstraints
475454
return {minVal, maxVal};
476455
}
477456

478-
mlir::Value DynamicArgumentAnalysis::pickIntegerUnboundValue(mlir::Value arg, llvm::SmallVector<uint64_t> memberAddress) {
479-
auto deduced = deduceIntegerUnboundValueConstraints(arg, memberAddress);
457+
mlir::Value DynamicArgumentAnalysis::pickIntegerUnboundValue(UnboundValue unbound) {
458+
auto deduced = deduceIntegerUnboundValueConstraints(unbound);
480459
auto call = builder.create<mlir::rlc::CallOp>(
481460
loc,
482461
argPicker,
@@ -486,24 +465,23 @@ mlir::Value DynamicArgumentAnalysis::pickIntegerUnboundValue(mlir::Value arg, ll
486465
return call.getResult(0);
487466
}
488467

489-
/*
490-
memberAddress is the "path" from the arg to the value we want to pick.
491-
Example: arg = arg2, memberAddress = [2, 1] maps to MemberAccess(MemberAccess(arg,2), 1)
492-
*/
493-
mlir::Value DynamicArgumentAnalysis::pickUnboundValue(mlir::Value arg, llvm::SmallVector<uint64_t> memberAddress) {
494-
auto type = getUnboundValueType(arg, memberAddress);
468+
mlir::Value DynamicArgumentAnalysis::pickUnboundValue(UnboundValue unbound) {
469+
auto type = unbound.getType();
495470
if(type.isa<mlir::rlc::IntegerType>()) {
496-
return pickIntegerUnboundValue(arg, memberAddress);
471+
return pickIntegerUnboundValue(unbound);
497472
}
498473

499474
if(auto entityType = mlir::dyn_cast<mlir::rlc::EntityType>(type)) {
500475
auto entity = builder.create<mlir::rlc::UninitializedConstruct>(loc, entityType);
501476
for( auto memberType : llvm::enumerate(entityType.getBody())) {
502-
llvm::SmallVector<uint64_t> newMemberAddress(memberAddress);
477+
llvm::SmallVector<uint64_t> newMemberAddress(unbound.memberAddress);
503478
newMemberAddress.emplace_back(memberType.index());
504-
auto value = pickUnboundValue(arg, newMemberAddress);
479+
UnboundValue newUnboundValue {unbound.argument, newMemberAddress};
480+
auto value = pickUnboundValue(newUnboundValue);
505481
auto access = builder.create<mlir::rlc::MemberAccess>(loc, entity, memberType.index());
506482
builder.create<mlir::rlc::BuiltinAssignOp>(loc, access, value);
483+
// bind the value of this struct field while we pick the next field.
484+
bindings.emplace_back(std::pair(newUnboundValue, value));
507485
}
508486
return entity;
509487
}
@@ -514,7 +492,7 @@ mlir::Value DynamicArgumentAnalysis::pickUnboundValue(mlir::Value arg, llvm::Sma
514492

515493
mlir::Value DynamicArgumentAnalysis::pickArg(int argIndex) {
516494
auto arg = function.getPrecondition().getArgument(argIndex);
517-
return pickUnboundValue(arg, {});
495+
return pickUnboundValue({arg, {}});
518496
}
519497

520498
namespace mlir::rlc

tool/rlc/test/fuzzer_pick_struct.rl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@ import fuzzer.cpp_functions
22
import fuzzer.utils
33

44
ent Test:
5-
Int field1
6-
Int field2
5+
Int a
6+
Int b
77

88
act play() -> Play:
9-
act uses_struct(Test t) {t.field1 > 0, t.field1 < 5, t.field2 > 7, t.field2 < 16}
10-
11-
9+
act uses_struct(Test t) {
10+
t.a >= 0,
11+
t.a <= 5,
12+
t.b >= -10,
13+
t.a <= 2 or t.b >= 5,
14+
t.b <= 16
15+
}

0 commit comments

Comments
 (0)