Implement cache for hipStream in ROCm executor #869
Review Summary
This PR adds a process-level
Key finding: The use-after-free fix via
See inline comments for details.
hipStreamCreate on ROCm is expensive (~100 ms per stream). When a
PjRtClient is destroyed and a new one is immediately created (common in
tests and interactive use), all ~18 streams per device are destroyed and
recreated, blocking for several seconds.
This commit implements a process-level HipStreamHandleCache singleton
directly in rocm_stream.cc (ROCm-only, touches no CUDA/SYCL code).
Cache key: (device_ordinal, creation_flags, creation_priority_int).
On destruction (RocmStream::~RocmStream):
  1. BlockHostUntilDone() already ran -- stream is idle.
  2. hipStreamQuery() confirms idleness; on error the handle is
     destroyed rather than cached (no poisoning).
  3. hipStreamGetFlags / hipStreamGetPriority are called to build the
     exact cache key, ensuring a retrieved handle always matches the
     flags and priority the new stream would have used -- even if XLA
     later switches to hipStreamNonBlocking.
  4. Idle handle is stored; hipStreamDestroy is skipped.
On creation (RocmStream::Create via CreateStream):
  The cache is checked first; on hit the cached handle is returned
  directly and hipStreamCreate is skipped. On miss the cold path
  calls hipStreamCreate as before.
The LocalDeviceState and RocmStream wrapper objects are still created
and destroyed normally on every client instantiation. DNN state is
cleaned up via DeallocateStream as usual. Only the underlying HIP
queue (hipStream_t) is reused.
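The destruction and creation paths described above amount to a keyed pool of idle handles. The following is a minimal sketch under stated assumptions, not the code in rocm_stream.cc: StreamHandle is a stand-in for hipStream_t so the example compiles without HIP headers, and StreamHandleCache, TryTake, and Put are hypothetical names. The idleness check (hipStreamQuery) and the key construction (hipStreamGetFlags / hipStreamGetPriority) are assumed to happen in the caller before Put.

```cpp
#include <map>
#include <mutex>
#include <optional>
#include <tuple>
#include <vector>

// Stand-in for hipStream_t (assumption: lets the sketch build without HIP).
using StreamHandle = void*;

// Key from the description: (device_ordinal, creation_flags, creation_priority_int).
using CacheKey = std::tuple<int, unsigned int, int>;

// Hypothetical process-level cache of idle stream handles.
class StreamHandleCache {
 public:
  // Assumption: the singleton is leaked on purpose so that it outlives every
  // stream object and no destructor runs at process exit.
  static StreamHandleCache& Get() {
    static StreamHandleCache* cache = new StreamHandleCache();
    return *cache;
  }

  // Creation path: returns an idle handle with a matching key on a hit,
  // std::nullopt on a miss (the caller then creates a fresh stream).
  std::optional<StreamHandle> TryTake(const CacheKey& key) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = idle_.find(key);
    if (it == idle_.end() || it->second.empty()) return std::nullopt;
    StreamHandle handle = it->second.back();
    it->second.pop_back();
    return handle;
  }

  // Destruction path: stores a handle the caller has already verified idle.
  void Put(const CacheKey& key, StreamHandle handle) {
    std::lock_guard<std::mutex> lock(mu_);
    idle_[key].push_back(handle);
  }

 private:
  std::mutex mu_;
  std::map<CacheKey, std::vector<StreamHandle>> idle_;
};
```

In this sketch a miss in TryTake corresponds to the cold path (hipStreamCreate), and a handle whose idleness query failed would be destroyed by the caller instead of being passed to Put.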
Also fix a latent use-after-free in LocalDeviceState::~LocalDeviceState:
C++ destroys members in reverse declaration order. compute_events_
(line 352 in local_device_state.h) is declared after callback_thread_
(line 342), so its destructor runs *before* callback_thread_'s
destructor joins the worker thread. If callback_thread_ still has
pending pop_front(compute_events_) closures when compute_events_ is
destroyed, those closures access freed memory.
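The destruction-order rule behind this bug can be observed in isolation. The sketch below is a hypothetical miniature (Tracked, MiniDeviceState, and DestructionOrder are illustration-only names, not XLA types): two members are declared in the same relative order as callback_thread_ and compute_events_, and each records its name when its destructor runs.

```cpp
#include <string>
#include <utility>
#include <vector>

// Records the order in which the tracked members are destroyed.
static std::vector<std::string> g_destruction_log;

// Helper that logs its own name on destruction.
struct Tracked {
  explicit Tracked(std::string name) : name_(std::move(name)) {}
  ~Tracked() { g_destruction_log.push_back(name_); }
  std::string name_;
};

// Hypothetical miniature of the member layout described above.
struct MiniDeviceState {
  Tracked callback_thread_{"callback_thread_"};  // declared first
  Tracked compute_events_{"compute_events_"};    // declared second
};

// Builds and destroys one MiniDeviceState and returns the destructor order.
std::vector<std::string> DestructionOrder() {
  g_destruction_log.clear();
  { MiniDeviceState state; }  // members are destroyed at the closing brace
  return g_destruction_log;
}
```

DestructionOrder() returns compute_events_ first and callback_thread_ second: the member declared later is destroyed earlier, which is why a worker thread that is only joined in its own destructor can still be running while compute_events_ is already gone.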
The fix adds callback_thread_->Drain() between SynchronizeAllActivity()
and the explicit stream/event clears. After Drain() the callback thread
is idle and compute_events_ can be safely cleared.
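What a drain step guarantees can be illustrated with a small worker thread. This is a sketch only, assuming a simple single-threaded queue; MiniWorker and its methods are hypothetical and do not reproduce XLA's worker thread class. Drain() blocks until the queue is empty and no closure is running, so state captured by the closures can be torn down safely afterwards.

```cpp
#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>

// Hypothetical single-threaded worker with a Drain() barrier.
class MiniWorker {
 public:
  MiniWorker() : thread_([this] { Loop(); }) {}

  ~MiniWorker() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      stop_ = true;
    }
    cv_.notify_all();
    thread_.join();  // remaining closures run before the join completes
  }

  void Schedule(std::function<void()> fn) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(std::move(fn));
    }
    cv_.notify_all();
  }

  // Blocks until every scheduled closure has finished running.
  void Drain() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return queue_.empty() && !busy_; });
  }

 private:
  void Loop() {
    std::unique_lock<std::mutex> lock(mu_);
    for (;;) {
      cv_.wait(lock, [this] { return stop_ || !queue_.empty(); });
      if (queue_.empty()) return;  // stop requested and nothing left to run
      std::function<void()> fn = std::move(queue_.front());
      queue_.pop_front();
      busy_ = true;
      lock.unlock();
      fn();  // run the closure without holding the lock
      lock.lock();
      busy_ = false;
      cv_.notify_all();  // wake any thread blocked in Drain()
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<std::function<void()>> queue_;
  bool busy_ = false;
  bool stop_ = false;
  std::thread thread_;  // declared last so it starts after the state above
};
```

A destructor following the pattern described above would call Drain() on the worker first and only then clear the containers that pending closures refer to.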
For reference: the previous PR is #861.
I assume this hipStream cache will benefit not only that iota unit test but also all HIP-stream-related operations in XLA? Might it improve general e2e workloads as well?
Yes, this cache should reduce the execution time of all tests comprising more than one subtest.
@mfrancepillois, I am just trying to understand: is the current issue that local_device_state creates about 18 streams per device? 1 compute or are there other places where streams get created/destroyed? Ah, I see: this is because each subtest creates/destroys a LocalDeviceState as part of the PJRT client.
It caches the streams across multiple PJRT clients, not the streams within a single PJRT client.
OK, I see. Each subtest is now a full-fledged PJRT client, which makes it quite heavy. So, basically, this will only help us with test execution; it won't affect real workload performance. Possibly, we could also make stream creation in local_device_state lazy? I mean, we probably use just a few streams per subtest out of those 18?
Yes, this is only beneficial for tests.
Ah, I see. Yes, optimal performance is normally achieved with 4 hardware queues on ROCm. So, we can reduce the number of threads for ROCm: instead of creating 18 streams, we can just create 1 of each kind, e.g.: If I am not mistaken, currently we create them in the following order: but this needs to be tested, of course.
It is worse. We have 8 of them, 4 by priority. We need to see if we should disable stream priority. Sorry, I haven't got around to checking this. I understand that queue creation is somewhat expensive, but once you create all 8 of them they just get reused, so the stream should be lightweight.
Motivation
The iota_test was very slow on AMD targets (compared to NVIDIA) because the PJRT client was destroyed and recreated for each of the 4500 tests that make up the iota_test. This task is ~40× slower on ROCm than with CUDA (see table below). The main cause of the slowdown when creating and destroying a PJRT client lies in the creation and destruction of streams.
This PR implements a process-level HipStreamHandleCache singleton directly in rocm_stream.cc. Cache key: (device_ordinal, creation_flags, creation_priority_int).
On destruction (RocmStream::~RocmStream):
  1. BlockHostUntilDone() already ran -- stream is idle.
  2. hipStreamQuery() confirms idleness; on error the handle is
     destroyed rather than cached (no poisoning).
  3. hipStreamGetFlags / hipStreamGetPriority are called to build the
     exact cache key, ensuring a retrieved handle always matches the
     flags and priority the new stream would have used -- even if XLA
     later switches to hipStreamNonBlocking.
  4. Idle handle is stored; hipStreamDestroy is skipped.
On creation (RocmStream::Create via CreateStream):
  The cache is checked first; on hit the cached handle is returned
  directly and hipStreamCreate is skipped. On miss the cold path
  calls hipStreamCreate as before.
The LocalDeviceState and RocmStream wrapper objects are still created and destroyed normally on every client instantiation. DNN state is cleaned up via DeallocateStream as usual. Only the underlying HIP queue (hipStream_t) is reused.
Also fix a latent use-after-free in LocalDeviceState::~LocalDeviceState:
C++ destroys members in reverse declaration order. compute_events_ (line 352 in local_device_state.h) is declared after callback_thread_ (line 342), so its destructor runs before callback_thread_'s destructor joins the worker thread. If callback_thread_ still has pending pop_front(compute_events_) closures when compute_events_ is destroyed, those closures access freed memory.
The fix adds callback_thread_->Drain() between SynchronizeAllActivity() and the explicit stream/event clears. After Drain() the callback thread is idle and compute_events_ can be safely cleared.