@@ -107,6 +107,69 @@ inline void UpdateMaxIdsPerPartition(BlockRow<int>& global_max,
       global_max.cwiseMax(local_max_per_bucket.rowwise().sum().transpose());
 }
 
+template <typename SplitType>
+inline void UpdateMinibatchingSplit(
+    MatrixXi& ids_per_sc_partition_per_bucket,
+    MatrixXi& unique_ids_per_partition_per_bucket,
+    const int32_t global_sc_count, const int32_t max_ids_per_partition,
+    const int32_t max_unique_ids_per_partition, SplitType& minibatching_split) {
+  // This works both when minibatching is required and when it is not: if
+  // SplitType is bool, we only record whether minibatching is required; if it
+  // is MinibatchingSplit (a std::bitset<64>), we record the exact splits.
+  for (int global_sc_id = 0; global_sc_id < global_sc_count; ++global_sc_id) {
+    auto ids_per_bucket =
+        ids_per_sc_partition_per_bucket.row(global_sc_id).array();
+    auto unique_ids_per_bucket =
+        unique_ids_per_partition_per_bucket.row(global_sc_id).array();
+    if constexpr (std::is_same_v<SplitType, MinibatchingSplit>) {
+      // The arrays must be mutable as ComputeMinibatchingSplit modifies them.
+      // absl::MakeSpan works because the arrays are row-major, so their
+      // values are contiguous in memory.
+      static_assert(decltype(ids_per_bucket)::IsRowMajor);
+      static_assert(decltype(unique_ids_per_bucket)::IsRowMajor);
+      // NOTE: ComputeMinibatchingSplit modifies the span, but we have already
+      // updated the output stats.
+      minibatching_split |= ComputeMinibatchingSplit(
+          absl::MakeSpan(ids_per_bucket.data(), ids_per_bucket.size()),
+          max_ids_per_partition);
+      minibatching_split |=
+          ComputeMinibatchingSplit(absl::MakeSpan(unique_ids_per_bucket.data(),
+                                                  unique_ids_per_bucket.size()),
+                                   max_unique_ids_per_partition);
+    } else {
+      minibatching_split |=
+          unique_ids_per_bucket.maxCoeff() > max_unique_ids_per_partition ||
+          ids_per_bucket.maxCoeff() > max_ids_per_partition;
+    }
+  }
+}
+
+inline void LogSparseCoreStats(
+    const int32_t local_sc_id, const absl::string_view stacked_table_name,
+    const MatrixXi& ids_per_sc_partition_per_bucket,
+    const MatrixXi& unique_ids_per_partition_per_bucket, const size_t keys_size,
+    const PartitionedCooTensors& grouped_coo_tensors) {
+  if (VLOG_IS_ON(2)) {
+    LOG(INFO) << "For table " << stacked_table_name << " on local SparseCore "
+              << local_sc_id
+              << ": Observed ids per global SparseCore partition: "
+              << ids_per_sc_partition_per_bucket.rowwise().sum();
+
+    LOG(INFO) << "For table " << stacked_table_name << " on local SparseCore "
+              << local_sc_id
+              << ": Observed unique ids per global SparseCore partition: "
+              << unique_ids_per_partition_per_bucket.rowwise().sum();
+
+    LOG(INFO) << "For table " << stacked_table_name << " on local SparseCore "
+              << local_sc_id << ": Total number of ids processed: " << keys_size
+              << ", total after deduplication: "
+              << ids_per_sc_partition_per_bucket.sum()
+              << ", total after drop id: "
+              << grouped_coo_tensors.Size(local_sc_id);
+  }
+}
+
 }  // namespace internal
 
 // Sorts and groups the provided COO tensors in this hierarchy: Local SC ->
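
Note on the SplitType template parameter above: the same loop serves two callers, one that only needs a yes/no answer (SplitType = bool) and one that needs the exact split points (SplitType = MinibatchingSplit, a std::bitset<64> per the comment). The following is a minimal, self-contained sketch of that dispatch, not the library code: Split, MarkOverflowingBuckets, and AccumulateSplit are hypothetical stand-ins, and the overflow-to-bit mapping only hints at what ComputeMinibatchingSplit does with the span.

#include <algorithm>
#include <bitset>
#include <cstddef>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>

// Stand-in for MinibatchingSplit, described in the diff as a std::bitset<64>.
using Split = std::bitset<64>;

// Hypothetical stand-in for ComputeMinibatchingSplit: here it simply sets one
// bit per bucket whose id count exceeds the limit. The real routine derives
// actual split points; this only shows the shape of the interface.
Split MarkOverflowingBuckets(const std::vector<int>& ids_per_bucket,
                             int max_ids_per_partition) {
  Split split;
  for (std::size_t b = 0; b < ids_per_bucket.size() && b < 64; ++b) {
    if (ids_per_bucket[b] > max_ids_per_partition) split.set(b);
  }
  return split;
}

// Same dual-use pattern as UpdateMinibatchingSplit: with SplitType = Split we
// accumulate exact split bits, with SplitType = bool we only record whether
// any bucket exceeded the limit.
template <typename SplitType>
void AccumulateSplit(const std::vector<int>& ids_per_bucket,
                     int max_ids_per_partition, SplitType& out) {
  if constexpr (std::is_same_v<SplitType, Split>) {
    out |= MarkOverflowingBuckets(ids_per_bucket, max_ids_per_partition);
  } else {
    int max_count = 0;
    for (int count : ids_per_bucket) max_count = std::max(max_count, count);
    out |= (max_count > max_ids_per_partition);
  }
}

int main() {
  const std::vector<int> ids_per_bucket = {10, 40, 12, 55};  // one partition row
  const int limit = 32;

  bool minibatching_required = false;  // detection pass: yes/no answer
  AccumulateSplit(ids_per_bucket, limit, minibatching_required);

  Split split;  // bucketing pass: exact split bits
  AccumulateSplit(ids_per_bucket, limit, split);

  std::cout << "required=" << minibatching_required
            << " split=" << split.to_string().substr(60) << "\n";
  // Prints: required=1 split=1010 (buckets 1 and 3 exceed the limit of 32).
}
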
@@ -206,42 +269,43 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
       // dedup the id by adding the gains.
       if (col_id == prev_col_id && row_id == prev_row_id) {
         grouped_coo_tensors.MergeWithLastCoo(coo_tensor);
-      } else {
-        const bool is_new_col =
-            (bucket_id != prev_bucket_id || col_id != prev_col_id);
-        // For stats, we need to count this ID if it is not a duplicate.
-        ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) += 1;
-        if (is_new_col) {
-          unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) += 1;
-        }
+        continue;
+      }
+
+      const bool is_new_col =
+          (bucket_id != prev_bucket_id || col_id != prev_col_id);
+      // For stats, we need to count this ID if it is not a duplicate.
+      ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) += 1;
+      if (is_new_col) {
+        unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) += 1;
+      }
 
-        // We do NOT drop IDs when minibatching is enabled and we are in the
-        // first pass (`create_buckets=false`), as we need to detect limit
-        // overflows to decide if minibatching is required. So, we only check if
-        // limits would be exceeded in cases where we might drop an ID.
-        bool would_exceed_limits = false;
-        if (!options.enable_minibatching || create_buckets) {
-          would_exceed_limits =
-              (ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) >
-               max_ids_per_partition) ||
-              (is_new_col &&
-               (unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) >
-                max_unique_ids_per_partition));
-        }
+      // We do NOT drop IDs when minibatching is enabled and we are in the
+      // first pass (`create_buckets=false`), as we need to detect limit
+      // overflows to decide if minibatching is required.
+      const bool can_drop_id = !options.enable_minibatching || create_buckets;
 
-        // If adding the ID would exceed limits and ID dropping is allowed, drop
-        // it.
-        if (would_exceed_limits && allow_id_dropping) {
-          // Dropped id.
-          ++stats.dropped_id_count;
-        } else {
-          grouped_coo_tensors.Add(local_sc_id, bucket_id, coo_tensor);
-          prev_col_id = col_id;
-          prev_row_id = row_id;
-          prev_bucket_id = bucket_id;
-        }
+      const bool exceeds_ids_limit =
+          ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) >
+          max_ids_per_partition;
+      const bool exceeds_unique_ids_limit =
+          is_new_col &&
+          unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) >
+              max_unique_ids_per_partition;
+      const bool would_exceed_limits =
+          exceeds_ids_limit || exceeds_unique_ids_limit;
+
+      // If ID dropping is allowed and limits would be exceeded, drop the ID.
+      if (can_drop_id && would_exceed_limits && allow_id_dropping) {
+        // Dropped id.
+        ++stats.dropped_id_count;
+      } else {
+        grouped_coo_tensors.Add(local_sc_id, bucket_id, coo_tensor);
+        prev_col_id = col_id;
+        prev_row_id = row_id;
+        prev_bucket_id = bucket_id;
+      }
       }
-    }
+    }  // end key loop
     grouped_coo_tensors.FillRemainingScBuckets();
 
     // Update global max using this device's values.
@@ -258,56 +322,20 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
         })
         .sum();
 
-    if (VLOG_IS_ON(2)) {
-      LOG(INFO) << "Observed ids per partition/sparsecore"
-                << " for table " << stacked_table_name << ": "
-                << ids_per_sc_partition_per_bucket.rowwise().sum();
-
-      LOG(INFO) << "Observed unique ids per partition/sparsecore"
-                << " for table " << stacked_table_name << ": "
-                << unique_ids_per_partition_per_bucket.rowwise().sum();
-
-      LOG(INFO) << "Total number of ids for table " << stacked_table_name
-                << " on SparseCore" << local_sc_id << ": " << keys.size()
-                << ", after deduplication: "
-                << ids_per_sc_partition_per_bucket.sum()
-                << ", after drop id: " << grouped_coo_tensors.Size(local_sc_id);
-    }
+    internal::LogSparseCoreStats(
+        local_sc_id, stacked_table_name, ids_per_sc_partition_per_bucket,
+        unique_ids_per_partition_per_bucket, keys.size(), grouped_coo_tensors);
 
     const int32_t observed_max_ids_per_bucket =
         ids_per_sc_partition_per_bucket.maxCoeff();
     const int32_t observed_max_unique_ids_per_bucket =
         unique_ids_per_partition_per_bucket.maxCoeff();
 
     if (options.enable_minibatching) {
-      // This works both when minibatching is required and not. In the former
-      // case we have bool which tells us if minibatching is required,
-      // in the latter case it is std::bitset<64> which tells us the exact
-      // splits.
-      for (int global_sc_id = 0; global_sc_id < global_sc_count;
-           ++global_sc_id) {
-        auto ids_per_bucket =
-            ids_per_sc_partition_per_bucket.row(global_sc_id).array();
-        auto unique_ids_per_bucket =
-            unique_ids_per_partition_per_bucket.row(global_sc_id).array();
-        if constexpr (std::is_same_v<SplitType, MinibatchingSplit>) {
-          // absl::Makespan works because the array would be row-major and
-          // values would be contiguous in memory.
-          static_assert(decltype(ids_per_bucket)::IsRowMajor);
-          static_assert(decltype(unique_ids_per_bucket)::IsRowMajor);
-          // NOTE: ComputeSplit modifies the span, but we have already updated
-          // the output stats.
-          minibatching_split |= internal::ComputeMinibatchingSplit(
-              absl::MakeSpan(ids_per_bucket), max_ids_per_partition);
-          minibatching_split |= internal::ComputeMinibatchingSplit(
-              absl::MakeSpan(unique_ids_per_bucket),
-              max_unique_ids_per_partition);
-        } else {
-          minibatching_split |=
-              unique_ids_per_bucket.maxCoeff() > max_unique_ids_per_partition ||
-              ids_per_bucket.maxCoeff() > max_ids_per_partition;
-        }
-      }
+      internal::UpdateMinibatchingSplit(
+          ids_per_sc_partition_per_bucket, unique_ids_per_partition_per_bucket,
+          global_sc_count, max_ids_per_partition, max_unique_ids_per_partition,
+          minibatching_split);
     }
 
     // Only validate if creating minibatching buckets or when minibatching is
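
The refactor in the hunk above also makes the drop decision explicit: an id may be dropped only when can_drop_id holds (minibatching disabled, or we are already creating buckets), the per-partition limits would be exceeded, and allow_id_dropping is set. A minimal sketch of that predicate, with a hypothetical ShouldDropId wrapper and made-up inputs:

#include <iostream>

// Mirrors the locals in the diff: an id may only be dropped outside the
// minibatching "detection" pass.
bool ShouldDropId(bool enable_minibatching, bool create_buckets,
                  bool would_exceed_limits, bool allow_id_dropping) {
  const bool can_drop_id = !enable_minibatching || create_buckets;
  return can_drop_id && would_exceed_limits && allow_id_dropping;
}

int main() {
  // First pass with minibatching enabled: never drop, even over the limit,
  // so the overflow is detected and minibatching is triggered instead.
  const bool first_pass = ShouldDropId(/*enable_minibatching=*/true,
                                       /*create_buckets=*/false,
                                       /*would_exceed_limits=*/true,
                                       /*allow_id_dropping=*/true);
  // Minibatching disabled, or second pass with buckets: dropping is allowed.
  const bool no_minibatching = ShouldDropId(false, false, true, true);
  const bool bucketing_pass = ShouldDropId(true, true, true, true);
  std::cout << first_pass << " " << no_minibatching << " " << bucketing_pass
            << "\n";  // prints: 0 1 1
}
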
@@ -317,7 +345,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
         observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket,
         max_ids_per_partition, max_unique_ids_per_partition,
         stacked_table_name, allow_id_dropping);
-  }
+  }  // end local_sc_id loop
 
   return grouped_coo_tensors;
 }
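
For reference, the per-bucket statistics maintained in the key loop (second hunk) operate on COO entries already sorted by (bucket, col, row): an entry repeating the previous (col, row) is merged by accumulating its gain, every surviving entry increments ids-per-bucket, and a new (bucket, col) pair additionally increments unique-ids-per-bucket. A standalone sketch with a hypothetical Coo struct and made-up input:

#include <iostream>
#include <vector>

struct Coo { int bucket_id, col_id, row_id; float gain; };

int main() {
  // Already sorted by (bucket, col, row), as the real loop assumes.
  std::vector<Coo> coos = {{0, 7, 1, 1.f}, {0, 7, 1, 1.f},  // duplicate -> merged
                           {0, 7, 2, 1.f}, {0, 9, 2, 1.f}, {1, 9, 3, 1.f}};
  std::vector<Coo> merged;
  std::vector<int> ids_per_bucket(2, 0), unique_ids_per_bucket(2, 0);
  int prev_bucket = -1, prev_col = -1, prev_row = -1;
  for (const Coo& coo : coos) {
    if (coo.col_id == prev_col && coo.row_id == prev_row) {
      merged.back().gain += coo.gain;  // dedup by accumulating the gain
      continue;
    }
    ids_per_bucket[coo.bucket_id] += 1;
    if (coo.bucket_id != prev_bucket || coo.col_id != prev_col) {
      unique_ids_per_bucket[coo.bucket_id] += 1;  // first time seeing this col
    }
    merged.push_back(coo);
    prev_bucket = coo.bucket_id;
    prev_col = coo.col_id;
    prev_row = coo.row_id;
  }
  std::cout << "bucket0: ids=" << ids_per_bucket[0]
            << " unique=" << unique_ids_per_bucket[0]
            << "  bucket1: ids=" << ids_per_bucket[1]
            << " unique=" << unique_ids_per_bucket[1] << "\n";
  // Prints: bucket0: ids=3 unique=2  bucket1: ids=1 unique=1
}
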