From 9b96c45b6320c68323eb8af2c9ef08335641c051 Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Sun, 15 Mar 2026 04:19:23 -0700 Subject: [PATCH] feat(scheduler): implement thread-safe singleton get_scheduler/reset_scheduler --- scripts/task_scheduler.py | 23 +++++++++++++---- tests/test_task_scheduler.py | 49 +++++++++++++++++++++++++++++++++--- 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/scripts/task_scheduler.py b/scripts/task_scheduler.py index 574d020..307ba7f 100644 --- a/scripts/task_scheduler.py +++ b/scripts/task_scheduler.py @@ -196,12 +196,25 @@ _scheduler_lock = threading.Lock() def get_scheduler(db_path: Path, run_task_fn: Callable = None) -> TaskScheduler: """Return the process-level TaskScheduler singleton, constructing it if needed. - run_task_fn is required on the first call (when the singleton is constructed); - ignored on subsequent calls. Pass scripts.task_runner._run_task. + run_task_fn is required on the first call; ignored on subsequent calls. + Safety: inner lock + double-check prevents double-construction under races. + The outer None check is a fast-path performance optimisation only. """ - raise NotImplementedError + global _scheduler + if _scheduler is None: # fast path — avoids lock on steady state + with _scheduler_lock: + if _scheduler is None: # re-check under lock (double-checked locking) + if run_task_fn is None: + raise ValueError("run_task_fn required on first get_scheduler() call") + _scheduler = TaskScheduler(db_path, run_task_fn) + _scheduler.start() + return _scheduler def reset_scheduler() -> None: - """Shut down and clear the singleton. TEST TEARDOWN ONLY — not for production use.""" - raise NotImplementedError + """Shut down and clear the singleton. TEST TEARDOWN ONLY.""" + global _scheduler + with _scheduler_lock: + if _scheduler is not None: + _scheduler.shutdown() + _scheduler = None diff --git a/tests/test_task_scheduler.py b/tests/test_task_scheduler.py index f174c08..4128467 100644 --- a/tests/test_task_scheduler.py +++ b/tests/test_task_scheduler.py @@ -72,10 +72,7 @@ def _noop_run_task(*args, **kwargs): def clean_scheduler(): """Reset singleton between every test.""" yield - try: - reset_scheduler() - except NotImplementedError: - pass + reset_scheduler() def test_default_budgets_used_when_no_config(tmp_db): @@ -330,3 +327,47 @@ def test_worker_crash_releases_vram(tmp_db): # Second task still ran, VRAM was released assert 2 in log assert s._reserved_vram == 0.0 + + +def test_get_scheduler_returns_singleton(tmp_db): + """Multiple calls to get_scheduler() return the same instance.""" + s1 = get_scheduler(tmp_db, _noop_run_task) + s2 = get_scheduler(tmp_db, _noop_run_task) + assert s1 is s2 + + +def test_singleton_thread_safe(tmp_db): + """Concurrent get_scheduler() calls produce exactly one instance.""" + instances = [] + errors = [] + + def _get(): + try: + instances.append(get_scheduler(tmp_db, _noop_run_task)) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=_get) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors + assert len(set(id(s) for s in instances)) == 1 # all the same object + + +def test_reset_scheduler_cleans_up(tmp_db): + """reset_scheduler() shuts down the scheduler; no threads linger.""" + s = get_scheduler(tmp_db, _noop_run_task) + thread = s._thread + assert thread.is_alive() + + reset_scheduler() + + thread.join(timeout=2.0) + assert not thread.is_alive() + + # After reset, get_scheduler creates a fresh instance + s2 = get_scheduler(tmp_db, _noop_run_task) + assert s2 is not s