Skip to content

Commit 21a7c86

Browse files
pschuh authored and Google-ML-Automation committed
Simplify how device events are handled. Because definition_stream() is always
provided and only used if the stream is defined, we can avoid tracking usage_stream separately.

PiperOrigin-RevId: 814765407
1 parent 204266d commit 21a7c86

File tree

3 files changed: +24 -14 lines changed

xla/pjrt/pjrt_stream_executor_client.cc

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1420,7 +1420,9 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
14201420
for (const auto& stream_and_event : events) {
14211421
VLOG(4)
14221422
<< "Checking whether need to wait for stream_and_event: stream: "
1423-
<< stream_and_event.stream
1423+
<< (stream_and_event.event->IsDefined()
1424+
? stream_and_event.event->definition_stream()
1425+
: nullptr)
14241426
<< "; event: " << &*stream_and_event.event
14251427
<< "; reference_held: " << stream_and_event.reference_held
14261428
<< "; is_predetermined_error: "
@@ -1494,7 +1496,7 @@ void PjRtStreamExecutorBuffer::ConvertUsageHold(TrackedDeviceBuffer* buffer,
14941496
bool reference_held) {
14951497
absl::MutexLock lock(&mu_);
14961498
CHECK(device_buffer() == buffer || device_buffer() == nullptr);
1497-
buffer->AddUsageEvent(usage_stream, std::move(event), reference_held);
1499+
buffer->AddUsageEvent(std::move(event), reference_held);
14981500
DecrementUsage();
14991501
}
15001502

xla/pjrt/tracked_device_buffer.cc

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ limitations under the License.
3232
#include "absl/status/status.h"
3333
#include "absl/synchronization/mutex.h"
3434
#include "absl/types/span.h"
35+
#include "xla/pjrt/device_event.h"
3536
#include "xla/pjrt/event_pool.h"
3637
#include "xla/pjrt/pjrt_client.h"
3738
#include "xla/pjrt/pjrt_common.h"
@@ -257,31 +258,31 @@ void TrackedDeviceBuffer::ConfirmDonation() {
257258
ReleaseDeviceMemory();
258259
}
259260

260-
void TrackedDeviceBuffer::AddUsageEvent(se::Stream* usage_stream,
261-
BufferSequencingEventRef event,
261+
void TrackedDeviceBuffer::AddUsageEvent(BufferSequencingEventRef event,
262262
bool reference_held) {
263263
CHECK(in_use_);
264264

265265
// If the event is 0, it means that the event is not recorded yet and the task
266266
// related to this event is deferred, so just add it.
267267
if (!event->IsDefined()) {
268-
usage_events_.push_back({usage_stream, event, reference_held});
268+
usage_events_.push_back({event, reference_held});
269269
return;
270270
}
271+
auto* usage_stream = event->definition_stream();
271272

272273
for (auto& existing : usage_events_) {
273274
// If the existing event is 0, it means that the event is not recorded yet
274275
// and the task related to this event is deferred, so don't replace it.
275276
if (!existing.event->IsDefined()) continue;
276-
if (existing.stream == usage_stream) {
277+
if (existing.event->definition_stream() == usage_stream) {
277278
if (*existing.event < *event) {
278279
existing.event = event;
279280
existing.reference_held = reference_held;
280281
}
281282
return;
282283
}
283284
}
284-
usage_events_.push_back({usage_stream, event, reference_held});
285+
usage_events_.push_back({event, reference_held});
285286
}
286287

287288
TrackedDeviceBuffer::StreamAndEventContainer
@@ -312,6 +313,16 @@ tsl::RCReference<CommonPjRtRawBuffer> TrackedDeviceBuffer::GetRawBuffer(
312313
device_memory_);
313314
}
314315

316+
void TrackedDeviceBuffer::AddUsageEvent(
317+
tsl::RCReference<PjRtDeviceEvent> event) {
318+
if (event) {
319+
AddUsageEvent(
320+
tensorflow::down_cast<PjRtStreamExecutorDeviceEvent*>(event.get())
321+
->event(),
322+
true);
323+
}
324+
}
325+
315326
void GetDeviceBufferEvents(
316327
const TrackedDeviceBuffer& buffer, bool get_usage_events,
317328
absl::flat_hash_set<BufferSequencingEvent*>* events) {

xla/pjrt/tracked_device_buffer.h

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -163,6 +163,8 @@ class BufferSequencingEvent : tsl::AsyncPayload::KeepOnError {
163163
se::Stream* definition_stream;
164164
};
165165

166+
se::Stream* definition_stream() const { return event_->definition_stream; }
167+
166168
private:
167169
uint64_t sequence_number() const;
168170

@@ -218,8 +220,6 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
218220
public:
219221
// Helper object to keep track of usage of the buffer on streams.
220222
struct StreamAndEvent {
221-
// A stream the buffer has been used on.
222-
se::Stream* stream;
223223
// An event that is later than the most recent usage of the buffer on
224224
// stream.
225225
BufferSequencingEventRef event;
@@ -282,8 +282,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
282282
// reference to *this to stay live until after the host
283283
// is sure that the usage (transfer or execution) has
284284
// completed.
285-
void AddUsageEvent(se::Stream* usage_stream, BufferSequencingEventRef event,
286-
bool reference_held);
285+
void AddUsageEvent(BufferSequencingEventRef event, bool reference_held);
287286

288287
using StreamAndEventContainer = absl::InlinedVector<StreamAndEvent, 3>;
289288
// Returns the set of streams that the buffer was used on, and for each stream
@@ -303,9 +302,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer {
303302
tsl::RCReference<CommonPjRtRawBuffer> GetRawBuffer(
304303
PjRtMemorySpace* memory_space) override;
305304

306-
void AddUsageEvent(tsl::RCReference<PjRtDeviceEvent> event) override {
307-
LOG(FATAL) << "Implement";
308-
}
305+
void AddUsageEvent(tsl::RCReference<PjRtDeviceEvent> event) override;
309306

310307
void Delete(PjRtMemorySpace* memory_space) override {
311308
LOG(FATAL) << "Implement";

0 commit comments

Comments
 (0)