Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def Neura_FDivOp : Op<NeuraDialect, "fdiv"> {
Example:
%result = neura.fdiv %a, %b : f32
}];
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
let arguments = (ins AnyType:$lhs, Optional<AnyType>:$rhs);
let results = (outs AnyType:$result);
let traits = [SameOperandsAndResultElementType];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ bool HeuristicMapping::mapWithBacktrack(
<< sorted_ops_with_levels.size() - materialized_ops.size()
<< " non-materialized operations, " << materialized_ops.size()
<< " operations require physical mapping." << "\n";

llvm::outs() << "[HeuristicMapping] Materialized operations list:\n";
for (size_t i = 0; i < materialized_ops.size(); ++i) {
llvm::outs() << i << " " << *materialized_ops[i].first
<< " (level: " << materialized_ops[i].second << ")\n";
}

// Stores the mapping state snapshots for backtracking.
std::vector<MappingStateSnapshot> snapshots;
Expand Down
131 changes: 122 additions & 9 deletions lib/NeuraDialect/Mapping/MappingState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
using namespace mlir;
using namespace mlir::neura;

// Constants for table formatting in dumpOpToLocs.
// Base column width: the separator row draws kKeyMaxLen + 1 dashes for each
// time-slot column.
constexpr int kKeyMaxLen = 36;
// Width (35 characters) that the text of each table cell is padded or
// truncated to.
constexpr int kCellWidth = kKeyMaxLen - 1;

// Constructs a mapping state that records `II` and the `is_spatial_only` flag.
// `arch` is accepted but is not read by this constructor.
MappingState::MappingState(const Architecture &arch, int II,
                           bool is_spatial_only)
    : II(II),
      is_spatial_only(is_spatial_only) {
}
Expand Down Expand Up @@ -200,21 +206,128 @@ void MappingState::releaseRoute(Operation *op) {
}

// Writes a table of the operations mapped onto tiles to `os`.
//
// Rows are tile IDs and columns are time slots (time_step % II). Only
// locations whose resource type is "tile" are shown. When several operations
// land in the same (tile, slot) cell, only the first one recorded is printed.
// Each cell is padded or truncated to kCellWidth characters.
void MappingState::dumpOpToLocs(llvm::raw_ostream &os) const {
  os << "=== MappingState: Resource Allocation Table ===\n";

  // The slot computation below is `time_step % II`, which is undefined for
  // II == 0 and meaningless for negative II, so bails out early.
  if (II <= 0) {
    os << "Invalid II = " << II << "; cannot build table.\n";
    os << "=== End ===\n";
    return;
  }

  // Width of the row-label column; together with the " | " delimiter it lines
  // up with the "---------+" prefix of the separator row.
  constexpr size_t kLabelWidth = 8;
  constexpr size_t kCellChars = static_cast<size_t>(kCellWidth);

  // Writes `text` followed by enough spaces to reach `width` characters.
  // Text that is already `width` or longer is written unchanged.
  auto write_padded = [&os](const std::string &text, size_t width) {
    os << text;
    if (text.size() < width) {
      os << std::string(width - text.size(), ' ');
    }
  };

  // Collects all tiles and time steps (modulo II).
  std::set<int> tile_ids;
  // Time slots that are actually occupied.
  // NOTE(review): assumes time_step is non-negative, so slots fall in
  // [0, II) - confirm against the mapper.
  std::set<int> time_slots;
  // Maps (tile_id, time_slot) to list of (operation, actual_time_step).
  std::map<std::pair<int, int>, std::vector<std::pair<Operation *, int>>>
      tile_slot_to_ops;

  for (const auto &[op, locs] : op_to_locs) {
    for (const MappingLoc &loc : locs) {
      auto *res = loc.resource;
      // Only shows tiles in the table.
      if (res->getType() == "tile") {
        tile_ids.insert(res->getId());
        // Computes modulo II.
        const int time_slot = loc.time_step % II;
        time_slots.insert(time_slot);
        tile_slot_to_ops[{res->getId(), time_slot}].push_back(
            {op, loc.time_step});
      }
    }
  }

  if (tile_ids.empty() || time_slots.empty()) {
    os << "No tile operations mapped.\n";
    os << "=== End ===\n";
    return;
  }

  os << "II = " << II << "\n";

  // Prints header - one column per occupied time slot. The header text is
  // padded by its real length so any number of digits stays aligned.
  os << "\n";
  write_padded("Tile", kLabelWidth);
  os << " | ";
  for (int slot : time_slots) {
    write_padded("t%" + std::to_string(II) + "=" + std::to_string(slot),
                 kCellChars);
    os << " | ";
  }
  os << "\n";

  // Prints separator line.
  os << std::string(kLabelWidth + 1, '-') << "+";
  for (size_t i = 0; i < time_slots.size(); ++i) {
    os << std::string(static_cast<size_t>(kKeyMaxLen + 1), '-') << "+";
  }
  os << "\n";

  // Prints each tile as a row.
  for (int tile_id : tile_ids) {
    write_padded("Tile#" + std::to_string(tile_id), kLabelWidth);
    os << " | ";

    for (int slot : time_slots) {
      // Stays empty when nothing is mapped to this (tile, slot) pair.
      std::string op_str;

      auto it = tile_slot_to_ops.find({tile_id, slot});
      if (it != tile_slot_to_ops.end() && !it->second.empty()) {
        // Multiple operations may exist in the same slot (from different
        // iterations). Shows the first one.
        Operation *op = it->second.front().first;
        const int actual_time = it->second.front().second;

        // Builds operation string: %result = op_name(%operand1, ...).
        llvm::raw_string_ostream op_stream(op_str);
        mlir::OpPrintingFlags flags;

        // Prints result (if exists).
        if (op->getNumResults() > 0) {
          op->getResult(0).printAsOperand(op_stream, flags);
          op_stream << " = ";
        }

        // Prints operation name (removes "neura." prefix).
        std::string op_name = op->getName().getStringRef().str();
        if (op_name.rfind("neura.", 0) == 0) {
          op_name = op_name.substr(6);
        }
        op_stream << op_name;

        // Prints operands.
        if (op->getNumOperands() > 0) {
          op_stream << "(";
          for (unsigned i = 0; i < op->getNumOperands(); ++i) {
            if (i > 0) {
              op_stream << ", ";
            }
            op->getOperand(i).printAsOperand(op_stream, flags);
          }
          op_stream << ")";
        }

        // Adds time annotation when the actual step is II or later.
        if (actual_time >= II) {
          op_stream << " (t=" << actual_time << ")";
        }

        op_stream.flush();

        // Truncates string if too long to fit in the cell.
        if (op_str.size() > kCellChars) {
          op_str = op_str.substr(0, kCellChars - 3) + "...";
        }
      }

      // Pads to fixed width (kCellWidth chars); an empty string renders an
      // empty cell.
      write_padded(op_str, kCellChars);
      os << " | ";
    }
    os << "\n";
  }

  os << "\n=== Legend ===\n";
  os << "- Table shows operations mapped to tiles (modulo II scheduling)\n";
  os << "- Column headers: t%II=X means time slot X (t=X, X+II, X+2*II, ...)\n";
  os << "- Operations with (t=Y) annotation are scheduled at actual time step Y\n";
  os << "- Operations without annotation are scheduled at t=0 to t=" << (II - 1)
     << "\n";
  os << "=== End ===\n";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,8 @@ DEFINE_BINARY_OP_PATTERN(FAdd, FAddOp)
DEFINE_BINARY_OP_PATTERN(FSub, FSubOp)
// Generates: FuseFMulConstantPattern.
DEFINE_BINARY_OP_PATTERN(FMul, FMulOp)
// Generates: FuseFDivConstantPattern.
DEFINE_BINARY_OP_PATTERN(FDiv, FDivOp)
// Generates: FuseICmpConstantPattern.
// Note: ICmpOp has a cmp_type attribute that is automatically preserved.
DEFINE_BINARY_OP_PATTERN(ICmp, ICmpOp)
Expand Down Expand Up @@ -587,6 +589,7 @@ struct FoldConstantPass
patterns.add<FuseFAddConstantPattern>(&getContext());
patterns.add<FuseFSubConstantPattern>(&getContext());
patterns.add<FuseFMulConstantPattern>(&getContext());
patterns.add<FuseFDivConstantPattern>(&getContext());
patterns.add<FuseFMaxConstantPattern>(&getContext());
patterns.add<FuseFMinConstantPattern>(&getContext());

Expand Down
Loading