Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 211 additions & 3 deletions include/NeuraDialect/Architecture/Architecture.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#ifndef NEURA_ARCHITECTURE_H
#define NEURA_ARCHITECTURE_H

#include <cassert>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include <set>
#include <unordered_map>
#include <optional>
#include <memory>
#include <vector>

namespace mlir {
namespace neura {
Expand All @@ -15,6 +17,25 @@ namespace neura {
enum class ResourceKind {
Tile,
Link,
FunctionUnit,
Register,
RegisterFile,
RegisterFileCluster,
};

// Enum for function unit resource type.
enum class FunctionUnitKind {
FixedPointAdder,
FixedPointMultiplier,
CustomizableFunctionUnit,
};

// Enum for supported operation types.
enum OperationKind {
IAdd = 0,
IMul = 1,
FAdd = 2,
FMul = 3
};

//===----------------------------------------------------------------------===//
Expand All @@ -31,7 +52,81 @@ class BasicResource {

//===----------------------------------------------------------------------===//
// Forward declaration for use in Tile
class Tile;
class Link;
class FunctionUnit;
class Register;
class RegisterFile;
class RegisterFileCluster;

//===----------------------------------------------------------------------===//
// Function Unit.
//===----------------------------------------------------------------------===//

class FunctionUnit : public BasicResource {
public:
FunctionUnit(int id);

int getId() const override;
std::string getType() const override { return "function_unit"; }
ResourceKind getKind() const override { return ResourceKind::FunctionUnit; }

static bool classof(const BasicResource *res) {
return res && res->getKind() == ResourceKind::FunctionUnit;
}

Tile* getTile() const;

void setTile(Tile* tile);

std::set<OperationKind> getSupportedOperations() const {
return supported_operations;
}

bool canSupportOperation(OperationKind operation) const {
for (const auto &op : supported_operations) {
if (op == operation) {
return true;
}
}
return false;
}

protected:
std::set<OperationKind> supported_operations;

private:
int id;
Tile* tile;
};

class FixedPointAdder : public FunctionUnit {
public:
FixedPointAdder(int id) : FunctionUnit(id) {
supported_operations.insert(OperationKind::IAdd);
}
std::string getType() const override { return "fixed_point_adder"; }
ResourceKind getKind() const override { return ResourceKind::FunctionUnit; }
};

class FixedPointMultiplier : public FunctionUnit {
public:
FixedPointMultiplier(int id) : FunctionUnit(id) {
supported_operations.insert(OperationKind::IMul);
}
std::string getType() const override { return "fixed_point_multiplier"; }
ResourceKind getKind() const override { return ResourceKind::FunctionUnit; }
};

class CustomizableFunctionUnit : public FunctionUnit {
public:
CustomizableFunctionUnit(int id) : FunctionUnit(id) {}
std::string getType() const override { return "customizable_function_unit"; }
ResourceKind getKind() const override { return ResourceKind::FunctionUnit; }
void addSupportedOperation(OperationKind operation_kind) {
supported_operations.insert(operation_kind);
}
};

//===----------------------------------------------------------------------===//
// Tile
Expand Down Expand Up @@ -59,13 +154,33 @@ class Tile : public BasicResource {
const std::set<Link*>& getOutLinks() const;
const std::set<Link*>& getInLinks() const;

void addFunctionUnit(std::unique_ptr<FunctionUnit> func_unit) {
assert(func_unit && "Cannot add null function unit");
func_unit->setTile(this);
functional_unit_storage.push_back(std::move(func_unit));
functional_units.insert(functional_unit_storage.back().get());
}

bool canSupportOperation(OperationKind operation) const {
for (FunctionUnit *fu : functional_units) {
if (fu->canSupportOperation(operation)) {
return true;
}
}
// TODO: Check if the tile can support the operation based on its capabilities.
// @Jackcuii, https://github.com/coredac/dataflow/issues/82.
return true;
}

private:
int id;
int x, y;
std::set<Tile*> src_tiles;
std::set<Tile*> dst_tiles;
std::set<Link*> in_links;
std::set<Link*> out_links;
std::vector<std::unique_ptr<FunctionUnit>> functional_unit_storage; // Owns FUs.
std::set<FunctionUnit*> functional_units; // Non-owning, for fast lookup.
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -96,6 +211,99 @@ class Link : public BasicResource {
Tile* dst_tile;
};

//===----------------------------------------------------------------------===//
// Register
//===----------------------------------------------------------------------===//

class Register : public BasicResource {
public:
Register(int id);

int getId() const override;

std::string getType() const override { return "register"; }

ResourceKind getKind() const override { return ResourceKind::Register; }

static bool classof(const BasicResource *res) {
return res && res->getKind() == ResourceKind::Register;
}

Tile* getTile() const;

void setRegisterFile(RegisterFile* register_file);

RegisterFile* getRegisterFile() const;

private:
int id;
RegisterFile* register_file;
};

//===----------------------------------------------------------------------===//
// Register File
//===----------------------------------------------------------------------===//

class RegisterFile : public BasicResource {
public:
RegisterFile(int id);

int getId() const override;

std::string getType() const override { return "register_file"; }

ResourceKind getKind() const override { return ResourceKind::RegisterFile; }

static bool classof(const BasicResource *res) {
return res && res->getKind() == ResourceKind::RegisterFile;
}

Tile* getTile() const;

void setRegisterFileCluster(RegisterFileCluster* register_file_cluster);

void addRegister(Register* reg);

const std::map<int, Register*>& getRegisters() const;
RegisterFileCluster* getRegisterFileCluster() const;

private:
int id;
std::map<int, Register*> registers;
RegisterFileCluster* register_file_cluster = nullptr;
};

//===----------------------------------------------------------------------===//
// Register File Cluster
//===----------------------------------------------------------------------===//

class RegisterFileCluster : public BasicResource {
public:
RegisterFileCluster(int id);
int getId() const override;

std::string getType() const override { return "register_file_cluster"; }

ResourceKind getKind() const override { return ResourceKind::RegisterFileCluster; }

static bool classof(const BasicResource *res) {
return res && res->getKind() == ResourceKind::RegisterFileCluster;
}

Tile* getTile() const;
void setTile(Tile* tile);

void addRegisterFile(RegisterFile* register_file);
const std::map<int, RegisterFile*>& getRegisterFiles() const;

private:
int id;
Tile* tile;
std::map<int, RegisterFile*> register_files;
};

//===----------------------------------------------------------------------===//

struct PairHash {
std::size_t operator()(const std::pair<int, int> &coord) const {
return std::hash<int>()(coord.first) ^ (std::hash<int>()(coord.second) << 1);
Expand Down
101 changes: 101 additions & 0 deletions lib/NeuraDialect/Architecture/Architecture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
using namespace mlir;
using namespace mlir::neura;

//===----------------------------------------------------------------------===//
// Tile
//===----------------------------------------------------------------------===//

Tile::Tile(int id, int x, int y) {
this->id = id;
this->x = x;
this->y = y;

// TODO: Add function units based on architecture specs.
// @Jackcuii, https://github.com/coredac/dataflow/issues/82.
addFunctionUnit(std::make_unique<FixedPointAdder>(0));
}

int Tile::getId() const { return id; }
Expand All @@ -34,6 +42,10 @@ const std::set<Link *> &Tile::getOutLinks() const { return out_links; }

const std::set<Link *> &Tile::getInLinks() const { return in_links; }

//===----------------------------------------------------------------------===//
// Link
//===----------------------------------------------------------------------===//

Link::Link(int id) { this->id = id; }

int Link::getId() const { return id; }
Expand All @@ -49,6 +61,95 @@ void Link::connect(Tile *src, Tile *dst) {
src->linkDstTile(this, dst);
}

//===----------------------------------------------------------------------===//
// FunctionUnit
//===----------------------------------------------------------------------===//

FunctionUnit::FunctionUnit(int id) { this->id = id; }

int FunctionUnit::getId() const { return id; }

void FunctionUnit::setTile(Tile* tile) {
this->tile = tile;
}

Tile *FunctionUnit::getTile() const {
return this->tile;
}

//===----------------------------------------------------------------------===//
// Register
//===----------------------------------------------------------------------===//

Tile *Register::getTile() const {
return this->register_file ? register_file->getTile() : nullptr;
}

Register::Register(int id) { this->id = id; }

int Register::getId() const { return id; }

void Register::setRegisterFile(RegisterFile* register_file) {
this->register_file = register_file;
}

//===----------------------------------------------------------------------===//
// Register File
//===----------------------------------------------------------------------===//

RegisterFile::RegisterFile(int id) { this->id = id; }

int RegisterFile::getId() const { return id; }

Tile *RegisterFile::getTile() const {
return this->register_file_cluster ? register_file_cluster->getTile() : nullptr;
}

void RegisterFile::setRegisterFileCluster(RegisterFileCluster* register_file_cluster) {
this->register_file_cluster = register_file_cluster;
}

void RegisterFile::addRegister(Register* reg) {
registers[reg->getId()] = reg;
reg->setRegisterFile(this);
}

const std::map<int, Register*>& RegisterFile::getRegisters() const {
return this->registers;
}
//===----------------------------------------------------------------------===//
// Register File Cluster
//===----------------------------------------------------------------------===//

RegisterFileCluster* RegisterFile::getRegisterFileCluster() const {
return this->register_file_cluster;
}

RegisterFileCluster::RegisterFileCluster(int id) { this->id = id; }

int RegisterFileCluster::getId() const { return id; }

void RegisterFileCluster::setTile(Tile* tile) {
this->tile = tile;
}

Tile *RegisterFileCluster::getTile() const {
return this->tile;
}

void RegisterFileCluster::addRegisterFile(RegisterFile* register_file) {
register_files[register_file->getId()] = register_file;
register_file->setRegisterFileCluster(this);
}

const std::map<int, RegisterFile*>& RegisterFileCluster::getRegisterFiles() const {
return this->register_files;
}

//===----------------------------------------------------------------------===//
// Architecture
//===----------------------------------------------------------------------===//

Architecture::Architecture(int width, int height) {
const int num_tiles = width * height;

Expand Down
Loading