Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/core/graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "core/op_type.h"
#include "operators/ElementWise.h"
#include "operators/Gemm.h"
#include "operators/Clip.h"

namespace infini {

Expand All @@ -25,6 +26,7 @@ class GraphBuilderObj {
Tensor add(Tensor A, Tensor B, std::optional<Tensor> Y = std::nullopt);
Tensor sub(Tensor A, Tensor B, std::optional<Tensor> Y = std::nullopt);
Tensor mul(Tensor A, Tensor B, std::optional<Tensor> Y = std::nullopt);
Tensor clip(Tensor A, Tensor min_val, Tensor max_val, std::optional<Tensor> Y = std::nullopt);
string printGraph() const;

Graph getGraph() const;
Expand Down
28 changes: 28 additions & 0 deletions include/operators/Clip.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once
#include "core/graph.h"
#include "core/operator.h"

#include <infiniop/ops/clip.h>


namespace infini {
/**
 * @brief Element-wise Clip operator: bounds `input` by `min_val` / `max_val`
 * (delegates the actual computation to infiniop's clip kernel).
 *
 * Inputs are ordered {input, min_val, max_val}; there is a single output
 * with the same shape and dtype as the input.
 */
class ClipObj : public OperatorObj {
public:
/**
 * @brief Construct a new Clip object
 *
 * @param graph The computation graph that this operator belongs to.
 * @param input The input tensor.
 * @param min_val The minimum value tensor for clipping.
 * @param max_val The maximum value tensor for clipping.
 * @param output The output tensor.
 */
ClipObj(GraphObj *graph, Tensor input, Tensor min_val, Tensor max_val, Tensor output);
/// @brief Human-readable summary: "Clip(input=..,min_val=..,max_val=..,output=..)".
string toString() const override;
/// @brief Destroys the infiniop clip descriptor if one was created.
~ClipObj() override;

/// @brief Builds the infiniop clip descriptor from the current tensor shapes/strides.
void createOpDesc() override;
/// @brief Output shape equals the input shape (Clip is shape-preserving).
optional<vector<ShapeExpr>> inferShape() override;
/// @brief Output dtype equals the input dtype (Clip is type-preserving).
vector<DataType> inferDataType() const override;
};
} // namespace infini
2 changes: 2 additions & 0 deletions python/bindings/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ void bind_graph_builder(py::module &m) {
py::arg("Y") = py::none())
.def("mul", &GraphBuilderObj::mul, py::arg("A"), py::arg("B"),
py::arg("Y") = py::none())
.def("clip", &GraphBuilderObj::clip, py::arg("A"), py::arg("min_val"), py::arg("max_val"),
py::arg("Y") = py::none())
.def("to_string", &GraphBuilderObj::printGraph)
.def_property_readonly("graph", &GraphBuilderObj::getGraph);
}
Expand Down
9 changes: 8 additions & 1 deletion python/src/infinitensor/converter/unified_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,11 @@ def convert_add(translator, node):
def convert_add(translator, node):
    """Translate an FX `add` node: look up both operand tensors and emit a
    builder `add` op, recording the result tensor under the node.

    Bug fix: this previously called `translator.builder.sub(...)`, so every
    traced addition was lowered to a subtraction.
    """
    a = translator.tensors[node.args[0]]
    b = translator.tensors[node.args[1]]
    translator.tensors[node] = translator.builder.add(a, b, None)

@registry.register("clip","Tensor")
def convert_clip_tensor(translator, node):
    """Translate an FX tensor-overload `clip` node into a builder clip op.

    The first three node args are (input, min, max); all three must already
    have been translated into builder tensors.
    """
    x, lo, hi = (translator.tensors[arg] for arg in node.args[:3])
    translator.tensors[node] = translator.builder.clip(x, lo, hi, None)
33 changes: 33 additions & 0 deletions python/tests/test_torch_fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,39 @@ def forward(self, x, y):
print("✅ Test passed!")


def test_clip(runtime, torch_rng_seed):
    """Use fixtures defined in conftest.py directly"""
    print(f"Testing with runtime on device: {runtime}")
    print(f"Random seed: {torch_rng_seed}")

    # A minimal module whose forward is a single tensor-overload clip.
    class ClipModel(torch.nn.Module):
        def forward(self, x, min_val, max_val):
            return torch.clip(x, min=min_val, max=max_val)

    model = ClipModel()

    # Build random float32 inputs for (x, min, max); dtypes must match the
    # model's expectations even though the values are arbitrary.
    input_info = [((5, 4), "float32"), ((5, 4), "float32"), ((5, 4), "float32")]
    input_tensors = []
    for shape, dtype in input_info:
        data = np.random.randn(*shape).astype(dtype)
        input_tensors.append(torch.as_tensor(data))

