Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix overflowing in intersection_intermediates.remove_if #1209

Merged
merged 15 commits into from
Jul 28, 2023
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,12 @@ linestring_intersection_result<T, index_t> pairwise_linestring_intersection(
stream);

points.remove_if(range(point_flags.begin(), point_flags.end()), stream);

rmm::device_uvector<int> point_flags_int(point_flags.size(), stream);
thrust::copy(
rmm::exec_policy(stream), point_flags.begin(), point_flags.end(), point_flags_int.begin());
}

// Phase 4: Assemble results as union column
auto num_union_column_rows = points.geoms->size() + segments.geoms->size();
auto geometry_collection_offsets =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ namespace detail {

namespace intersection_functors {

/**
* @brief Cast `uint8_t` to `X`
*
* @tparam X The type to cast to
*/
template <typename X>
struct uchar_to_x {
X __device__ operator()(uint8_t c) { return static_cast<X>(c); }
};

/** @brief Functor to compute the updated offset buffer after `remove_if` operation.
*
* Given the `i`th row in the `geometry_collection_offset`, find the number of all
Expand Down Expand Up @@ -292,11 +302,13 @@ struct linestring_intersection_intermediates {
rmm::device_uvector<index_t> reduced_flags(num_pairs(), stream);
auto keys_begin = make_geometry_id_iterator<index_t>(offsets->begin(), offsets->end());

auto iflags =
thrust::make_transform_iterator(flags.begin(), intersection_functors::uchar_to_x<index_t>{});
auto [keys_end, flags_end] =
thrust::reduce_by_key(rmm::exec_policy(stream),
keys_begin,
keys_begin + flags.size(),
flags.begin(),
iflags,
reduced_keys.begin(),
reduced_flags.begin(),
thrust::equal_to<index_t>(),
Expand Down
29 changes: 29 additions & 0 deletions cpp/tests/intersection/linestring_intersection_large_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2027,3 +2027,32 @@ TYPED_TEST(LinestringIntersectionLargeTest, LongInput)
CUSPATIAL_RUN_TEST(
this->template verify_legal_result, multilinestrings1.range(), multilinestrings2.range());
}

template <typename T>
struct coordinate_functor {
cuspatial::vec_2d<T> __device__ operator()(std::size_t i)
{
return cuspatial::vec_2d<T>{static_cast<T>(i), static_cast<T>(i)};
}
};

TYPED_TEST(LinestringIntersectionLargeTest, LongInput_2)
{
using P = cuspatial::vec_2d<TypeParam>;
auto geometry_offset = cuspatial::test::make_device_vector({0, 1});
auto part_offset = cuspatial::test::make_device_vector({0, 130});
auto coordinates = rmm::device_uvector<P>(260, this->stream());

thrust::tabulate(rmm::exec_policy(this->stream()),
coordinates.begin(),
thrust::next(coordinates.begin(), 128),
coordinate_functor<TypeParam>{});

coordinates.set_element(128, P{127.0, 0.0}, this->stream());
coordinates.set_element(129, P{0.0, 0.0}, this->stream());

auto rng = cuspatial::make_multilinestring_range(
1, geometry_offset.begin(), 1, part_offset.begin(), 130, coordinates.begin());

CUSPATIAL_RUN_TEST(this->template verify_legal_result, rng, rng);
}