@@ -225,105 +225,96 @@ kernel void nms<DTYPE ## 4, DTYPE>( \
225225 uint2 tgid [[threadgroup_position_in_grid]], \
226226 uint2 tid2 [[thread_position_in_threadgroup]]);
227227
228- template<typename T, typename integer_t>
228+ template <typename T, typename integer_t>
229229kernel void roi_align(
230230 constant T * input [[buffer(0)]],
231231 constant T * rois [[buffer(1)]],
232232 device T * output [[buffer(2)]],
233- constant int64_t & output_size [[buffer(3)]],
233+ constant float & spatial_scale [[buffer(3)]],
234234 constant int64_t & channels [[buffer(4)]],
235235 constant int64_t & height [[buffer(5)]],
236236 constant int64_t & width [[buffer(6)]],
237237 constant int64_t & pooled_height [[buffer(7)]],
238238 constant int64_t & pooled_width [[buffer(8)]],
239239 constant int64_t & sampling_ratio [[buffer(9)]],
240240 constant bool & aligned [[buffer(10)]],
241- constant float & spatial_scale [[buffer(11)]],
242- uint2 tgid [[threadgroup_position_in_grid]],
243- uint2 tptg [[threads_per_threadgroup]],
244- uint2 tid2 [[thread_position_in_threadgroup]]){
245- MPS_1D_KERNEL_LOOP(index, output_size, 1) {
246- // (n, c, ph, pw) is an element in the pooled output
247- integer_t pw = index % pooled_width;
248- integer_t ph = (index / pooled_width) % pooled_height;
249- integer_t c = (index / pooled_width / pooled_height) % channels;
250- integer_t n = index / pooled_width / pooled_height / channels;
251-
252- constant T* offset_rois = rois + n * 5;
253- integer_t roi_batch_ind = offset_rois[0];
254-
255- // Do not using rounding; this implementation detail is critical
256- T offset = aligned ? (T)0.5 : (T)0.0;
257- T roi_start_w = offset_rois[1] * spatial_scale - offset;
258- T roi_start_h = offset_rois[2] * spatial_scale - offset;
259- T roi_end_w = offset_rois[3] * spatial_scale - offset;
260- T roi_end_h = offset_rois[4] * spatial_scale - offset;
261-
262- T roi_width = roi_end_w - roi_start_w;
263- T roi_height = roi_end_h - roi_start_h;
264- if (!aligned) {
265- // Force malformed ROIs to be 1x1
266- roi_width = max(roi_width, (T)1.);
267- roi_height = max(roi_height, (T)1.);
268- }
269-
270- T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
271- T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
272-
273- constant T* offset_input =
274- input + (roi_batch_ind * channels + c) * height * width;
275-
276- // We use roi_bin_grid to sample the grid and mimic integral
277- integer_t roi_bin_grid_h = (sampling_ratio > 0)
278- ? sampling_ratio
279- : ceil(roi_height / pooled_height); // e.g., = 2
280- integer_t roi_bin_grid_w =
281- (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
282-
283- // We do average (integral) pooling inside a bin
284- // When the grid is empty, output zeros.
285- const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1)); // e.g. = 4
286-
287- T output_val = 0.;
288- for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
289- {
290- const T y = roi_start_h + ph * bin_size_h +
291- static_cast<T>(iy + .5f) * bin_size_h /
292- static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
293- for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
294- const T x = roi_start_w + pw * bin_size_w +
295- static_cast<T>(ix + .5f) * bin_size_w /
296- static_cast<T>(roi_bin_grid_w);
241+ uint index [[thread_position_in_grid]])
242+ {
243+ // Decode linear index into (n, c, ph, pw)
244+ integer_t pw = index % pooled_width;
245+ integer_t ph = (index / pooled_width) % pooled_height;
246+ integer_t c = (index / pooled_width / pooled_height) % channels;
247+ integer_t n = index / (pooled_width * pooled_height * channels);
248+
249+ constant T* offset_rois = rois + n * 5;
250+ integer_t roi_batch_ind = static_cast<integer_t>(offset_rois[0]);
251+
252+ // Do not using rounding; this implementation detail is critical
253+ T offset = aligned ? static_cast<T>(0.5) : static_cast<T>(0.0);
254+ T roi_start_w = offset_rois[1] * spatial_scale - offset;
255+ T roi_start_h = offset_rois[2] * spatial_scale - offset;
256+ T roi_end_w = offset_rois[3] * spatial_scale - offset;
257+ T roi_end_h = offset_rois[4] * spatial_scale - offset;
258+
259+ T roi_width = roi_end_w - roi_start_w;
260+ T roi_height = roi_end_h - roi_start_h;
261+
262+ if (!aligned) {
263+ // Force malformed ROIs to be 1x1
264+ roi_width = max(roi_width, static_cast<T>(1.0));
265+ roi_height = max(roi_height, static_cast<T>(1.0));
266+ }
297267
298- T val = bilinear_interpolate(offset_input, height, width, y, x, index);
299- output_val += val;
300- }
268+ T bin_size_h = roi_height / static_cast<T>(pooled_height);
269+ T bin_size_w = roi_width / static_cast<T>(pooled_width);
270+
271+ constant T* offset_input = input + (roi_batch_ind * channels + c) * height * width;
272+
273+ // We use roi_bin_grid to sample the grid and mimic integral
274+ integer_t roi_bin_grid_h = sampling_ratio > 0
275+ ? sampling_ratio
276+ : static_cast<integer_t>(ceil(roi_height / static_cast<T>(pooled_height)));
277+ integer_t roi_bin_grid_w = sampling_ratio > 0
278+ ? sampling_ratio
279+ : static_cast<integer_t>(ceil(roi_width / static_cast<T>(pooled_width)));
280+
281+ // We do average (integral) pooling inside a bin
282+ // When the grid is empty, output zeros.
283+ const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1));
284+ T output_val = static_cast<T>(0.0);
285+
286+ for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
287+ T y = roi_start_h + static_cast<T>(ph) * bin_size_h +
288+ (static_cast<T>(iy) + static_cast<T>(0.5)) * bin_size_h / static_cast<T>(roi_bin_grid_h);
289+ for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
290+ T x = roi_start_w + static_cast<T>(pw) * bin_size_w +
291+ (static_cast<T>(ix) + static_cast<T>(0.5)) * bin_size_w / static_cast<T>(roi_bin_grid_w);
292+
293+ T val = bilinear_interpolate(offset_input, height, width, y, x, index);
294+ output_val += val;
301295 }
302- output_val /= count;
303-
304- output[index] = output_val;
305296 }
297+
298+ output_val /= count;
299+ output[index] = output_val;
306300}
307301
308- #define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
309- template \
310- [[host_name("roi_align_" #DTYPE)]] \
311- kernel void roi_align<DTYPE, INT_DTYPE>( \
312- constant DTYPE * input [[buffer(0)]], \
313- constant DTYPE * rois [[buffer(1)]], \
314- device DTYPE * output [[buffer(2)]], \
315- constant int64_t & output_size [[buffer(3)]], \
316- constant int64_t & channels [[buffer(4)]], \
317- constant int64_t & height [[buffer(5)]], \
318- constant int64_t & width [[buffer(6)]], \
319- constant int64_t & pooled_height [[buffer(7)]], \
320- constant int64_t & pooled_width [[buffer(8)]], \
321- constant int64_t & sampling_ratio [[buffer(9)]], \
322- constant bool & aligned [[buffer(10)]], \
323- constant float & spatial_scale [[buffer(11)]], \
324- uint2 tgid [[threadgroup_position_in_grid]], \
325- uint2 tptg [[threads_per_threadgroup]], \
326- uint2 tid2 [[thread_position_in_threadgroup]]);
302+ #define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
303+ template \
304+ [[host_name("roi_align_" #DTYPE)]] \
305+ kernel void roi_align<DTYPE, INT_DTYPE>( \
306+ constant DTYPE * input [[buffer(0)]], \
307+ constant DTYPE * rois [[buffer(1)]], \
308+ device DTYPE * output [[buffer(2)]], \
309+ constant float & spatial_scale [[buffer(3)]], \
310+ constant int64_t & channels [[buffer(4)]], \
311+ constant int64_t & height [[buffer(5)]], \
312+ constant int64_t & width [[buffer(6)]], \
313+ constant int64_t & pooled_height [[buffer(7)]], \
314+ constant int64_t & pooled_width [[buffer(8)]], \
315+ constant int64_t & sampling_ratio [[buffer(9)]], \
316+ constant bool & aligned [[buffer(10)]], \
317+ uint index [[thread_position_in_grid]]);
327318
328319template<typename T, typename integer_t>
329320kernel void roi_align_backward(
@@ -1005,7 +996,7 @@ kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>( \
1005996 constant int64_t & width [[buffer(7)]], \
1006997 constant int64_t & pooled_height [[buffer(8)]], \
1007998 constant int64_t & pooled_width [[buffer(9)]], \
1008- constant int64_t & channels_out [[buffer(10)]], \
999+ constant int64_t & channels_out [[buffer(10)]], \
10091000 constant float & spatial_scale [[buffer(11)]], \
10101001 uint2 tgid [[threadgroup_position_in_grid]], \
10111002 uint2 tptg [[threads_per_threadgroup]], \
0 commit comments