Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 55 additions & 17 deletions lib/NeuraDialect/Transforms/GenerateCodePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,35 +124,62 @@ static std::string getOpcode(Operation *op) {
if (opcode.rfind("neura.", 0) == 0) opcode = opcode.substr(6);
if (isConstant(op)) return "CONSTANT";
std::transform(opcode.begin(), opcode.end(), opcode.begin(), ::toupper);

// For comparison operations, appends the comparison type to the opcode.
Comment thread
n0thingNoob marked this conversation as resolved.
if (auto icmp_op = dyn_cast<ICmpOp>(op)) {
std::string cmp_type = icmp_op.getCmpType().str();
std::transform(cmp_type.begin(), cmp_type.end(), cmp_type.begin(), ::toupper);
return opcode + "_" + cmp_type;
}
if (auto fcmp_op = dyn_cast<FCmpOp>(op)) {
std::string cmp_type = fcmp_op.getCmpType().str();
std::transform(cmp_type.begin(), cmp_type.end(), cmp_type.begin(), ::toupper);
return opcode + "_" + cmp_type;
}

// For cast operations, appends the cast type to the opcode.
if (auto cast_op = dyn_cast<CastOp>(op)) {
std::string cast_type = cast_op.getCastType().str();
std::transform(cast_type.begin(), cast_type.end(), cast_type.begin(), ::toupper);
return opcode + "_" + cast_type;
}

return opcode;
}

// Extracts constant literal from an attribute.
// Returns formatted string like "#10" or "#3.0", or empty string if not found.
Comment thread
n0thingNoob marked this conversation as resolved.
static std::string extractConstantLiteralFromAttr(Attribute attr) {
if (!attr) return "";

if (auto integer_attr = dyn_cast<IntegerAttr>(attr))
return "#" + std::to_string(integer_attr.getInt());
if (auto float_attr = dyn_cast<FloatAttr>(attr))
return "#" + std::to_string(float_attr.getValueAsDouble());

return "";
}

// Literals for CONSTANT operations, e.g. "#10" / "#0" / "#3.0".
static std::string getConstantLiteral(Operation *op) {
if (isConstant(op)) {
if (auto value_attr = op->getAttr("value")) {
if (auto integer_attr = dyn_cast<IntegerAttr>(value_attr))
return "#" + std::to_string(integer_attr.getInt());
if (auto float_attr = dyn_cast<FloatAttr>(value_attr))
return "#" + std::to_string(float_attr.getValueAsDouble());
//TODO: Issue #154: handle argument situations.
// if (auto string_attr = dyn_cast<StringAttr>(value_attr)) {
// std::string value = string_attr.getValue().str();
// return value;
// }
std::string result = extractConstantLiteralFromAttr(value_attr);
if (!result.empty()) return result;
}
return "#0";
}

// Checks for constant_value attribute in non-CONSTANT operations.
if (auto constant_value_attr = op->getAttr("constant_value")) {
if (auto integer_attr = dyn_cast<IntegerAttr>(constant_value_attr))
return "#" + std::to_string(integer_attr.getInt());
if (auto float_attr = dyn_cast<FloatAttr>(constant_value_attr))
return "#" + std::to_string(float_attr.getValueAsDouble());
//TODO: Issue #154: handle argument situations.
// if (auto string_attr = dyn_cast<StringAttr>(constant_value_attr))
// return string_attr.getValue().str();
std::string result = extractConstantLiteralFromAttr(constant_value_attr);
if (!result.empty()) return result;
}

// Checks for rhs_value attribute (for binary operations with constant RHS).
if (auto rhs_value_attr = op->getAttr("rhs_value")) {
std::string result = extractConstantLiteralFromAttr(rhs_value_attr);
if (!result.empty()) return result;
}

return "";
Expand Down Expand Up @@ -368,12 +395,23 @@ struct GenerateCodePass
// Checks if operation has constant_value attribute (for non-CONSTANT operations).
inst.src_operands.emplace_back(getConstantLiteral(op), "RED");
} else {
// Handles normal operands.
// Handles normal operands, including operations with rhs_value attribute.
SmallVector<Value> operands; operands.reserve(op->getNumOperands());

// Processes actual Value operands (if any).
for (Value v : op->getOperands()) {
operands.push_back(v);
inst.src_operands.emplace_back("UNRESOLVED", "RED");
}

// Handles cases where binary operations have the RHS constant stored as an attribute.
if (auto rhs_value_attr = op->getAttr("rhs_value")) {
std::string rhs_literal = extractConstantLiteralFromAttr(rhs_value_attr);
if (!rhs_literal.empty()) {
inst.src_operands.emplace_back(rhs_literal, "RED");
}
}

operation_to_operands[op] = std::move(operands);
}

Expand Down
8 changes: 8 additions & 0 deletions test/e2e/fir/fir_kernel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,16 @@
// YAML: instructions:
// YAML: - opcode: "GRANT_ONCE"
// YAML: - opcode: "RETURN"
// YAML: - opcode: "ICMP_EQ"

// ASM: PE(0,1):
// ASM-NEXT: {
// ASM-NEXT: GRANT_ONCE, [#0] -> [$0]
// ASM-NEXT: } (t=3)
// ASM: PE(2,2):
// ASM-NEXT: {
// ASM-NEXT: DATA_MOV, [WEST, RED] -> [$0]
// ASM-NEXT: } (t=2)
// ASM-NEXT: {
// ASM-NEXT: ICMP_EQ, [EAST, RED], [#32] -> [$0], [WEST, RED]
// ASM-NEXT: } (t=3)
2 changes: 1 addition & 1 deletion test/e2e/histogram/histogram_kernel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
// YAML: instructions:
// YAML: - opcode: "GRANT_ONCE"
// YAML: - opcode: "FDIV"
// YAML: - opcode: "CAST"
// YAML: - opcode: "CAST_FPTOSI"

// ASM: PE(2,2):
// ASM-NEXT: {
Expand Down
3 changes: 3 additions & 0 deletions test/e2e/relu/relu_kernel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@
// YAML: compiled_ii: 5
// YAML: instructions:
// YAML: - opcode: "DATA_MOV"
// YAML: - opcode: "CAST_TRUNC"
// YAML: - opcode: "ICMP_EQ"
// YAML: - opcode: "ICMP_SGE"

// ASM: PE(2,1):
// ASM-NEXT: {
Expand Down