38 changes: 36 additions & 2 deletions src/gpu/intel/eltwise/ref.cl
@@ -17,6 +17,7 @@
#include "gpu/intel/include/dispatch.h"
#include "gpu/intel/include/eltwise.h"
#include "gpu/intel/include/io.h"
#include "gpu/intel/include/philox.h"
#include "gpu/intel/include/post_ops.h"
#include "gpu/intel/include/types_interop.h"

@@ -29,8 +30,19 @@

#if IS_FWD
__kernel void ref_eltwise_fwd(__global SRC_DATA_T *src,
__global DST_DATA_T *dst, float alpha, float beta,
int64x3_t offset POST_OP_ARGS) {
__global DST_DATA_T *dst, float alpha, float beta, int64x3_t offset
#if WITH_DROPOUT
,
__global uchar *dropout_mask_buf,
#if USE_HOST_SCALARS
long dropout_seed, long dropout_offset,
float dropout_p
#else
__global long *dropout_seed_buf, __global long *dropout_offset_buf,
__global float *dropout_p_buf
#endif
#endif
POST_OP_ARGS) {
#if USE_GWS_GET
dim_t d0 = GWS_GET_D0();
dim_t d1 = GWS_GET_D1();
@@ -77,6 +89,28 @@ __kernel void ref_eltwise_fwd(__global SRC_DATA_T *src,
#endif

APPLY_POST_OPS_SERIAL(tmp_s, dst_data, d0, d1, d2, d3, d4, d5);

#if WITH_DROPOUT
#if !USE_HOST_SCALARS
long dropout_seed = dropout_seed_buf[0];
long dropout_offset = USE_OFFSET ? dropout_offset_buf[0] : 0;
float dropout_p = dropout_p_buf[0];
#endif
uint dropout_threshold = get_dropout_threshold(dropout_p);
float dropout_inv_q = (dropout_p != 1.f) ? 1.f / (1.f - dropout_p) : 0.f;
#if USE_OFFSET
uint res = philox_4x32_s64(
(ulong)data_off, (ulong)dropout_seed, (ulong)dropout_offset);
#else
uint res = philox_4x32(data_off, (uint)dropout_seed);
#endif
uchar dropout = res > dropout_threshold;
tmp_s = (dropout) ? tmp_s * dropout_inv_q : 0;
#if HAS_OUTPUT_MASK
dropout_mask_buf[data_off] = dropout;
#endif
#endif

write(dst + data_off, tmp_s);
}

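The block above follows the usual inverted-dropout convention: an element is kept when its Philox draw exceeds the threshold derived from dropout_p (so the keep probability is roughly 1 - p), kept elements are rescaled by 1/(1 - p) to preserve the expected value, and the keep decision is optionally written to the mask buffer. A minimal scalar sketch of that per-element step, with an illustrative function name that is not part of the library:

#include <stdint.h>

/* Per-element inverted dropout: `draw` stands in for the 32-bit Philox output
 * generated from this element's offset, seed, and (optional) offset counter. */
static float apply_dropout_ref(float value, uint32_t draw, uint32_t threshold,
        float p, unsigned char *mask_out) {
    float inv_q = (p != 1.0f) ? 1.0f / (1.0f - p) : 0.0f;
    unsigned char keep = draw > threshold;
    if (mask_out) *mask_out = keep; /* mirrors the HAS_OUTPUT_MASK path */
    return keep ? value * inv_q : 0.0f;
}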
59 changes: 53 additions & 6 deletions src/gpu/intel/eltwise/ref.cpp
@@ -66,7 +66,7 @@ static status_t init_conf_common(

static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
const conf_t &conf, const post_ops_t &post_ops,
const memory_desc_t *dst_md) {
const dropout_t &dropout, const memory_desc_t *dst_md) {
kernel_ctx.set_data_type(conf.data_type);
kernel_ctx.require_stateless_addressing(conf.require_stateless_addressing);

@@ -76,6 +76,10 @@ static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
kernel_ctx.define_int("GWS1", conf.dispatch.nd_range().global_range()[1]);
kernel_ctx.define_int("GWS2", conf.dispatch.nd_range().global_range()[2]);
kernel_ctx.define_int("USE_CUSTOM_GWS_GET_ID", 1);
kernel_ctx.define_int("WITH_DROPOUT", !dropout.has_default_values());
kernel_ctx.define_int("USE_HOST_SCALARS", dropout.use_host_scalars_);
kernel_ctx.define_int("USE_OFFSET", dropout.use_offset_);
kernel_ctx.define_int("HAS_OUTPUT_MASK", dropout.has_output_mask());

bool with_binary_post_ops
= post_ops.find(primitive_kind_t::dnnl_binary) != -1;
@@ -103,8 +107,8 @@ status_t ref_fwd_t::pd_t::init_conf(impl::engine_t *engine) {

status_t ref_fwd_t::pd_t::init_kernel_ctx(
compute::kernel_ctx_t &kernel_ctx) const {
return init_kernel_ctx_common(
kernel_ctx, conf, attr()->post_ops_, invariant_dst_md());
return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_,
attr()->dropout_, invariant_dst_md());
}

status_t ref_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
@@ -122,9 +126,52 @@ status_t ref_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
arg_list.set(1, dst);
arg_list.set(2, alpha);
arg_list.set(3, beta);
const bool with_dropout = !pd()->attr()->dropout_.has_default_values();

int arg_idx = 5;
if (with_dropout) {
const bool use_host_scalars = pd()->attr()->dropout_.use_host_scalars_;
const bool use_offset = pd()->attr()->dropout_.use_offset_;

const auto &dropout_p
= CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_PROBABILITY);
const auto &dropout_seed = CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_SEED);
const auto &dropout_offset
= CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_OFFSET);
arg_list.set(arg_idx++, CTX_OUT_STORAGE(DNNL_ARG_ATTR_DROPOUT_MASK));
if (use_host_scalars) {
int64_t scalar_seed = 0;
int64_t scalar_offset = 0;
float scalar_prob = 0.f;
const host_scalar_memory_storage_t *seed_storage
= utils::downcast<const host_scalar_memory_storage_t *>(
&dropout_seed);
CHECK(seed_storage->get_scalar_value(
&scalar_seed, sizeof(scalar_seed)));
if (use_offset) {
const host_scalar_memory_storage_t *offset_storage
= utils::downcast<const host_scalar_memory_storage_t *>(
&dropout_offset);
CHECK(offset_storage->get_scalar_value(
&scalar_offset, sizeof(scalar_offset)));
}
const host_scalar_memory_storage_t *prob_storage
= utils::downcast<const host_scalar_memory_storage_t *>(
&dropout_p);
CHECK(prob_storage->get_scalar_value(
&scalar_prob, sizeof(scalar_prob)));
arg_list.set(arg_idx++, scalar_seed);
arg_list.set(arg_idx++, scalar_offset);
arg_list.set(arg_idx++, scalar_prob);
} else {
arg_list.set(arg_idx++, dropout_seed);
arg_list.set(arg_idx++, dropout_offset);
arg_list.set(arg_idx++, dropout_p);
}
}

append_post_ops_to_arg_list(
ctx, arg_list, 5, pd()->attr()->post_ops_, *pd()->dst_md());
ctx, arg_list, arg_idx, pd()->attr()->post_ops_, *pd()->dst_md());

auto nd_range = conf.dispatch.nd_range();
return large_parallel_for(ctx, nd_range, kernel_, arg_list, 4);
@@ -136,8 +183,8 @@ status_t ref_bwd_t::pd_t::init_conf(impl::engine_t *engine) {

status_t ref_bwd_t::pd_t::init_kernel_ctx(
compute::kernel_ctx_t &kernel_ctx) const {
return init_kernel_ctx_common(
kernel_ctx, conf, attr()->post_ops_, invariant_dst_md());
return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_,
attr()->dropout_, invariant_dst_md());
}

status_t ref_bwd_t::execute_backward_dense(const exec_ctx_t &ctx) const {
22 changes: 21 additions & 1 deletion src/gpu/intel/eltwise/ref.hpp
@@ -35,11 +35,30 @@ struct ref_fwd_t : public primitive_t {
using gpu_eltwise_fwd_pd_t::gpu_eltwise_fwd_pd_t;

DECLARE_COMMON_PD_T("ocl:ref:any", ref_fwd_t);
status_t dropout_ok() const {
if (attr_.dropout_.has_default_values()) return status::success;

assert(memory_desc_wrapper(dst_md(0)).format_kind()
== format_kind::blocked);

using namespace format_tag;
// See `ref_dropout(...)` comment which explains the requirement.
VDISPATCH_ELTWISE_IC(memory_desc_matches_one_of_tag(
*dst_md(0), ncdhw, nchw, ncw, nc)
&& IMPLICATION(attr_.dropout_.has_output_mask(),
memory_desc_wrapper(dst_md(0)).similar_to(
attr_.dropout_.dropout_desc_, true,
false)),
VERBOSE_UNSUPPORTED_DROPOUT);

return status::success;
}

status_t init(impl::engine_t *engine) {
auto *intel_engine = utils::downcast<intel::engine_t *>(engine);

const auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops;
const auto attr_skip_mask = (primitive_attr_t::skip_mask_t::post_ops
| primitive_attr_t::skip_mask_t::dropout);

using namespace alg_kind;
VDISPATCH_ELTWISE(is_fwd(), VERBOSE_BAD_PROPKIND);
@@ -68,6 +87,7 @@ struct ref_fwd_t : public primitive_t {
intel_engine->mayiuse(
compute::device_ext_t::khr_fp16)),
VERBOSE_UNSUPPORTED_DT_CFG);
CHECK(dropout_ok());
CHECK(init_conf(engine));
return status::success;
}
12 changes: 12 additions & 0 deletions src/gpu/intel/include/philox.h
@@ -89,4 +89,16 @@ float stochastic_round_fwd(float s, long idx, uint seed) {
}
#endif

#if WITH_DROPOUT
// No need to enable fp64 extensions just to compute (double)p * 0xFFFFFFFFu
uint get_dropout_threshold(float p) {
if (p >= 1.f) return 0xFFFFFFFFu;
char exponent = 126 - ((as_uint(p) >> 23) & 0x7F);
if ((p <= 0.f) || (exponent > 31)) return 0u;
uint mantissa = (as_uint(p) << 8) | 0x80000000u;
if (!exponent) return (convert_ulong(mantissa) * 0xFFFFFFFFuL) >> 32;
return ((convert_ulong(mantissa >> exponent) * 0xFFFFFFFFuL) >> 32)
+ !!(mantissa & ((1u << exponent) - 1u));
}
#endif
#endif
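For reference, the bit-level math above stands in for the double-precision product named in the comment. A minimal host-side sketch of that reference computation, assuming p is the drop probability (the helper name dropout_threshold_ref is illustrative, not part of the library):

#include <stdint.h>

/* Reference that get_dropout_threshold() approximates without requiring the
 * device fp64 extension: map the drop probability p onto the full 32-bit
 * range so it can be compared directly against a Philox draw. */
static uint32_t dropout_threshold_ref(float p) {
    if (p >= 1.0f) return 0xFFFFFFFFu;
    if (p <= 0.0f) return 0u;
    return (uint32_t)((double)p * 0xFFFFFFFFu);
}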
15 changes: 1 addition & 14 deletions src/gpu/intel/matmul/ref.cl
@@ -22,19 +22,6 @@
((d0) * (s0) + (d1) * (s1) + (d2) * (s2) + (d3) * (s3) + (d4) * (s4) \
+ (d5) * (s5))

#if WITH_DROPOUT
// No need to enable fp64 extensions just to compute (double)p * 0xFFFFFFFFu
uint get_dropout_threshold(float p) {
if (p >= 1.f) return 0xFFFFFFFFu;
char exponent = 126 - ((as_uint(p) >> 23) & 0x7F);
if ((p <= 0.f) || (exponent > 31)) return 0u;
uint mantissa = (as_uint(p) << 8) | 0x80000000u;
if (!exponent) return (convert_ulong(mantissa) * 0xFFFFFFFFuL) >> 32;
return ((convert_ulong(mantissa >> exponent) * 0xFFFFFFFFuL) >> 32)
+ !!(mantissa & ((1u << exponent) - 1u));
}
#endif

__kernel void ref_matmul(__global SRC_DATA_T *A, __global WEI_DATA_T *B,
__global DST_DATA_T *C, __global BIA_DATA_T *bia,
#if WITH_HOST_SRC_ZP
@@ -294,7 +281,7 @@ __kernel void ref_matmul(__global SRC_DATA_T *A, __global WEI_DATA_T *B,
float po_acc = convert_float(temp);

#if WITH_DROPOUT
#if USE_OFFSET
#if WITH_SEED_S64 && USE_OFFSET
uint res = philox_4x32_s64(
dst_off, (ulong)dropout_seed, (ulong)dropout_offset);
#else
3 changes: 3 additions & 0 deletions src/gpu/intel/matmul/ref.hpp
@@ -247,9 +247,12 @@ struct ref_t : public primitive_t {
status_t init(impl::engine_t *engine) override {
compute::kernel_ctx_t kernel_ctx;

bool with_seed_s64
= (pd()->attr()->dropout_.seed_dt_) == data_type::s64;

Contributor:
What is the status of the non-s64 seed? It is still supported via the API but not through benchdnn?

Contributor Author (@h-sadia, Jan 8, 2026):
Yes, that's true. The seed was previously s32 in the API.

Contributor:
Seems like something that should be enabled through benchdnn so we can test the full range of supported features.

Contributor:
It was a deliberate choice not to test s32 in favor of s64. It's highly unlikely anybody will use this functionality besides PyTorch, which has a requirement for an int64 seed.

Contributor Author:
Keeping it as a keepsake for a while @kealan-barbieri ;)
int ndims = pd()->dst_md()->ndims;
kernel_ctx.define_int("DST_NDIMS", ndims);
kernel_ctx.define_int("WITH_BIAS", pd()->with_bias());
kernel_ctx.define_int("WITH_SEED_S64", with_seed_s64);
kernel_ctx.define_int(
"WITH_DROPOUT", !pd()->attr()->dropout_.has_default_values());
kernel_ctx.define_int(
42 changes: 39 additions & 3 deletions src/gpu/intel/softmax/simple.cl
@@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/intel/include/philox.h"
#include "gpu/intel/softmax/simple.h"

#if IS_FWD
@@ -24,7 +26,19 @@

__kernel void
simple_softmax_fwd_generic(__global SRC_DATA_T *src, __global DATA_T *dst,
__global float *src_scale, __global float *dst_scale POST_OP_ARGS) {
__global float *src_scale, __global float *dst_scale
#if WITH_DROPOUT
,
__global uchar *dropout_mask_buf,
#if USE_HOST_SCALARS
long dropout_seed, long dropout_offset,
float dropout_p
#else
__global long *dropout_seed_buf, __global long *dropout_offset_buf,
__global float *dropout_p_buf
#endif
#endif
POST_OP_ARGS) {

const int dim[] = {
(get_global_id(0) / GROUP_SIZE) % BLOCK_0,
@@ -85,7 +99,7 @@ simple_softmax_fwd_generic(__global SRC_DATA_T *src, __global DATA_T *dst,
// finding max value for each sub_group
if (!(NEEDS_PADDING(dim[0], dim[1], dim[2], dim[3], dim[4], begin))) {
for (int i = begin; i < end && i < DD(SOFTMAX_AXIS_IDX); ++i) {
size_t data_off
dim_t data_off
= DATA_OFF(dim[0], dim[1], dim[2], dim[3], dim[4], i);
d[i - begin] = SRC_TO_REF(src[data_off]);
max_ = max(max_, d[i - begin]);
@@ -128,7 +142,7 @@ simple_softmax_fwd_generic(__global SRC_DATA_T *src, __global DATA_T *dst,
#endif

for (int i = begin; i < end; ++i) {
size_t data_off = DATA_OFF(dim[0], dim[1], dim[2], dim[3], dim[4], i);
dim_t data_off = DATA_OFF(dim[0], dim[1], dim[2], dim[3], dim[4], i);

POST_OP_DATA_T tmp;
if (NEEDS_PADDING(dim[0], dim[1], dim[2], dim[3], dim[4], i)) {
@@ -164,6 +178,28 @@ simple_softmax_fwd_generic(__global SRC_DATA_T *src, __global DATA_T *dst,

#if WITH_DST_SCALES
tmp /= dst_scale[0];
#endif

#if WITH_DROPOUT
#if !USE_HOST_SCALARS
long dropout_seed = dropout_seed_buf[0];
long dropout_offset = USE_OFFSET ? dropout_offset_buf[0] : 0;
float dropout_p = dropout_p_buf[0];
#endif
uint dropout_threshold = get_dropout_threshold(dropout_p);
float dropout_inv_q
= (dropout_p != 1.f) ? 1.f / (1.f - dropout_p) : 0.f;
#if USE_OFFSET
uint res = philox_4x32_s64(
(ulong)data_off, (ulong)dropout_seed, (ulong)dropout_offset);
#else
uint res = philox_4x32(data_off, (uint)dropout_seed);
#endif
uchar dropout = res > dropout_threshold;
tmp = (dropout) ? tmp * dropout_inv_q : 0;
#if HAS_OUTPUT_MASK
dropout_mask_buf[data_off] = dropout;
#endif
#endif
dst[data_off] = TO_DST(tmp);
}
54 changes: 48 additions & 6 deletions src/gpu/intel/softmax/simple.cpp
@@ -29,15 +29,57 @@ status_t simple_fwd_t::execute_generic(const exec_ctx_t &ctx) const {
auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
auto &src_scale = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
auto &dst_scale = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
const bool with_dropout = !pd()->attr()->dropout_.has_default_values();

compute::kernel_arg_list_t arg_list;
arg_list.set(0, src);
arg_list.set(1, dst);
arg_list.set(2, src_scale);
arg_list.set(3, dst_scale);
append_post_ops_to_arg_list(
ctx, arg_list, 4, pd()->attr()->post_ops_, *pd()->dst_md());
int arg_idx = 0;
arg_list.set(arg_idx++, src);
arg_list.set(arg_idx++, dst);
arg_list.set(arg_idx++, src_scale);
arg_list.set(arg_idx++, dst_scale);
if (with_dropout) {
const bool use_host_scalars = pd()->attr()->dropout_.use_host_scalars_;
const bool use_offset = pd()->attr()->dropout_.use_offset_;

const auto &dropout_p
= CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_PROBABILITY);
const auto &dropout_seed = CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_SEED);
const auto &dropout_offset
= CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_OFFSET);

arg_list.set(arg_idx++, CTX_OUT_STORAGE(DNNL_ARG_ATTR_DROPOUT_MASK));
if (use_host_scalars) {
int64_t scalar_seed = 0;
int64_t scalar_offset = 0;
float scalar_prob = 0.f;
const host_scalar_memory_storage_t *seed_storage
= utils::downcast<const host_scalar_memory_storage_t *>(
&dropout_seed);
CHECK(seed_storage->get_scalar_value(
&scalar_seed, sizeof(scalar_seed)));
if (use_offset) {
const host_scalar_memory_storage_t *offset_storage
= utils::downcast<const host_scalar_memory_storage_t *>(
&dropout_offset);
CHECK(offset_storage->get_scalar_value(
&scalar_offset, sizeof(scalar_offset)));
}
const host_scalar_memory_storage_t *prob_storage
= utils::downcast<const host_scalar_memory_storage_t *>(
&dropout_p);
CHECK(prob_storage->get_scalar_value(
&scalar_prob, sizeof(scalar_prob)));
arg_list.set(arg_idx++, scalar_seed);
arg_list.set(arg_idx++, scalar_offset);
arg_list.set(arg_idx++, scalar_prob);
} else {
arg_list.set(arg_idx++, dropout_seed);
arg_list.set(arg_idx++, dropout_offset);
arg_list.set(arg_idx++, dropout_p);
}
}
append_post_ops_to_arg_list(
ctx, arg_list, arg_idx++, pd()->attr()->post_ops_, *pd()->dst_md());
if (pd()->group_size > 1) {
auto nd_range = compute::nd_range_t(pd()->gws, pd()->lws);
return parallel_for(ctx, nd_range, kernel_, arg_list);