Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions src/agents/judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@
Operation,
OperationType,
)
from src.schemas.code import (
code_annotation_content_hash,
code_annotation_fields_from_storage_content,
code_annotation_identity_key,
snippet_fields_from_storage_content,
snippet_identity_hash,
)
from src.storage.base import BaseVectorStore, SearchResult


Expand Down Expand Up @@ -193,6 +200,10 @@ async def arun_deterministic(self, state: Dict[str, Any]) -> JudgeResult:
result = await self._deterministic_profile(new_items, user_id)
elif domain == JudgeDomain.TEMPORAL:
result = await self._deterministic_temporal(new_items, user_id)
elif domain == JudgeDomain.CODE:
result = await self._deterministic_code(new_items, user_id)
elif domain == JudgeDomain.SNIPPET:
result = await self._deterministic_snippet(new_items, user_id)
else:
self.logger.warning(
"Deterministic judge unsupported for %s; falling back to LLM judge.",
Expand Down Expand Up @@ -482,6 +493,97 @@ async def _deterministic_temporal(

return JudgeResult(operations=operations, confidence=1.0)

async def _deterministic_code(
    self, new_items: list, user_id: str,
) -> JudgeResult:
    """Judge code annotations without an LLM.

    Deduplicates incoming items by their repo/target/type identity key
    (last occurrence wins), then looks up each key in the vector store
    in parallel to decide ADD / NOOP / UPDATE.
    """
    # Collapse duplicates extracted within a single turn: one entry
    # per identity key, keeping the most recent (content, fields) pair.
    deduped: dict[str, tuple[str, dict[str, Any]]] = {}
    for raw in new_items:
        text = str(raw)
        parsed = code_annotation_fields_from_storage_content(text)
        deduped[code_annotation_identity_key(parsed)] = (text, parsed)

    async def _judge_one(text: str, parsed: dict[str, Any]) -> Operation:
        # Exact-match lookup on the identity key, scoped to this user.
        existing = await self._lookup_metadata_match({
            "user_id": user_id,
            "domain": JudgeDomain.CODE.value,
            "annotation_key": code_annotation_identity_key(parsed),
        })

        if existing is None:
            return Operation(
                type=OperationType.ADD,
                content=text,
                reason="No code annotation with the same repo/target/type key.",
            )

        # Same identity: compare content hashes to detect a real change.
        stored_hash = str((existing.metadata or {}).get("annotation_hash", ""))
        if code_annotation_content_hash(parsed) == stored_hash:
            return Operation(
                type=OperationType.NOOP,
                content=text,
                embedding_id=existing.id,
                reason="Existing code annotation is unchanged.",
            )
        return Operation(
            type=OperationType.UPDATE,
            content=text,
            embedding_id=existing.id,
            reason="Existing code annotation target has updated content.",
        )

    pending = [_judge_one(text, parsed) for text, parsed in deduped.values()]
    operations = await asyncio.gather(*pending)
    return JudgeResult(operations=operations, confidence=1.0)
Comment on lines +496 to +539
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _deterministic_code method processes items sequentially and lacks deduplication of the incoming new_items. If multiple identical annotations are extracted in a single turn, this will result in redundant operations and potential duplicate records in the vector store. It is recommended to deduplicate items by their identity key and use asyncio.gather to perform metadata lookups in parallel, maintaining consistency with the profile and temporal domains.

    async def _deterministic_code(
        self, new_items: list, user_id: str,
    ) -> JudgeResult:
        # Deduplicate items by identity key to prevent redundant operations
        unique_items: dict[str, tuple[str, dict]] = {}
        for item in new_items:
            content = str(item)
            fields = code_annotation_fields_from_storage_content(content)
            key = code_annotation_identity_key(fields)
            unique_items[key] = (content, fields)

        async def _process_one(content: str, fields: dict) -> Operation:
            match = await self._lookup_metadata_match({
                "user_id": user_id,
                "domain": JudgeDomain.CODE.value,
                "annotation_key": code_annotation_identity_key(fields),
            })

            if match is None:
                return Operation(
                    type=OperationType.ADD,
                    content=content,
                    reason="No code annotation with the same repo/target/type key.",
                )

            incoming_hash = code_annotation_content_hash(fields)
            existing_hash = str((match.metadata or {}).get("annotation_hash", ""))
            if incoming_hash == existing_hash:
                return Operation(
                    type=OperationType.NOOP,
                    content=content,
                    embedding_id=match.id,
                    reason="Existing code annotation is unchanged.",
                )
            else:
                return Operation(
                    type=OperationType.UPDATE,
                    content=content,
                    embedding_id=match.id,
                    reason="Existing code annotation target has updated content.",
                )

        tasks = [_process_one(c, f) for c, f in unique_items.values()]
        operations = await asyncio.gather(*tasks)
        return JudgeResult(operations=list(operations), confidence=1.0)


async def _deterministic_snippet(
    self, new_items: list, user_id: str,
) -> JudgeResult:
    """Judge code snippets without an LLM.

    Deduplicates incoming items by normalized snippet identity hash
    (last occurrence wins), then checks each hash against the vector
    store in parallel: unseen snippets become ADD, known ones NOOP.
    """
    # One entry per identity hash; duplicates within the turn collapse.
    deduped: dict[str, tuple[str, dict[str, Any]]] = {}
    for raw in new_items:
        text = str(raw)
        parsed = snippet_fields_from_storage_content(text)
        deduped[snippet_identity_hash(parsed)] = (text, parsed)

    async def _judge_one(text: str, parsed: dict[str, Any]) -> Operation:
        existing = await self._lookup_metadata_match({
            "user_id": user_id,
            "domain": JudgeDomain.SNIPPET.value,
            "snippet_hash": snippet_identity_hash(parsed),
        })

        if existing is None:
            return Operation(
                type=OperationType.ADD,
                content=text,
                reason="No snippet with the same normalized code/content identity.",
            )
        # Identity hash already covers the content, so a match is a no-op.
        return Operation(
            type=OperationType.NOOP,
            content=text,
            embedding_id=existing.id,
            reason="Same snippet was already stored for this user.",
        )

    pending = [_judge_one(text, parsed) for text, parsed in deduped.values()]
    operations = await asyncio.gather(*pending)
    return JudgeResult(operations=operations, confidence=1.0)
Comment on lines +541 to +574
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to _deterministic_code, the _deterministic_snippet method should deduplicate new_items by their identity hash and parallelize the metadata lookups using asyncio.gather to improve performance and prevent duplicate operations.

    async def _deterministic_snippet(
        self, new_items: list, user_id: str,
    ) -> JudgeResult:
        # Deduplicate items by snippet hash to prevent redundant operations
        unique_items: dict[str, tuple[str, dict]] = {}
        for item in new_items:
            content = str(item)
            fields = snippet_fields_from_storage_content(content)
            h = snippet_identity_hash(fields)
            unique_items[h] = (content, fields)

        async def _process_one(content: str, fields: dict) -> Operation:
            match = await self._lookup_metadata_match({
                "user_id": user_id,
                "domain": JudgeDomain.SNIPPET.value,
                "snippet_hash": snippet_identity_hash(fields),
            })

            if match is None:
                return Operation(
                    type=OperationType.ADD,
                    content=content,
                    reason="No snippet with the same normalized code/content identity.",
                )
            else:
                return Operation(
                    type=OperationType.NOOP,
                    content=content,
                    embedding_id=match.id,
                    reason="Same snippet was already stored for this user.",
                )

        tasks = [_process_one(c, f) for c, f in unique_items.values()]
        operations = await asyncio.gather(*tasks)
        return JudgeResult(operations=list(operations), confidence=1.0)


async def _lookup_metadata_match(
    self, filters: Dict[str, Any],
) -> Optional[SearchResult]:
    """Return the top metadata-filtered hit from the vector store, if any.

    Yields ``None`` when no store is configured or the store does not
    support metadata search (duck-typed via ``search_by_metadata``).
    """
    store = self.vector_store
    if not store:
        return None
    search = getattr(store, "search_by_metadata", None)
    if search is None:
        return None
    # The store API is synchronous; run it off the event loop.
    hits = await asyncio.to_thread(search, filters=filters, top_k=1)
    return _first_match(hits or [])

# -- Response parsing --------------------------------------------------

def _parse_response(
Expand Down
22 changes: 16 additions & 6 deletions src/pipelines/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
)
from src.schemas.events import EventResult
from src.schemas.image import ImageResult
from src.schemas.judge import JudgeDomain, JudgeResult, OperationType
from src.schemas.judge import JudgeDomain, JudgeResult
from src.schemas.profile import ProfileResult
from src.schemas.summary import SummaryResult
from src.schemas.weaver import WeaverResult
Expand Down Expand Up @@ -765,7 +765,13 @@ async def _node_extract_code(self, state: IngestState) -> Dict[str, Any]:
]
all_items.append(" | ".join(parts))

judge_result = await self.judge.arun({
code_judge = JudgeAgent(
model=self.judge.model,
vector_store=self.code_vector_store,
graph_event_search=self._graph_event_search_wrapper,
top_k=self.judge.top_k,
)
judge_result = await code_judge.arun_deterministic({
"domain": JudgeDomain.CODE,
"new_items": all_items,
"user_id": user_id,
Expand Down Expand Up @@ -805,15 +811,19 @@ async def _node_extract_snippet(self, state: IngestState) -> Dict[str, Any]:
]
all_items.append(" | ".join(parts))

judge_result = await self.judge.arun({
self.weaver.snippet_vector_store = self._get_snippet_store(user_id)
snippet_judge = JudgeAgent(
model=self.judge.model,
vector_store=self.weaver.snippet_vector_store,
graph_event_search=self._graph_event_search_wrapper,
top_k=self.judge.top_k,
)
judge_result = await snippet_judge.arun_deterministic({
"domain": JudgeDomain.SNIPPET,
"new_items": all_items,
"user_id": user_id,
})

# Bind the user-scoped snippet store before executing
self.weaver.snippet_vector_store = self._get_snippet_store(user_id)

weaver_result = await self.weaver.execute(
judge_result=judge_result,
domain=JudgeDomain.SNIPPET,
Expand Down
94 changes: 15 additions & 79 deletions src/pipelines/weaver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
Operation,
OperationType,
)
from src.schemas.code import (
code_annotation_fields_from_storage_content,
code_annotation_pinecone_metadata,
snippet_fields_from_storage_content,
snippet_pinecone_metadata,
snippet_search_text,
)
from src.schemas.weaver import ExecutedOp, OpStatus, WeaverResult
from src.storage.base import BaseVectorStore

Expand Down Expand Up @@ -573,15 +580,7 @@ async def _code_add(self, op: Operation, user_id: str) -> ExecutedOp:
import asyncio
loop = asyncio.get_running_loop()
embedding = await loop.run_in_executor(None, self.embed_fn, op.content)
metadata: Dict[str, Any] = {
"user_id": user_id,
"domain": "code",
"annotation_type": parsed.get("annotation_type", ""),
"target_symbol": parsed.get("target_symbol", ""),
"target_file": parsed.get("target_file", ""),
"repo": parsed.get("repo", ""),
"severity": parsed.get("severity", ""),
}
metadata = code_annotation_pinecone_metadata(user_id, parsed)

from functools import partial
ids = await loop.run_in_executor(
Expand Down Expand Up @@ -633,15 +632,7 @@ async def _code_update(self, op: Operation, user_id: str) -> ExecutedOp:
import asyncio
loop = asyncio.get_running_loop()
embedding = await loop.run_in_executor(None, self.embed_fn, op.content)
metadata: Dict[str, Any] = {
"user_id": user_id,
"domain": "code",
"annotation_type": parsed.get("annotation_type", ""),
"target_symbol": parsed.get("target_symbol", ""),
"target_file": parsed.get("target_file", ""),
"repo": parsed.get("repo", ""),
"severity": parsed.get("severity", ""),
}
metadata = code_annotation_pinecone_metadata(user_id, parsed)

from functools import partial
success = await loop.run_in_executor(
Expand Down Expand Up @@ -725,20 +716,12 @@ async def _snippet_add(self, op: Operation, user_id: str) -> ExecutedOp:
)

parsed = _parse_snippet_content(op.content)
searchable = parsed.get("content", op.content)
searchable = snippet_search_text(parsed)
import asyncio
loop = asyncio.get_running_loop()
embedding = await loop.run_in_executor(None, self.embed_fn, searchable)

metadata: Dict[str, Any] = {
"user_id": user_id,
"domain": "snippet",
"code_snippet": parsed.get("code_snippet", ""),
"language": parsed.get("language", ""),
"snippet_type": parsed.get("snippet_type", "algorithm"),
"tags": parsed.get("tags", ""),
"source": "chat",
}
metadata = snippet_pinecone_metadata(user_id, parsed)

from functools import partial
ids = await loop.run_in_executor(
Expand Down Expand Up @@ -771,20 +754,12 @@ async def _snippet_update(self, op: Operation, user_id: str) -> ExecutedOp:
)

parsed = _parse_snippet_content(op.content)
searchable = parsed.get("content", op.content)
searchable = snippet_search_text(parsed)
import asyncio
loop = asyncio.get_running_loop()
embedding = await loop.run_in_executor(None, self.embed_fn, searchable)

metadata: Dict[str, Any] = {
"user_id": user_id,
"domain": "snippet",
"code_snippet": parsed.get("code_snippet", ""),
"language": parsed.get("language", ""),
"snippet_type": parsed.get("snippet_type", "algorithm"),
"tags": parsed.get("tags", ""),
"source": "chat",
}
metadata = snippet_pinecone_metadata(user_id, parsed)

from functools import partial
success = await loop.run_in_executor(
Expand Down Expand Up @@ -911,29 +886,7 @@ def _parse_snippet_content(content: str) -> Dict[str, str]:

Falls back gracefully if the content doesn't match.
"""
parts = [p.strip() for p in content.split(" | ")]
result: Dict[str, str] = {}

if len(parts) >= 5:
result["content"] = parts[0]
result["code_snippet"] = parts[1]
result["language"] = parts[2]
result["snippet_type"] = parts[3]
result["tags"] = parts[4]
elif len(parts) >= 3:
result["content"] = parts[0]
result["code_snippet"] = parts[1]
result["language"] = parts[2]
result["snippet_type"] = "algorithm"
result["tags"] = ""
else:
result["content"] = content
result["code_snippet"] = ""
result["language"] = ""
result["snippet_type"] = "algorithm"
result["tags"] = ""

return result
return snippet_fields_from_storage_content(content)


def _parse_code_annotation_content(content: str) -> Dict[str, str]:
Expand All @@ -945,21 +898,4 @@ def _parse_code_annotation_content(content: str) -> Dict[str, str]:
Falls back gracefully if the content doesn't match the expected format
(treats the entire string as the annotation content).
"""
parts = [p.strip() for p in content.split("|")]
result: Dict[str, str] = {}

if len(parts) >= 6:
result["annotation_type"] = parts[0] or "explanation"
result["target_symbol"] = parts[1]
result["target_file"] = parts[2]
result["repo"] = parts[3]
result["severity"] = parts[4]
result["content"] = parts[5]
elif len(parts) >= 2:
result["annotation_type"] = parts[0] or "explanation"
result["content"] = " | ".join(parts[1:])
else:
result["content"] = content
result["annotation_type"] = "explanation"

return result
return code_annotation_fields_from_storage_content(content)
Loading
Loading