fix(scheduler): join batch worker threads in shutdown()

Previously shutdown() only joined the scheduler loop thread. Batch
worker threads (which decrement _reserved_vram in their finally block)
could still be running when shutdown returned, leaving stale VRAM
accounting. Now snapshots active workers under lock and joins them all.

Snapshot-then-join pattern avoids holding the lock across blocking join
calls (which would deadlock since workers acquire the same lock on exit).
This commit is contained in:
pyr0ball 2026-04-01 11:21:30 -07:00
parent 6b8e421eb2
commit aa51794f45

View file

@ -203,11 +203,21 @@ class TaskScheduler:
self._wake.set() self._wake.set()
def shutdown(self, timeout: float = 5.0) -> None: def shutdown(self, timeout: float = 5.0) -> None:
"""Signal the scheduler to stop and wait for it to exit.""" """Signal the scheduler to stop and wait for it to exit.
Joins both the scheduler loop thread and any active batch worker
threads so callers can rely on clean state (e.g. _reserved_vram == 0)
immediately after this returns.
"""
self._stop.set() self._stop.set()
self._wake.set() self._wake.set()
if self._thread and self._thread.is_alive(): if self._thread and self._thread.is_alive():
self._thread.join(timeout=timeout) self._thread.join(timeout=timeout)
# Join active batch workers so _reserved_vram is settled on return
with self._lock:
workers = list(self._active.values())
for worker in workers:
worker.join(timeout=timeout)
def _scheduler_loop(self) -> None: def _scheduler_loop(self) -> None:
while not self._stop.is_set(): while not self._stop.is_set():