Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions include/NeuraDialect/Architecture/Architecture.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,11 @@ class Link : public BasicResource {

class Register : public BasicResource {
public:
Register(int id);
Register(int global_id, int per_tile_id);

int getId() const override;

int getPerTileId() const;

std::string getType() const override { return "register"; }

Expand All @@ -279,7 +281,8 @@ class Register : public BasicResource {
RegisterFile *getRegisterFile() const;

private:
int id;
int global_id;
int per_tile_id;
RegisterFile *register_file;
};

Expand Down Expand Up @@ -406,8 +409,7 @@ class Architecture {
void applyTileOverrides(const std::vector<TileOverride>& tile_overrides);
void createLinks(const LinkDefaults& link_defaults, BaseTopology base_topology);
void applyLinkOverrides(const std::vector<LinkOverride>& link_overrides);
void createRegisterFileCluster(Tile *tile, int num_registers, int &reg_id);
void recreateRegisterFileCluster(Tile *tile, int num_registers);
void createRegisterFileCluster(Tile *tile, int num_registers, int &num_already_assigned_global_registers, int global_id_start = -1);
bool linkExists(Tile *src_tile, Tile *dst_tile);

// Helper methods for creating different topology links.
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ struct LlvmReturnToNeuraReturn : public OpRewritePattern<LLVM::ReturnOp> {
}
};


struct FuncReturnToNeuraReturn : public OpRewritePattern<func::ReturnOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down
62 changes: 25 additions & 37 deletions lib/NeuraDialect/Architecture/Architecture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,11 @@ Tile *FunctionUnit::getTile() const { return this->tile; }
// Register
//===----------------------------------------------------------------------===//

Register::Register(int id) { this->id = id; }
Register::Register(int global_id, int per_tile_id) : global_id(global_id), per_tile_id(per_tile_id) {}

int Register::getId() const { return id; }
int Register::getId() const { return global_id; }

int Register::getPerTileId() const { return per_tile_id; }

Tile *Register::getTile() const {
return this->register_file ? register_file->getTile() : nullptr;
Expand Down Expand Up @@ -337,18 +339,24 @@ void Architecture::initializeTiles(int per_cgra_rows, int per_cgra_columns) {
}

// Creates register file cluster for a tile.
void Architecture::createRegisterFileCluster(Tile *tile, int num_registers, int &reg_id) {
void Architecture::createRegisterFileCluster(Tile *tile, int num_registers, int &num_already_assigned_global_registers, int global_id_start) {
const int k_num_regs_per_regfile = 8; // Keep this fixed for now.
const int k_num_regfiles_per_cluster = num_registers / k_num_regs_per_regfile;

// If global_id_start is specified, ensures it doesn't go backwards.
if (global_id_start >= 0) {
num_already_assigned_global_registers = std::max(num_already_assigned_global_registers, global_id_start);
}

RegisterFileCluster *register_file_cluster = new RegisterFileCluster(tile->getId());

// Creates registers as a register file.
int local_reg_id = 0;
for (int file_idx = 0; file_idx < k_num_regfiles_per_cluster; ++file_idx) {
RegisterFile *register_file = new RegisterFile(file_idx);
for (int reg_idx = 0; reg_idx < k_num_regs_per_regfile; ++reg_idx) {
Register *reg = new Register(reg_id++);
register_file->addRegister(reg);
for (int reg = 0; reg < k_num_regs_per_regfile; ++reg) {
Register *new_register = new Register(num_already_assigned_global_registers++, local_reg_id++);
register_file->addRegister(new_register);
}
register_file_cluster->addRegisterFile(register_file);
}
Expand All @@ -358,48 +366,20 @@ void Architecture::createRegisterFileCluster(Tile *tile, int num_registers, int

// Configures default tile settings.
void Architecture::configureDefaultTileSettings(const TileDefaults& tile_defaults) {
int reg_id = 0;
int num_already_assigned_global_registers = 0;
for (int y = 0; y < getPerCgraRows(); ++y) {
for (int x = 0; x < getPerCgraColumns(); ++x) {
Tile *tile = getTile(x, y);

// Creates register file cluster with default capacity.
createRegisterFileCluster(tile, tile_defaults.num_registers, reg_id);
createRegisterFileCluster(tile, tile_defaults.num_registers, num_already_assigned_global_registers);

// Configures function units based on tile_defaults.operations.
configureTileFunctionUnits(tile, tile_defaults.operations);
}
}
}

// Recreates register file cluster with new capacity.
void Architecture::recreateRegisterFileCluster(Tile *tile, int num_registers) {
constexpr int kNumRegsPerRegfile = 8; // Keep this fixed for now.
const int kNumRegfilesPerCluster = num_registers / kNumRegsPerRegfile;

// Removes existing register file cluster.
if (tile->getRegisterFileCluster()) {
delete tile->getRegisterFileCluster();
}

// Creates new register file cluster with override capacity.
RegisterFileCluster *new_register_file_cluster =
new RegisterFileCluster(tile->getId());

// Creates registers with new capacity.
int reg_id = tile->getId() * 1000; // Use tile ID as base to avoid conflicts.
for (int file_idx = 0; file_idx < kNumRegfilesPerCluster; ++file_idx) {
RegisterFile *register_file = new RegisterFile(file_idx);
for (int reg_idx = 0; reg_idx < kNumRegsPerRegfile; ++reg_idx) {
Register *reg = new Register(reg_id++);
register_file->addRegister(reg);
}
new_register_file_cluster->addRegisterFile(register_file);
}

// Adds new register file cluster to the tile.
tile->addRegisterFileCluster(new_register_file_cluster);
}

// Applies tile overrides to modify specific tiles.
void Architecture::applyTileOverrides(const std::vector<TileOverride>& tile_overrides) {
Expand All @@ -417,7 +397,15 @@ void Architecture::applyTileOverrides(const std::vector<TileOverride>& tile_over

// Overrides num_registers if specified.
if (override.num_registers > 0) {
recreateRegisterFileCluster(tile, override.num_registers);
// Removes existing register file cluster.
if (tile->getRegisterFileCluster()) {
delete tile->getRegisterFileCluster();
}

// Creates new register file cluster with override capacity.
// Uses tile ID as base to avoid conflicts with existing registers.
int dummy_ref = 0; // Not used when global_id_start is specified.
createRegisterFileCluster(tile, override.num_registers, dummy_ref, tile->getId() * 1000);
}
}
}
Expand Down
10 changes: 9 additions & 1 deletion lib/NeuraDialect/Mapping/MappingState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,21 @@ void MappingState::encodeMappingState() {
mapping_entries.push_back(dict);
} else if (loc.resource->getKind() == ResourceKind::Register) {
kind_str = "register";
Register *reg = static_cast<Register *>(loc.resource);
int global_id = loc.resource->getId();
int per_tile_register_id = reg->getPerTileId();

auto dict = mlir::DictionaryAttr::get(
ctx, {mlir::NamedAttribute(mlir::StringAttr::get(ctx, "resource"),
mlir::StringAttr::get(ctx, kind_str)),
mlir::NamedAttribute(
mlir::StringAttr::get(ctx, "id"),
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
loc.resource->getId())),
global_id)),
mlir::NamedAttribute(
mlir::StringAttr::get(ctx, "per_tile_register_id"),
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
per_tile_register_id)),
mlir::NamedAttribute(
mlir::StringAttr::get(ctx, "time_step"),
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
Expand Down
38 changes: 29 additions & 9 deletions lib/NeuraDialect/Transforms/GenerateCodePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ struct Tile {
struct ArrayConfig {
int columns;
int rows;
int compiled_ii = -1;
std::vector<Tile> cores;
};

Expand Down Expand Up @@ -109,8 +110,9 @@ static std::optional<int> getMappedRegId(Operation *op) {
auto resource_attr = dyn_cast_or_null<StringAttr>(location_dict.get("resource"));
if (!resource_attr) continue;
if (resource_attr.getValue() == "register" || resource_attr.getValue() == "reg") {
if (auto register_id = dyn_cast_or_null<IntegerAttr>(location_dict.get("id")))
return register_id.getInt();
if (auto per_tile_register_id = dyn_cast_or_null<IntegerAttr>(location_dict.get("per_tile_register_id"))) {
return per_tile_register_id.getInt();
}
}
}
}
Expand Down Expand Up @@ -248,10 +250,10 @@ static SmallVector<RegStep, 4> collectRegSteps(Operation *op) {
auto resource_attr = dyn_cast_or_null<StringAttr>(location_dict.get("resource"));
if (!resource_attr) continue;
if (resource_attr.getValue() == "register" || resource_attr.getValue() == "reg") {
auto register_id = dyn_cast_or_null<IntegerAttr>(location_dict.get("id"));
auto per_tile_register_id = dyn_cast_or_null<IntegerAttr>(location_dict.get("per_tile_register_id"));
auto time_step = dyn_cast_or_null<IntegerAttr>(location_dict.get("time_step"));
if (!register_id || !time_step) continue;
steps.push_back({(int)register_id.getInt(), (int)time_step.getInt()});
if (!per_tile_register_id || !time_step) continue;
steps.push_back({(int)per_tile_register_id.getInt(), (int)time_step.getInt()});
}
}
}
Expand Down Expand Up @@ -317,6 +319,15 @@ struct GenerateCodePass
return {columns, rows};
}

int getCompiledII(func::FuncOp function) {
if (auto mapping_info = function->getAttrOfType<DictionaryAttr>("mapping_info")) {
if (auto compiled_ii = dyn_cast_or_null<IntegerAttr>(mapping_info.get("compiled_ii"))) {
return compiled_ii.getInt();
}
}
return -1;
}

// ---------- Single-walk indexing ----------.
// Do everything that needs walks in a single pass:.
// - record operation_placements.
Expand Down Expand Up @@ -571,8 +582,8 @@ struct GenerateCodePass
}
}

ArrayConfig buildArrayConfig(int columns, int rows) {
ArrayConfig config{columns, rows, {}};
ArrayConfig buildArrayConfig(int columns, int rows, int compiled_ii = -1) {
ArrayConfig config{columns, rows, compiled_ii, {}};
std::map<std::pair<int,int>, std::vector<Instruction>> tile_insts;

// Flattens and sorts by timesteps.
Expand Down Expand Up @@ -600,7 +611,11 @@ struct GenerateCodePass
llvm::raw_fd_ostream yaml_out("tmp-generated-instructions.yaml", ec);
if (ec) return;

yaml_out << "array_config:\n columns: " << config.columns << "\n rows: " << config.rows << "\n cores:\n";
yaml_out << "array_config:\n columns: " << config.columns << "\n rows: " << config.rows;
if (config.compiled_ii >= 0) {
yaml_out << "\n compiled_ii: " << config.compiled_ii;
}
yaml_out << "\n cores:\n";
for (const Tile &core : config.cores) {
yaml_out << " - column: " << core.col_idx << "\n row: " << core.row_idx
<< "\n core_id: \"" << core.core_id << "\"\n entries:\n";
Expand Down Expand Up @@ -657,6 +672,10 @@ struct GenerateCodePass
llvm::raw_fd_ostream asm_out("tmp-generated-instructions.asm", ec);
if (ec) return;

if (config.compiled_ii >= 0) {
asm_out << "# Compiled II: " << config.compiled_ii << "\n\n";
}

for (const Tile &core : config.cores) {
asm_out << "PE(" << core.col_idx << "," << core.row_idx << "):\n";

Expand Down Expand Up @@ -765,7 +784,8 @@ struct GenerateCodePass
expandMovImpl<true>(op, topo, reserve_to_phi_map);
logUnresolvedOperands();

ArrayConfig config = buildArrayConfig(columns, rows);
int compiled_ii = getCompiledII(func);
ArrayConfig config = buildArrayConfig(columns, rows, compiled_ii);
writeYAMLOutput(config);
writeASMOutput(config);
}
Expand Down
Loading