Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/layout/legacy/barnes_hut.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ void barnes_hut(raft::handle_t const& handle,
old_forces,
old_forces + n,
swinging,
prevent_overlapping,
vertex_mobility,
speed,
n);
Expand Down
16 changes: 12 additions & 4 deletions cpp/src/layout/legacy/bh_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -628,11 +628,12 @@ __global__ static __launch_bounds__(THREADS6, FACTOR6) void apply_forces_bh(
float* restrict old_dx,
float* restrict old_dy,
const float* restrict swinging,
const bool prevent_overlapping,
const float* restrict vertex_mobility,
const float speed,
const int n)
{
// For evrery vertex
// For every vertex
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += gridDim.x * blockDim.x) {
// Store displacement needed for next iteration.
const float dx = (repel_x[i] + attract_x[i]);
Expand All @@ -642,9 +643,16 @@ __global__ static __launch_bounds__(THREADS6, FACTOR6) void apply_forces_bh(

// Update positions
float mobility_factor = vertex_mobility ? vertex_mobility[i] : 1.0f;
float factor = mobility_factor * speed / (1.0 + sqrt(speed * swinging[i]));
Y_x[i] += dx * factor;
Y_y[i] += dy * factor;
float factor = speed / (1.0 + sqrt(speed * swinging[i]));

if (prevent_overlapping) {
factor = 0.1 * factor;
float df = sqrt(dx * dx + dy * dy);
factor = min(factor * df, 10.0f) / df;
}

Y_x[i] += dx * mobility_factor * factor;
Y_y[i] += dy * mobility_factor * factor;
}
}

Expand Down
1 change: 1 addition & 0 deletions cpp/src/layout/legacy/exact_fa2.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ void exact_fa2(raft::handle_t const& handle,
d_old_forces,
d_old_forces + n,
d_swinging,
prevent_overlapping,
vertex_mobility,
speed,
n,
Expand Down
16 changes: 13 additions & 3 deletions cpp/src/layout/legacy/fa2_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -313,19 +313,27 @@ __global__ static void update_positions_kernel(float* restrict x_pos,
float* restrict old_dx,
float* restrict old_dy,
const float* restrict swinging,
const bool prevent_overlapping,
const float* restrict vertex_mobility,
const float speed,
const vertex_t n)
{
// For every node.
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += gridDim.x * blockDim.x) {
const float mobility_factor = vertex_mobility ? vertex_mobility[i] : 1.0f;
const float factor = mobility_factor * speed / (1.0 + sqrt(speed * swinging[i]));
const float dx = (repel_x[i] + attract_x[i]);
const float dy = (repel_y[i] + attract_y[i]);

x_pos[i] += dx * factor;
y_pos[i] += dy * factor;
float factor = speed / (1.0 + sqrt(speed * swinging[i]));

if (prevent_overlapping) {
factor = 0.1 * factor;
float df = sqrt(dx * dx + dy * dy);
factor = min(factor * df, 10.0f) / df;
}

x_pos[i] += dx * mobility_factor * factor;
y_pos[i] += dy * mobility_factor * factor;
old_dx[i] = dx;
old_dy[i] = dy;
}
Expand All @@ -341,6 +349,7 @@ void apply_forces(float* restrict x_pos,
float* restrict old_dx,
float* restrict old_dy,
const float* restrict swinging,
const bool prevent_overlapping,
const float* restrict vertex_mobility,
const float speed,
const vertex_t n,
Expand All @@ -365,6 +374,7 @@ void apply_forces(float* restrict x_pos,
old_dx,
old_dy,
swinging,
prevent_overlapping,
vertex_mobility,
speed,
n);
Expand Down
Loading