Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion include/NeuraDialect/Architecture/Architecture.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,12 @@ class Link : public BasicResource {

class Register : public BasicResource {
public:
Register(int id);
Register(int id, int local_id);

int getId() const override;

int getLocalId() const;
void setLocalId(int local_id);

std::string getType() const override { return "register"; }

Expand All @@ -280,6 +283,7 @@ class Register : public BasicResource {

private:
int id;
int local_id;
RegisterFile *register_file;
};

Expand Down
122 changes: 122 additions & 0 deletions lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,123 @@ struct LlvmReturnToNeuraReturn : public OpRewritePattern<LLVM::ReturnOp> {
}
};

// Lowers llvm.xor to neura ops.
//
// BUG FIX: the previous version replaced xor with neura.or, which is
// semantically wrong (1 xor 1 == 0, but 1 or 1 == 1). With the neura ops used
// elsewhere in this pass (or / mul / sub), xor is expressible for i1 values
// via the identity: a XOR b = (a OR b) - (a AND b), modeling AND as
// multiplication (valid only for 0/1 values). Wider integers are rejected
// instead of being miscompiled.
struct LlvmXOrToNeuraOr : public OpRewritePattern<LLVM::XOrOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LLVM::XOrOp op,
                                PatternRewriter &rewriter) const override {
    // Gets operands.
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Type result_type = op.getType();
    Location loc = op.getLoc();

    // Only i1 xor can be modeled with or/mul/sub; refuses wider widths.
    if (!result_type.isInteger(1))
      return failure();

    // a OR b.
    auto or_op = rewriter.create<neura::OrOp>(loc, result_type, lhs, rhs);
    // a AND b, modeled as a * b (valid for i1).
    auto and_op = rewriter.create<neura::MulOp>(loc, result_type, lhs, rhs);
    // a XOR b = (a OR b) - (a AND b).
    rewriter.replaceOpWithNewOp<neura::SubOp>(op, result_type, or_op, and_op);
    return success();
  }
};

// Lowers llvm.and to neura.mul.
//
// BUG FIX: a AND b == a * b only holds for i1 (0/1) values; for wider
// integers bitwise AND is not multiplication (e.g. 2 & 2 == 2 but 2 * 2 == 4).
// The previous version matched every llvm.and regardless of width, silently
// miscompiling multi-bit ands. Now only i1 is rewritten; other widths are
// left for a future dedicated lowering.
struct LlvmAndToNeuraMul : public OpRewritePattern<LLVM::AndOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LLVM::AndOp op,
                                PatternRewriter &rewriter) const override {
    // Gets operands.
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Type result_type = op.getType();

    // The multiplication trick is only sound for boolean (i1) operands.
    if (!result_type.isInteger(1))
      return failure();

    // For boolean AND, uses multiplication (a AND b = a * b for boolean values).
    rewriter.replaceOpWithNewOp<neura::MulOp>(op, result_type, lhs, rhs);
    return success();
  }
};

// Lowers llvm.fneg to neura.fsub, rewriting -x as (0.0 - x).
// NOTE(review): 0.0 - x is not bit-identical to fneg for signed zeros
// (0.0 - (+0.0) yields +0.0, while fneg(+0.0) is -0.0); acceptable only if
// the target does not care about signed zeros — confirm.
struct LlvmFNegToNeuraFSub : public OpRewritePattern<LLVM::FNegOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LLVM::FNegOp op,
                                PatternRewriter &rewriter) const override {
    // Gets operand.
    Value operand = op.getOperand();
    Type result_type = op.getType();
    Location loc = op.getLoc();

    // Creates constant 0.0 of the same float type as the result.
    auto zero_attr = rewriter.getFloatAttr(result_type, 0.0);
    auto zero_const = rewriter.create<neura::ConstantOp>(loc, result_type, zero_attr);

    // Replaces with 0.0 - operand.
    rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, zero_const, operand);
    return success();
  }
};

