feat: add thread detection helpers to reddit_comment strategy
This commit is contained in:
parent
ca9b2ac0b2
commit
9d955b2c50
2 changed files with 91 additions and 1 deletions
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from datetime import date, timedelta
|
from datetime import date, timedelta
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -43,3 +44,39 @@ def parse_occurrence(occurrence: str | None) -> tuple[int, int] | None:
|
||||||
if ordinal not in _ORDINAL_MAP or weekday_name not in _WEEKDAY_MAP:
|
if ordinal not in _ORDINAL_MAP or weekday_name not in _WEEKDAY_MAP:
|
||||||
raise ValueError(f"Unrecognised occurrence: {occurrence!r}")
|
raise ValueError(f"Unrecognised occurrence: {occurrence!r}")
|
||||||
return _WEEKDAY_MAP[weekday_name], _ORDINAL_MAP[ordinal]
|
return _WEEKDAY_MAP[weekday_name], _ORDINAL_MAP[ordinal]
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_thread_id_from_url(url: str) -> str:
|
||||||
|
"""Extract a Reddit post ID from a full Reddit URL.
|
||||||
|
|
||||||
|
Expects URLs of the form:
|
||||||
|
https://www.reddit.com/r/<sub>/comments/<id>/<title>/
|
||||||
|
Raises ValueError if the ID cannot be found.
|
||||||
|
"""
|
||||||
|
match = re.search(r"/comments/([a-z0-9]+)/", url)
|
||||||
|
if not match:
|
||||||
|
raise ValueError(f"Cannot extract thread id from {url!r}")
|
||||||
|
return match.group(1)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_sticky(
|
||||||
|
sub: str,
|
||||||
|
title_pattern: str,
|
||||||
|
session_file: str | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""Search the hot listing of a subreddit for a post matching title_pattern.
|
||||||
|
|
||||||
|
Uses the Reddit public JSON API (no auth required).
|
||||||
|
Returns the post ID (e.g. "abc123") of the first match, or None.
|
||||||
|
"""
|
||||||
|
url = f"https://www.reddit.com/r/{sub}/hot.json?limit=10"
|
||||||
|
response = httpx.get(url, headers={"User-Agent": "magpie/1.0"})
|
||||||
|
payload = response.json()
|
||||||
|
children = payload.get("data", {}).get("children", [])
|
||||||
|
pattern_lower = title_pattern.lower()
|
||||||
|
for child in children:
|
||||||
|
post = child.get("data", {})
|
||||||
|
title = post.get("title", "")
|
||||||
|
if pattern_lower in title.lower():
|
||||||
|
return post.get("id")
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,12 @@
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from app.services.platforms.reddit_comment import is_nth_weekday, parse_occurrence
|
from unittest.mock import MagicMock
|
||||||
|
import pytest
|
||||||
|
from app.services.platforms.reddit_comment import (
|
||||||
|
is_nth_weekday,
|
||||||
|
parse_occurrence,
|
||||||
|
_extract_thread_id_from_url,
|
||||||
|
_find_sticky,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# --- is_nth_weekday ---
|
# --- is_nth_weekday ---
|
||||||
|
|
@ -49,3 +56,49 @@ def test_parse_occurrence_unknown_raises():
|
||||||
assert False, "Expected ValueError"
|
assert False, "Expected ValueError"
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# --- _extract_thread_id_from_url ---
|
||||||
|
|
||||||
|
def test_extract_thread_id_success():
|
||||||
|
url = "https://www.reddit.com/r/flipping/comments/abc123/weekly_thread/"
|
||||||
|
assert _extract_thread_id_from_url(url) == "abc123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_thread_id_invalid():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_extract_thread_id_from_url("https://www.reddit.com/r/flipping/")
|
||||||
|
|
||||||
|
|
||||||
|
# --- _find_sticky ---
|
||||||
|
|
||||||
|
def _make_hot_json(posts: list[dict]) -> dict:
|
||||||
|
"""Build a fake Reddit hot.json payload."""
|
||||||
|
return {"data": {"children": [{"data": p} for p in posts]}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_sticky_found(monkeypatch):
|
||||||
|
fake_response = MagicMock()
|
||||||
|
fake_response.json.return_value = _make_hot_json([
|
||||||
|
{"id": "abc123", "title": "Weekly Self-Promotion Thread"},
|
||||||
|
{"id": "xyz999", "title": "General Discussion"},
|
||||||
|
])
|
||||||
|
|
||||||
|
import app.services.platforms.reddit_comment as rc_module
|
||||||
|
monkeypatch.setattr(rc_module.httpx, "get", lambda *a, **kw: fake_response)
|
||||||
|
|
||||||
|
result = _find_sticky("flipping", "Self-Promotion")
|
||||||
|
assert result == "abc123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_sticky_not_found(monkeypatch):
|
||||||
|
fake_response = MagicMock()
|
||||||
|
fake_response.json.return_value = _make_hot_json([
|
||||||
|
{"id": "xyz999", "title": "General Discussion"},
|
||||||
|
])
|
||||||
|
|
||||||
|
import app.services.platforms.reddit_comment as rc_module
|
||||||
|
monkeypatch.setattr(rc_module.httpx, "get", lambda *a, **kw: fake_response)
|
||||||
|
|
||||||
|
result = _find_sticky("flipping", "Self-Promotion")
|
||||||
|
assert result is None
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue