Skip to content

Commit

Permalink
allow setting "ValidationMode"
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 12, 2024
1 parent 8978d89 commit 43ccaf4
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 125 deletions.
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/webgpu/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,13 @@ struct ProgramOutput {
TensorShape override_shape;
};

// Selects how much validation the WebGPU execution provider performs.
// The enumerators are ordered from least to most strict, so a mode can be
// tested with `>=` against the minimum level a check requires.
enum class ValidationMode {
  // The "skip_validation" device toggle is requested only in this mode.
  Disabled = 0,
  // "skip_validation" is no longer requested; no checks gated on Basic/Full run.
  WGPUOnly = 1,
  // Turns on the checks in WebGpuContext::Run that are gated on `>= Basic`.
  Basic = 2,
  // Additionally turns on the per-item checks that are gated on `>= Full`.
  Full = 3,
};

namespace detail {
class ProgramWrapper;
}
Expand Down
242 changes: 124 additions & 118 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,79 +17,6 @@
namespace onnxruntime {
namespace webgpu {

namespace {

// Returns the names of the toggles to enable on the adapter.
// See the description of all the toggles in toggles.cpp:
//   "use_dxc"           - Shader Model 6+ features (e.g. float16)
//   "allow_unsafe_apis" - chromium experimental features
std::vector<const char*> GetEnabledAdapterToggles() {
  std::vector<const char*> enabled{
      "use_dxc",
      "allow_unsafe_apis",
  };
  return enabled;
}

// Returns the names of the toggles to enable on the device.
// These toggles may affect performance. Other toggles that may be useful:
// "dump_shaders", "disable_symbol_renaming".
std::vector<const char*> GetEnabledDeviceToggles() {
  std::vector<const char*> enabled;
#ifdef NDEBUG
  // todo: when skip validation, the process may crash.
  //       need careful decision to enable this toggle.
  //       revisit this flag before release.
  enabled.push_back("skip_validation");
#endif
  enabled.push_back("disable_robustness");
  enabled.push_back("disable_workgroup_init");
  enabled.push_back("d3d_disable_ieee_strictness");
  return enabled;
}

// Returns the names of the toggles to explicitly disable on the device.
std::vector<const char*> GetDisabledDeviceToggles() {
  std::vector<const char*> disabled{
      "lazy_clear_resource_on_first_use",
  };
  return disabled;
}

// Returns the subset of the candidate features below that `adapter` reports
// via HasFeature(), preserving the order in which the candidates are listed.
std::vector<wgpu::FeatureName> GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) {
  constexpr wgpu::FeatureName candidates[]{
      wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses,
      wgpu::FeatureName::TimestampQuery,
      wgpu::FeatureName::ShaderF16,
      wgpu::FeatureName::Subgroups,
      wgpu::FeatureName::SubgroupsF16};

  std::vector<wgpu::FeatureName> supported;
  for (const wgpu::FeatureName candidate : candidates) {
    if (!adapter.HasFeature(candidate)) {
      continue;  // not offered by this adapter, so do not request it
    }
    supported.push_back(candidate);
  }
  return supported;
}

// Builds the limits to request by copying a fixed set of limits from what
// `adapter` reports. Fields not assigned here keep whatever
// `wgpu::RequiredLimits{}` initializes them to.
// Fails via ORT_ENFORCE if the adapter limits cannot be queried.
wgpu::RequiredLimits GetRequiredLimits(const wgpu::Adapter& adapter) {
  wgpu::SupportedLimits adapter_limits;
  ORT_ENFORCE(adapter.GetLimits(&adapter_limits));

  wgpu::RequiredLimits required_limits{};
  const auto& supported = adapter_limits.limits;
  auto& requested = required_limits.limits;

  requested.maxBindGroups = supported.maxBindGroups;
  requested.maxComputeWorkgroupStorageSize = supported.maxComputeWorkgroupStorageSize;
  requested.maxComputeWorkgroupsPerDimension = supported.maxComputeWorkgroupsPerDimension;
  requested.maxStorageBufferBindingSize = supported.maxStorageBufferBindingSize;
  requested.maxBufferSize = supported.maxBufferSize;
  requested.maxComputeInvocationsPerWorkgroup = supported.maxComputeInvocationsPerWorkgroup;
  requested.maxComputeWorkgroupSizeX = supported.maxComputeWorkgroupSizeX;
  requested.maxComputeWorkgroupSizeY = supported.maxComputeWorkgroupSizeY;
  requested.maxComputeWorkgroupSizeZ = supported.maxComputeWorkgroupSizeZ;

  return required_limits;
}

} // namespace

void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info) {
std::call_once(init_flag_, [this, &webgpu_ep_info]() {
// Initialization.Step.1 - Create wgpu::Instance
Expand Down Expand Up @@ -194,67 +121,73 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog
const auto& inputs = program.Inputs();
const auto& outputs = program.Outputs();

#ifndef NDEBUG // if debug build
ORT_ENFORCE(std::all_of(inputs.begin(), inputs.end(), [](const ProgramInput& input) {
const auto* tensor = input.tensor;
return tensor != nullptr &&
tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault &&
tensor->Location().device.Type() == OrtDevice::GPU &&
!strcmp(tensor->Location().name, WEBGPU_BUFFER);
}),
"All inputs must be tensors on WebGPU buffers.");

ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](const ProgramOutput& output) {
const auto* tensor = output.tensor;
return tensor != nullptr &&
tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault &&
tensor->Location().device.Type() == OrtDevice::GPU &&
!strcmp(tensor->Location().name, WEBGPU_BUFFER);
}),
"All outputs must be tensors on WebGPU buffers.");
#endif

if (outputs.size() == 0) {
return Status::OK();
}

if (ValidationMode() >= ValidationMode::Basic) {
ORT_ENFORCE(std::all_of(inputs.begin(), inputs.end(), [](const ProgramInput& input) {
const auto* tensor = input.tensor;
return tensor != nullptr &&
tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault &&
tensor->Location().device.Type() == OrtDevice::GPU &&
!strcmp(tensor->Location().name, WEBGPU_BUFFER);
}),
"All inputs must be tensors on WebGPU buffers.");

ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](const ProgramOutput& output) {
const auto* tensor = output.tensor;
return tensor != nullptr &&
tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault &&
tensor->Location().device.Type() == OrtDevice::GPU &&
!strcmp(tensor->Location().name, WEBGPU_BUFFER);
}),
"All outputs must be tensors on WebGPU buffers.");
}

const ProgramMetadata metadata = program.GetMetadata();

// validate program metadata
{
if (ValidationMode() >= ValidationMode::Basic) {
const auto& [constants, overridable_constants, uniform_variables] = metadata;

// check overridable constants
ORT_RETURN_IF(program.OverridableConstants().size() != overridable_constants.size(),
"Size of overridable constants mismatch in program \"", program.Name(),
"\", Expected: ", overridable_constants.size(),
", Actual: ", program.OverridableConstants().size());
size_t num_overridable_constants = program.OverridableConstants().size();
for (size_t i = 0; i < num_overridable_constants; ++i) {
const auto& override_value = program.OverridableConstants()[i];
const auto& definition = overridable_constants[i];
ORT_RETURN_IF(override_value.has_value && override_value.type != definition.type,
"Overridable override_value[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(),
"\", Expected: ", definition.type,
", Actual: ", override_value.type);
ORT_RETURN_IF(!override_value.has_value && !definition.has_default_value,
"Overridable override_value[", i, "] (", definition.name, ") no override_value specified in program \"", program.Name(),
"\"");

if (ValidationMode() >= ValidationMode::Full) {
size_t num_overridable_constants = program.OverridableConstants().size();
for (size_t i = 0; i < num_overridable_constants; ++i) {
const auto& override_value = program.OverridableConstants()[i];
const auto& definition = overridable_constants[i];
ORT_RETURN_IF(override_value.has_value && override_value.type != definition.type,
"Overridable override_value[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(),
"\", Expected: ", definition.type,
", Actual: ", override_value.type);
ORT_RETURN_IF(!override_value.has_value && !definition.has_default_value,
"Overridable override_value[", i, "] (", definition.name, ") no override_value specified in program \"", program.Name(),
"\"");
}
}

// check uniform variables
ORT_RETURN_IF(program.UniformVariables().size() != uniform_variables.size(),
"Size of uniform_value variables mismatch in program \"", program.Name(),
"\", Expected: ", uniform_variables.size(),
", Actual: ", program.UniformVariables().size());
size_t num_uniform_variables = program.UniformVariables().size();
for (size_t i = 0; i < num_uniform_variables; ++i) {
const auto& uniform_value = program.UniformVariables()[i];
const auto& definition = uniform_variables[i];
ORT_RETURN_IF(uniform_value.length > 0 && uniform_value.data_type != definition.data_type,
"Uniform variable[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(),
"\", Expected: ", definition.data_type,
", Actual: ", uniform_value.data_type);

if (ValidationMode() >= ValidationMode::Full) {
size_t num_uniform_variables = program.UniformVariables().size();
for (size_t i = 0; i < num_uniform_variables; ++i) {
const auto& uniform_value = program.UniformVariables()[i];
const auto& definition = uniform_variables[i];
ORT_RETURN_IF(uniform_value.length > 0 && uniform_value.data_type != definition.data_type,
"Uniform variable[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(),
"\", Expected: ", definition.data_type,
", Actual: ", uniform_value.data_type);
}
}
}

Expand Down Expand Up @@ -295,9 +228,11 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog
// prepare shape uniforms for shader variables (if any) and user defined uniforms
std::vector<ProgramUniformVariableValue> shape_uniforms;
shape_uniforms.reserve(program_artifact->shape_uniform_ranks.size() * 2);
ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size(),
"Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(),
") does not match current program (input: ", inputs.size(), ", output: ", outputs.size(), ")");
if (ValidationMode() >= ValidationMode::Basic) {
ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size(),
"Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(),
") does not match current program (input: ", inputs.size(), ", output: ", outputs.size(), ")");
}
for (size_t i = 0; i < program_artifact->shape_uniform_ranks.size(); ++i) {
SafeInt<int> expected_rank = program_artifact->shape_uniform_ranks[i];
if (expected_rank > 0) {
Expand Down Expand Up @@ -423,10 +358,81 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog
return Status::OK();
}

// Returns the names of the toggles to enable on the adapter.
// See the description of all the toggles in toggles.cpp:
//   "use_dxc"           - Shader Model 6+ features (e.g. float16)
//   "allow_unsafe_apis" - chromium experimental features
std::vector<const char*> WebGpuContext::GetEnabledAdapterToggles() const {
  std::vector<const char*> enabled{
      "use_dxc",
      "allow_unsafe_apis",
  };
  return enabled;
}

std::vector<const char*> WebGpuContext::GetEnabledDeviceToggles() const {
// Enable / disable other toggles that may affect the performance.
// Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming"
constexpr const char* toggles[] = {
"skip_validation", // only use "skip_validation" when ValidationMode is set to "Disabled"
"disable_robustness",
"disable_workgroup_init",
"d3d_disable_ieee_strictness",
};
return std::vector<const char*>(ValidationMode() >= ValidationMode::WGPUOnly
? std::begin(toggles) + 1
: std::begin(toggles),
std::end(toggles));
}

// Returns the names of the toggles to explicitly disable on the device.
std::vector<const char*> WebGpuContext::GetDisabledDeviceToggles() const {
  std::vector<const char*> disabled{
      "lazy_clear_resource_on_first_use",
  };
  return disabled;
}

// Returns the subset of the candidate features below that `adapter` reports
// via HasFeature(), preserving the order in which the candidates are listed.
std::vector<wgpu::FeatureName> WebGpuContext::GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const {
  constexpr wgpu::FeatureName candidates[]{
      wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses,
      wgpu::FeatureName::TimestampQuery,
      wgpu::FeatureName::ShaderF16,
      wgpu::FeatureName::Subgroups,
      wgpu::FeatureName::SubgroupsF16};

  std::vector<wgpu::FeatureName> supported;
  for (const wgpu::FeatureName candidate : candidates) {
    if (!adapter.HasFeature(candidate)) {
      continue;  // not offered by this adapter, so do not request it
    }
    supported.push_back(candidate);
  }
  return supported;
}

// Builds the limits to request by copying a fixed set of limits from what
// `adapter` reports. Fields not assigned here keep whatever
// `wgpu::RequiredLimits{}` initializes them to.
// Fails via ORT_ENFORCE if the adapter limits cannot be queried.
wgpu::RequiredLimits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) const {
  wgpu::SupportedLimits adapter_limits;
  ORT_ENFORCE(adapter.GetLimits(&adapter_limits));

  wgpu::RequiredLimits required_limits{};
  const auto& supported = adapter_limits.limits;
  auto& requested = required_limits.limits;

  requested.maxBindGroups = supported.maxBindGroups;
  requested.maxComputeWorkgroupStorageSize = supported.maxComputeWorkgroupStorageSize;
  requested.maxComputeWorkgroupsPerDimension = supported.maxComputeWorkgroupsPerDimension;
  requested.maxStorageBufferBindingSize = supported.maxStorageBufferBindingSize;
  requested.maxBufferSize = supported.maxBufferSize;
  requested.maxComputeInvocationsPerWorkgroup = supported.maxComputeInvocationsPerWorkgroup;
  requested.maxComputeWorkgroupSizeX = supported.maxComputeWorkgroupSizeX;
  requested.maxComputeWorkgroupSizeY = supported.maxComputeWorkgroupSizeY;
  requested.maxComputeWorkgroupSizeZ = supported.maxComputeWorkgroupSizeZ;

  return required_limits;
}

std::unordered_map<int32_t, std::unique_ptr<WebGpuContext>> WebGpuContextFactory::contexts_;
OrtMutex WebGpuContextFactory::mutex_;

WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) {
WebGpuContext& WebGpuContextFactory::CreateContext(int context_id,
WGPUInstance instance,
WGPUAdapter adapter,
WGPUDevice device,
ValidationMode validation_mode) {
if (context_id == 0) {
// context ID is preserved for the default context. User cannot use context ID 0 as a custom context.
ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr,
Expand All @@ -441,7 +447,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance

auto it = contexts_.find(context_id);
if (it == contexts_.end()) {
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device));
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, validation_mode));
it = contexts_.emplace(context_id, std::move(context)).first;
} else if (context_id != 0) {
ORT_ENFORCE(it->second->instance_.Get() == instance && it->second->adapter_.Get() == adapter && it->second->device_.Get() == device,
Expand Down
21 changes: 19 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ class ProgramBase;

class WebGpuContextFactory {
public:
static WebGpuContext& CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device);
static WebGpuContext& CreateContext(int context_id,
WGPUInstance instance,
WGPUAdapter adapter,
WGPUDevice device,
ValidationMode validation_mode);
static WebGpuContext& GetContext(int context_id);

private:
Expand Down Expand Up @@ -95,18 +99,31 @@ class WebGpuContext final {

webgpu::BufferManager& BufferManager() const { return *buffer_mgr_; }

// Accessor for the validation mode stored in `validation_mode_`.
// (Defined in-class, so it is implicitly inline.)
webgpu::ValidationMode ValidationMode() const { return validation_mode_; }

Status Run(const ComputeContext& context, const ProgramBase& program);

private:
WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) : instance_{instance}, adapter_{adapter}, device_{device} {}
WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode)
: instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode} {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);