// Lowers llvm.lshr to neura.shl with a negated shift amount.
// NOTE(review): in LLVM semantics shifting by a negative amount is poison;
// this lowering is only correct if neura.shl explicitly defines a negative
// amount as a logical RIGHT shift — confirm against the target ISA spec.
// Also verify that the intended semantics are logical (not arithmetic) right
// shift, since this pattern matches lshr specifically.
struct LlvmLShrToNeuraShl : public OpRewritePattern<LLVM::LShrOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LLVM::LShrOp op,
                                PatternRewriter &rewriter) const override {
    // Gets operands.
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Type result_type = op.getType();
    Location loc = op.getLoc();

    // Implements logical right shift as left shift with negative amount.
    // lshr(x, n) = shl(x, -n).
    // Creates constant 0.
    auto zero_attr = rewriter.getIntegerAttr(rhs.getType(), 0);
    auto zero_const = rewriter.create<neura::ConstantOp>(loc, rhs.getType(), zero_attr);

    // Negates the shift amount: -n = 0 - n.
    auto neg_rhs = rewriter.create<neura::SubOp>(loc, rhs.getType(), zero_const, rhs);

    // Replaces with shl(lhs, -rhs).
    rewriter.replaceOpWithNewOp<neura::ShlOp>(op, result_type, lhs, neg_rhs);
    return success();
  }
};

// Lowers llvm.select to arithmetic:
//   result = condition * true_value + (1 - condition) * false_value.
// NOTE(review): the condition is i1 while true/false values have result_type;
// the neura.mul below mixes those operand types directly — confirm neura.mul
// accepts mixed-width operands, otherwise a cast (e.g. neura.cast) is needed
// on the condition first.
// NOTE(review): getIntegerAttr on a float result_type would be invalid, so
// float-typed selects appear unsupported by this pattern — confirm and guard
// if needed.
struct LlvmSelectToNeuraOps : public OpRewritePattern<LLVM::SelectOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LLVM::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // Gets operands: condition, true_value, false_value.
    Value condition = op.getCondition();
    Value true_value = op.getTrueValue();
    Value false_value = op.getFalseValue();
    Type result_type = op.getType();
    Location loc = op.getLoc();

    // Implements: result = condition * true_value + (1 - condition) * false_value.
    // For i1 (boolean), this works perfectly since i1 is 0 or 1.

    // Step 1: Computes condition * true_value.
    auto cond_times_true = rewriter.create<neura::MulOp>(
        loc, result_type, condition, true_value);

    // Step 2: Creates constant 1.
    auto one_attr = rewriter.getIntegerAttr(result_type, 1);
    auto one_const = rewriter.create<neura::ConstantOp>(loc, result_type, one_attr);

    // Step 3: Computes (1 - condition).
    auto not_condition = rewriter.create<neura::SubOp>(
        loc, result_type, one_const, condition);

    // Step 4: Computes (1 - condition) * false_value.
    auto not_cond_times_false = rewriter.create<neura::MulOp>(
        loc, result_type, not_condition, false_value);

    // Step 5: Computes result = cond_times_true + not_cond_times_false.
    rewriter.replaceOpWithNewOp<neura::AddOp>(
        op, result_type, cond_times_true, not_cond_times_false);

    return success();
  }
};

struct FuncReturnToNeuraReturn : public OpRewritePattern<func::ReturnOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -752,6 +869,11 @@ struct LowerLlvmToNeuraPass
patterns.add<LlvmFDivToNeuraFDiv>(&getContext());
patterns.add<LlvmFPToSIToNeuraCast>(&getContext());
patterns.add<LlvmFMulAddToNeuraFMulFAdd>(&getContext());
patterns.add<LlvmXOrToNeuraOr>(&getContext());
patterns.add<LlvmAndToNeuraMul>(&getContext());
patterns.add<LlvmFNegToNeuraFSub>(&getContext());
patterns.add<LlvmLShrToNeuraShl>(&getContext());
patterns.add<LlvmSelectToNeuraOps>(&getContext());

FrozenRewritePatternSet frozen(std::move(patterns));

Expand Down
12 changes: 9 additions & 3 deletions lib/NeuraDialect/Architecture/Architecture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,14 @@ Tile *FunctionUnit::getTile() const { return this->tile; }
// Register
//===----------------------------------------------------------------------===//

