Skip to content

Commit 87d9beb

Browse files
Fix segmentation fault in NLLLoss kernel (#2111)
Fixed the following issues found by test/test_nn.py::TestNNDeviceTypeXPU::test_nll_loss_large_tensor_reduction_mean_xpu and test_nll_loss_large_tensor_reduction_sum_xpu: 1. Segmentation faults caused by pointer-type conversion errors that resulted in invalid memory addresses. 2. Kernel-call errors caused by incorrect branch conditions. --------- Co-authored-by: mengfei25 <[email protected]>
1 parent 8d373ba commit 87d9beb

File tree

4 files changed

+563
-632
lines changed

4 files changed

+563
-632
lines changed

src/ATen/native/xpu/LossNLL.cpp

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,9 @@ TORCH_IMPL_FUNC(nll_loss_forward_out_xpu)
1818
int64_t ignore_index,
1919
const Tensor& output,
2020
const Tensor& total_weight) {
21+
const Tensor& weight = weight_opt.getTensorRef();
2122
xpu::nll_loss_forward_kernel(
22-
self,
23-
target,
24-
((weight_opt.has_value() && (*weight_opt).defined())
25-
? at::OptionalTensorRef(*weight_opt)
26-
: at::OptionalTensorRef()),
27-
reduction,
28-
ignore_index,
29-
output,
30-
total_weight);
23+
output, total_weight, self, target, weight, reduction, ignore_index);
3124
}
3225

3326
TORCH_IMPL_FUNC(nll_loss_backward_out_xpu)
@@ -39,19 +32,18 @@ TORCH_IMPL_FUNC(nll_loss_backward_out_xpu)
3932
int64_t ignore_index,
4033
const Tensor& total_weight,
4134
const Tensor& grad_input) {
35+
const Tensor& weight = weight_opt.getTensorRef();
4236
grad_input.zero_();
4337
xpu::nll_loss_backward_kernel(
38+
grad_input,
4439
grad_output,
4540
self,
4641
target,
47-
((weight_opt.has_value() && (*weight_opt).defined())
48-
? at::OptionalTensorRef(*weight_opt)
49-
: at::OptionalTensorRef()),
50-
reduction,
51-
ignore_index,
5242
total_weight,
53-
grad_input);
43+
weight,
44+
reduction,
45+
ignore_index);
5446
}
5547

5648
} // namespace native
57-
} // namespace at
49+
} // namespace at

src/ATen/native/xpu/sycl/KernelUtils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,20 @@
1111
i = _i_n_d_e_x)
1212

1313
#define XPU_KERNEL_LOOP(item, i, n) XPU_KERNEL_LOOP_TYPE(item, i, n, int)
14+
15+
constexpr int SYCL_NUM_THREADS = 1024;
16+
17+
inline int GET_GROUPS(
18+
const int64_t N,
19+
const int64_t max_threads_per_group = SYCL_NUM_THREADS) {
20+
TORCH_INTERNAL_ASSERT(
21+
N > 0, "XPU kernel launch blocks must be positive, but got N=", N);
22+
constexpr int64_t max_int = std::numeric_limits<int>::max();
23+
24+
// Round up division for positive number that cannot cause integer overflow
25+
auto group_num = (N - 1) / max_threads_per_group + 1;
26+
TORCH_INTERNAL_ASSERT(
27+
group_num <= max_int, "Can't schedule too many blocks on XPU device");
28+
29+
return static_cast<int>(group_num);
30+
}

0 commit comments

Comments
 (0)