diff --git a/pyproject.toml b/pyproject.toml index a3b8cab..5f75da9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "google-auth-oauthlib>=1.2.0", "sentry-sdk[fastapi]>=2.0.0", "prometheus-client>=0.20.0", + "python-multipart>=0.0.9", ] [project.optional-dependencies] diff --git a/src/api/routes/memory.py b/src/api/routes/memory.py index 3397a69..bc7c910 100644 --- a/src/api/routes/memory.py +++ b/src/api/routes/memory.py @@ -8,6 +8,7 @@ import asyncio import logging +import threading import time from typing import Any, Dict, List @@ -119,6 +120,153 @@ def _detect_chat_provider(url: str) -> str: return "unknown" +def _clean_message_text(value: Any) -> str: + if isinstance(value, str): + return " ".join(value.split()) + if isinstance(value, list): + parts = [_clean_message_text(item) for item in value] + return "\n".join(part for part in parts if part).strip() + if isinstance(value, dict): + for key in ("text", "content", "value", "markdown", "plainText", "body"): + text = _clean_message_text(value.get(key)) + if text: + return text + return "" + + +def _role_from_message(message: Dict[str, Any]) -> str | None: + role_value = None + for key in ("sender", "role", "author", "type"): + if key in message: + role_value = message[key] + break + + if isinstance(role_value, dict): + role_value = role_value.get("role") or role_value.get("name") + if not isinstance(role_value, str): + return None + + lowered = role_value.lower() + if lowered in {"human", "user"}: + return "user" + if lowered in {"assistant", "model", "claude"}: + return "assistant" + return None + + +def _pairs_from_message_list(messages: List[Any]) -> List[MessagePair]: + pairs: List[MessagePair] = [] + current_user = "" + + for item in messages: + if not isinstance(item, dict): + continue + + role = _role_from_message(item) + text = _clean_message_text(item) + if not role or not text: + continue + + if role == "user": + current_user = text + elif role == "assistant" and current_user: + pairs.append(MessagePair( + user_query=current_user, + agent_response=text, + )) + current_user = "" + + return pairs + + +def _extract_structured_pairs(data: Any) -> List[MessagePair]: + if isinstance(data, dict): + messages = data.get("messages") + if isinstance(messages, list): + pairs = _pairs_from_message_list(messages) + if pairs: + return pairs + + for value in data.values(): + pairs = _extract_structured_pairs(value) + if pairs: + return pairs + + if isinstance(data, list): + pairs = _pairs_from_message_list(data) + if pairs: + return pairs + + for value in data: + pairs = _extract_structured_pairs(value) + if pairs: + return pairs + + return [] + + +def _extract_script_pairs(soup: BeautifulSoup) -> List[MessagePair]: + for script in soup.find_all("script"): + script_text = script.string or script.get_text() + if not script_text: + continue + + candidates = [] + if script.get("id") == "__NEXT_DATA__": + candidates.append(script_text) + + match = re.search( + r"__PRELOADED_STATE__\s*=\s*(\{.*\})", + script_text, + re.DOTALL, + ) + if match: + candidates.append(match.group(1)) + + for candidate in candidates: + try: + data = json.loads(candidate) + except json.JSONDecodeError as exc: + logger.warning("Failed to parse chat structured state: %s", exc) + continue + + pairs = _extract_structured_pairs(data) + if pairs: + return pairs + + return [] + + +def _message_pairs_from_dom(user_blocks: List[Any], agent_blocks: List[Any]) -> List[MessagePair]: + pairs: List[MessagePair] = [] + + for user_block, agent_block in zip(user_blocks, agent_blocks): + user_query = user_block.get_text(separator="\n", strip=True) + agent_response = agent_block.get_text(separator="\n", strip=True) + if user_query and agent_response: + pairs.append(MessagePair( + user_query=user_query, + agent_response=agent_response, + )) + + return pairs + + +def _looks_unavailable(html: str) -> bool: + lowered = html.lower() + markers = ( + "conversation is private", + "private conversation", + "sign in to request access", + "sign in to view", + "does not exist", + "not found", + "no longer available", + "unable to load", + ) + return any(marker in lowered for marker in markers) + + async def _render_chat_share(url: str) -> tuple[str, str]: return await asyncio.to_thread(_render_chat_share_sync, url) @@ -128,8 +276,6 @@ async def _render_chat_share(url: str) -> tuple[str, str]: # reuse it across scrape requests. The browser is thread-safe when each # request uses its own BrowserContext. -import threading - _browser_lock = threading.Lock() _pw_instance = None _browser_instance = None @@ -239,51 +385,44 @@ def _extract_chat_pairs(url: str, html: str) -> tuple[str, str, List[MessagePair if provider == "chatgpt": user_msgs = soup.find_all("div", {"data-message-author-role": "user"}) asst_msgs = soup.find_all("div", {"data-message-author-role": "assistant"}) - for u, a in zip(user_msgs, asst_msgs): - pairs.append(MessagePair( - user_query=u.get_text(separator="\n").strip(), - agent_response=a.get_text(separator="\n").strip(), - )) + pairs = _message_pairs_from_dom(user_msgs, asst_msgs) if pairs: extraction_method = "dom" elif provider == "claude": - script_state = soup.find("script", string=re.compile(r"__PRELOADED_STATE__")) - if script_state and script_state.string: - try: - match = re.search( - r"__PRELOADED_STATE__\s*=\s*(\{.*?\});", - script_state.string, - re.DOTALL, - ) - if match: - data = json.loads(match.group(1)) - messages = data.get("chat", {}).get("messages", []) - current_user = "" - for msg in messages: - if msg.get("sender") == "human": - current_user = msg.get("text", "") - elif msg.get("sender") == "assistant": - pairs.append(MessagePair( - user_query=current_user, - agent_response=msg.get("text", ""), - )) - current_user = "" - if pairs: - extraction_method = "structured" - except Exception as exc: - logger.warning("Failed to parse Claude preloaded state: %s", exc) + pairs = _extract_script_pairs(soup) + if pairs: + extraction_method = "structured" + else: + user_blocks = soup.select( + "[data-testid*='human'], [data-testid*='user'], .font-user-message" + ) + asst_blocks = soup.select( + "[data-testid*='assistant'], [data-testid*='claude'], " + ".font-claude-message" + ) + pairs = _message_pairs_from_dom(user_blocks, asst_blocks) + if pairs: + extraction_method = "dom" elif provider == "gemini": - user_blocks = soup.select("message-content[role='user'], div.user-query") - model_blocks = soup.select("message-content[role='model'], div.model-response") - for u, m in zip(user_blocks, model_blocks): - pairs.append(MessagePair( - user_query=u.get_text(separator="\n").strip(), - agent_response=m.get_text(separator="\n").strip(), - )) + user_blocks = soup.select( + "message-content[role='user'], div.user-query, [data-testid*='user']" + ) + model_blocks = soup.select( + "message-content[role='model'], div.model-response, " + "[data-testid*='model'], [data-testid*='assistant']" + ) + pairs = _message_pairs_from_dom(user_blocks, model_blocks) if pairs: extraction_method = "dom" + else: + pairs = _extract_script_pairs(soup) + if pairs: + extraction_method = "structured" + + if not pairs and provider != "unknown" and _looks_unavailable(html): + extraction_method = "unavailable" if not pairs and provider == "unknown": paragraphs = [ diff --git a/tests/test_api_memory_scrape.py b/tests/test_api_memory_scrape.py new file mode 100644 index 0000000..d280c10 --- /dev/null +++ b/tests/test_api_memory_scrape.py @@ -0,0 +1,114 @@ +import os + +os.environ.setdefault("PINECONE_API_KEY", "test-pinecone-key") +os.environ.setdefault("NEO4J_PASSWORD", "test-neo4j-password") +os.environ.setdefault("GEMINI_API_KEY", "test-gemini-key") + +from src.api.routes.memory import _detect_chat_provider, _extract_chat_pairs + + +def test_detects_supported_public_chat_share_providers() -> None: + assert _detect_chat_provider("https://chatgpt.com/share/abc") == "chatgpt" + assert _detect_chat_provider("https://chat.openai.com/share/abc") == "chatgpt" + assert _detect_chat_provider("https://claude.ai/share/abc") == "claude" + assert _detect_chat_provider("https://gemini.google.com/share/abc") == "gemini" + assert _detect_chat_provider("https://example.com/share/abc") == "unknown" + + +def test_extracts_claude_pairs_from_next_data_script() -> None: + html = """ + +
+ + + + """ + + provider, method, pairs = _extract_chat_pairs("https://claude.ai/share/abc", html) + + assert provider == "claude" + assert method == "structured" + assert len(pairs) == 1 + assert pairs[0].user_query == "Summarize this release note." + assert pairs[0].agent_response == "Here is a short summary." + + +def test_extracts_claude_preloaded_state_when_message_contains_brace_semicolon() -> None: + html = r""" + + + + + + """ + + provider, method, pairs = _extract_chat_pairs("https://claude.ai/share/abc", html) + + assert provider == "claude" + assert method == "structured" + assert len(pairs) == 1 + assert pairs[0].agent_response == "Use const value = {}; then keep explaining." + + +def test_extracts_gemini_pairs_from_public_share_dom() -> None: + html = """ + + +Sign in to request access to this shared conversation.
+ + + """ + + provider, method, pairs = _extract_chat_pairs("https://claude.ai/share/private", html) + + assert provider == "claude" + assert method == "unavailable" + assert pairs == []