Register::Register(int id) { this->id = id; }
// Constructs a register with its globally unique id and its local id
// (the register's index within the owning tile's register-file cluster).
Register::Register(int id, int local_id) : id(id), local_id(local_id) {}

// Returns the globally unique register id.
int Register::getId() const { return id; }

// Returns the register's index local to its tile's register-file cluster.
int Register::getLocalId() const { return local_id; }

// Updates the tile-local register index.
void Register::setLocalId(int local_id) { this->local_id = local_id; }

// Returns the tile owning this register via its register file, or nullptr
// when the register is not attached to any register file.
Tile *Register::getTile() const {
  return this->register_file ? register_file->getTile() : nullptr;
}
Expand Down Expand Up @@ -344,10 +348,11 @@ void Architecture::createRegisterFileCluster(Tile *tile, int num_registers, int
RegisterFileCluster *register_file_cluster = new RegisterFileCluster(tile->getId());

// Creates registers as a register file.
int local_reg_id = 0;
for (int file_idx = 0; file_idx < k_num_regfiles_per_cluster; ++file_idx) {
RegisterFile *register_file = new RegisterFile(file_idx);
for (int reg_idx = 0; reg_idx < k_num_regs_per_regfile; ++reg_idx) {
Register *reg = new Register(reg_id++);
Register *reg = new Register(reg_id++, local_reg_id++);
register_file->addRegister(reg);
}
register_file_cluster->addRegisterFile(register_file);
Expand Down Expand Up @@ -388,10 +393,11 @@ void Architecture::recreateRegisterFileCluster(Tile *tile, int num_registers) {

// Creates registers with new capacity.
int reg_id = tile->getId() * 1000; // Use tile ID as base to avoid conflicts.
int local_reg_id = 0;
for (int file_idx = 0; file_idx < kNumRegfilesPerCluster; ++file_idx) {
RegisterFile *register_file = new RegisterFile(file_idx);
for (int reg_idx = 0; reg_idx < kNumRegsPerRegfile; ++reg_idx) {
Register *reg = new Register(reg_id++);
Register *reg = new Register(reg_id++, local_reg_id++);
register_file->addRegister(reg);
}
new_register_file_cluster->addRegisterFile(register_file);
Expand Down
10 changes: 9 additions & 1 deletion lib/NeuraDialect/Mapping/MappingState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,21 @@ void MappingState::encodeMappingState() {
mapping_entries.push_back(dict);
} else if (loc.resource->getKind() == ResourceKind::Register) {
kind_str = "register";
Register *reg = static_cast<Register *>(loc.resource);
int global_id = loc.resource->getId();
int local_register_id = reg->getLocalId();

auto dict = mlir::DictionaryAttr::get(
ctx, {mlir::NamedAttribute(mlir::StringAttr::get(ctx, "resource"),
mlir::StringAttr::get(ctx, kind_str)),
mlir::NamedAttribute(
mlir::StringAttr::get(ctx, "id"),
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
loc.resource->getId())),
global_id)),
mlir::NamedAttribute(
mlir::StringAttr::get(ctx, "local_register_id"),
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
local_register_id)),
mlir::NamedAttribute(
mlir::StringAttr::get(ctx, "time_step"),
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
Expand Down
38 changes: 29 additions & 9 deletions lib/NeuraDialect/Transforms/GenerateCodePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ struct Tile {
// Aggregate describing the generated array program that is serialized to
// YAML/ASM output. Member order matters: this struct is brace-initialized
// positionally (columns, rows, compiled_ii, cores) — do not reorder.
struct ArrayConfig {
  int columns;               // Number of tile columns in the array.
  int rows;                  // Number of tile rows in the array.
  int compiled_ii = -1;      // Compiled initiation interval; -1 means unknown/absent.
  std::vector<Tile> cores;   // Per-tile instruction listings, one entry per core.
};

Expand Down Expand Up @@ -109,8 +110,9 @@ static std::optional<int> getMappedRegId(Operation *op) {
auto resource_attr = dyn_cast_or_null<StringAttr>(location_dict.get("resource"));
if (!resource_attr) continue;
if (resource_attr.getValue() == "register" || resource_attr.getValue() == "reg") {
if (auto register_id = dyn_cast_or_null<IntegerAttr>(location_dict.get("id")))
return register_id.getInt();
if (auto local_register_id = dyn_cast_or_null<IntegerAttr>(location_dict.get("local_register_id"))) {
return local_register_id.getInt();
}
}
}
}
Expand Down Expand Up @@ -248,10 +250,10 @@ static SmallVector<RegStep, 4> collectRegSteps(Operation *op) {
auto resource_attr = dyn_cast_or_null<StringAttr>(location_dict.get("resource"));
if (!resource_attr) continue;
if (resource_attr.getValue() == "register" || resource_attr.getValue() == "reg") {
auto register_id = dyn_cast_or_null<IntegerAttr>(location_dict.get("id"));
auto local_register_id = dyn_cast_or_null<IntegerAttr>(location_dict.get("local_register_id"));
auto time_step = dyn_cast_or_null<IntegerAttr>(location_dict.get("time_step"));
if (!register_id || !time_step) continue;
steps.push_back({(int)register_id.getInt(), (int)time_step.getInt()});
if (!local_register_id || !time_step) continue;
steps.push_back({(int)local_register_id.getInt(), (int)time_step.getInt()});
}
}
}
Expand Down Expand Up @@ -317,6 +319,15 @@ struct GenerateCodePass
return {columns, rows};
}

// Reads the compiled initiation interval from the function's "mapping_info"
// dictionary attribute. Returns -1 when the attribute (or its "compiled_ii"
// entry) is missing.
int getCompiledII(func::FuncOp function) {
  auto mapping_info = function->getAttrOfType<DictionaryAttr>("mapping_info");
  if (!mapping_info)
    return -1;
  auto compiled_ii = dyn_cast_or_null<IntegerAttr>(mapping_info.get("compiled_ii"));
  if (!compiled_ii)
    return -1;
  return compiled_ii.getInt();
}

// ---------- Single-walk indexing ----------.
// Do everything that needs walks in a single pass:.
// - record operation_placements.
Expand Down Expand Up @@ -571,8 +582,8 @@ struct GenerateCodePass
}
}

