
Commit 143d4a1

[JAX SC] Refactor SortAndGroup for readability and reduce nesting.
PiperOrigin-RevId: 825779129
1 parent a19e081 commit 143d4a1

File tree

1 file changed: +105 -77 lines


jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h

Lines changed: 105 additions & 77 deletions
@@ -107,6 +107,69 @@ inline void UpdateMaxIdsPerPartition(BlockRow<int>& global_max,
       global_max.cwiseMax(local_max_per_bucket.rowwise().sum().transpose());
 }
 
+template <typename SplitType>
+inline void UpdateMinibatchingSplit(
+    MatrixXi& ids_per_sc_partition_per_bucket,
+    MatrixXi& unique_ids_per_partition_per_bucket,
+    const int32_t global_sc_count, const int32_t max_ids_per_partition,
+    const int32_t max_unique_ids_per_partition, SplitType& minibatching_split) {
+  // This works both when minibatching is required and not. In the former
+  // case we have bool which tells us if minibatching is required,
+  // in the latter case it is std::bitset<64> which tells us the exact
+  // splits.
+  for (int global_sc_id = 0; global_sc_id < global_sc_count; ++global_sc_id) {
+    auto ids_per_bucket =
+        ids_per_sc_partition_per_bucket.row(global_sc_id).array();
+    auto unique_ids_per_bucket =
+        unique_ids_per_partition_per_bucket.row(global_sc_id).array();
+    if constexpr (std::is_same_v<SplitType, MinibatchingSplit>) {
+      // The arrays must be mutable as ComputeMinibatchingSplit modifies them.
+      // absl::MakeSpan works because the array would be row-major and
+      // values would be contiguous in memory.
+      static_assert(decltype(ids_per_bucket)::IsRowMajor);
+      static_assert(decltype(unique_ids_per_bucket)::IsRowMajor);
+      // NOTE: ComputeSplit modifies the span, but we have already updated
+      // the output stats.
+      minibatching_split |= ComputeMinibatchingSplit(
+          absl::MakeSpan(ids_per_bucket.data(), ids_per_bucket.size()),
+          max_ids_per_partition);
+      minibatching_split |=
+          ComputeMinibatchingSplit(absl::MakeSpan(unique_ids_per_bucket.data(),
+                                                  unique_ids_per_bucket.size()),
+                                   max_unique_ids_per_partition);
+    } else {
+      minibatching_split |=
+          unique_ids_per_bucket.maxCoeff() > max_unique_ids_per_partition ||
+          ids_per_bucket.maxCoeff() > max_ids_per_partition;
+    }
+  }
+}
+
+inline void LogSparseCoreStats(
+    const int32_t local_sc_id, const absl::string_view stacked_table_name,
+    const MatrixXi& ids_per_sc_partition_per_bucket,
+    const MatrixXi& unique_ids_per_partition_per_bucket, const size_t keys_size,
+    const PartitionedCooTensors& grouped_coo_tensors) {
+  if (VLOG_IS_ON(2)) {
+    LOG(INFO) << "For table " << stacked_table_name << " on local SparseCore "
+              << local_sc_id
+              << ": Observed ids per global SparseCore partition: "
+              << ids_per_sc_partition_per_bucket.rowwise().sum();
+
+    LOG(INFO) << "For table " << stacked_table_name << " on local SparseCore "
+              << local_sc_id
+              << ": Observed unique ids per global SparseCore partition: "
+              << unique_ids_per_partition_per_bucket.rowwise().sum();
+
+    LOG(INFO) << "For table " << stacked_table_name << " on local SparseCore "
+              << local_sc_id << ": Total number of ids processed: " << keys_size
+              << ", total after deduplication: "
+              << ids_per_sc_partition_per_bucket.sum()
+              << ", total after drop id: "
+              << grouped_coo_tensors.Size(local_sc_id);
+  }
+}
+
 }  // namespace internal
 
 // Sorts and groups the provided COO tensors in this hierarchy: Local SC ->
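
The helper added above is templated on SplitType so the same loop serves both minibatching passes: a plain bool answers "is minibatching required at all?", while a MinibatchingSplit (a std::bitset<64>, per the comment) records exactly where to split. A minimal standalone sketch of that dispatch, using hypothetical stand-ins (FakeSplit, RecordOverflow) rather than the library's real types and functions:

#include <bitset>
#include <cstdint>
#include <iostream>
#include <type_traits>

// Hypothetical stand-in for MinibatchingSplit, assumed here to behave like
// std::bitset<64>; not the library's actual type.
using FakeSplit = std::bitset<64>;

template <typename SplitType>
void RecordOverflow(int bucket_id, int32_t ids_in_bucket,
                    int32_t max_ids_per_partition, SplitType& split) {
  if constexpr (std::is_same_v<SplitType, FakeSplit>) {
    // Bucketed pass: remember exactly which bucket boundary must become a split.
    if (ids_in_bucket > max_ids_per_partition) split.set(bucket_id);
  } else {
    // First pass: only remember whether any bucket overflows at all.
    split |= (ids_in_bucket > max_ids_per_partition);
  }
}

int main() {
  bool needs_minibatching = false;
  RecordOverflow(/*bucket_id=*/3, /*ids_in_bucket=*/70,
                 /*max_ids_per_partition=*/64, needs_minibatching);

  FakeSplit split;
  RecordOverflow(3, 70, 64, split);

  std::cout << needs_minibatching << " " << split.to_ullong() << "\n";  // 1 8
  return 0;
}

The if constexpr branch is resolved at compile time, so the bool instantiation carries none of the split-computation code.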
@@ -206,42 +269,43 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
       // dedup the id by adding the gains.
       if (col_id == prev_col_id && row_id == prev_row_id) {
         grouped_coo_tensors.MergeWithLastCoo(coo_tensor);
-      } else {
-        const bool is_new_col =
-            (bucket_id != prev_bucket_id || col_id != prev_col_id);
-        // For stats, we need to count this ID if it is not a duplicate.
-        ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) += 1;
-        if (is_new_col) {
-          unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) += 1;
-        }
+        continue;
+      }
+
+      const bool is_new_col =
+          (bucket_id != prev_bucket_id || col_id != prev_col_id);
+      // For stats, we need to count this ID if it is not a duplicate.
+      ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) += 1;
+      if (is_new_col) {
+        unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) += 1;
+      }
 
-        // We do NOT drop IDs when minibatching is enabled and we are in the
-        // first pass (`create_buckets=false`), as we need to detect limit
-        // overflows to decide if minibatching is required. So, we only check if
-        // limits would be exceeded in cases where we might drop an ID.
-        bool would_exceed_limits = false;
-        if (!options.enable_minibatching || create_buckets) {
-          would_exceed_limits =
-              (ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) >
-               max_ids_per_partition) ||
-              (is_new_col &&
-               (unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) >
-                max_unique_ids_per_partition));
-        }
+      // We do NOT drop IDs when minibatching is enabled and we are in the
+      // first pass (`create_buckets=false`), as we need to detect limit
+      // overflows to decide if minibatching is required.
+      const bool can_drop_id = !options.enable_minibatching || create_buckets;
 
-        // If adding the ID would exceed limits and ID dropping is allowed, drop
-        // it.
-        if (would_exceed_limits && allow_id_dropping) {
-          // Dropped id.
-          ++stats.dropped_id_count;
-        } else {
-          grouped_coo_tensors.Add(local_sc_id, bucket_id, coo_tensor);
-          prev_col_id = col_id;
-          prev_row_id = row_id;
-          prev_bucket_id = bucket_id;
-        }
+      const bool exceeds_ids_limit =
+          ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) >
+          max_ids_per_partition;
+      const bool exceeds_unique_ids_limit =
+          is_new_col &&
+          unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) >
+          max_unique_ids_per_partition;
+      const bool would_exceed_limits =
+          exceeds_ids_limit || exceeds_unique_ids_limit;
+
+      // If ID dropping is allowed and limits would be exceeded, drop the ID.
+      if (can_drop_id && would_exceed_limits && allow_id_dropping) {
+        // Dropped id.
+        ++stats.dropped_id_count;
+      } else {
+        grouped_coo_tensors.Add(local_sc_id, bucket_id, coo_tensor);
+        prev_col_id = col_id;
+        prev_row_id = row_id;
+        prev_bucket_id = bucket_id;
       }
-    }
+    }  // end key loop
     grouped_coo_tensors.FillRemainingScBuckets();
 
     // Update global max using this device's values.
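
With the nested else replaced by an early continue, the drop decision reads as a single predicate over named booleans. A compact sketch of that predicate under simplified, assumed parameters (ShouldDropId and DropDecisionInputs are illustrative names, not part of this file):

#include <cstdint>

struct DropDecisionInputs {
  bool enable_minibatching;
  bool create_buckets;              // true on the bucketed (second) pass
  bool allow_id_dropping;
  bool is_new_col;
  int32_t ids_in_partition;         // running count including this id
  int32_t unique_ids_in_partition;  // running count including this id
  int32_t max_ids_per_partition;
  int32_t max_unique_ids_per_partition;
};

// Returns true if the current id should be dropped instead of grouped.
inline bool ShouldDropId(const DropDecisionInputs& in) {
  // Dropping is only considered when minibatching is off or we are already on
  // the bucketed pass; the first minibatching pass must observe every overflow.
  const bool can_drop_id = !in.enable_minibatching || in.create_buckets;
  const bool exceeds_ids_limit =
      in.ids_in_partition > in.max_ids_per_partition;
  const bool exceeds_unique_ids_limit =
      in.is_new_col &&
      in.unique_ids_in_partition > in.max_unique_ids_per_partition;
  return can_drop_id && (exceeds_ids_limit || exceeds_unique_ids_limit) &&
         in.allow_id_dropping;
}

Keeping can_drop_id separate from would_exceed_limits preserves the original behavior: on the first minibatching pass overflows are still counted in the stats but the id is never dropped.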
@@ -258,56 +322,20 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
             })
         .sum();
 
-    if (VLOG_IS_ON(2)) {
-      LOG(INFO) << "Observed ids per partition/sparsecore"
-                << " for table " << stacked_table_name << ": "
-                << ids_per_sc_partition_per_bucket.rowwise().sum();
-
-      LOG(INFO) << "Observed unique ids per partition/sparsecore"
-                << " for table " << stacked_table_name << ": "
-                << unique_ids_per_partition_per_bucket.rowwise().sum();
-
-      LOG(INFO) << "Total number of ids for table " << stacked_table_name
-                << " on SparseCore" << local_sc_id << ": " << keys.size()
-                << ", after deduplication: "
-                << ids_per_sc_partition_per_bucket.sum()
-                << ", after drop id: " << grouped_coo_tensors.Size(local_sc_id);
-    }
+    internal::LogSparseCoreStats(
+        local_sc_id, stacked_table_name, ids_per_sc_partition_per_bucket,
+        unique_ids_per_partition_per_bucket, keys.size(), grouped_coo_tensors);
 
     const int32_t observed_max_ids_per_bucket =
         ids_per_sc_partition_per_bucket.maxCoeff();
     const int32_t observed_max_unique_ids_per_bucket =
         unique_ids_per_partition_per_bucket.maxCoeff();
 
     if (options.enable_minibatching) {
-      // This works both when minibatching is required and not. In the former
-      // case we have bool which tells us if minibatching is required,
-      // in the latter case it is std::bitset<64> which tells us the exact
-      // splits.
-      for (int global_sc_id = 0; global_sc_id < global_sc_count;
-           ++global_sc_id) {
-        auto ids_per_bucket =
-            ids_per_sc_partition_per_bucket.row(global_sc_id).array();
-        auto unique_ids_per_bucket =
-            unique_ids_per_partition_per_bucket.row(global_sc_id).array();
-        if constexpr (std::is_same_v<SplitType, MinibatchingSplit>) {
-          // absl::Makespan works because the array would be row-major and
-          // values would be contiguous in memory.
-          static_assert(decltype(ids_per_bucket)::IsRowMajor);
-          static_assert(decltype(unique_ids_per_bucket)::IsRowMajor);
-          // NOTE: ComputeSplit modifies the span, but we have already updated
-          // the output stats.
-          minibatching_split |= internal::ComputeMinibatchingSplit(
-              absl::MakeSpan(ids_per_bucket), max_ids_per_partition);
-          minibatching_split |= internal::ComputeMinibatchingSplit(
-              absl::MakeSpan(unique_ids_per_bucket),
-              max_unique_ids_per_partition);
-        } else {
-          minibatching_split |=
-              unique_ids_per_bucket.maxCoeff() > max_unique_ids_per_partition ||
-              ids_per_bucket.maxCoeff() > max_ids_per_partition;
-        }
-      }
+      internal::UpdateMinibatchingSplit(
+          ids_per_sc_partition_per_bucket, unique_ids_per_partition_per_bucket,
+          global_sc_count, max_ids_per_partition, max_unique_ids_per_partition,
+          minibatching_split);
     }
 
     // Only validate if creating minibatching buckets or when minibatching is
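
internal::UpdateMinibatchingSplit accumulates into minibatching_split with |= once per global SparseCore, so the final value is the union of the splits each SparseCore requires. A tiny sketch of that union semantics, assuming MinibatchingSplit behaves like std::bitset<64> (as the removed comment states) and using made-up bucket indices:

#include <bitset>
#include <iostream>

int main() {
  using Split = std::bitset<64>;  // stand-in for MinibatchingSplit (assumed)

  Split split;                    // starts empty: no forced bucket boundaries
  Split from_sc0, from_sc2;
  from_sc0.set(5);                // SparseCore 0 overflows around bucket 5
  from_sc2.set(5);                // SparseCore 2 overflows around buckets 5 and 12
  from_sc2.set(12);

  split |= from_sc0;              // accumulate per-SparseCore requirements
  split |= from_sc2;

  // Union of requirements: buckets 5 and 12 must be split, so count() == 2.
  std::cout << split.count() << "\n";
  return 0;
}

When SplitType is bool, the same |= accumulation collapses to a plain logical OR, i.e. "some partition somewhere overflowed".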
@@ -317,7 +345,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
         observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket,
         max_ids_per_partition, max_unique_ids_per_partition,
         stacked_table_name, allow_id_dropping);
-  }
+  }  // end local_sc_id loop
 
   return grouped_coo_tensors;
 }
