diff --git a/examples/addconst/addconst.h b/examples/addconst/addconst.h index b5e14df3..c2692745 100644 --- a/examples/addconst/addconst.h +++ b/examples/addconst/addconst.h @@ -10,7 +10,7 @@ class AddConst : public dfcxx::Kernel { public: - std::string_view getName() override { + std::string_view getName() const override { return "AddConst"; } diff --git a/examples/idct/idct.h b/examples/idct/idct.h index 3d48be3d..4617e98d 100644 --- a/examples/idct/idct.h +++ b/examples/idct/idct.h @@ -13,7 +13,7 @@ static const int32_t kSIZE = kDIM * kDIM; class IDCT : public dfcxx::Kernel { public: - std::string_view getName() override { + std::string_view getName() const override { return "IDCT"; } diff --git a/examples/matrixmul2/matrixmul2.h b/examples/matrixmul2/matrixmul2.h index a11ca153..7fee9313 100644 --- a/examples/matrixmul2/matrixmul2.h +++ b/examples/matrixmul2/matrixmul2.h @@ -10,7 +10,7 @@ class MatrixMul2 : public dfcxx::Kernel { public: - std::string_view getName() override { + std::string_view getName() const override { return "MatrixMul2"; } diff --git a/examples/movingsum/movingsum.h b/examples/movingsum/movingsum.h index f8a87935..34350a57 100644 --- a/examples/movingsum/movingsum.h +++ b/examples/movingsum/movingsum.h @@ -10,7 +10,7 @@ class MovingSum : public dfcxx::Kernel { public: - std::string_view getName() override { + std::string_view getName() const override { return "MovingSum"; } diff --git a/examples/muxmul/muxmul.h b/examples/muxmul/muxmul.h index c7745b5f..92f20bb4 100644 --- a/examples/muxmul/muxmul.h +++ b/examples/muxmul/muxmul.h @@ -10,7 +10,7 @@ class MuxMul : public dfcxx::Kernel { public: - std::string_view getName() override { + std::string_view getName() const override { return "MuxMul"; } diff --git a/examples/polynomial2/polynomial2.h b/examples/polynomial2/polynomial2.h index e32d5961..98414da8 100644 --- a/examples/polynomial2/polynomial2.h +++ b/examples/polynomial2/polynomial2.h @@ -10,7 +10,7 @@ class Polynomial2 : public dfcxx::Kernel { public: - std::string_view getName() override { + std::string_view getName() const override { return "Polynomial2"; } diff --git a/examples/polynomial2_inst/add_int_2_mul_int3.json b/examples/polynomial2_inst/add_int_2_mul_int3.json new file mode 100644 index 00000000..da7ea74b --- /dev/null +++ b/examples/polynomial2_inst/add_int_2_mul_int3.json @@ -0,0 +1,4 @@ +{ + "ADD_INT": 2, + "MUL_INT": 3 +} diff --git a/examples/polynomial2_inst/polynomial2_inst.cpp b/examples/polynomial2_inst/polynomial2_inst.cpp new file mode 100644 index 00000000..cea74f6e --- /dev/null +++ b/examples/polynomial2_inst/polynomial2_inst.cpp @@ -0,0 +1,16 @@ +//===----------------------------------------------------------------------===// +// +// Part of the Utopia HLS Project, under the Apache License v2.0 +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2024 ISP RAS (http://www.ispras.ru) +// +//===----------------------------------------------------------------------===// + +#include "polynomial2_inst.h" + +#include + +std::unique_ptr start() { + Polynomial2Inst *kernel = new Polynomial2Inst(); + return std::unique_ptr(kernel); +} diff --git a/examples/polynomial2_inst/polynomial2_inst.h b/examples/polynomial2_inst/polynomial2_inst.h new file mode 100644 index 00000000..57f8dfdf --- /dev/null +++ b/examples/polynomial2_inst/polynomial2_inst.h @@ -0,0 +1,55 @@ +//===----------------------------------------------------------------------===// +// +// Part of the Utopia HLS Project, under the Apache License v2.0 +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2024 ISP RAS (http://www.ispras.ru) +// +//===----------------------------------------------------------------------===// + +#include "dfcxx/DFCXX.h" + +class Polynomial2 : public dfcxx::Kernel { +public: + std::string_view getName() const override { + return "Polynomial2"; + } + + ~Polynomial2() override = default; + + Polynomial2() : dfcxx::Kernel() { + using dfcxx::DFType; + using dfcxx::DFVariable; + + const DFType &type = dfUInt(32); + DFVariable x = io.input("x", type); + DFVariable squared = x * x; + DFVariable squaredPlusX = squared + x; + DFVariable result = squaredPlusX + x; + DFVariable out = io.output("out", type); + out.connect(result); + } +}; + +class Polynomial2Inst : public dfcxx::Kernel { +public: + std::string_view getName() const override { + return "Polynomial2Inst"; + } + + ~Polynomial2Inst() override = default; + + Polynomial2Inst() : dfcxx::Kernel() { + using dfcxx::DFType; + using dfcxx::DFVariable; + + const DFType &type = dfUInt(32); + DFVariable x = io.input("x", type); + DFVariable intermediate = io.newStream(type); + instance({ + {x, "x"}, + {intermediate, "out"} + }); + DFVariable out = io.output("out", type); + out.connect(intermediate); + } +}; diff --git a/examples/polynomial2_inst/sim.txt b/examples/polynomial2_inst/sim.txt new file mode 100644 index 00000000..1e7620b8 --- /dev/null +++ b/examples/polynomial2_inst/sim.txt @@ -0,0 +1,5 @@ +x 0x32 + +x 0x45 + +x 0x56 diff --git a/examples/scalar3/scalar3.h b/examples/scalar3/scalar3.h index 946d22ee..2b50b479 100644 --- a/examples/scalar3/scalar3.h +++ b/examples/scalar3/scalar3.h @@ -10,7 +10,7 @@ class Scalar3 : public dfcxx::Kernel { public: - std::string_view getName() override { + std::string_view getName() const override { return "Scalar3"; } diff --git a/src/main.cpp b/src/main.cpp index 79cf2eb3..dd754d5f 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -44,6 +44,7 @@ struct SimContext { int hlsMain(const HlsContext &context) { auto kernel = start(); + if (!kernel->check()) { return 1; } bool useASAP = context.options.asapScheduler; return !kernel->compile(context.options.latencyCfg, context.options.outNames, diff --git a/src/model/dfcxx/include/dfcxx/graph.h b/src/model/dfcxx/include/dfcxx/graph.h index c41ffa1d..ee1bcc96 100644 --- a/src/model/dfcxx/include/dfcxx/graph.h +++ b/src/model/dfcxx/include/dfcxx/graph.h @@ -13,6 +13,8 @@ #include "dfcxx/node.h" #include "dfcxx/vars/var.h" +#include +#include #include #include #include @@ -21,6 +23,7 @@ namespace dfcxx { class Graph { private: + std::unordered_map nameMap; std::unordered_set nodes; std::unordered_set startNodes; std::unordered_map> inputs; @@ -38,12 +41,24 @@ class Graph { const std::unordered_map &getConnections() const; - Node findNode(DFVariableImpl *var); - void addNode(DFVariableImpl *var, OpType type, NodeData data); void addChannel(DFVariableImpl *source, DFVariableImpl *target, unsigned opInd, bool connect); + + void transferFrom(Graph &&graph); + + void resetNodeName(const std::string &name); + + void deleteNode(Node node); + + void rebindInput(Node source, Node input, Graph &graph); + + Node rebindOutput(Node output, Node target, Graph &graph); + + Node findNode(const std::string &name); + + Node findNode(DFVariableImpl *var); }; } // namespace dfcxx diff --git a/src/model/dfcxx/include/dfcxx/io.h b/src/model/dfcxx/include/dfcxx/io.h index 5ec0e66e..412b8444 100644 --- a/src/model/dfcxx/include/dfcxx/io.h +++ b/src/model/dfcxx/include/dfcxx/io.h @@ -28,6 +28,10 @@ class IO { DFVariable inputScalar(const std::string &name, const DFType &type); + DFVariable newStream(const DFType &type); + + DFVariable newScalar(const DFType &type); + DFVariable output(const std::string &name, const DFType &type); DFVariable outputScalar(const std::string &name, const DFType &type); diff --git a/src/model/dfcxx/include/dfcxx/kernel.h b/src/model/dfcxx/include/dfcxx/kernel.h index a64fe114..82a13a39 100644 --- a/src/model/dfcxx/include/dfcxx/kernel.h +++ b/src/model/dfcxx/include/dfcxx/kernel.h @@ -18,9 +18,11 @@ #include "dfcxx/types/type.h" #include "dfcxx/vars/var.h" +#include #include #include #include +#include #include // This forward declaration is needed to avoid @@ -37,6 +39,12 @@ class Kernel { bool compileDot(llvm::raw_fd_ostream *stream); + void rebindInput(DFVariable source, Node input, Kernel &kern); + + DFVariable rebindOutput(Node output, DFVariable target, Kernel &kern); + + void deleteNode(Node node); + protected: IO io; Offset offset; @@ -51,12 +59,31 @@ class Kernel { DFType dfBool(); + using IOBinding = std::pair; + + template + void instance(std::initializer_list bindings, Args && ...args) { + Kern kern(std::forward(args)...); + + for (auto &binding: bindings) { + Node node = kern.meta.graph.findNode(binding.second); + kern.meta.graph.resetNodeName(binding.second); + if (node.type == OpType::IN) { + rebindInput(binding.first, node, kern); + } else { + binding.first = rebindOutput(node, binding.first, kern); + } + } + + meta.transferFrom(std::move(kern.meta)); + } + Kernel(); public: virtual ~Kernel() = default; - virtual std::string_view getName() = 0; + virtual std::string_view getName() const = 0; const Graph &getGraph() const; @@ -71,6 +98,11 @@ class Kernel { bool simulate(const std::string &inDataPath, const std::string &outFilePath); + bool check() const; + +// Checker methods. +private: + bool checkValidNodes() const; }; } // namespace dfcxx diff --git a/src/model/dfcxx/include/dfcxx/kernmeta.h b/src/model/dfcxx/include/dfcxx/kernmeta.h index a1d3715a..2d4520e9 100644 --- a/src/model/dfcxx/include/dfcxx/kernmeta.h +++ b/src/model/dfcxx/include/dfcxx/kernmeta.h @@ -25,6 +25,8 @@ struct KernMeta { KernMeta() = default; KernMeta(const KernMeta &) = delete; ~KernMeta() = default; + + void transferFrom(KernMeta &&meta); }; } // namespace dfcxx diff --git a/src/model/dfcxx/include/dfcxx/kernstorage.h b/src/model/dfcxx/include/dfcxx/kernstorage.h index 3fc43316..c5ae5868 100644 --- a/src/model/dfcxx/include/dfcxx/kernstorage.h +++ b/src/model/dfcxx/include/dfcxx/kernstorage.h @@ -26,7 +26,11 @@ class KernStorage { DFVariableImpl *addVariable(DFVariableImpl *var); + void deleteVariable(DFVariableImpl *var); + ~KernStorage(); + + void transferFrom(KernStorage &&storage); }; } // namespace dfcxx diff --git a/src/model/dfcxx/include/dfcxx/node.h b/src/model/dfcxx/include/dfcxx/node.h index 9c7eb85f..59cd1a9f 100644 --- a/src/model/dfcxx/include/dfcxx/node.h +++ b/src/model/dfcxx/include/dfcxx/node.h @@ -14,7 +14,8 @@ namespace dfcxx { enum OpType : uint8_t { - OFFSET = 0, + NONE = 0, // Is not allowed in a fully constructed kernel. + OFFSET, IN, OUT, CONST, @@ -53,6 +54,7 @@ struct Node { Node(DFVariableImpl *var, OpType type, NodeData data); bool operator==(const Node &node) const; + bool operator!=(const Node &node) const { return !(*this == node); } }; } // namespace dfcxx diff --git a/src/model/dfcxx/include/dfcxx/vars/var.h b/src/model/dfcxx/include/dfcxx/vars/var.h index 6ab75330..9b98bd99 100644 --- a/src/model/dfcxx/include/dfcxx/vars/var.h +++ b/src/model/dfcxx/include/dfcxx/vars/var.h @@ -50,6 +50,8 @@ class DFVariableImpl { std::string_view getName() const; + void resetName(); + IODirection getDirection() const; const KernMeta &getMeta() const; diff --git a/src/model/dfcxx/lib/dfcxx/CMakeLists.txt b/src/model/dfcxx/lib/dfcxx/CMakeLists.txt index 2d758039..03b765f9 100644 --- a/src/model/dfcxx/lib/dfcxx/CMakeLists.txt +++ b/src/model/dfcxx/lib/dfcxx/CMakeLists.txt @@ -18,6 +18,7 @@ set(SOURCES ${VAR_BUILDERS_SOURCES} ${TYPE_BUILDERS_SOURCES} converter.cpp + kernmeta.cpp kernstorage.cpp io.cpp offset.cpp diff --git a/src/model/dfcxx/lib/dfcxx/IRbuilders/builder.cpp b/src/model/dfcxx/lib/dfcxx/IRbuilders/builder.cpp index 2dffc351..f173f4f3 100644 --- a/src/model/dfcxx/lib/dfcxx/IRbuilders/builder.cpp +++ b/src/model/dfcxx/lib/dfcxx/IRbuilders/builder.cpp @@ -275,7 +275,7 @@ void DFCIRBuilder::translate(Node node, const Graph &graph, } default: { // TODO: Add proper logging: https://github.com/ispras/utopia-hls/issues/13 - std::cout << "[ERROR] Unknown node type id: " << node.type << std::endl; + std::cout << "[ERROR] Unknown/unsupported node type id: " << node.type << std::endl; assert(false); }; } @@ -283,8 +283,9 @@ void DFCIRBuilder::translate(Node node, const Graph &graph, map[node] = newOp->getResult(0); auto &connections = graph.getConnections(); - if (connections.find(node) != connections.end()) { - auto conSrc = connections.at(node).source; + auto it = connections.find(node); + if (it != connections.end()) { + auto conSrc = it->second.source; builder.create(loc, map[node], map[conSrc]); } } diff --git a/src/model/dfcxx/lib/dfcxx/graph.cpp b/src/model/dfcxx/lib/dfcxx/graph.cpp index 447a5e79..67d5322f 100644 --- a/src/model/dfcxx/lib/dfcxx/graph.cpp +++ b/src/model/dfcxx/lib/dfcxx/graph.cpp @@ -9,6 +9,7 @@ #include "dfcxx/graph.h" #include +#include namespace dfcxx { @@ -32,6 +33,14 @@ const std::unordered_map &Graph::getConnections() const { return connections; } +Node Graph::findNode(const std::string &name) { + auto it = nameMap.find(name); + if (it == nameMap.end()) { + throw new std::invalid_argument("Non-existent node with name: " + name); + } + return it->second; +} + Node Graph::findNode(DFVariableImpl *var) { return *std::find_if(nodes.begin(), nodes.end(), [&](const Node &node) { return node.var == var; }); @@ -42,6 +51,12 @@ void Graph::addNode(DFVariableImpl *var, OpType type, NodeData data) { if (type == IN || type == CONST) { startNodes.emplace(var, type, data); } + + auto name = node.first->var->getName(); + if (!name.empty()) { + nameMap[name] = *(node.first); + } + // The following lines create empty channel vectors // for new nodes. This allows to use .at() on unconnected // nodes without getting an exception. @@ -62,4 +77,77 @@ void Graph::addChannel(DFVariableImpl *source, DFVariableImpl *target, } } +void Graph::transferFrom(Graph &&graph) { + nodes.merge(std::move(graph.nodes)); + inputs.merge(std::move(graph.inputs)); + outputs.merge(std::move(graph.outputs)); + connections.merge(std::move(graph.connections)); +} + +void Graph::resetNodeName(const std::string &name) { + auto it = nameMap.find(name); + if (it == nameMap.end()) { + throw new std::invalid_argument("Non-existent node with name: " + name); + } + DFVariableImpl *ptr = it->second.var; + nameMap.erase(it); + ptr->resetName(); +} + +void Graph::deleteNode(Node node) { + connections.erase(node); + inputs.erase(node); + outputs.erase(node); + startNodes.erase(node); + nodes.erase(node); +} + +void Graph::rebindInput(Node source, Node input, Graph &graph) { + auto &conns = graph.connections; + for (auto &out: graph.outputs[input]) { + for (auto &in: graph.inputs[out.target]) { + if (in.source == input && out == in) { + in.source = source; + outputs[source].push_back(in); + break; + } + } + auto it = conns.find(out.target); + if (it != conns.end() && it->second.source == input) { + it->second.source = source; + } + } +} + +Node Graph::rebindOutput(Node output, Node target, Graph &graph) { + auto &inSrc = graph.inputs[output].front().source; + auto &outs = graph.outputs[inSrc]; + for (auto it = outs.begin(); it != outs.end(); ++it) { + if (it->target != output) { continue; } + if (target.type == OpType::NONE) { + outs.erase(it); + for (auto &out: outputs[target]) { + for (auto &in: inputs[out.target]) { + if (in.source == target && out == in) { + in.source = it->source; + outs.push_back(in); + } + } + auto conIt = connections.find(out.target); + if (conIt != connections.end() && conIt->second.source == target) { + conIt->second.source = it->source; + } + } + target = it->source; + } else { + it->target = target; + inputs[target].clear(); + inputs[target].push_back(*it); + connections[target] = *it; + } + break; + } + return target; +} + } // namespace dfcxx diff --git a/src/model/dfcxx/lib/dfcxx/io.cpp b/src/model/dfcxx/lib/dfcxx/io.cpp index 1be9b888..30adb489 100644 --- a/src/model/dfcxx/lib/dfcxx/io.cpp +++ b/src/model/dfcxx/lib/dfcxx/io.cpp @@ -34,6 +34,26 @@ DFVariable IO::inputScalar(const std::string &name, const DFType &type) { return var; } +DFVariable IO::newStream(const DFType &type) { + auto *var = meta.varBuilder.buildStream("", + IODirection::NONE, + meta, + type); + meta.storage.addVariable(var); + meta.graph.addNode(var, OpType::NONE, NodeData {}); + return var; +} + +DFVariable IO::newScalar(const DFType &type) { + auto *var = meta.varBuilder.buildScalar("", + IODirection::NONE, + meta, + type); + meta.storage.addVariable(var); + meta.graph.addNode(var, OpType::NONE, NodeData {}); + return var; +} + DFVariable IO::output(const std::string &name, const DFType &type) { auto *var = meta.varBuilder.buildStream(name, IODirection::OUTPUT, diff --git a/src/model/dfcxx/lib/dfcxx/kernel.cpp b/src/model/dfcxx/lib/dfcxx/kernel.cpp index 3cb639c5..204d05df 100644 --- a/src/model/dfcxx/lib/dfcxx/kernel.cpp +++ b/src/model/dfcxx/lib/dfcxx/kernel.cpp @@ -147,6 +147,34 @@ bool Kernel::compileDot(llvm::raw_fd_ostream *stream) { return true; } +void Kernel::rebindInput(DFVariable source, Node input, Kernel &kern) { + auto sourceNode = meta.graph.findNode(source); + meta.graph.rebindInput(sourceNode, + input, + kern.meta.graph); + + kern.deleteNode(input); +} + +DFVariable Kernel::rebindOutput(Node output, DFVariable target, Kernel &kern) { + auto targetNode = meta.graph.findNode(target); + auto node = meta.graph.rebindOutput(output, + targetNode, + kern.meta.graph); + + if (targetNode != node) { + deleteNode(targetNode); + } + + kern.deleteNode(output); + return node.var; +} + +void Kernel::deleteNode(Node node) { + meta.graph.deleteNode(node); + meta.storage.deleteVariable(node.var); +} + bool Kernel::compile(const DFLatencyConfig &config, const std::vector &outputPaths, const Scheduler &sched) { @@ -204,4 +232,30 @@ bool Kernel::simulate(const std::string &inDataPath, return sim.simulate(input, output); } +bool Kernel::check() const { + const auto &nodes = meta.graph.getNodes(); + const auto &startNodes = meta.graph.getStartNodes(); + const auto &connections = meta.graph.getConnections(); + std::cout << "[UTOPIA] Kernel: " << getName() << std::endl; + std::cout << "[UTOPIA] Nodes: " << nodes.size() << std::endl; + std::cout << "[UTOPIA] Start nodes: " << startNodes.size() << std::endl; + std::cout << "[UTOPIA] Connections: " << connections.size() << std::endl; + + return checkValidNodes(); +} + +bool Kernel::checkValidNodes() const { + std::cout << "[UTOPIA] Checking whether constructed nodes are valid: "; + + const auto &nodes = meta.graph.getNodes(); + for (const Node &node: nodes) { + if (node.type == OpType::NONE) { + std::cout << "found invalid node(s). Abort." << std::endl; + return false; + } + } + std::cout << "finished." << std::endl; + return true; +} + } // namespace dfcxx diff --git a/src/model/dfcxx/lib/dfcxx/kernmeta.cpp b/src/model/dfcxx/lib/dfcxx/kernmeta.cpp new file mode 100644 index 00000000..4d1335f9 --- /dev/null +++ b/src/model/dfcxx/lib/dfcxx/kernmeta.cpp @@ -0,0 +1,18 @@ +//===----------------------------------------------------------------------===// +// +// Part of the Utopia HLS Project, under the Apache License v2.0 +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2024 ISP RAS (http://www.ispras.ru) +// +//===----------------------------------------------------------------------===// + +#include "dfcxx/kernmeta.h" + +namespace dfcxx { + +void KernMeta::transferFrom(KernMeta &&meta) { + graph.transferFrom(std::move(meta.graph)); + storage.transferFrom(std::move(meta.storage)); +} + +} // namespace dfcxx diff --git a/src/model/dfcxx/lib/dfcxx/kernstorage.cpp b/src/model/dfcxx/lib/dfcxx/kernstorage.cpp index f94071bc..ae1eb4f0 100644 --- a/src/model/dfcxx/lib/dfcxx/kernstorage.cpp +++ b/src/model/dfcxx/lib/dfcxx/kernstorage.cpp @@ -29,6 +29,9 @@ DFVariableImpl *KernStorage::addVariable(DFVariableImpl *var) { return *(variables.insert(var).first); } +void KernStorage::deleteVariable(DFVariableImpl *var) { + variables.erase(var); +} KernStorage::~KernStorage() { for (DFTypeImpl *type: types) { @@ -39,4 +42,9 @@ KernStorage::~KernStorage() { } } +void KernStorage::transferFrom(KernStorage &&storage) { + types.merge(std::move(storage.types)); + variables.merge(std::move(storage.variables)); +} + } // namespace dfcxx diff --git a/src/model/dfcxx/lib/dfcxx/vars/var.cpp b/src/model/dfcxx/lib/dfcxx/vars/var.cpp index e413ef51..62164200 100644 --- a/src/model/dfcxx/lib/dfcxx/vars/var.cpp +++ b/src/model/dfcxx/lib/dfcxx/vars/var.cpp @@ -21,6 +21,10 @@ std::string_view DFVariableImpl::getName() const { return name; } +void DFVariableImpl::resetName() { + name.clear(); +} + DFVariableImpl::IODirection DFVariableImpl::getDirection() const { return direction; }