ArrayConfig buildArrayConfig(int columns, int rows) {
ArrayConfig config{columns, rows, {}};
ArrayConfig buildArrayConfig(int columns, int rows, int compiled_ii = -1) {
ArrayConfig config{columns, rows, compiled_ii, {}};
std::map<std::pair<int,int>, std::vector<Instruction>> tile_insts;

// Flattens and sorts by timesteps.
Expand Down Expand Up @@ -600,7 +611,11 @@ struct GenerateCodePass
llvm::raw_fd_ostream yaml_out("tmp-generated-instructions.yaml", ec);
if (ec) return;

yaml_out << "array_config:\n columns: " << config.columns << "\n rows: " << config.rows << "\n cores:\n";
yaml_out << "array_config:\n columns: " << config.columns << "\n rows: " << config.rows;
if (config.compiled_ii >= 0) {
yaml_out << "\n compiled_ii: " << config.compiled_ii;
}
yaml_out << "\n cores:\n";
for (const Tile &core : config.cores) {
yaml_out << " - column: " << core.col_idx << "\n row: " << core.row_idx
<< "\n core_id: \"" << core.core_id << "\"\n entries:\n";
Expand Down Expand Up @@ -657,6 +672,10 @@ struct GenerateCodePass
llvm::raw_fd_ostream asm_out("tmp-generated-instructions.asm", ec);
if (ec) return;

if (config.compiled_ii >= 0) {
asm_out << "# Compiled II: " << config.compiled_ii << "\n\n";
}

for (const Tile &core : config.cores) {
asm_out << "PE(" << core.col_idx << "," << core.row_idx << "):\n";

Expand Down Expand Up @@ -765,7 +784,8 @@ struct GenerateCodePass
expandMovImpl<true>(op, topo, reserve_to_phi_map);
logUnresolvedOperands();

ArrayConfig config = buildArrayConfig(columns, rows);
int compiled_ii = getCompiledII(func);
ArrayConfig config = buildArrayConfig(columns, rows, compiled_ii);
writeYAMLOutput(config);
writeASMOutput(config);
}
Expand Down
Loading