diff --git a/app/services/platforms/__init__.py b/app/services/platforms/__init__.py index a987cec..d9ae1d3 100644 --- a/app/services/platforms/__init__.py +++ b/app/services/platforms/__init__.py @@ -2,12 +2,13 @@ from __future__ import annotations from app.services.platforms.base import PostingStrategy, PostResult from app.services.platforms.reddit_post import RedditPostStrategy +from app.services.platforms.reddit_comment import RedditCommentStrategy _REGISTRY: dict[str, PostingStrategy] = { s.campaign_type: s() for s in [ RedditPostStrategy, - # RedditCommentStrategy — added in Plan B + RedditCommentStrategy, # BlogPostStrategy — added in Plan C ] } diff --git a/app/services/platforms/reddit_comment.py b/app/services/platforms/reddit_comment.py index ec6f2e3..7461a6c 100644 --- a/app/services/platforms/reddit_comment.py +++ b/app/services/platforms/reddit_comment.py @@ -7,6 +7,8 @@ from datetime import date, timedelta import httpx from app.services.platforms.base import PostingStrategy, PostResult +from app.services.reddit.client import RedditClient +from app.core.config import get_settings logger = logging.getLogger(__name__) @@ -85,3 +87,51 @@ def _find_sticky( if pattern_lower in title.lower(): return post.get("id") return None + + +class RedditCommentStrategy(PostingStrategy): + campaign_type = "reddit_comment" + + def supports_dupe_guard(self) -> bool: + return False # comment threads may appear multiple times + + def execute( + self, + *, + target: str, + title: str, + body: str, + flair: str | None = None, + extra: dict | None = None, + ) -> PostResult: + extra = extra or {} + thread_url_override = extra.get("thread_url_override") + thread_title_pattern = extra.get("thread_title_pattern") + session_file = get_settings().reddit_session_file + + # Resolve thread_id + if thread_url_override: + thread_id = _extract_thread_id_from_url(thread_url_override) + elif thread_title_pattern: + thread_id = _find_sticky( + sub=target, + title_pattern=thread_title_pattern, + session_file=session_file, + ) + if thread_id is None: + raise ValueError( + f"No thread matching {thread_title_pattern!r} found in r/{target}" + ) + else: + raise ValueError( + "RedditCommentStrategy requires thread_url_override or thread_title_pattern in extra" + ) + + client = RedditClient(session_file=session_file) + comment_url = client.comment(thread_id=thread_id, body=body) + + # Reddit comment() may return empty URL; reconstruct from thread_id + if not comment_url: + comment_url = f"https://www.reddit.com/r/{target}/comments/{thread_id}/" + + return PostResult(url=comment_url, metadata={"thread_id": thread_id}) diff --git a/tests/services/platforms/test_reddit_comment_strategy.py b/tests/services/platforms/test_reddit_comment_strategy.py new file mode 100644 index 0000000..9bc7513 --- /dev/null +++ b/tests/services/platforms/test_reddit_comment_strategy.py @@ -0,0 +1,90 @@ +from unittest.mock import patch, MagicMock +import pytest +from app.services.platforms.reddit_comment import RedditCommentStrategy +from app.services.platforms.base import PostResult + + +@pytest.fixture +def strategy(): + return RedditCommentStrategy() + + +def test_execute_with_url_override(strategy, monkeypatch): + """Uses thread_url_override to get thread_id, calls client.comment()""" + mock_client = MagicMock() + mock_client.comment.return_value = "https://www.reddit.com/r/flipping/comments/abc123/_/xyz789/" + with patch("app.services.platforms.reddit_comment.RedditClient", return_value=mock_client): + with patch("app.services.platforms.reddit_comment.get_settings") as mock_settings: + mock_settings.return_value.reddit_session_file = "/fake/session.json" + result = strategy.execute( + target="flipping", + title="ignored", + body="Hello thread!", + extra={"thread_url_override": "https://www.reddit.com/r/flipping/comments/abc123/weekly/"}, + ) + assert result.url == "https://www.reddit.com/r/flipping/comments/abc123/_/xyz789/" + assert result.metadata["thread_id"] == "abc123" + mock_client.comment.assert_called_once_with(thread_id="abc123", body="Hello thread!") + + +def test_execute_with_title_pattern_found(strategy): + """Uses _find_sticky to locate thread, posts comment""" + with patch("app.services.platforms.reddit_comment._find_sticky", return_value="def456"): + with patch("app.services.platforms.reddit_comment.RedditClient") as MockClient: + with patch("app.services.platforms.reddit_comment.get_settings") as mock_settings: + mock_settings.return_value.reddit_session_file = "/fake/session.json" + MockClient.return_value.comment.return_value = "" + result = strategy.execute( + target="cscareerquestions", + title="ignored", + body="Job search tool", + extra={"thread_title_pattern": "Monthly Resume"}, + ) + assert "def456" in result.url + assert result.metadata["thread_id"] == "def456" + + +def test_execute_thread_not_found(strategy): + """Raises ValueError when _find_sticky returns None""" + with patch("app.services.platforms.reddit_comment._find_sticky", return_value=None): + with patch("app.services.platforms.reddit_comment.get_settings") as mock_settings: + mock_settings.return_value.reddit_session_file = "/fake/session.json" + with pytest.raises(ValueError, match="No thread matching"): + strategy.execute( + target="cscareerquestions", + title="ignored", + body="body", + extra={"thread_title_pattern": "Monthly Resume"}, + ) + + +def test_execute_no_extra_raises(strategy): + """Raises ValueError when neither thread_url_override nor thread_title_pattern provided""" + with patch("app.services.platforms.reddit_comment.get_settings") as mock_settings: + mock_settings.return_value.reddit_session_file = "/fake/session.json" + with pytest.raises(ValueError, match="requires thread_url_override or thread_title_pattern"): + strategy.execute(target="flipping", title="t", body="b", extra={}) + + +def test_reconstructed_url_on_empty_comment_url(strategy): + """When client.comment() returns empty string, reconstructs URL from thread_id""" + with patch("app.services.platforms.reddit_comment.RedditClient") as MockClient: + with patch("app.services.platforms.reddit_comment.get_settings") as mock_settings: + mock_settings.return_value.reddit_session_file = "/fake/session.json" + MockClient.return_value.comment.return_value = "" + result = strategy.execute( + target="flipping", + title="t", + body="b", + extra={"thread_url_override": "https://www.reddit.com/r/flipping/comments/abc123/weekly/"}, + ) + assert result.url == "https://www.reddit.com/r/flipping/comments/abc123/" + + +def test_supports_dupe_guard_false(strategy): + assert strategy.supports_dupe_guard() is False + + +def test_registry_contains_reddit_comment(): + from app.services.platforms import _REGISTRY + assert "reddit_comment" in _REGISTRY