    # Trace the model and lower it through the translator.
    translator = TorchFXTranslator(runtime)
    translator.import_from_fx(model, input_tensors)

    translator.run(input_tensors)
    # Fetch the produced outputs.
    outputs = translator.get_outputs()

    # A single output tensor with the input's shape is expected.
    assert len(outputs) == 1
    assert outputs[0].shape == (5, 4)
    print("✅ Test passed!")


if __name__ == "__main__":
# Can run this file directly
import sys
Expand Down
11 changes: 11 additions & 0 deletions src/core/graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ Tensor GraphBuilderObj::gemm(Tensor A, Tensor B, Tensor C, float alpha,
} \
}

/**
 * @brief Add a Clip op to the graph.
 *
 * @param A        Input tensor.
 * @param min_val  Lower-bound tensor.
 * @param max_val  Upper-bound tensor.
 * @param Y        Optional pre-allocated output; if absent a new output
 *                 tensor is created by the graph.
 * @return The output tensor of the Clip op.
 */
Tensor GraphBuilderObj::clip(Tensor A, Tensor min_val, Tensor max_val, std::optional<Tensor> Y) {
    if (Y.has_value()) {
        // Take the output handle out of the optional ONCE. The previous code
        // did `addOpWithOutputs(..., std::move(Y.value()))` and then returned
        // `Y.value()` — a use-after-move that returned a moved-from (empty)
        // tensor handle to the caller.
        Tensor out = std::move(Y.value());
        g->addOpWithOutputs<ClipObj>(std::move(A), std::move(min_val),
                                     std::move(max_val), out);
        return out;
    }
    return g
        ->addOp<ClipObj>(std::move(A), std::move(min_val), std::move(max_val),
                         nullptr)
        ->getOutput(0);
}

DEFINE_BINARY_OP(add, OpType::Add);
DEFINE_BINARY_OP(sub, OpType::Sub);
DEFINE_BINARY_OP(mul, OpType::Mul);
Expand Down
27 changes: 27 additions & 0 deletions src/kernels/Clip.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include "core/runtime.h"
#include "operators/Clip.h"

namespace infini {

// Kernel wrapper that dispatches Clip to infiniop.
class ClipOp : public Kernel {
// Executes one Clip: builds the op descriptor, queries the workspace size,
// and launches infiniopClip on the runtime's current stream.
// NOTE(review): createOpDesc() is invoked on every compute() call —
// confirm the previous descriptor (if any) is not leaked on re-execution.
void compute(const Operator &_op,
const RuntimeObj *runtime) const override {
auto op = as<ClipObj>(_op);
op->createOpDesc();
// Raw device pointers: output first, then input, min, max.
void *yData = (op->getOutput(0)->getRawDataPtr<void *>());
void *const aData = (op->getInput(0)->getRawDataPtr<void *>());
void *const min_val = (op->getInput(1)->getRawDataPtr<void *>());
void *const max_val = (op->getInput(2)->getRawDataPtr<void *>());
// Ask infiniop how much scratch space this clip needs, then borrow it
// from the runtime-owned workspace.
size_t workspace_size = 0;
CHECK_INFINI_ERROR(infiniopGetClipWorkspaceSize(
(infiniopClipDescriptor_t)op->getInfiniOpDesc(), &workspace_size));
void *workspace = runtime->getWorkspace(workspace_size);
// Launch asynchronously on the current thread's stream.
CHECK_INFINI_ERROR(infiniopClip(
(infiniopClipDescriptor_t)op->getInfiniOpDesc(), workspace,
workspace_size, yData, aData, min_val, max_val,
runtime->getCurrentThreadContext()->stream));
}
};
// Registration: bind the operator type to this kernel implementation and add
// the pair to the kernel registry for every supported device.
REGISTER_KERNEL_ALL_DEVICES(OpType::Clip, ClipOp);
} // namespace infini
92 changes: 92 additions & 0 deletions src/operators/Clip.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "operators/Clip.h"
#include "core/runtime.h"

namespace infini {

// Construct a Clip op with inputs {input, min_val, max_val} and one output,
// then validate it against the owning graph (shape/dtype consistency).
ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor min_val, Tensor max_val,
Tensor output)
: OperatorObj(OpType::Clip, TensorVec{input, min_val, max_val}, {output}) {
// checkValid also runs shape/dtype inference against the graph.
IT_ASSERT(checkValid(graph));
}

