diff --git a/src/agents/judge.py b/src/agents/judge.py index f237ebd..db3e26c 100644 --- a/src/agents/judge.py +++ b/src/agents/judge.py @@ -31,6 +31,13 @@ Operation, OperationType, ) +from src.schemas.code import ( + code_annotation_content_hash, + code_annotation_fields_from_storage_content, + code_annotation_identity_key, + snippet_fields_from_storage_content, + snippet_identity_hash, +) from src.storage.base import BaseVectorStore, SearchResult @@ -193,6 +200,10 @@ async def arun_deterministic(self, state: Dict[str, Any]) -> JudgeResult: result = await self._deterministic_profile(new_items, user_id) elif domain == JudgeDomain.TEMPORAL: result = await self._deterministic_temporal(new_items, user_id) + elif domain == JudgeDomain.CODE: + result = await self._deterministic_code(new_items, user_id) + elif domain == JudgeDomain.SNIPPET: + result = await self._deterministic_snippet(new_items, user_id) else: self.logger.warning( "Deterministic judge unsupported for %s; falling back to LLM judge.", @@ -482,6 +493,97 @@ async def _deterministic_temporal( return JudgeResult(operations=operations, confidence=1.0) + async def _deterministic_code( + self, new_items: list, user_id: str, + ) -> JudgeResult: + unique_items: dict[str, tuple[str, dict[str, Any]]] = {} + for item in new_items: + content = str(item) + fields = code_annotation_fields_from_storage_content(content) + unique_items[code_annotation_identity_key(fields)] = (content, fields) + + async def _process_one(content: str, fields: dict[str, Any]) -> Operation: + match = await self._lookup_metadata_match({ + "user_id": user_id, + "domain": JudgeDomain.CODE.value, + "annotation_key": code_annotation_identity_key(fields), + }) + + if match is None: + return Operation( + type=OperationType.ADD, + content=content, + reason="No code annotation with the same repo/target/type key.", + ) + + incoming_hash = code_annotation_content_hash(fields) + existing_hash = str((match.metadata or {}).get("annotation_hash", "")) + if incoming_hash == existing_hash: + return Operation( + type=OperationType.NOOP, + content=content, + embedding_id=match.id, + reason="Existing code annotation is unchanged.", + ) + return Operation( + type=OperationType.UPDATE, + content=content, + embedding_id=match.id, + reason="Existing code annotation target has updated content.", + ) + + operations = await asyncio.gather(*( + _process_one(content, fields) + for content, fields in unique_items.values() + )) + return JudgeResult(operations=operations, confidence=1.0) + + async def _deterministic_snippet( + self, new_items: list, user_id: str, + ) -> JudgeResult: + unique_items: dict[str, tuple[str, dict[str, Any]]] = {} + for item in new_items: + content = str(item) + fields = snippet_fields_from_storage_content(content) + unique_items[snippet_identity_hash(fields)] = (content, fields) + + async def _process_one(content: str, fields: dict[str, Any]) -> Operation: + match = await self._lookup_metadata_match({ + "user_id": user_id, + "domain": JudgeDomain.SNIPPET.value, + "snippet_hash": snippet_identity_hash(fields), + }) + + if match is None: + return Operation( + type=OperationType.ADD, + content=content, + reason="No snippet with the same normalized code/content identity.", + ) + return Operation( + type=OperationType.NOOP, + content=content, + embedding_id=match.id, + reason="Same snippet was already stored for this user.", + ) + + operations = await asyncio.gather(*( + _process_one(content, fields) + for content, fields in unique_items.values() + )) + return JudgeResult(operations=operations, confidence=1.0) + + async def _lookup_metadata_match( + self, filters: Dict[str, Any], + ) -> Optional[SearchResult]: + if not self.vector_store: + return None + search_fn = getattr(self.vector_store, "search_by_metadata", None) + if search_fn is None: + return None + results = await asyncio.to_thread(search_fn, filters=filters, top_k=1) + return _first_match(results or []) + # -- Response parsing -------------------------------------------------- def _parse_response( diff --git a/src/pipelines/ingest.py b/src/pipelines/ingest.py index 9833d67..1bf8d66 100644 --- a/src/pipelines/ingest.py +++ b/src/pipelines/ingest.py @@ -82,7 +82,7 @@ ) from src.schemas.events import EventResult from src.schemas.image import ImageResult -from src.schemas.judge import JudgeDomain, JudgeResult, OperationType +from src.schemas.judge import JudgeDomain, JudgeResult from src.schemas.profile import ProfileResult from src.schemas.summary import SummaryResult from src.schemas.weaver import WeaverResult @@ -765,7 +765,13 @@ async def _node_extract_code(self, state: IngestState) -> Dict[str, Any]: ] all_items.append(" | ".join(parts)) - judge_result = await self.judge.arun({ + code_judge = JudgeAgent( + model=self.judge.model, + vector_store=self.code_vector_store, + graph_event_search=self._graph_event_search_wrapper, + top_k=self.judge.top_k, + ) + judge_result = await code_judge.arun_deterministic({ "domain": JudgeDomain.CODE, "new_items": all_items, "user_id": user_id, @@ -805,15 +811,19 @@ async def _node_extract_snippet(self, state: IngestState) -> Dict[str, Any]: ] all_items.append(" | ".join(parts)) - judge_result = await self.judge.arun({ + self.weaver.snippet_vector_store = self._get_snippet_store(user_id) + snippet_judge = JudgeAgent( + model=self.judge.model, + vector_store=self.weaver.snippet_vector_store, + graph_event_search=self._graph_event_search_wrapper, + top_k=self.judge.top_k, + ) + judge_result = await snippet_judge.arun_deterministic({ "domain": JudgeDomain.SNIPPET, "new_items": all_items, "user_id": user_id, }) - # Bind the user-scoped snippet store before executing - self.weaver.snippet_vector_store = self._get_snippet_store(user_id) - weaver_result = await self.weaver.execute( judge_result=judge_result, domain=JudgeDomain.SNIPPET, diff --git a/src/pipelines/weaver.py b/src/pipelines/weaver.py index 55d55cb..21cac2e 100644 --- a/src/pipelines/weaver.py +++ b/src/pipelines/weaver.py @@ -23,6 +23,13 @@ Operation, OperationType, ) +from src.schemas.code import ( + code_annotation_fields_from_storage_content, + code_annotation_pinecone_metadata, + snippet_fields_from_storage_content, + snippet_pinecone_metadata, + snippet_search_text, +) from src.schemas.weaver import ExecutedOp, OpStatus, WeaverResult from src.storage.base import BaseVectorStore @@ -573,15 +580,7 @@ async def _code_add(self, op: Operation, user_id: str) -> ExecutedOp: import asyncio loop = asyncio.get_running_loop() embedding = await loop.run_in_executor(None, self.embed_fn, op.content) - metadata: Dict[str, Any] = { - "user_id": user_id, - "domain": "code", - "annotation_type": parsed.get("annotation_type", ""), - "target_symbol": parsed.get("target_symbol", ""), - "target_file": parsed.get("target_file", ""), - "repo": parsed.get("repo", ""), - "severity": parsed.get("severity", ""), - } + metadata = code_annotation_pinecone_metadata(user_id, parsed) from functools import partial ids = await loop.run_in_executor( @@ -633,15 +632,7 @@ async def _code_update(self, op: Operation, user_id: str) -> ExecutedOp: import asyncio loop = asyncio.get_running_loop() embedding = await loop.run_in_executor(None, self.embed_fn, op.content) - metadata: Dict[str, Any] = { - "user_id": user_id, - "domain": "code", - "annotation_type": parsed.get("annotation_type", ""), - "target_symbol": parsed.get("target_symbol", ""), - "target_file": parsed.get("target_file", ""), - "repo": parsed.get("repo", ""), - "severity": parsed.get("severity", ""), - } + metadata = code_annotation_pinecone_metadata(user_id, parsed) from functools import partial success = await loop.run_in_executor( @@ -725,20 +716,12 @@ async def _snippet_add(self, op: Operation, user_id: str) -> ExecutedOp: ) parsed = _parse_snippet_content(op.content) - searchable = parsed.get("content", op.content) + searchable = snippet_search_text(parsed) import asyncio loop = asyncio.get_running_loop() embedding = await loop.run_in_executor(None, self.embed_fn, searchable) - metadata: Dict[str, Any] = { - "user_id": user_id, - "domain": "snippet", - "code_snippet": parsed.get("code_snippet", ""), - "language": parsed.get("language", ""), - "snippet_type": parsed.get("snippet_type", "algorithm"), - "tags": parsed.get("tags", ""), - "source": "chat", - } + metadata = snippet_pinecone_metadata(user_id, parsed) from functools import partial ids = await loop.run_in_executor( @@ -771,20 +754,12 @@ async def _snippet_update(self, op: Operation, user_id: str) -> ExecutedOp: ) parsed = _parse_snippet_content(op.content) - searchable = parsed.get("content", op.content) + searchable = snippet_search_text(parsed) import asyncio loop = asyncio.get_running_loop() embedding = await loop.run_in_executor(None, self.embed_fn, searchable) - metadata: Dict[str, Any] = { - "user_id": user_id, - "domain": "snippet", - "code_snippet": parsed.get("code_snippet", ""), - "language": parsed.get("language", ""), - "snippet_type": parsed.get("snippet_type", "algorithm"), - "tags": parsed.get("tags", ""), - "source": "chat", - } + metadata = snippet_pinecone_metadata(user_id, parsed) from functools import partial success = await loop.run_in_executor( @@ -911,29 +886,7 @@ def _parse_snippet_content(content: str) -> Dict[str, str]: Falls back gracefully if the content doesn't match. """ - parts = [p.strip() for p in content.split(" | ")] - result: Dict[str, str] = {} - - if len(parts) >= 5: - result["content"] = parts[0] - result["code_snippet"] = parts[1] - result["language"] = parts[2] - result["snippet_type"] = parts[3] - result["tags"] = parts[4] - elif len(parts) >= 3: - result["content"] = parts[0] - result["code_snippet"] = parts[1] - result["language"] = parts[2] - result["snippet_type"] = "algorithm" - result["tags"] = "" - else: - result["content"] = content - result["code_snippet"] = "" - result["language"] = "" - result["snippet_type"] = "algorithm" - result["tags"] = "" - - return result + return snippet_fields_from_storage_content(content) def _parse_code_annotation_content(content: str) -> Dict[str, str]: @@ -945,21 +898,4 @@ def _parse_code_annotation_content(content: str) -> Dict[str, str]: Falls back gracefully if the content doesn't match the expected format (treats the entire string as the annotation content). """ - parts = [p.strip() for p in content.split("|")] - result: Dict[str, str] = {} - - if len(parts) >= 6: - result["annotation_type"] = parts[0] or "explanation" - result["target_symbol"] = parts[1] - result["target_file"] = parts[2] - result["repo"] = parts[3] - result["severity"] = parts[4] - result["content"] = parts[5] - elif len(parts) >= 2: - result["annotation_type"] = parts[0] or "explanation" - result["content"] = " | ".join(parts[1:]) - else: - result["content"] = content - result["annotation_type"] = "explanation" - - return result + return code_annotation_fields_from_storage_content(content) diff --git a/src/schemas/code.py b/src/schemas/code.py index e989f79..8975929 100644 --- a/src/schemas/code.py +++ b/src/schemas/code.py @@ -12,8 +12,9 @@ from __future__ import annotations +import hashlib from enum import Enum -from typing import List, Optional +from typing import Any, List, Optional from pydantic import BaseModel, Field @@ -290,3 +291,152 @@ def annotations_namespace(org_id: str) -> str: def snippets_namespace(user_id: str) -> str: return f"{user_id}:snippets" + + +# --------------------------------------------------------------------------- +# Pinecone identity helpers for code/snippet memory +# --------------------------------------------------------------------------- + +def normalize_code_text(value: str | None) -> str: + """Normalize code only for stable identity keys, not for display/storage.""" + text = (value or "").replace("\r\n", "\n").replace("\r", "\n").strip() + return "\n".join(line.rstrip() for line in text.split("\n")) + + +def normalize_lookup_text(value: Any) -> str: + return " ".join(str(value or "").strip().lower().split()) + + +def stable_hash(*parts: Any) -> str: + joined = "\x1f".join(normalize_lookup_text(part) for part in parts) + return hashlib.sha256(joined.encode("utf-8")).hexdigest() + + +def snippet_fields_from_storage_content(content: str) -> dict[str, str]: + """Parse the pipe-delimited snippet string emitted by the ingest pipeline.""" + parts = [p.strip() for p in content.split(" | ")] + if len(parts) >= 5: + return { + "content": parts[0], + "code_snippet": parts[1].replace("\\n", "\n"), + "language": parts[2], + "snippet_type": parts[3] or SnippetType.ALGORITHM.value, + "tags": parts[4], + } + if len(parts) >= 3: + return { + "content": parts[0], + "code_snippet": parts[1].replace("\\n", "\n"), + "language": parts[2], + "snippet_type": SnippetType.ALGORITHM.value, + "tags": "", + } + return { + "content": content, + "code_snippet": "", + "language": "", + "snippet_type": SnippetType.ALGORITHM.value, + "tags": "", + } + + +def snippet_identity_hash(fields: dict[str, Any]) -> str: + code = normalize_code_text(fields.get("code_snippet")) + content = normalize_lookup_text(fields.get("content")) + language = normalize_lookup_text(fields.get("language")) + identity_body = code or content + return stable_hash(language, identity_body) + + +def snippet_search_text(fields: dict[str, Any]) -> str: + """Return the text embedded for semantic snippet search.""" + content = str(fields.get("content") or "").strip() + language = str(fields.get("language") or "").strip() + tags = fields.get("tags") or "" + tags_text = ", ".join(tags) if isinstance(tags, list) else str(tags) + parts = [content] + if language: + parts.append(f"language: {language}") + if tags_text: + parts.append(f"tags: {tags_text}") + return "\n".join(part for part in parts if part) + + +def snippet_pinecone_metadata(user_id: str, fields: dict[str, Any]) -> dict[str, Any]: + tags = fields.get("tags") or "" + if isinstance(tags, list): + tags = ",".join(str(tag).strip() for tag in tags if str(tag).strip()) + return { + "user_id": user_id, + "domain": "snippet", + "snippet_hash": snippet_identity_hash(fields), + "code_snippet": str(fields.get("code_snippet") or ""), + "language": normalize_lookup_text(fields.get("language")), + "snippet_type": str(fields.get("snippet_type") or SnippetType.ALGORITHM.value), + "tags": str(tags), + "source": str(fields.get("source") or SnippetSource.CHAT.value), + } + + +def code_annotation_fields_from_storage_content(content: str) -> dict[str, str]: + """Parse the pipe-delimited code annotation string emitted by ingest.""" + parts = [p.strip() for p in content.split("|")] + if len(parts) >= 6: + return { + "annotation_type": parts[0] or AnnotationType.EXPLANATION.value, + "target_symbol": parts[1], + "target_file": parts[2], + "repo": parts[3], + "severity": parts[4], + "content": " | ".join(parts[5:]).strip(), + } + if len(parts) >= 2: + return { + "annotation_type": parts[0] or AnnotationType.EXPLANATION.value, + "target_symbol": "", + "target_file": "", + "repo": "", + "severity": "", + "content": " | ".join(parts[1:]).strip(), + } + return { + "annotation_type": AnnotationType.EXPLANATION.value, + "target_symbol": "", + "target_file": "", + "repo": "", + "severity": "", + "content": content, + } + + +def code_annotation_identity_key(fields: dict[str, Any]) -> str: + target = fields.get("target_symbol") or fields.get("target_file") or "" + return "|".join([ + normalize_lookup_text(fields.get("repo")), + normalize_lookup_text(target), + normalize_lookup_text(fields.get("annotation_type")), + ]) + + +def code_annotation_content_hash(fields: dict[str, Any]) -> str: + return stable_hash( + code_annotation_identity_key(fields), + fields.get("severity"), + fields.get("content"), + ) + + +def code_annotation_pinecone_metadata( + user_id: str, fields: dict[str, Any], +) -> dict[str, Any]: + return { + "user_id": user_id, + "domain": "code", + "annotation_key": code_annotation_identity_key(fields), + "annotation_hash": code_annotation_content_hash(fields), + "annotation_type": str(fields.get("annotation_type") or ""), + "target_symbol": str(fields.get("target_symbol") or ""), + "target_file": str(fields.get("target_file") or ""), + "repo": str(fields.get("repo") or ""), + "severity": str(fields.get("severity") or ""), + } diff --git a/tests/test_deterministic_memory_layer.py b/tests/test_deterministic_memory_layer.py index fe659c0..bcb1937 100644 --- a/tests/test_deterministic_memory_layer.py +++ b/tests/test_deterministic_memory_layer.py @@ -23,6 +23,12 @@ sys.modules.setdefault("langchain_core.messages", messages) from src.agents.judge import JudgeAgent +from src.schemas.code import ( + code_annotation_fields_from_storage_content, + code_annotation_pinecone_metadata, + snippet_fields_from_storage_content, + snippet_pinecone_metadata, +) from src.schemas.judge import JudgeDomain, OperationType from src.schemas.weaver import OpStatus from src.storage.base import SearchResult @@ -379,3 +385,151 @@ async def test_temporal_memory_layer_updates_same_date_changed_details(): assert weaver_result.executed[0].status == OpStatus.SUCCESS assert graph.updated[0][1] == "04-24" assert graph.events[("04-24", "demo")]["desc"] == "Updated product demo" + + +@pytest.mark.asyncio +async def test_snippet_deterministic_judge_noops_same_snippet_across_sessions(): + store = FakeVectorStore() + content = "Binary search helper | def bs():\\n return 1 | python | utility | search" + fields = snippet_fields_from_storage_content(content) + store.seed( + "snippet-1", + "Binary search helper", + snippet_pinecone_metadata("user-1", fields), + ) + judge = JudgeAgent(model=ModelMustNotBeCalled(), vector_store=store) + + judge_result = await judge.arun_deterministic({ + "domain": "snippet", + "new_items": [ + "Same helper pasted later | def bs():\n return 1 | Python | utility | dsa" + ], + "user_id": "user-1", + }) + + assert judge_result.confidence == 1.0 + assert judge_result.operations[0].type == OperationType.NOOP + assert judge_result.operations[0].embedding_id == "snippet-1" + + weaver = Weaver(vector_store=store, embed_fn=fake_embed) + weaver_result = await weaver.execute( + judge_result, + JudgeDomain.SNIPPET, + "user-1", + ) + + assert weaver_result.total == 0 + assert not store.add_calls + + +@pytest.mark.asyncio +async def test_snippet_deterministic_judge_deduplicates_incoming_batch(): + store = FakeVectorStore() + judge = JudgeAgent(model=ModelMustNotBeCalled(), vector_store=store) + + judge_result = await judge.arun_deterministic({ + "domain": "snippet", + "new_items": [ + "Binary search helper | def bs():\n return 1 | python | utility | search", + "Binary search helper again | def bs():\n return 1 | Python | utility | dsa", + ], + "user_id": "user-1", + }) + + assert judge_result.confidence == 1.0 + assert len(judge_result.operations) == 1 + assert judge_result.operations[0].type == OperationType.ADD + + +@pytest.mark.asyncio +async def test_snippet_weaver_stores_identity_hash_and_search_text(): + store = FakeVectorStore() + judge_result = await JudgeAgent( + model=ModelMustNotBeCalled(), + vector_store=store, + ).arun_deterministic({ + "domain": "snippet", + "new_items": [ + "Binary search helper | def bs():\\n return 1 | python | utility | search,array" + ], + "user_id": "user-1", + }) + + weaver = Weaver(vector_store=store, embed_fn=fake_embed) + weaver_result = await weaver.execute( + judge_result, + JudgeDomain.SNIPPET, + "user-1", + ) + + assert weaver_result.succeeded == 1 + record = store.records[weaver_result.executed[0].new_id] + assert record["text"] == "Binary search helper\nlanguage: python\ntags: search,array" + assert record["metadata"]["domain"] == "snippet" + assert len(record["metadata"]["snippet_hash"]) == 64 + assert record["metadata"]["code_snippet"] == "def bs():\n return 1" + + +@pytest.mark.asyncio +async def test_code_deterministic_judge_updates_same_target_annotation(): + store = FakeVectorStore() + existing = ( + "bug_report | Auth.login | src/auth.py | api | high | " + "Token refresh can fail" + ) + store.seed( + "code-1", + existing, + code_annotation_pinecone_metadata( + "user-1", + code_annotation_fields_from_storage_content(existing), + ), + ) + judge = JudgeAgent(model=ModelMustNotBeCalled(), vector_store=store) + + changed = ( + "bug_report | Auth.login | src/auth.py | api | high | " + "Token refresh can fail after session rotation" + ) + judge_result = await judge.arun_deterministic({ + "domain": "code", + "new_items": [changed], + "user_id": "user-1", + }) + + assert judge_result.operations[0].type == OperationType.UPDATE + assert judge_result.operations[0].embedding_id == "code-1" + + weaver = Weaver(vector_store=store, embed_fn=fake_embed) + weaver_result = await weaver.execute(judge_result, JudgeDomain.CODE, "user-1") + + assert weaver_result.succeeded == 1 + assert store.records["code-1"]["metadata"]["annotation_key"] == ( + "api|auth.login|bug_report" + ) + assert len(store.records["code-1"]["metadata"]["annotation_hash"]) == 64 + + +@pytest.mark.asyncio +async def test_code_deterministic_judge_deduplicates_incoming_batch(): + store = FakeVectorStore() + judge = JudgeAgent(model=ModelMustNotBeCalled(), vector_store=store) + + first = ( + "bug_report | Auth.login | src/auth.py | api | high | " + "Token refresh can fail" + ) + changed = ( + "bug_report | Auth.login | src/auth.py | api | high | " + "Token refresh can fail after session rotation" + ) + judge_result = await judge.arun_deterministic({ + "domain": "code", + "new_items": [first, changed], + "user_id": "user-1", + }) + + assert judge_result.confidence == 1.0 + assert len(judge_result.operations) == 1 + assert judge_result.operations[0].type == OperationType.ADD + assert judge_result.operations[0].content == changed diff --git a/tests/unit/test_schemas.py b/tests/unit/test_schemas.py index 3905c53..a1f2f62 100644 --- a/tests/unit/test_schemas.py +++ b/tests/unit/test_schemas.py @@ -9,6 +9,11 @@ ExtractedAnnotation, SnippetRecord, SnippetType, + code_annotation_fields_from_storage_content, + code_annotation_pinecone_metadata, + snippet_fields_from_storage_content, + snippet_pinecone_metadata, + snippet_search_text, annotations_namespace, snippets_namespace, symbols_namespace, @@ -84,3 +89,32 @@ def test_code_schema_enums_and_namespace_helpers(): assert symbols_namespace("acme", "payments") == "acme:payments:symbols" assert annotations_namespace("acme") == "acme:annotations" assert snippets_namespace("user-1") == "user-1:snippets" + + +def test_code_and_snippet_pinecone_metadata_have_stable_identity_keys(): + snippet_fields = snippet_fields_from_storage_content( + "Binary search helper | def bs():\\n return 1 | Python | utility | search,array" + ) + snippet_meta = snippet_pinecone_metadata("user-1", snippet_fields) + + same_snippet_fields = snippet_fields_from_storage_content( + "Binary search helper again | def bs():\n return 1 | python | utility | search" + ) + + assert snippet_search_text(snippet_fields) == ( + "Binary search helper\nlanguage: Python\ntags: search,array" + ) + assert snippet_meta["domain"] == "snippet" + assert snippet_meta["language"] == "python" + assert snippet_meta["snippet_hash"] == snippet_pinecone_metadata( + "user-1", same_snippet_fields, + )["snippet_hash"] + + annotation_fields = code_annotation_fields_from_storage_content( + "bug_report | Auth.login | src/auth.py | api | high | Token refresh can fail" + ) + annotation_meta = code_annotation_pinecone_metadata("user-1", annotation_fields) + + assert annotation_meta["domain"] == "code" + assert annotation_meta["annotation_key"] == "api|auth.login|bug_report" + assert len(annotation_meta["annotation_hash"]) == 64