feat(scheduler): implement thread-safe singleton get_scheduler/reset_scheduler
This commit is contained in:
parent
84ce68af46
commit
1d9020c99a
2 changed files with 63 additions and 9 deletions
|
|
@ -196,12 +196,25 @@ _scheduler_lock = threading.Lock()
|
||||||
def get_scheduler(db_path: Path, run_task_fn: Callable = None) -> TaskScheduler:
|
def get_scheduler(db_path: Path, run_task_fn: Callable = None) -> TaskScheduler:
|
||||||
"""Return the process-level TaskScheduler singleton, constructing it if needed.
|
"""Return the process-level TaskScheduler singleton, constructing it if needed.
|
||||||
|
|
||||||
run_task_fn is required on the first call (when the singleton is constructed);
|
run_task_fn is required on the first call; ignored on subsequent calls.
|
||||||
ignored on subsequent calls. Pass scripts.task_runner._run_task.
|
Safety: inner lock + double-check prevents double-construction under races.
|
||||||
|
The outer None check is a fast-path performance optimisation only.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
global _scheduler
|
||||||
|
if _scheduler is None: # fast path — avoids lock on steady state
|
||||||
|
with _scheduler_lock:
|
||||||
|
if _scheduler is None: # re-check under lock (double-checked locking)
|
||||||
|
if run_task_fn is None:
|
||||||
|
raise ValueError("run_task_fn required on first get_scheduler() call")
|
||||||
|
_scheduler = TaskScheduler(db_path, run_task_fn)
|
||||||
|
_scheduler.start()
|
||||||
|
return _scheduler
|
||||||
|
|
||||||
|
|
||||||
def reset_scheduler() -> None:
|
def reset_scheduler() -> None:
|
||||||
"""Shut down and clear the singleton. TEST TEARDOWN ONLY — not for production use."""
|
"""Shut down and clear the singleton. TEST TEARDOWN ONLY."""
|
||||||
raise NotImplementedError
|
global _scheduler
|
||||||
|
with _scheduler_lock:
|
||||||
|
if _scheduler is not None:
|
||||||
|
_scheduler.shutdown()
|
||||||
|
_scheduler = None
|
||||||
|
|
|
||||||
|
|
@ -72,10 +72,7 @@ def _noop_run_task(*args, **kwargs):
|
||||||
def clean_scheduler():
|
def clean_scheduler():
|
||||||
"""Reset singleton between every test."""
|
"""Reset singleton between every test."""
|
||||||
yield
|
yield
|
||||||
try:
|
|
||||||
reset_scheduler()
|
reset_scheduler()
|
||||||
except NotImplementedError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_budgets_used_when_no_config(tmp_db):
|
def test_default_budgets_used_when_no_config(tmp_db):
|
||||||
|
|
@ -330,3 +327,47 @@ def test_worker_crash_releases_vram(tmp_db):
|
||||||
# Second task still ran, VRAM was released
|
# Second task still ran, VRAM was released
|
||||||
assert 2 in log
|
assert 2 in log
|
||||||
assert s._reserved_vram == 0.0
|
assert s._reserved_vram == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_scheduler_returns_singleton(tmp_db):
|
||||||
|
"""Multiple calls to get_scheduler() return the same instance."""
|
||||||
|
s1 = get_scheduler(tmp_db, _noop_run_task)
|
||||||
|
s2 = get_scheduler(tmp_db, _noop_run_task)
|
||||||
|
assert s1 is s2
|
||||||
|
|
||||||
|
|
||||||
|
def test_singleton_thread_safe(tmp_db):
|
||||||
|
"""Concurrent get_scheduler() calls produce exactly one instance."""
|
||||||
|
instances = []
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def _get():
|
||||||
|
try:
|
||||||
|
instances.append(get_scheduler(tmp_db, _noop_run_task))
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=_get) for _ in range(20)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert not errors
|
||||||
|
assert len(set(id(s) for s in instances)) == 1 # all the same object
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_scheduler_cleans_up(tmp_db):
|
||||||
|
"""reset_scheduler() shuts down the scheduler; no threads linger."""
|
||||||
|
s = get_scheduler(tmp_db, _noop_run_task)
|
||||||
|
thread = s._thread
|
||||||
|
assert thread.is_alive()
|
||||||
|
|
||||||
|
reset_scheduler()
|
||||||
|
|
||||||
|
thread.join(timeout=2.0)
|
||||||
|
assert not thread.is_alive()
|
||||||
|
|
||||||
|
# After reset, get_scheduler creates a fresh instance
|
||||||
|
s2 = get_scheduler(tmp_db, _noop_run_task)
|
||||||
|
assert s2 is not s
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue