magpie/app/services/platforms/reddit_comment.py

141 lines
4.9 KiB
Python

from __future__ import annotations
import logging
import re
from datetime import date, timedelta
import httpx
from app.services.platforms.base import PostingStrategy, PostResult
from app.services.reddit.client import RedditClient
from app.core.config import get_settings
logger = logging.getLogger(__name__)
# Weekday names → int (0=Mon, 6=Sun)
_WEEKDAY_MAP = {
"monday": 0, "tuesday": 1, "wednesday": 2, "thursday": 3,
"friday": 4, "saturday": 5, "sunday": 6,
}
# Ordinal names → n
_ORDINAL_MAP = {"first": 1, "second": 2, "third": 3}
def is_nth_weekday(dt: date, weekday: int, n: int) -> bool:
"""True if dt is the nth occurrence of weekday (0=Mon, 6=Sun) in its month."""
first_of_month = dt.replace(day=1)
days_until = (weekday - first_of_month.weekday()) % 7
first_occurrence = first_of_month + timedelta(days=days_until)
nth_occurrence = first_occurrence + timedelta(weeks=n - 1)
return dt == nth_occurrence
def parse_occurrence(occurrence: str | None) -> tuple[int, int] | None:
"""Parse an occurrence string into (weekday, n) or None for 'every'.
Supported: "first_sunday", "second_monday", "third_friday", etc.
Returns None for "every" or None input.
Raises ValueError for unrecognised patterns.
"""
if occurrence is None or occurrence == "every":
return None
parts = occurrence.lower().split("_", 1)
if len(parts) != 2:
raise ValueError(f"Unrecognised occurrence format: {occurrence!r}")
ordinal, weekday_name = parts
if ordinal not in _ORDINAL_MAP or weekday_name not in _WEEKDAY_MAP:
raise ValueError(f"Unrecognised occurrence: {occurrence!r}")
return _WEEKDAY_MAP[weekday_name], _ORDINAL_MAP[ordinal]
def _extract_thread_id_from_url(url: str) -> str:
"""Extract a Reddit post ID from a full Reddit URL.
Expects URLs of the form:
https://www.reddit.com/r/<sub>/comments/<id>/<title>/
Raises ValueError if the ID cannot be found.
"""
match = re.search(r"/comments/([a-zA-Z0-9]+)/", url)
if not match:
raise ValueError(f"Cannot extract thread id from {url!r}")
return match.group(1)
def _find_sticky(
sub: str,
title_pattern: str,
session_file: str | None = None,
) -> str | None:
"""Search the hot listing of a subreddit for a post matching title_pattern.
Uses the Reddit public JSON API (no auth required).
Returns the post ID (e.g. "abc123") of the first match, or None.
"""
# TODO: use session_file for authenticated requests on private subs
url = f"https://www.reddit.com/r/{sub}/hot.json?limit=10"
try:
response = httpx.get(url, headers={"User-Agent": "magpie/1.0"}, timeout=10)
response.raise_for_status()
except httpx.HTTPError as exc:
logger.warning("Reddit hot.json request failed for r/%s: %s", sub, exc)
raise RuntimeError(f"Failed to fetch hot listing for r/{sub}") from exc
payload = response.json()
if "data" not in payload:
logger.warning("Unexpected Reddit API response: %r", payload)
return None
children = payload.get("data", {}).get("children", [])
pattern_lower = title_pattern.lower()
for child in children:
post = child.get("data", {})
title = post.get("title", "")
if pattern_lower in title.lower():
return post.get("id")
return None
class RedditCommentStrategy(PostingStrategy):
campaign_type = "reddit_comment"
def supports_dupe_guard(self) -> bool:
return False # comment threads may appear multiple times
def execute(
self,
*,
target: str,
title: str,
body: str,
flair: str | None = None,
extra: dict | None = None,
) -> PostResult:
extra = extra or {}
thread_url_override = extra.get("thread_url_override")
thread_title_pattern = extra.get("thread_title_pattern")
session_file = get_settings().reddit_session_file
# Resolve thread_id
if thread_url_override:
thread_id = _extract_thread_id_from_url(thread_url_override)
elif thread_title_pattern:
thread_id = _find_sticky(
sub=target,
title_pattern=thread_title_pattern,
session_file=session_file,
)
if thread_id is None:
raise ValueError(
f"No thread matching {thread_title_pattern!r} found in r/{target}"
)
else:
raise ValueError(
"RedditCommentStrategy requires thread_url_override or thread_title_pattern in extra"
)
client = RedditClient(session_file=session_file)
comment_url = client.comment(thread_id=thread_id, body=body)
# Reddit comment() may return empty URL; reconstruct from thread_id
if not comment_url:
comment_url = f"https://www.reddit.com/r/{target}/comments/{thread_id}/"
return PostResult(url=comment_url, metadata={"thread_id": thread_id})