diff --git a/src/gpu/intel/eltwise/ref.cl b/src/gpu/intel/eltwise/ref.cl
index 9a9c556c3af..782318639f9 100644
--- a/src/gpu/intel/eltwise/ref.cl
+++ b/src/gpu/intel/eltwise/ref.cl
@@ -17,6 +17,7 @@
 #include "gpu/intel/include/dispatch.h"
 #include "gpu/intel/include/eltwise.h"
 #include "gpu/intel/include/io.h"
+#include "gpu/intel/include/philox.h"
 #include "gpu/intel/include/post_ops.h"
 #include "gpu/intel/include/types_interop.h"
 
@@ -29,8 +30,19 @@
 #if IS_FWD
 
 __kernel void ref_eltwise_fwd(__global SRC_DATA_T *src,
-        __global DST_DATA_T *dst, float alpha, float beta,
-        int64x3_t offset POST_OP_ARGS) {
+        __global DST_DATA_T *dst, float alpha, float beta, int64x3_t offset
+#if WITH_DROPOUT
+        ,
+        __global uchar *dropout_mask_buf,
+#if USE_HOST_SCALARS
+        long dropout_seed, long dropout_offset,
+        float dropout_p
+#else
+        __global long *dropout_seed_buf, __global long *dropout_offset_buf,
+        __global float *dropout_p_buf
+#endif
+#endif
+        POST_OP_ARGS) {
 #if USE_GWS_GET
     dim_t d0 = GWS_GET_D0();
     dim_t d1 = GWS_GET_D1();
@@ -77,6 +89,28 @@ __kernel void ref_eltwise_fwd(__global SRC_DATA_T *src,
 #endif
 
     APPLY_POST_OPS_SERIAL(tmp_s, dst_data, d0, d1, d2, d3, d4, d5);
+
+#if WITH_DROPOUT
+#if !USE_HOST_SCALARS
+    long dropout_seed = dropout_seed_buf[0];
+    long dropout_offset = USE_OFFSET ? dropout_offset_buf[0] : 0;
+    float dropout_p = dropout_p_buf[0];
+#endif
+    uint dropout_threshold = get_dropout_threshold(dropout_p);
+    float dropout_inv_q = (dropout_p != 1.f) ? 1.f / (1.f - dropout_p) : 0.f;
+#if USE_OFFSET
+    uint res = philox_4x32_s64(
+            (ulong)data_off, (ulong)dropout_seed, (ulong)dropout_offset);
+#else
+    uint res = philox_4x32(data_off, (uint)dropout_seed);
+#endif
+    uchar dropout = res > dropout_threshold;
+    tmp_s = (dropout) ? tmp_s * dropout_inv_q : 0;
+#if HAS_OUTPUT_MASK
+    dropout_mask_buf[data_off] = dropout;
+#endif
+#endif
+
     write(dst + data_off, tmp_s);
 }
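The epilogue added here is standard inverted dropout: an element is kept when its Philox draw exceeds the threshold, and kept values are pre-scaled by `1/(1-p)` so the output's expected value matches the input and inference needs no extra scaling. A minimal per-element C++ model of that math, assuming `rng` stands in for the `philox_4x32(...)` draw and `threshold` for `get_dropout_threshold(p)` (defined later in this patch):

```cpp
#include <cstdint>

// Per-element model of the kernel epilogue above (a sketch, not the
// kernel code itself).
float dropout_model(float x, uint32_t rng, uint32_t threshold, float p) {
    bool keep = rng > threshold; // true with probability ~(1 - p)
    // "Inverted" dropout: survivors are scaled by 1/(1-p). p == 1 drops
    // everything, so the division is guarded exactly as in the kernel.
    float inv_q = (p != 1.f) ? 1.f / (1.f - p) : 0.f;
    return keep ? x * inv_q : 0.f;
}
```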
diff --git a/src/gpu/intel/eltwise/ref.cpp b/src/gpu/intel/eltwise/ref.cpp
index ee9e4cb3506..35b3e26098d 100644
--- a/src/gpu/intel/eltwise/ref.cpp
+++ b/src/gpu/intel/eltwise/ref.cpp
@@ -66,7 +66,7 @@ static status_t init_conf_common(
 
 static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
         const conf_t &conf, const post_ops_t &post_ops,
-        const memory_desc_t *dst_md) {
+        const dropout_t &dropout, const memory_desc_t *dst_md) {
     kernel_ctx.set_data_type(conf.data_type);
     kernel_ctx.require_stateless_addressing(conf.require_stateless_addressing);
 
@@ -76,6 +76,10 @@ static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
     kernel_ctx.define_int("GWS1", conf.dispatch.nd_range().global_range()[1]);
     kernel_ctx.define_int("GWS2", conf.dispatch.nd_range().global_range()[2]);
     kernel_ctx.define_int("USE_CUSTOM_GWS_GET_ID", 1);
+    kernel_ctx.define_int("WITH_DROPOUT", !dropout.has_default_values());
+    kernel_ctx.define_int("USE_HOST_SCALARS", dropout.use_host_scalars_);
+    kernel_ctx.define_int("USE_OFFSET", dropout.use_offset_);
+    kernel_ctx.define_int("HAS_OUTPUT_MASK", dropout.has_output_mask());
 
     bool with_binary_post_ops
             = post_ops.find(primitive_kind_t::dnnl_binary) != -1;
@@ -103,8 +107,8 @@ status_t ref_fwd_t::pd_t::init_conf(impl::engine_t *engine) {
 
 status_t ref_fwd_t::pd_t::init_kernel_ctx(
         compute::kernel_ctx_t &kernel_ctx) const {
-    return init_kernel_ctx_common(
-            kernel_ctx, conf, attr()->post_ops_, invariant_dst_md());
+    return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_,
+            attr()->dropout_, invariant_dst_md());
 }
 
 status_t ref_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
@@ -122,9 +126,52 @@ status_t ref_fwd_t::execute_forward_dense(const exec_ctx_t &ctx) const {
     arg_list.set(1, dst);
     arg_list.set(2, alpha);
     arg_list.set(3, beta);
+    const bool with_dropout = !pd()->attr()->dropout_.has_default_values();
+
+    int arg_idx = 5;
+    if (with_dropout) {
+        const bool use_host_scalars = pd()->attr()->dropout_.use_host_scalars_;
+        const bool use_offset = pd()->attr()->dropout_.use_offset_;
+
+        const auto &dropout_p
+                = CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_PROBABILITY);
+        const auto &dropout_seed = CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_SEED);
+        const auto &dropout_offset
+                = CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_OFFSET);
+        arg_list.set(arg_idx++, CTX_OUT_STORAGE(DNNL_ARG_ATTR_DROPOUT_MASK));
+        if (use_host_scalars) {
+            int64_t scalar_seed = 0;
+            int64_t scalar_offset = 0;
+            float scalar_prob = 0.f;
+            const host_scalar_memory_storage_t *seed_storage
+                    = utils::downcast<const host_scalar_memory_storage_t *>(
+                            &dropout_seed);
+            CHECK(seed_storage->get_scalar_value(
+                    &scalar_seed, sizeof(scalar_seed)));
+            if (use_offset) {
+                const host_scalar_memory_storage_t *offset_storage
+                        = utils::downcast<const host_scalar_memory_storage_t *>(
+                                &dropout_offset);
+                CHECK(offset_storage->get_scalar_value(
+                        &scalar_offset, sizeof(scalar_offset)));
+            }
+            const host_scalar_memory_storage_t *prob_storage
+                    = utils::downcast<const host_scalar_memory_storage_t *>(
+                            &dropout_p);
+            CHECK(prob_storage->get_scalar_value(
+                    &scalar_prob, sizeof(scalar_prob)));
+            arg_list.set(arg_idx++, scalar_seed);
+            arg_list.set(arg_idx++, scalar_offset);
+            arg_list.set(arg_idx++, scalar_prob);
+        } else {
+            arg_list.set(arg_idx++, dropout_seed);
+            arg_list.set(arg_idx++, dropout_offset);
+            arg_list.set(arg_idx++, dropout_p);
+        }
+    }
     append_post_ops_to_arg_list(
-            ctx, arg_list, 5, pd()->attr()->post_ops_, *pd()->dst_md());
+            ctx, arg_list, arg_idx, pd()->attr()->post_ops_, *pd()->dst_md());
 
     auto nd_range = conf.dispatch.nd_range();
     return large_parallel_for(ctx, nd_range, kernel_, arg_list, 4);
@@ -136,8 +183,8 @@ status_t ref_bwd_t::pd_t::init_conf(impl::engine_t *engine) {
 
 status_t ref_bwd_t::pd_t::init_kernel_ctx(
         compute::kernel_ctx_t &kernel_ctx) const {
-    return init_kernel_ctx_common(
-            kernel_ctx, conf, attr()->post_ops_, invariant_dst_md());
+    return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_,
+            attr()->dropout_, invariant_dst_md());
 }
 
 status_t ref_bwd_t::execute_backward_dense(const exec_ctx_t &ctx) const {
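The host-scalar branch above repeats one pattern three times: downcast the storage to `host_scalar_memory_storage_t` and read the value out. A hypothetical helper that would consolidate it — not part of the patch, shown only to make the data flow explicit (all type and function names are the ones the patch itself uses):

```cpp
// Hypothetical consolidation of the downcast-and-read pattern above.
template <typename T>
status_t read_host_scalar(const memory_storage_t &storage, T &out) {
    const auto *host_scalar
            = utils::downcast<const host_scalar_memory_storage_t *>(&storage);
    return host_scalar->get_scalar_value(&out, sizeof(out));
}

// Usage, mirroring the seed read in execute_forward_dense():
//     int64_t scalar_seed = 0;
//     CHECK(read_host_scalar(dropout_seed, scalar_seed));
```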
diff --git a/src/gpu/intel/eltwise/ref.hpp b/src/gpu/intel/eltwise/ref.hpp
index 281f8fff08d..c84d731da14 100644
--- a/src/gpu/intel/eltwise/ref.hpp
+++ b/src/gpu/intel/eltwise/ref.hpp
@@ -35,11 +35,30 @@ struct ref_fwd_t : public primitive_t {
         using gpu_eltwise_fwd_pd_t::gpu_eltwise_fwd_pd_t;
 
         DECLARE_COMMON_PD_T("ocl:ref:any", ref_fwd_t);
 
+        status_t dropout_ok() const {
+            if (attr_.dropout_.has_default_values()) return status::success;
+
+            assert(memory_desc_wrapper(dst_md(0)).format_kind()
+                    == format_kind::blocked);
+
+            using namespace format_tag;
+            // See `ref_dropout(...)` comment which explains the requirement.
+            VDISPATCH_ELTWISE_IC(memory_desc_matches_one_of_tag(
+                                         *dst_md(0), ncdhw, nchw, ncw, nc)
+                            && IMPLICATION(attr_.dropout_.has_output_mask(),
+                                    memory_desc_wrapper(dst_md(0)).similar_to(
+                                            attr_.dropout_.dropout_desc_, true,
+                                            false)),
+                    VERBOSE_UNSUPPORTED_DROPOUT);
+
+            return status::success;
+        }
+
         status_t init(impl::engine_t *engine) {
             auto *intel_engine = utils::downcast<intel::engine_t *>(engine);
 
-            const auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops;
+            const auto attr_skip_mask = (primitive_attr_t::skip_mask_t::post_ops
+                    | primitive_attr_t::skip_mask_t::dropout);
 
             using namespace alg_kind;
             VDISPATCH_ELTWISE(is_fwd(), VERBOSE_BAD_PROPKIND);
@@ -68,6 +87,7 @@ struct ref_fwd_t : public primitive_t {
                             intel_engine->mayiuse(
                                     compute::device_ext_t::khr_fp16)),
                     VERBOSE_UNSUPPORTED_DT_CFG);
+            CHECK(dropout_ok());
             CHECK(init_conf(engine));
             return status::success;
         }
diff --git a/src/gpu/intel/include/philox.h b/src/gpu/intel/include/philox.h
index 2926a114923..031117fac26 100644
--- a/src/gpu/intel/include/philox.h
+++ b/src/gpu/intel/include/philox.h
@@ -89,4 +89,16 @@ float stochastic_round_fwd(float s, long idx, uint seed) {
 }
 #endif
 
+#if WITH_DROPOUT
+// No need to enable fp64 extensions just to compute (double)p * 0xFFFFFFFFu
+uint get_dropout_threshold(float p) {
+    if (p >= 1.f) return 0xFFFFFFFFu;
+    char exponent = 126 - ((as_uint(p) >> 23) & 0x7F);
+    if ((p <= 0.f) || (exponent > 31)) return 0u;
+    uint mantissa = (as_uint(p) << 8) | 0x80000000u;
+    if (!exponent) return (convert_ulong(mantissa) * 0xFFFFFFFFuL) >> 32;
+    return ((convert_ulong(mantissa >> exponent) * 0xFFFFFFFFuL) >> 32)
+            + !!(mantissa & ((1u << exponent) - 1u));
+}
+#endif
 #endif
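The helper (moved here from `matmul/ref.cl` so all three kernels can share it) reconstructs `(double)p * 0xFFFFFFFFu` in integer arithmetic: it extracts p's mantissa, aligns it so that `mantissa >> exponent == floor(p * 2^32)`, multiplies by `0xFFFFFFFF`, and bumps the result by one whenever the shift dropped bits. A quick host-side sanity check, assuming a plain-C++ re-spelling of the OpenCL code (only `as_uint`/`convert_ulong` are translated; the arithmetic is unchanged), verifying the result stays within a couple of units of the fp64 product:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

static uint32_t dropout_threshold_host(float p) {
    if (p >= 1.f) return 0xFFFFFFFFu;
    uint32_t bits;
    std::memcpy(&bits, &p, sizeof(bits)); // as_uint(p)
    int exponent = 126 - ((bits >> 23) & 0x7F);
    if ((p <= 0.f) || (exponent > 31)) return 0u;
    uint32_t mantissa = (bits << 8) | 0x80000000u;
    if (!exponent) return ((uint64_t)mantissa * 0xFFFFFFFFull) >> 32;
    return (((uint64_t)(mantissa >> exponent) * 0xFFFFFFFFull) >> 32)
            + !!(mantissa & ((1u << exponent) - 1u));
}

int main() {
    for (uint32_t i = 0; i <= 1000; ++i) {
        float p = i / 1000.f;
        double ref = (double)p * 0xFFFFFFFFu; // what the comment refers to
        assert(std::fabs((double)dropout_threshold_host(p) - ref) <= 2.0);
    }
    return 0;
}
```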
diff --git a/src/gpu/intel/matmul/ref.cl b/src/gpu/intel/matmul/ref.cl
index 863730e4690..16f4c490881 100644
--- a/src/gpu/intel/matmul/ref.cl
+++ b/src/gpu/intel/matmul/ref.cl
@@ -22,19 +22,6 @@
         ((d0) * (s0) + (d1) * (s1) + (d2) * (s2) + (d3) * (s3) + (d4) * (s4) \
                 + (d5) * (s5))
 
-#if WITH_DROPOUT
-// No need to enable fp64 extensions just to compute (double)p * 0xFFFFFFFFu
-uint get_dropout_threshold(float p) {
-    if (p >= 1.f) return 0xFFFFFFFFu;
-    char exponent = 126 - ((as_uint(p) >> 23) & 0x7F);
-    if ((p <= 0.f) || (exponent > 31)) return 0u;
-    uint mantissa = (as_uint(p) << 8) | 0x80000000u;
-    if (!exponent) return (convert_ulong(mantissa) * 0xFFFFFFFFuL) >> 32;
-    return ((convert_ulong(mantissa >> exponent) * 0xFFFFFFFFuL) >> 32)
-            + !!(mantissa & ((1u << exponent) - 1u));
-}
-#endif
-
 __kernel void ref_matmul(__global SRC_DATA_T *A, __global WEI_DATA_T *B,
         __global DST_DATA_T *C, __global BIA_DATA_T *bia,
 #if WITH_HOST_SRC_ZP
@@ -294,7 +281,7 @@ __kernel void ref_matmul(__global SRC_DATA_T *A, __global WEI_DATA_T *B,
     float po_acc = convert_float(temp);
 
 #if WITH_DROPOUT
-#if USE_OFFSET
+#if WITH_SEED_S64 && USE_OFFSET
     uint res = philox_4x32_s64(
             dst_off, (ulong)dropout_seed, (ulong)dropout_offset);
 #else
diff --git a/src/gpu/intel/matmul/ref.hpp b/src/gpu/intel/matmul/ref.hpp
index 1b4a06350e7..5fef40c5f7c 100644
--- a/src/gpu/intel/matmul/ref.hpp
+++ b/src/gpu/intel/matmul/ref.hpp
@@ -247,9 +247,12 @@ struct ref_t : public primitive_t {
     status_t init(impl::engine_t *engine) override {
         compute::kernel_ctx_t kernel_ctx;
 
+        bool with_seed_s64
+                = (pd()->attr()->dropout_.seed_dt_) == data_type::s64;
        int ndims = pd()->dst_md()->ndims;
         kernel_ctx.define_int("DST_NDIMS", ndims);
         kernel_ctx.define_int("WITH_BIAS", pd()->with_bias());
+        kernel_ctx.define_int("WITH_SEED_S64", with_seed_s64);
         kernel_ctx.define_int(
                 "WITH_DROPOUT", !pd()->attr()->dropout_.has_default_values());
         kernel_ctx.define_int(
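`WITH_SEED_S64` tightens the matmul kernel's RNG dispatch: previously `USE_OFFSET` alone selected the 64-bit Philox entry point, even when the seed was not an s64 value. A host-side sketch of the decision the two defines now encode (names from the patch; in the kernel this is a compile-time `#if`, not runtime code):

```cpp
enum class philox_variant { s64_counter, u32_counter };

philox_variant pick_variant(bool with_seed_s64, bool use_offset) {
    // The 64-bit seed/offset entry point is only meaningful when the seed
    // really is s64 *and* an offset is supplied; any other combination
    // falls back to the 32-bit philox_4x32 path.
    return (with_seed_s64 && use_offset) ? philox_variant::s64_counter
                                         : philox_variant::u32_counter;
}
```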
diff --git a/src/gpu/intel/softmax/simple.cl b/src/gpu/intel/softmax/simple.cl
index 0b65b0bbb65..3abce105b66 100644
--- a/src/gpu/intel/softmax/simple.cl
+++ b/src/gpu/intel/softmax/simple.cl
@@ -13,6 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  *******************************************************************************/
+
+#include "gpu/intel/include/philox.h"
 #include "gpu/intel/softmax/simple.h"
 
 #if IS_FWD
@@ -24,7 +26,19 @@
 __attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) __kernel void
 simple_softmax_fwd_generic(__global SRC_DATA_T *src, __global DATA_T *dst,
-        __global float *src_scale, __global float *dst_scale POST_OP_ARGS) {
+        __global float *src_scale, __global float *dst_scale
+#if WITH_DROPOUT
+        ,
+        __global uchar *dropout_mask_buf,
+#if USE_HOST_SCALARS
+        long dropout_seed, long dropout_offset,
+        float dropout_p
+#else
+        __global long *dropout_seed_buf, __global long *dropout_offset_buf,
+        __global float *dropout_p_buf
+#endif
+#endif
+        POST_OP_ARGS) {
 
     const int dim[] = {
             (get_global_id(0) / GROUP_SIZE) % BLOCK_0,
@@ -85,7 +99,7 @@ simple_softmax_fwd_generic(__global SRC_DATA_T *src, __global DATA_T *dst,
     // finding max value for each sub_group
     if (!(NEEDS_PADDING(dim[0], dim[1], dim[2], dim[3], dim[4], begin))) {
         for (int i = begin; i < end && i < DD(SOFTMAX_AXIS_IDX); ++i) {
-            size_t data_off
+            dim_t data_off
                     = DATA_OFF(dim[0], dim[1], dim[2], dim[3], dim[4], i);
             d[i - begin] = SRC_TO_REF(src[data_off]);
             max_ = max(max_, d[i - begin]);
@@ -128,7 +142,7 @@ simple_softmax_fwd_generic(__global SRC_DATA_T *src, __global DATA_T *dst,
 #endif
 
     for (int i = begin; i < end; ++i) {
-        size_t data_off = DATA_OFF(dim[0], dim[1], dim[2], dim[3], dim[4], i);
+        dim_t data_off = DATA_OFF(dim[0], dim[1], dim[2], dim[3], dim[4], i);
 
         POST_OP_DATA_T tmp;
         if (NEEDS_PADDING(dim[0], dim[1], dim[2], dim[3], dim[4], i)) {
@@ -164,6 +178,28 @@ simple_softmax_fwd_generic(__global SRC_DATA_T *src, __global DATA_T *dst,
 
 #if WITH_DST_SCALES
         tmp /= dst_scale[0];
+#endif
+
+#if WITH_DROPOUT
+#if !USE_HOST_SCALARS
+        long dropout_seed = dropout_seed_buf[0];
+        long dropout_offset = USE_OFFSET ? dropout_offset_buf[0] : 0;
+        float dropout_p = dropout_p_buf[0];
+#endif
+        uint dropout_threshold = get_dropout_threshold(dropout_p);
+        float dropout_inv_q
+                = (dropout_p != 1.f) ? 1.f / (1.f - dropout_p) : 0.f;
+#if USE_OFFSET
+        uint res = philox_4x32_s64(
+                (ulong)data_off, (ulong)dropout_seed, (ulong)dropout_offset);
+#else
+        uint res = philox_4x32(data_off, (uint)dropout_seed);
+#endif
+        uchar dropout = res > dropout_threshold;
+        tmp = (dropout) ? tmp * dropout_inv_q : 0;
+#if HAS_OUTPUT_MASK
+        dropout_mask_buf[data_off] = dropout;
+#endif
 #endif
         dst[data_off] = TO_DST(tmp);
     }
diff --git a/src/gpu/intel/softmax/simple.cpp b/src/gpu/intel/softmax/simple.cpp
index e5328ac183b..14f0c8923c6 100644
--- a/src/gpu/intel/softmax/simple.cpp
+++ b/src/gpu/intel/softmax/simple.cpp
@@ -29,15 +29,57 @@ status_t simple_fwd_t::execute_generic(const exec_ctx_t &ctx) const {
     auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
     auto &src_scale = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
     auto &dst_scale = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
+    const bool with_dropout = !pd()->attr()->dropout_.has_default_values();
 
     compute::kernel_arg_list_t arg_list;
-    arg_list.set(0, src);
-    arg_list.set(1, dst);
-    arg_list.set(2, src_scale);
-    arg_list.set(3, dst_scale);
-    append_post_ops_to_arg_list(
-            ctx, arg_list, 4, pd()->attr()->post_ops_, *pd()->dst_md());
+    int arg_idx = 0;
+    arg_list.set(arg_idx++, src);
+    arg_list.set(arg_idx++, dst);
+    arg_list.set(arg_idx++, src_scale);
+    arg_list.set(arg_idx++, dst_scale);
+    if (with_dropout) {
+        const bool use_host_scalars = pd()->attr()->dropout_.use_host_scalars_;
+        const bool use_offset = pd()->attr()->dropout_.use_offset_;
+
+        const auto &dropout_p
+                = CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_PROBABILITY);
+        const auto &dropout_seed = CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_SEED);
+        const auto &dropout_offset
+                = CTX_IN_STORAGE(DNNL_ARG_ATTR_DROPOUT_OFFSET);
+        arg_list.set(arg_idx++, CTX_OUT_STORAGE(DNNL_ARG_ATTR_DROPOUT_MASK));
+        if (use_host_scalars) {
+            int64_t scalar_seed = 0;
+            int64_t scalar_offset = 0;
+            float scalar_prob = 0.f;
+            const host_scalar_memory_storage_t *seed_storage
+                    = utils::downcast<const host_scalar_memory_storage_t *>(
+                            &dropout_seed);
+            CHECK(seed_storage->get_scalar_value(
+                    &scalar_seed, sizeof(scalar_seed)));
+            if (use_offset) {
+                const host_scalar_memory_storage_t *offset_storage
+                        = utils::downcast<const host_scalar_memory_storage_t *>(
+                                &dropout_offset);
+                CHECK(offset_storage->get_scalar_value(
+                        &scalar_offset, sizeof(scalar_offset)));
+            }
+            const host_scalar_memory_storage_t *prob_storage
+                    = utils::downcast<const host_scalar_memory_storage_t *>(
+                            &dropout_p);
+            CHECK(prob_storage->get_scalar_value(
+                    &scalar_prob, sizeof(scalar_prob)));
+            arg_list.set(arg_idx++, scalar_seed);
+            arg_list.set(arg_idx++, scalar_offset);
+            arg_list.set(arg_idx++, scalar_prob);
+        } else {
+            arg_list.set(arg_idx++, dropout_seed);
+            arg_list.set(arg_idx++, dropout_offset);
+            arg_list.set(arg_idx++, dropout_p);
+        }
+    }
+    append_post_ops_to_arg_list(
+            ctx, arg_list, arg_idx++, pd()->attr()->post_ops_, *pd()->dst_md());
 
     if (pd()->group_size > 1) {
         auto nd_range = compute::nd_range_t(pd()->gws, pd()->lws);
         return parallel_for(ctx, nd_range, kernel_, arg_list);
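Both kernels index the mask buffer with the same linear offset they use for `dst` (`dropout_mask_buf[data_off] = dropout`). That is only correct when the mask is dense and laid out exactly like `dst`, which is what the `dropout_ok()` check (added above for eltwise and below for softmax) enforces via the plain-tag and `similar_to(...)` conditions. A minimal illustration of the invariant, assuming a hypothetical 4D nchw case:

```cpp
#include <cstddef>

// dst[nchw_off(...)] and mask[nchw_off(...)] must name the same logical
// element; a permuted or blocked mask layout would silently misalign them.
size_t nchw_off(
        size_t n, size_t c, size_t h, size_t w, size_t C, size_t H, size_t W) {
    return ((n * C + c) * H + h) * W + w;
}
```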
diff --git a/src/gpu/intel/softmax/simple.hpp b/src/gpu/intel/softmax/simple.hpp
index a18b240fac7..43b0ce87b9f 100644
--- a/src/gpu/intel/softmax/simple.hpp
+++ b/src/gpu/intel/softmax/simple.hpp
@@ -39,6 +39,25 @@ struct simple_fwd_t : public primitive_t {
                     {primitive_kind::eltwise, primitive_kind::binary});
         }
 
+        status_t dropout_ok() const {
+            if (attr_.dropout_.has_default_values()) return status::success;
+
+            assert(memory_desc_wrapper(dst_md(0)).format_kind()
+                    == format_kind::blocked);
+
+            using namespace format_tag;
+            // See `ref_dropout(...)` comment which explains the requirement.
+            VDISPATCH_SOFTMAX_IC(memory_desc_matches_one_of_tag(
+                                         *dst_md(0), ncdhw, nchw, ncw, nc)
+                            && IMPLICATION(attr_.dropout_.has_output_mask(),
+                                    memory_desc_wrapper(dst_md(0)).similar_to(
+                                            attr_.dropout_.dropout_desc_, true,
+                                            false)),
+                    VERBOSE_UNSUPPORTED_DROPOUT);
+
+            return status::success;
+        }
+
         status_t init(impl::engine_t *engine) {
             auto *intel_engine = utils::downcast<intel::engine_t *>(engine);
 
@@ -75,8 +94,9 @@ struct simple_fwd_t : public primitive_t {
             VDISPATCH_SOFTMAX(memory_desc_ndims_ok(src_md(), dst_md()),
                     VERBOSE_INCONSISTENT_NDIMS_WITH_VALS, "src", "dst",
                     src_md()->ndims, dst_md()->ndims);
-            VDISPATCH_SOFTMAX(attr()->has_default_values(skip_mask_t::scales
-                                      | skip_mask_t::post_ops),
+            VDISPATCH_SOFTMAX(
+                    attr()->has_default_values(skip_mask_t::scales
+                            | skip_mask_t::post_ops | skip_mask_t::dropout),
                     VERBOSE_UNSUPPORTED_ATTR);
             VDISPATCH_SOFTMAX(is_not_double_blk, VERBOSE_UNSUPPORTED_TAG);
             VDISPATCH_SOFTMAX(attr_scales_ok(), VERBOSE_UNSUPPORTED_SCALES_CFG);
@@ -85,6 +105,7 @@ struct simple_fwd_t : public primitive_t {
                     set_default_formats(), VERBOSE_UNSUPPORTED_TAG);
             VDISPATCH_SOFTMAX_SC(attr_.set_default_formats(dst_md(0)),
                     VERBOSE_UNSUPPORTED_POSTOP);
+            CHECK(dropout_ok());
 
             dim_t nelems = axis_size(true);
 
@@ -138,6 +159,13 @@ struct simple_fwd_t : public primitive_t {
         kernel_ctx.define_int("GROUP_SIZE", pd()->group_size);
         kernel_ctx.define_int("SUB_GROUP_SIZE", pd()->subgroup_size);
         kernel_ctx.define_int("IS_FWD", 1);
+        kernel_ctx.define_int(
+                "WITH_DROPOUT", !pd()->attr()->dropout_.has_default_values());
+        kernel_ctx.define_int(
+                "USE_HOST_SCALARS", pd()->attr()->dropout_.use_host_scalars_);
+        kernel_ctx.define_int("USE_OFFSET", pd()->attr()->dropout_.use_offset_);
+        kernel_ctx.define_int(
+                "HAS_OUTPUT_MASK", pd()->attr()->dropout_.has_output_mask());
         kernel_ctx.add_option("-cl-std=CL2.0");
         kernel_ctx.define_int("SOFTMAX_INF_AS_ZERO",
                 pd()->alg_kind() == alg_kind::softmax_accurate_inf_as_zero);
diff --git a/tests/benchdnn/eltwise/eltwise.cpp b/tests/benchdnn/eltwise/eltwise.cpp
index b3ff5038ae0..57aeee0f894 100644
--- a/tests/benchdnn/eltwise/eltwise.cpp
+++ b/tests/benchdnn/eltwise/eltwise.cpp
@@ -298,16 +298,6 @@ void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
         res->reason = skip_reason::data_type_not_supported;
         return;
     }
-
-    if (is_gpu()) {
-        if (!prb->attr.dropout.is_def()) {
-            BENCHDNN_PRINT(2, "[SKIP][%s:%d]: Dropout isn't supported.\n",
-                    __FILE__, __LINE__);
-            res->state = SKIPPED;
-            res->reason = skip_reason::case_not_supported;
-            return;
-        }
-    }
 }
 
 void skip_invalid_prb(const prb_t *prb, res_t *res) {
diff --git a/tests/benchdnn/inputs/eltwise/harness_eltwise_dropout b/tests/benchdnn/inputs/eltwise/harness_eltwise_dropout
index 7cffbb3e25f..0922280058c 100644
--- a/tests/benchdnn/inputs/eltwise/harness_eltwise_dropout
+++ b/tests/benchdnn/inputs/eltwise/harness_eltwise_dropout
@@ -4,7 +4,7 @@
 --dt=f32,bf16
 --tag=abx
 
---attr-dropout=0.5:12345678,0.75:12345678:undef,0.25:843921:any:1238976:true
+--attr-dropout=0.5:12345678,0.75:12345678:undef,0.25:843921:any:1238976:true,0.75:111786:any:121716:false
 
 --alpha=0 --beta=0
 --alg=exp,exp_dst
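Each `--attr-dropout` entry appears to follow `PROB[:SEED[:TAG[:OFFSET[:USE_HOST_SCALARS]]]]`, so the new case exercises the device-side (non-host-scalar) seed/offset path the kernels gained above; the softmax harness below gets the same case. A hypothetical standalone repro (shape and engine flags assumed, not taken from the patch):

```sh
./benchdnn --softmax --engine=gpu --sdt=f32 --ddt=f32 --stag=abx \
           --attr-dropout=0.75:111786:any:121716:false 1x16x128
```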
diff --git a/tests/benchdnn/inputs/softmax/harness_softmax_dropout b/tests/benchdnn/inputs/softmax/harness_softmax_dropout
index d4638c5c3d6..a20aa34fe2a 100644
--- a/tests/benchdnn/inputs/softmax/harness_softmax_dropout
+++ b/tests/benchdnn/inputs/softmax/harness_softmax_dropout
@@ -4,7 +4,7 @@
 --sdt=f32,bf16
 --ddt=f32,bf16
 
---attr-dropout=0.5:12345678,0.75:12345678:undef,0.25:843921:any:1238976:true
+--attr-dropout=0.5:12345678,0.75:12345678:undef,0.25:843921:any:1238976:true,0.75:111786:any:121716:false
 
 --stag=abx --dtag=any
 --batch=shapes_ci
diff --git a/tests/benchdnn/softmax/softmax.cpp b/tests/benchdnn/softmax/softmax.cpp
index b20cdf7d634..cab9ac98e9d 100644
--- a/tests/benchdnn/softmax/softmax.cpp
+++ b/tests/benchdnn/softmax/softmax.cpp
@@ -247,16 +247,6 @@ void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
         res->state = DEFERRED;
         return;
     }
-
-    if (is_gpu()) {
-        if (!prb->attr.dropout.is_def()) {
-            BENCHDNN_PRINT(2, "[SKIP][%s:%d]: Dropout isn't supported.\n",
-                    __FILE__, __LINE__);
-            res->state = SKIPPED;
-            res->reason = skip_reason::case_not_supported;
-            return;
-        }
-    }
 }
 
 void skip_invalid_prb(const prb_t *prb, res_t *res) {