// Private helpers that supply the toggle names, feature list and limits for
// this context; they are defined in webgpu_context.cc.
//
// NOTE: these are declared without a `WebGpuContext::` prefix on purpose.
// A qualified name on a member declaration inside its own class is not valid
// standard C++: MSVC accepts it as an extension, but GCC and Clang diagnose it
// ("extra qualification on member").
std::vector<const char*> GetEnabledAdapterToggles() const;
std::vector<const char*> GetEnabledDeviceToggles() const;
std::vector<const char*> GetDisabledDeviceToggles() const;
std::vector<wgpu::FeatureName> GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const;
wgpu::RequiredLimits GetRequiredLimits(const wgpu::Adapter& adapter) const;

std::once_flag init_flag_;

wgpu::Instance instance_;
wgpu::Adapter adapter_;
wgpu::Device device_;

webgpu::ValidationMode validation_mode_;

wgpu::AdapterInfo adapter_info_;
wgpu::Limits device_limits_;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
#include "core/framework/kernel_registry.h"
#include "core/graph/function_utils.h"
#include "core/graph/indexed_sub_graph.h"
#include "data_transfer.h"

#include "core/providers/webgpu/data_transfer.h"

namespace onnxruntime {

Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ enum class BufferCacheMode;
} // namespace webgpu

struct WebGpuExecutionProviderInfo {
WebGpuExecutionProviderInfo(DataLayout data_layout1, bool enable_graph_capture1)
: data_layout{data_layout1},
enable_graph_capture{enable_graph_capture1},
WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture)
: data_layout{data_layout},
enable_graph_capture{enable_graph_capture},
storage_buffer_cache_mode{},
uniform_buffer_cache_mode{},
query_resolve_buffer_cache_mode{},
Expand Down
Loading

0 comments on commit 43ccaf4

Please sign in to comment.