Skip to content

Commit

Permalink
make shape/stride correct when component != 1
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2024
1 parent 43ccaf4 commit eae4c3f
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 32 deletions.
51 changes: 51 additions & 0 deletions onnxruntime/core/providers/webgpu/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,57 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp
}
}

namespace {
TensorShape GetReducedShape(const TensorShape& shape, int component /* > 1 */) {
ORT_ENFORCE(shape.NumDimensions() > 0 && shape.GetDims()[shape.NumDimensions() - 1] % component == 0,
"Cannot reduce shape ", shape.ToString(), " by component=", component);
TensorShape reduced_shape = shape;
reduced_shape[reduced_shape.NumDimensions() - 1] /= component;
return reduced_shape;
}
} // namespace

ProgramInput::ProgramInput(const Tensor* tensor) : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {}

ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component)
: tensor{tensor},
dependency{dependency},
var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
use_override_shape{component > 1},
override_shape{} {
if (use_override_shape) {
override_shape = GetReducedShape(tensor->Shape(), component);
}
}

ProgramInput::ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component)
: tensor{tensor},
dependency{dependency},
var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
use_override_shape{true},
override_shape{override_shape} {}

ProgramOutput::ProgramOutput(Tensor* tensor)
: ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {}

ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component)
: tensor{tensor},
dependency{dependency},
var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
use_override_shape{component > 1},
override_shape{} {
if (use_override_shape) {
override_shape = GetReducedShape(tensor->Shape(), component);
}
}

ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component)
: tensor{tensor},
dependency{dependency},
var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
use_override_shape{true},
override_shape{override_shape} {}

ProgramBase::ProgramBase(const std::string& name)
: name_{name},
dispatch_group_size_x_{0},
Expand Down
34 changes: 6 additions & 28 deletions onnxruntime/core/providers/webgpu/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,20 +208,9 @@ int NumberOfComponents(ProgramVariableDataType type);
ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1);

struct ProgramInput {
ProgramInput(const Tensor* tensor)
: ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {}
ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1)
: tensor{tensor},
dependency{dependency},
var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
use_override_shape{false},
override_shape{} {}
ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component)
: tensor{tensor},
dependency{dependency},
var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
use_override_shape{true},
override_shape{override_shape} {}
ProgramInput(const Tensor* tensor);
ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1);
ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component);

const Tensor* tensor;
ProgramTensorMetadataDependency dependency;
Expand All @@ -231,20 +220,9 @@ struct ProgramInput {
};

struct ProgramOutput {
ProgramOutput(Tensor* tensor)
: ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {}
ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1)
: tensor{tensor},
dependency{dependency},
var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
use_override_shape{false},
override_shape{} {}
ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component)
: tensor{tensor},
dependency{dependency},
var_type{ToProgramVariableDataType(tensor->GetElementType(), component)},
use_override_shape{true},
override_shape{override_shape} {}
ProgramOutput(Tensor* tensor);
ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1);
ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component);

Tensor* tensor;
ProgramTensorMetadataDependency dependency;
Expand Down
4 changes: 0 additions & 4 deletions onnxruntime/core/providers/webgpu/shader_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,6 @@ Status ValidateVariableShape(const TensorShape& origin_shape,
// if override shape specified, assert override_size == ceil( origin_size / 4 )
ORT_RETURN_IF_NOT((origin_shape.Size() + num_components - 1) / num_components == override_shape.Size(),
"Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components);
} else if (num_components > 1) {
// if shape is not overriden, assert origin_shape[-1] % 4 == 0
ORT_RETURN_IF_NOT(origin_shape.Size() > 0 && origin_shape[origin_shape.NumDimensions() - 1] % num_components == 0,
"Tensor original shape ", origin_shape, " cannot be divided by component number ", num_components);
}

return Status::OK();
Expand Down

0 comments on commit eae4c3f

Please sign in to comment.