11#include " NeuraDialect/Architecture/Architecture.h"
2+ #include " llvm/Support/raw_ostream.h"
3+
24#include < cassert>
35
46using namespace mlir ;
@@ -72,14 +74,15 @@ void Link::connect(Tile* src, Tile* dst) {
7274Architecture::Architecture (int width, int height) {
7375 const int num_tiles = width * height;
7476
75- tileStorage.reserve (num_tiles);
76- tiles.reserve (num_tiles);
77+ tile_storage.reserve (num_tiles);
7778
7879 for (int i = 0 ; i < width; ++i) {
7980 for (int j = 0 ; j < height; ++j) {
80- auto tile = std::make_unique<Tile>(i * width + j, i, j);
81- tiles.push_back (tile.get ());
82- tileStorage.push_back (std::move (tile));
81+ const int id = i * width + j;
82+ auto tile = std::make_unique<Tile>(id, i, j);
83+ id_to_tile[id] = tile.get ();
84+ coord_to_tile[{i, j}] = tile.get ();
85+ tile_storage.push_back (std::move (tile));
8386 }
8487 }
8588
@@ -92,39 +95,54 @@ Architecture::Architecture(int width, int height) {
9295 if (i > 0 ) {
9396 auto link_towards_left = std::make_unique<Link>(link_id++);
9497 link_towards_left->connect (tile, getTile (i - 1 , j));
98+ link_storage.push_back (std::move (link_towards_left));
9599 }
96100 if (i < width - 1 ) {
97101 auto link_towards_right = std::make_unique<Link>(link_id++);
98102 link_towards_right->connect (tile, getTile (i + 1 , j));
103+ link_storage.push_back (std::move (link_towards_right));
99104 }
100105 if (j > 0 ) {
101106 auto link_towards_down = std::make_unique<Link>(link_id++);
102107 link_towards_down->connect (tile, getTile (i, j - 1 ));
108+ link_storage.push_back (std::move (link_towards_down));
103109 }
104110 if (j < height - 1 ) {
105111 auto link_towards_up = std::make_unique<Link>(link_id++);
106112 link_towards_up->connect (tile, getTile (i, j + 1 ));
113+ link_storage.push_back (std::move (link_towards_up));
107114 }
108115 }
109116 }
110117}
111118
112119Tile* Architecture::getTile (int id) {
113- for (const auto &tile : tiles) {
114- if (tile->getId () == id) {
115- return tile;
116- }
117- }
118- assert (false && " Tile with given ID not found" );
119- return nullptr ;
120+ auto it = id_to_tile.find (id);
121+ assert (it != id_to_tile.end () && " Tile with given ID not found" );
122+ return it->second ;
120123}
121124
122125Tile* Architecture::getTile (int x, int y) {
123- for (const auto &tile : tiles) {
124- if (tile->getX () == x && tile->getY () == y) {
125- return tile;
126- }
126+ auto it = coord_to_tile.find ({x, y});
127+ assert (it != coord_to_tile.end () && " Tile with given coordinates not found" );
128+ return it->second ;
129+ }
130+
131+ std::vector<Tile*> Architecture::getAllTiles () const {
132+ std::vector<Tile*> result;
133+ for (auto &tile : tile_storage)
134+ result.push_back (tile.get ());
135+ return result;
136+ }
137+
138+ int Architecture::getNumTiles () const {
139+ return static_cast <int >(id_to_tile.size ());
140+ }
141+
142+ std::vector<Link*> Architecture::getAllLinks () const {
143+ std::vector<Link*> all_links;
144+ for (const auto &link : link_storage) {
145+ all_links.push_back (link.get ());
127146 }
128- assert (false && " Tile with given coordinates not found" );
129- return nullptr ;
147+ return all_links;
130148}
0 commit comments