// Renders the op as "Clip(input=<guid>,min_val=<guid>,max_val=<guid>,output=<guid>)".
string ClipObj::toString() const {
    std::ostringstream repr;
    repr << "Clip("
         << "input=" << inputs[0]->getGuid() << ","
         << "min_val=" << inputs[1]->getGuid() << ","
         << "max_val=" << inputs[2]->getGuid() << ","
         << "output=" << outputs[0]->getGuid() << ")";
    return repr.str();
}

// Release the infiniop clip descriptor, if createOpDesc() ever built one.
// Destructors must not throw, so a failed destroy is only reported.
ClipObj::~ClipObj() {
    if (infiniOpDesc) {
        // Initialize at the point of use instead of the original
        // default-init-then-overwrite of `err`.
        const infiniStatus_t err =
            infiniopDestroyClipDescriptor((infiniopClipDescriptor_t)infiniOpDesc);
        if (err != INFINI_STATUS_SUCCESS) {
            std::cerr << "Warning: Clip descriptor destroy failed with error code "
                      << err << std::endl;
        }
    }
}

// Shape inference for Clip: the single output has exactly the input's shape.
optional<vector<ShapeExpr>> ClipObj::inferShape() {
// Clip does not change the shape of the input tensor
// Simply return the input shape as-is (supports both concrete and symbolic shapes)
// NOTE(review): min_val/max_val shapes are not checked here — presumably
// checkValid/infiniop enforces they match or broadcast; confirm.
auto inputShape = inputs[0]->getShape();
return {{inputShape}};
}

// Clip is type-preserving: the output carries the input tensor's dtype.
vector<DataType> ClipObj::inferDataType() const {
    const auto dtype = inputs[0]->getDataType();
    return vector<DataType>{dtype};
}

// Builds the infiniop clip descriptor from the current (concrete) shapes and
// strides of the output, input, min and max tensors, storing the result in
// infiniOpDesc. Requires shapes to be concrete: getConstantValue() is called
// on every shape/stride, so symbolic dims must be resolved before this runs.
void ClipObj::createOpDesc() {
auto yShape = outputs[0]->getShape();
auto yStride = outputs[0]->getStride();

auto xShape = inputs[0]->getShape();
auto xStride = inputs[0]->getStride();

auto minValShape = inputs[1]->getShape();
auto minValStride = inputs[1]->getStride();

auto maxValShape = inputs[2]->getShape();
auto maxValStride = inputs[2]->getStride();

// Temporary infiniop tensor descriptors; destroyed again after the clip
// descriptor has been created from them.
infiniopTensorDescriptor_t yTensor, xTensor, minValTensor, maxValTensor;

CHECK_INFINI_ERROR(infiniopCreateTensorDescriptor(
&yTensor, yShape->size(), yShape->getConstantValue().data(),
yStride->getConstantValue().data(),
outputs[0]->getDataType().getType()));

CHECK_INFINI_ERROR(infiniopCreateTensorDescriptor(
&xTensor, xShape->size(), xShape->getConstantValue().data(),
xStride->getConstantValue().data(),
inputs[0]->getDataType().getType()));

CHECK_INFINI_ERROR(infiniopCreateTensorDescriptor(
&minValTensor, minValShape->size(), minValShape->getConstantValue().data(),
minValStride->getConstantValue().data(),
inputs[1]->getDataType().getType()));

CHECK_INFINI_ERROR(infiniopCreateTensorDescriptor(
&maxValTensor, maxValShape->size(), maxValShape->getConstantValue().data(),
maxValStride->getConstantValue().data(),
inputs[2]->getDataType().getType()));

// NOTE(review): a fresh infiniop handle is created on every call and never
// destroyed — confirm whether the runtime should own a shared handle, and
// whether an existing infiniOpDesc is leaked when this is called twice.
infiniopHandle_t handle = nullptr;
CHECK_INFINI_ERROR(infiniopCreateHandle(&handle));

CHECK_INFINI_ERROR(infiniopCreateClipDescriptor(
handle, (infiniopClipDescriptor_t *)&infiniOpDesc, yTensor, xTensor,
minValTensor, maxValTensor));

// The clip descriptor owns what it needs; the per-tensor descriptors can go.
CHECK_INFINI_ERROR(infiniopDestroyTensorDescriptor(yTensor));
CHECK_INFINI_ERROR(infiniopDestroyTensorDescriptor(xTensor));
CHECK_INFINI_ERROR(infiniopDestroyTensorDescriptor(minValTensor));
CHECK_INFINI_ERROR(infiniopDestroyTensorDescriptor(maxValTensor));
}

} // namespace infini
Loading