[CANN]: Update the doc of CANN EP (#23087)
### Description
This PR updates the documentation of the Ascend CANN execution provider.
bachelor-dou authored Dec 18, 2024
1 parent 9c0b71e commit 7ad9caa
Showing 1 changed file with 146 additions and 24 deletions: docs/execution-providers/community-maintained/CANN-ExecutionProvider.md
@@ -32,10 +32,9 @@ Please refer to the table below for the official CANN package dependencies of each ONNX Runtime version.

|ONNX Runtime|CANN|
|---|---|
|v1.12.1|6.0.0|
|v1.13.1|6.0.0|
|v1.14.0|6.0.0|
|v1.15.0|6.0.0|
|v1.18.0|8.0.0|
|v1.19.0|8.0.0|
|v1.20.0|8.0.0|
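
Before running any of the samples below, it can help to confirm that your onnxruntime build actually includes the CANN execution provider. A minimal sketch using the C++ API (the provider name string and header are the same ones the samples below rely on):

```c++
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

#include "onnxruntime_cxx_api.h"

int main() {
  // List the execution providers compiled into this onnxruntime build.
  std::vector<std::string> providers = Ort::GetAvailableProviders();
  const bool has_cann =
      std::find(providers.begin(), providers.end(),
                "CANNExecutionProvider") != providers.end();
  std::cout << (has_cann ? "CANN EP is available"
                         : "CANN EP is NOT available")
            << std::endl;
  return 0;
}
```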

## Build

@@ -201,26 +200,149 @@ session = ort.InferenceSession(model_path, sess_options=options, providers=providers)
```

### C/C++

```c++
const static OrtApi *g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);

OrtSessionOptions *session_options;
g_ort->CreateSessionOptions(&session_options);

OrtCANNProviderOptions *cann_options = nullptr;
g_ort->CreateCANNProviderOptions(&cann_options);

std::vector<const char *> keys{"device_id", "npu_mem_limit", "arena_extend_strategy", "enable_cann_graph"};
std::vector<const char *> values{"0", "2147483648", "kSameAsRequested", "1"};

g_ort->UpdateCANNProviderOptions(cann_options, keys.data(), values.data(), keys.size());

g_ort->SessionOptionsAppendExecutionProvider_CANN(session_options, cann_options);

// Finally, don't forget to release the provider options and session options
g_ort->ReleaseCANNProviderOptions(cann_options);
g_ort->ReleaseSessionOptions(session_options);
```
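
The snippet above ignores the OrtStatus* returned by each C-API call for brevity. A minimal error-checking sketch; check_status is a hypothetical helper, not part of the ONNX Runtime API:

```c++
#include <cstdio>
#include <cstdlib>

#include "onnxruntime_c_api.h"

// Print the ORT error message and abort if a C-API call failed.
static void check_status(const OrtApi* ort, OrtStatus* status) {
  if (status != nullptr) {
    std::fprintf(stderr, "ORT error: %s\n", ort->GetErrorMessage(status));
    ort->ReleaseStatus(status);
    std::exit(EXIT_FAILURE);
  }
}

// Usage with the snippet above, e.g.:
//   check_status(g_ort, g_ort->CreateCANNProviderOptions(&cann_options));
//   check_status(g_ort, g_ort->UpdateCANNProviderOptions(
//       cann_options, keys.data(), values.data(), keys.size()));
```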
Note: The following sample shows model inference using [resnet50_Opset16.onnx](https://github.com/onnx/models/tree/main/Computer_Vision/resnet50_Opset16_timm) as an example.
Modify the model_path, as well as the input_prepare() and output_postprocess() functions, according to your needs.


```c++
#include <iostream>
#include <vector>

#include "onnxruntime_cxx_api.h"

// Path to the model; change this to your own model path.
const char* model_path = "./onnx/resnet50_Opset16.onnx";

/**
* @brief Input data preparation provided by user.
*
* @param num_input_nodes The number of model input nodes.
* @return A collection of input data.
*/
std::vector<std::vector<float>> input_prepare(size_t num_input_nodes) {
  std::vector<std::vector<float>> input_datas;
  input_datas.reserve(num_input_nodes);

  constexpr size_t input_data_size = 3 * 224 * 224;
  std::vector<float> input_data(input_data_size);
  // Initialize input data with values in [0.0, 1.0]
  for (unsigned int i = 0; i < input_data_size; i++)
    input_data[i] = static_cast<float>(i) / (input_data_size + 1);
  input_datas.push_back(input_data);

  return input_datas;
}

/**
 * @brief Model output data processing logic (to be adapted by the user).
 *
 * @param output_tensors The results of the model output.
 */
void output_postprocess(std::vector<Ort::Value>& output_tensors) {
  auto floatarr = output_tensors.front().GetTensorMutableData<float>();

  for (int i = 0; i < 5; i++) {
    std::cout << "Score for class [" << i << "] = " << floatarr[i] << '\n';
  }

  std::cout << "Done!" << std::endl;
}

/**
 * @brief The main function for model inference.
 *
 * The complete model inference flow, which generally does not need to be
 * changed.
 */
void inference() {
  const auto& api = Ort::GetApi();

  // Create the CANN provider options; CANN graph mode is enabled
  // via the enable_cann_graph key below.
  OrtCANNProviderOptions* cann_options = nullptr;
  api.CreateCANNProviderOptions(&cann_options);

  // Configuration of the EP
  std::vector<const char*> keys{
      "device_id",
      "npu_mem_limit",
      "arena_extend_strategy",
      "enable_cann_graph"};
  std::vector<const char*> values{"0", "4294967296", "kNextPowerOfTwo", "1"};
  api.UpdateCANNProviderOptions(
      cann_options, keys.data(), values.data(), keys.size());

  // Append the CANN EP to the general session options
  Ort::SessionOptions session_options;
  api.SessionOptionsAppendExecutionProvider_CANN(
      static_cast<OrtSessionOptions*>(session_options), cann_options);

  // The provider options are no longer needed once they have been appended
  api.ReleaseCANNProviderOptions(cann_options);

  // The environment must outlive the session
  Ort::Env env;
  Ort::Session session(env, model_path, session_options);

  Ort::AllocatorWithDefaultOptions allocator;

  // Input processing
  const size_t num_input_nodes = session.GetInputCount();
  std::vector<const char*> input_node_names;
  std::vector<Ort::AllocatedStringPtr> input_names_ptr;
  input_node_names.reserve(num_input_nodes);
  input_names_ptr.reserve(num_input_nodes);
  std::vector<std::vector<int64_t>> input_node_shapes;
  std::cout << "Number of model inputs: " << num_input_nodes << std::endl;
  for (size_t i = 0; i < num_input_nodes; i++) {
    auto input_name = session.GetInputNameAllocated(i, allocator);
    input_node_names.push_back(input_name.get());
    input_names_ptr.push_back(std::move(input_name));
    auto type_info = session.GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    input_node_shapes.push_back(tensor_info.GetShape());
  }

  // Output processing
  const size_t num_output_nodes = session.GetOutputCount();
  std::vector<const char*> output_node_names;
  std::vector<Ort::AllocatedStringPtr> output_names_ptr;
  output_names_ptr.reserve(num_output_nodes);
  output_node_names.reserve(num_output_nodes);
  for (size_t i = 0; i < num_output_nodes; i++) {
    auto output_name = session.GetOutputNameAllocated(i, allocator);
    output_node_names.push_back(output_name.get());
    output_names_ptr.push_back(std::move(output_name));
  }

  // The user needs to generate input data according to the real situation.
  std::vector<std::vector<float>> input_datas = input_prepare(num_input_nodes);

  auto memory_info = Ort::MemoryInfo::CreateCpu(
      OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);

  std::vector<Ort::Value> input_tensors;
  input_tensors.reserve(num_input_nodes);
  for (size_t i = 0; i < input_node_shapes.size(); i++) {
    auto input_tensor = Ort::Value::CreateTensor<float>(
        memory_info,
        input_datas[i].data(),
        input_datas[i].size(),
        input_node_shapes[i].data(),
        input_node_shapes[i].size());
    input_tensors.push_back(std::move(input_tensor));
  }

  auto output_tensors = session.Run(
      Ort::RunOptions{nullptr},
      input_node_names.data(),
      input_tensors.data(),
      num_input_nodes,
      output_node_names.data(),
      output_node_names.size());

  // Process the output tensors
  output_postprocess(output_tensors);
}

int main(int argc, char* argv[]) {
  inference();
  return 0;
}
```
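
Note that the sample assumes the model reports fully static input shapes. If a dimension is dynamic, GetShape() reports it as -1 and the element count used to size the input buffers is undefined. A minimal sketch that pins dynamic dimensions before input_prepare() is called; pin_dynamic_dims is a hypothetical helper, and a batch size of 1 is assumed:

```c++
#include <cstdint>
#include <vector>

// Replace any dynamic (-1) dimensions with a concrete size so that the
// element count used for buffer allocation is well defined.
void pin_dynamic_dims(std::vector<std::vector<int64_t>>& shapes) {
  for (auto& shape : shapes) {
    for (auto& dim : shape) {
      if (dim < 0) dim = 1;  // assume batch size 1
    }
  }
}
```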
## Supported ops
