From 588b623d30180575790c52af67435ce3dd71b91a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 11 May 2026 13:07:05 -0700 Subject: [PATCH] feat: Add GcsOffloader for asynchronously uploading content to Google Cloud Storage This change gives users the ability to upload content to GCP PiperOrigin-RevId: 913845322 --- core/pom.xml | 4 + .../BigQueryAgentAnalyticsPlugin.java | 5 +- .../plugins/agentanalytics/GcsOffloader.java | 84 ++++++++++ .../adk/plugins/agentanalytics/Parser.java | 146 ++++++++++++++++-- .../plugins/agentanalytics/PluginState.java | 68 +++++++- .../BigQueryAgentAnalyticsPluginTest.java | 86 +++++++++++ .../agentanalytics/JsonFormatterTest.java | 75 +++++++-- .../plugins/agentanalytics/ParserTest.java | 8 +- .../agentanalytics/PluginStateTest.java | 47 +++++- pom.xml | 6 + 10 files changed, 491 insertions(+), 38 deletions(-) create mode 100644 core/src/main/java/com/google/adk/plugins/agentanalytics/GcsOffloader.java diff --git a/core/pom.xml b/core/pom.xml index 53fd51883..21ab13c15 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -217,6 +217,10 @@ arrow-memory-netty 17.0.0 + + org.apache.tika + tika-core + diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java index 59e09c8a7..566dbd5a4 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -253,7 +253,10 @@ private Completable logEvent( parseFuture = state .getParser() - .parse(content) + .parse( + content, + traceIds.traceId(), + traceIds.spanId() != null ? 
traceIds.spanId() : "no_span") .thenAccept( parsedContent -> { row.put( diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/GcsOffloader.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/GcsOffloader.java new file mode 100644 index 000000000..8900a93dd --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/GcsOffloader.java @@ -0,0 +1,84 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.auth.Credentials; +import com.google.cloud.storage.BlobId; +import com.google.cloud.storage.BlobInfo; +import com.google.cloud.storage.Storage; +import com.google.cloud.storage.StorageOptions; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import org.jspecify.annotations.Nullable; + +/** Offloads content to GCS. 
*/ +class GcsOffloader { + private final Storage storage; + private final String bucketName; + private final Executor executor; + private final boolean isStorageOverride; + + GcsOffloader( + String projectId, + String bucketName, + Executor executor, + @Nullable Credentials credentials, + @Nullable Storage storageOverride) { + if (storageOverride != null) { + this.isStorageOverride = true; + this.storage = storageOverride; + } else { + this.isStorageOverride = false; + StorageOptions.Builder builder = StorageOptions.newBuilder().setProjectId(projectId); + if (credentials != null) { + builder.setCredentials(credentials); + } + this.storage = builder.build().getService(); + } + this.bucketName = bucketName; + this.executor = executor; + } + + /** Async wrapper around blocking GCS upload for binary data. */ + CompletableFuture uploadContent(byte[] data, String contentType, String path) { + return CompletableFuture.supplyAsync( + () -> { + BlobId blobId = BlobId.of(bucketName, path); + BlobInfo blobInfo = BlobInfo.newBuilder(blobId).setContentType(contentType).build(); + storage.create(blobInfo, data); + return String.format("gs://%s/%s", bucketName, path); + }, + executor); + } + + /** Async wrapper around blocking GCS upload for text data. 
*/ + CompletableFuture uploadContent(String data, String contentType, String path) { + return uploadContent(data.getBytes(UTF_8), contentType, path); + } + + String getBucketName() { + return bucketName; + } + + void close() throws Exception { + if (storage != null && !isStorageOverride) { + storage.close(); + } + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/Parser.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/Parser.java index 5db8be46c..f4eff09f7 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/Parser.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/Parser.java @@ -19,6 +19,7 @@ import static com.google.adk.plugins.agentanalytics.JsonFormatter.mapper; import static com.google.adk.plugins.agentanalytics.JsonFormatter.smartTruncate; import static com.google.adk.plugins.agentanalytics.JsonFormatter.truncate; +import static com.google.adk.plugins.agentanalytics.JsonFormatter.truncateAndAddSuffix; import static com.google.adk.plugins.agentanalytics.JsonFormatter.truncateWithStatus; import com.fasterxml.jackson.annotation.JsonProperty; @@ -39,16 +40,43 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.apache.tika.mime.MimeTypeException; +import org.apache.tika.mime.MimeTypes; import org.jspecify.annotations.Nullable; +import org.threeten.bp.Instant; +import org.threeten.bp.LocalDate; +import org.threeten.bp.ZoneOffset; /** Utility for parsing content for BigQuery logging. 
*/ final class Parser { + private static final String DEFAULT_EXTENSION = ".bin"; + private static final int MAX_OFFLOADED_TEXT_LENGTH = 200; + private static final Logger logger = Logger.getLogger(Parser.class.getName()); + private static final int INLINE_TEXT_LIMIT = 32 * 1024; // 32KB limit + private static final String UPLOAD_FAILED_MESSAGE = "[UPLOAD FAILED]"; + private static final String MEDIA_OFFLOADED_MESSAGE = "[MEDIA OFFLOADED]"; private static final String BINARY_DATA_MESSAGE = "[BINARY DATA]"; - private final int maxLength; + private static final String TEXT_OFFLOADED_SUFFIX = "... [OFFLOADED]"; + private static final MimeTypes MIME_TYPES = MimeTypes.getDefaultMimeTypes(); - Parser(int maxLength) { + private final @Nullable GcsOffloader offloader; + private final int maxLength; + private final @Nullable String connectionId; + private final boolean logMultiModalContent; + + Parser( + @Nullable GcsOffloader offloader, + int maxLength, + @Nullable String connectionId, + boolean logMultiModalContent) { + this.offloader = offloader; this.maxLength = maxLength; + this.connectionId = connectionId; + this.logMultiModalContent = logMultiModalContent; } @AutoValue @@ -152,9 +180,11 @@ static ObjectRef create( * Parses content into JSON payload and content parts, matching Python implementation. 
* * @param content the content to parse + * @param traceId the trace ID for GCS path + * @param spanId the span ID for GCS path * @return a CompletableFuture of ParsedContent object */ - CompletableFuture parse(Object content) { + CompletableFuture parse(Object content, String traceId, String spanId) { if (content instanceof LlmRequest llmRequest) { ObjectNode jsonPayload = mapper.createObjectNode(); ArrayNode messages = mapper.createArrayNode(); @@ -162,13 +192,15 @@ CompletableFuture parse(Object content) { List contents = llmRequest.contents(); for (Content c : contents) { - futures.add(parseContentObject(c)); + futures.add(parseContentObject(c, traceId, spanId)); } CompletableFuture systemFuture = null; if (llmRequest.config().isPresent() && llmRequest.config().get().systemInstruction().isPresent()) { - systemFuture = parseContentObject(llmRequest.config().get().systemInstruction().get()); + systemFuture = + parseContentObject( + llmRequest.config().get().systemInstruction().get(), traceId, spanId); futures.add(systemFuture); } CompletableFuture finalSystemFuture = systemFuture; @@ -202,7 +234,7 @@ CompletableFuture parse(Object content) { } if (content instanceof LlmResponse llmResponse) { ObjectNode jsonPayload = mapper.createObjectNode(); - return parseContentObject(llmResponse.content().orElse(null)) + return parseContentObject(llmResponse.content().orElse(null), traceId, spanId) .thenApply( parsed -> { ObjectNode summaryNode = mapper.createObjectNode(); @@ -225,7 +257,7 @@ CompletableFuture parse(Object content) { }); } if (content instanceof Content || content instanceof Part) { - return parseContentObject(content) + return parseContentObject(content, traceId, spanId) .thenApply( parsed -> { ObjectNode summaryNode = mapper.createObjectNode(); @@ -249,10 +281,13 @@ CompletableFuture parse(Object content) { * Parses a Content or Part object into summary text and content parts. 
* * @param content the Content or Part object to parse + * @param traceId the trace ID for GCS path + * @param spanId the span ID for GCS path * @return a CompletableFuture of ParsedContentObject containing parts, summary, and truncation * flag */ - private CompletableFuture parseContentObject(Object content) { + private CompletableFuture parseContentObject( + Object content, String traceId, String spanId) { List parts; if (content instanceof Content c) { parts = c.parts().orElse(ImmutableList.of()); @@ -265,7 +300,7 @@ private CompletableFuture parseContentObject(Object content List> partFutures = new ArrayList<>(); for (int i = 0; i < parts.size(); i++) { - partFutures.add(processPart(parts.get(i), i)); + partFutures.add(processPart(parts.get(i), i, traceId, spanId)); } return CompletableFuture.allOf(partFutures.toArray(new CompletableFuture[0])) @@ -295,7 +330,8 @@ private CompletableFuture parseContentObject(Object content }); } - private CompletableFuture processPart(Part part, int index) { + private CompletableFuture processPart( + Part part, int index, String traceId, String spanId) { ContentPart.Builder partBuilder = ContentPart.builder() .setPartIndex(index) @@ -320,17 +356,89 @@ private CompletableFuture processPart(Part part, int index) { if (part.inlineData().isPresent()) { Blob blob = part.inlineData().get(); String mimeType = blob.mimeType().orElse("application/octet-stream"); - partBuilder.setText(BINARY_DATA_MESSAGE).setMimeType(mimeType); - return CompletableFuture.completedFuture( - TruncationResult.create(mapper.valueToTree(partBuilder.build()), false)); + if (logMultiModalContent && offloader != null) { + String ext = DEFAULT_EXTENSION; + try { + ext = MIME_TYPES.forName(mimeType).getExtension(); + } catch (MimeTypeException e) { + logger.log(Level.WARNING, "Failed to get extension for mime type " + mimeType, e); + } + String path = + String.format( + "%s/%s/%s_p%d_%s%s", + getLocalDate(), traceId, spanId, index, UUID.randomUUID(), ext); + 
return offloader + .uploadContent(blob.data().orElse(new byte[0]), mimeType, path) + .handle( + (uri, ex) -> { + if (ex != null) { + logger.log(Level.WARNING, "Failed to offload content to GCS", ex); + partBuilder.setText(UPLOAD_FAILED_MESSAGE); + } else { + ObjectNode details = mapper.createObjectNode(); + ObjectNode gcsMetadata = details.putObject("gcs_metadata"); + gcsMetadata.put("content_type", mimeType); + + partBuilder + .setStorageMode("GCS_REFERENCE") + .setUri(uri) + .setMimeType(mimeType) + .setText(MEDIA_OFFLOADED_MESSAGE) + .setObjectRef( + mapper.valueToTree(ObjectRef.create(uri, null, connectionId, details))); + } + return TruncationResult.create(mapper.valueToTree(partBuilder.build()), false); + }); + } else { + partBuilder.setText(BINARY_DATA_MESSAGE).setMimeType(mimeType); + return CompletableFuture.completedFuture( + TruncationResult.create(mapper.valueToTree(partBuilder.build()), false)); + } } // CASE C: Text if (part.text().isPresent()) { String text = part.text().get(); - TruncationResult res = truncateWithStatus(text, maxLength); - partBuilder.setText(res.node().asText()); - return CompletableFuture.completedFuture( - TruncationResult.create(mapper.valueToTree(partBuilder.build()), res.isTruncated())); + int textLen = Utf8.encodedLength(text); + int offloadThreshold = Math.min(INLINE_TEXT_LIMIT, maxLength); + + if (offloader != null && textLen > offloadThreshold) { + + String path = + String.format( + "%s/%s/%s_p%d_%s.txt", getLocalDate(), traceId, spanId, index, UUID.randomUUID()); + return offloader + .uploadContent(text, "text/plain", path) + .handle( + (uri, ex) -> { + if (ex != null) { + logger.log(Level.WARNING, "Failed to offload text to GCS", ex); + TruncationResult res = truncateWithStatus(text, maxLength); + partBuilder.setText(res.node().asText()); + return TruncationResult.create( + mapper.valueToTree(partBuilder.build()), res.isTruncated()); + } else { + ObjectNode details = mapper.createObjectNode(); + ObjectNode gcsMetadata = 
details.putObject("gcs_metadata"); + gcsMetadata.put("content_type", "text/plain"); + + partBuilder + .setStorageMode("GCS_REFERENCE") + .setUri(uri) + .setMimeType("text/plain") + .setText( + truncateAndAddSuffix( + text, MAX_OFFLOADED_TEXT_LENGTH, TEXT_OFFLOADED_SUFFIX)) + .setObjectRef( + mapper.valueToTree(ObjectRef.create(uri, null, connectionId, details))); + return TruncationResult.create(mapper.valueToTree(partBuilder.build()), true); + } + }); + } else { + TruncationResult res = truncateWithStatus(text, maxLength); + partBuilder.setText(res.node().asText()); + return CompletableFuture.completedFuture( + TruncationResult.create(mapper.valueToTree(partBuilder.build()), res.isTruncated())); + } } if (part.functionCall().isPresent()) { FunctionCall fc = part.functionCall().get(); @@ -379,4 +487,8 @@ ArrayNode formatContentParts(Optional content) { } return partsArray; } + + private LocalDate getLocalDate() { + return Instant.now().atZone(ZoneOffset.UTC).toLocalDate(); + } } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java index 0654fab5d..2870e7053 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java @@ -21,21 +21,36 @@ import java.util.Collection; import java.util.Set; import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; +import 
org.jspecify.annotations.Nullable; import org.threeten.bp.Duration; +import org.threeten.bp.Instant; /** Manages state for the BigQueryAgentAnalyticsPlugin. */ class PluginState { private static final Logger logger = Logger.getLogger(PluginState.class.getName()); + private static final int GCS_OFFLOAD_CORE_POOL_SIZE = 2; + private static final int GCS_OFFLOAD_MAX_THREADS = 10; + // Max number of queued tasks. Once the queue is full and all threads are busy, new tasks run + // in the caller thread. + private static final int GCS_OFFLOAD_QUEUE_SIZE = 100; + // Idle time before threads are terminated. + private static final int GCS_OFFLOAD_IDLE_TIME_SECONDS = 30; + private final BigQueryLoggerConfig config; private final ScheduledExecutorService executor; + private final ExecutorService offloadExecutor; private final BigQueryWriteClient writeClient; private static final AtomicLong threadCounter = new AtomicLong(0); // Map of invocation ID to BatchProcessor. @@ -45,6 +60,7 @@ class PluginState { private final ConcurrentHashMap traceManagers = new ConcurrentHashMap<>(); // Cache of invocation ID to Boolean indicating invocation ID has been processed. private final Cache processedInvocations; + private final GcsOffloader offloader; private final Parser parser; private final ConcurrentHashMap>> pendingTasks = new ConcurrentHashMap<>(); @@ -54,6 +70,7 @@ class PluginState { this.executor = Executors.newScheduledThreadPool( 2, r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement())); + this.offloadExecutor = createGcsOffloadThreadPool(); // One write client per plugin instance, shared by all invocations. 
this.writeClient = createWriteClient(config); this.processedInvocations = @@ -61,7 +78,25 @@ class PluginState { .maximumSize(10000) .expireAfterWrite(java.time.Duration.ofMinutes(10)) .build(); - this.parser = new Parser(config.maxContentLength()); + this.offloader = getGcsOffloader(config); + this.parser = + new Parser( + offloader, + config.maxContentLength(), + config.connectionId().orElse(null), + config.logMultiModalContent()); + } + + private static ExecutorService createGcsOffloadThreadPool() { + return new ThreadPoolExecutor( + GCS_OFFLOAD_CORE_POOL_SIZE, // The lower limit of threads. + GCS_OFFLOAD_MAX_THREADS, // The upper limit of threads. + GCS_OFFLOAD_IDLE_TIME_SECONDS, // Time to keep idle threads alive. + TimeUnit.SECONDS, + new ArrayBlockingQueue<>(GCS_OFFLOAD_QUEUE_SIZE), // workQueue: bounded FIFO task queue. + r -> new Thread(r, "bq-analytics-plugin-offload-" + threadCounter.getAndIncrement()), + // Sensible rejection policy to execute tasks in the caller thread. + new ThreadPoolExecutor.CallerRunsPolicy()); } ScheduledExecutorService getExecutor() { @@ -142,6 +177,14 @@ BatchProcessor getBatchProcessor(String invocationId) { }); } + protected @Nullable GcsOffloader getGcsOffloader(BigQueryLoggerConfig config) { + if (config.gcsBucketName().isEmpty()) { + return null; + } + return new GcsOffloader( + config.projectId(), config.gcsBucketName(), offloadExecutor, config.credentials(), null); + } + Parser getParser() { return parser; } @@ -263,13 +306,34 @@ Completable close() { } try { executor.shutdown(); - if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { + offloadExecutor.shutdown(); + long totalTimeoutMillis = config.shutdownTimeout().toMillis(); + Instant startTime = Instant.now(); + if (!executor.awaitTermination(totalTimeoutMillis, MILLISECONDS)) { executor.shutdownNow(); } + long elapsedTimeMillis = Duration.between(startTime, Instant.now()).toMillis(); + long remainingMillis = totalTimeoutMillis - 
elapsedTimeMillis; + if (remainingMillis > 0) { + if (!offloadExecutor.awaitTermination(remainingMillis, MILLISECONDS)) { + offloadExecutor.shutdownNow(); + } + } else { + offloadExecutor.shutdownNow(); + } } catch (InterruptedException e) { executor.shutdownNow(); + offloadExecutor.shutdownNow(); Thread.currentThread().interrupt(); } + + try { + if (offloader != null) { + offloader.close(); + } + } catch (Exception e) { + logger.log(Level.WARNING, "Failed to close GCS offloader", e); + } }); } } diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java index 836442cad..59a6bb245 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -17,6 +17,7 @@ package com.google.adk.plugins.agentanalytics; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -24,11 +25,13 @@ import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.adk.agents.BaseAgent; @@ -802,6 +805,7 @@ public void logEvent_handlesExceptionFromFormatter() throws Exception { (content, eventType) -> { throw new 
RuntimeException("Formatter error"); }; + BigQueryLoggerConfig formattedConfig = config.toBuilder().contentFormatter(formatter).build(); PluginState formattedState = new PluginState(formattedConfig) { @@ -1042,6 +1046,88 @@ public void logEvent_createsUniqueProcessorPerInvocation() throws Exception { testExecutor.shutdown(); } + @Test + public void logEvent_offloadsToGcs_whenLargeContent() throws Exception { + GcsOffloader mockOffloader = mock(GcsOffloader.class); + when(mockOffloader.uploadContent(anyString(), anyString(), anyString())) + .thenReturn(CompletableFuture.completedFuture("gs://test-bucket/large.txt")); + + BigQueryLoggerConfig gcsConfig = config.toBuilder().gcsBucketName("test-bucket").build(); + PluginState gcsState = + new PluginState(gcsConfig) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; + } + + @Override + protected GcsOffloader getGcsOffloader(BigQueryLoggerConfig config) { + return mockOffloader; + } + }; + BigQueryAgentAnalyticsPlugin gcsPlugin = + new BigQueryAgentAnalyticsPlugin(gcsConfig, mockBigQuery, gcsState); + + // Large text (> 32KB default threshold) + String largeText = "a".repeat(40000); + Content content = Content.fromParts(Part.fromText(largeText)); + gcsPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + verify(mockOffloader, atLeastOnce()).uploadContent(anyString(), anyString(), anyString()); + + Map row = gcsState.getBatchProcessor("invocation_id").queue.poll(); + assertNotNull(row); + @SuppressWarnings("unchecked") // Test only + List contentParts = (List) row.get("content_parts"); + assertEquals("GCS_REFERENCE", contentParts.get(0).get("storage_mode").asText()); + assertEquals("gs://test-bucket/large.txt", contentParts.get(0).get("uri").asText()); + } + + @Test + public void logEvent_offloadsToGcs_whenMultimodalContent() throws Exception 
{ + GcsOffloader mockOffloader = mock(GcsOffloader.class); + when(mockOffloader.uploadContent(any(byte[].class), anyString(), anyString())) + .thenReturn(CompletableFuture.completedFuture("gs://test-bucket/image.png")); + + BigQueryLoggerConfig gcsConfig = config.toBuilder().gcsBucketName("test-bucket").build(); + PluginState gcsState = + new PluginState(gcsConfig) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; + } + + @Override + protected GcsOffloader getGcsOffloader(BigQueryLoggerConfig config) { + return mockOffloader; + } + }; + BigQueryAgentAnalyticsPlugin gcsPlugin = + new BigQueryAgentAnalyticsPlugin(gcsConfig, mockBigQuery, gcsState); + + Content content = Content.fromParts(Part.fromBytes("test-data".getBytes(UTF_8), "image/png")); + gcsPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + verify(mockOffloader, atLeastOnce()).uploadContent(any(byte[].class), anyString(), anyString()); + + Map row = gcsState.getBatchProcessor("invocation_id").queue.poll(); + assertNotNull(row); + @SuppressWarnings("unchecked") // Test only + List contentParts = (List) row.get("content_parts"); + assertEquals("GCS_REFERENCE", contentParts.get(0).get("storage_mode").asText()); + assertEquals("gs://test-bucket/image.png", contentParts.get(0).get("uri").asText()); + } + private static class FakeAgent extends BaseAgent { FakeAgent(String name) { super(name, "description", null, null, null); diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java index 4883438b6..663f5e5cd 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java @@ -16,16 +16,22 @@ package 
com.google.adk.plugins.agentanalytics; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ArrayNode; import com.google.adk.models.LlmRequest; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Blob; import com.google.genai.types.Content; import com.google.genai.types.FileData; import com.google.genai.types.FunctionCall; @@ -33,6 +39,7 @@ import com.google.genai.types.Part; import java.util.Arrays; import java.util.List; +import java.util.concurrent.CompletableFuture; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -49,7 +56,8 @@ public void parse_llmRequest_populatesPrompt() throws Exception { Content.fromParts(Part.fromText("hello")).toBuilder().role("user").build())) .build(); - Parser.ParsedContent result = new Parser(100).parse(request).get(); + Parser.ParsedContent result = + new Parser(null, 100, null, true).parse(request, "trace", "span").get(); assertTrue(result.content().has("prompt")); ArrayNode prompt = (ArrayNode) result.content().get("prompt"); @@ -69,7 +77,8 @@ public void parse_llmRequest_populatesSystemPrompt() throws Exception { .build()) .build(); - Parser.ParsedContent result = new Parser(100).parse(request).get(); + Parser.ParsedContent result = + new Parser(null, 100, null, true).parse(request, "trace", "span").get(); assertTrue(result.content().has("system_prompt")); assertEquals("be helpful", result.content().get("system_prompt").asText()); @@ -79,7 +88,8 @@ public void 
parse_llmRequest_populatesSystemPrompt() throws Exception { @Test public void parse_string_truncates() throws Exception { String longString = "this is a very long string that should be truncated"; - Parser.ParsedContent result = new Parser(24).parse(longString).get(); + Parser.ParsedContent result = + new Parser(null, 24, null, true).parse(longString, "trace", "span").get(); assertTrue(result.isTruncated()); assertEquals("this is a ...[truncated]", result.content().asText()); @@ -89,7 +99,8 @@ public void parse_string_truncates() throws Exception { public void parse_map_truncatesNested() throws Exception { ImmutableMap map = ImmutableMap.of("key", "this is a very long value that should definitely be truncated"); - Parser.ParsedContent result = new Parser(24).parse(map).get(); + Parser.ParsedContent result = + new Parser(null, 24, null, true).parse(map, "trace", "span").get(); assertTrue(result.isTruncated()); assertEquals("this is a ...[truncated]", result.content().get("key").asText()); @@ -98,7 +109,8 @@ public void parse_map_truncatesNested() throws Exception { @Test public void parse_content_returnsSummary() throws Exception { Content content = Content.fromParts(Part.fromText("part 1"), Part.fromText("part 2")); - Parser.ParsedContent result = new Parser(100).parse(content).get(); + Parser.ParsedContent result = + new Parser(null, 100, null, true).parse(content, "trace", "span").get(); assertEquals("part 1 | part 2", result.content().get("text_summary").asText()); assertEquals(2, result.parts().size()); @@ -109,7 +121,8 @@ public void parse_content_withFileData() throws Exception { FileData fileData = FileData.builder().fileUri("gs://bucket/file.txt").mimeType("text/plain").build(); Content content = Content.fromParts(Part.builder().fileData(fileData).build()); - Parser.ParsedContent result = new Parser(100).parse(content).get(); + Parser.ParsedContent result = + new Parser(null, 100, null, true).parse(content, "trace", "span").get(); assertEquals(1, 
result.parts().size()); JsonNode partData = result.parts().get(0); @@ -122,7 +135,8 @@ public void parse_content_withFileData() throws Exception { public void parse_content_withFunctionCall() throws Exception { FunctionCall fc = FunctionCall.builder().name("myFunction").build(); Content content = Content.fromParts(Part.builder().functionCall(fc).build()); - Parser.ParsedContent result = new Parser(100).parse(content).get(); + Parser.ParsedContent result = + new Parser(null, 100, null, true).parse(content, "trace", "span").get(); assertEquals(1, result.parts().size()); JsonNode partData = result.parts().get(0); @@ -135,7 +149,8 @@ public void parse_content_withFunctionCall() throws Exception { public void parse_list_truncatesElements() throws Exception { List list = Arrays.asList("short", "this is a very long string that should be truncated"); - Parser.ParsedContent result = new Parser(24).parse(list).get(); + Parser.ParsedContent result = + new Parser(null, 24, null, true).parse(list, "trace", "span").get(); assertTrue(result.isTruncated()); JsonNode arrayNode = result.content(); @@ -145,6 +160,44 @@ public void parse_list_truncatesElements() throws Exception { assertEquals("this is a ...[truncated]", arrayNode.get(1).asText()); } + @Test + public void parse_withOffloader_offloadsLargeText() throws Exception { + GcsOffloader offloader = mock(GcsOffloader.class); + when(offloader.uploadContent(anyString(), anyString(), anyString())) + .thenReturn(CompletableFuture.completedFuture("gs://mock-bucket/path")); + + Content content = + Content.fromParts(Part.fromText("this text is longer than 10 characters".repeat(100))); + Parser.ParsedContent result = + new Parser(offloader, 10, "conn", true).parse(content, "trace", "span").get(); + + assertEquals(1, result.parts().size()); + JsonNode partData = result.parts().get(0); + assertEquals("GCS_REFERENCE", partData.get("storage_mode").asText()); + assertEquals("gs://mock-bucket/path", partData.get("uri").asText()); + 
assertTrue(partData.get("text").asText().contains("[OFFLOADED]")); + assertEquals("conn", partData.get("object_ref").get("authorizer").asText()); + } + + @Test + public void parse_withOffloader_offloadsBinaryData() throws Exception { + GcsOffloader offloader = mock(GcsOffloader.class); + when(offloader.uploadContent(any(byte[].class), anyString(), anyString())) + .thenReturn(CompletableFuture.completedFuture("gs://mock-bucket/image.png")); + + Blob blob = Blob.builder().data("fake-image".getBytes(UTF_8)).mimeType("image/png").build(); + Content content = Content.fromParts(Part.builder().inlineData(blob).build()); + Parser.ParsedContent result = + new Parser(offloader, 100, "conn", true).parse(content, "trace", "span").get(); + + assertEquals(1, result.parts().size()); + JsonNode partData = result.parts().get(0); + assertEquals("GCS_REFERENCE", partData.get("storage_mode").asText()); + assertEquals("gs://mock-bucket/image.png", partData.get("uri").asText()); + assertEquals("image/png", partData.get("mime_type").asText()); + assertEquals("[MEDIA OFFLOADED]", partData.get("text").asText()); + } + @Test public void truncate_variousInputs() { assertNull(JsonFormatter.truncate(null, 10)); @@ -188,7 +241,8 @@ public void parse_multibyteString_truncatesBasedOnBytes() throws Exception { // "こんにちはこんにちは" is 30 bytes, but 10 characters. String nihongo = "こんにちはこんにちは"; // With budget 20, effective budget is 6, so only 2 characters (6 bytes) should be kept. 
- Parser.ParsedContent result = new Parser(20).parse(nihongo).get(); + Parser.ParsedContent result = + new Parser(null, 20, null, true).parse(nihongo, "trace", "span").get(); assertTrue(result.isTruncated()); assertEquals("こん...[truncated]", result.content().asText()); @@ -197,7 +251,8 @@ public void parse_multibyteString_truncatesBasedOnBytes() throws Exception { @Test public void parse_multibyteContent_truncatesBasedOnBytes() throws Exception { Content content = Content.fromParts(Part.fromText("こんにちはこんにちは")); - Parser.ParsedContent result = new Parser(20).parse(content).get(); + Parser.ParsedContent result = + new Parser(null, 20, null, true).parse(content, "trace", "span").get(); assertTrue(result.isTruncated()); assertEquals("こん...[truncated]", result.content().get("text_summary").asText()); diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/ParserTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/ParserTest.java index 9bae03331..385e81082 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/ParserTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/ParserTest.java @@ -38,13 +38,13 @@ public final class ParserTest { @Before public void setUp() { - parser = new Parser(100); + parser = new Parser(null, 100, "connectionId", true); } @Test public void parse_part_coversLine280() throws Exception { Part part = Part.fromText("test part"); - CompletableFuture future = parser.parse(part); + CompletableFuture future = parser.parse(part, "traceId", "spanId"); Parser.ParsedContent result = future.get(); assertEquals("{\"text_summary\":\"test part\"}", result.content().toString()); @@ -56,7 +56,7 @@ public void parse_part_coversLine280() throws Exception { public void parse_part_withInlineData_coversProcessPart() throws Exception { Blob blob = Blob.builder().mimeType("image/png").data(new byte[] {1, 2, 3}).build(); Part part = Part.builder().inlineData(blob).build(); - CompletableFuture future = 
parser.parse(part); + CompletableFuture future = parser.parse(part, "traceId", "spanId"); Parser.ParsedContent result = future.get(); assertEquals(1, result.parts().size()); @@ -104,7 +104,7 @@ public void parse_multipartContent_coversLine310() throws Exception { // Call private method using helper if necessary, but parseContentObject is private. // However, parse(Object content, ...) calls it. - CompletableFuture future = parser.parse(content); + CompletableFuture future = parser.parse(content, "traceId", "spanId"); Parser.ParsedContent result = future.get(); assertTrue(result.isTruncated()); diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java index 444cc8a6d..8ca2195ab 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java @@ -17,6 +17,7 @@ package com.google.adk.plugins.agentanalytics; import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atLeastOnce; @@ -29,9 +30,12 @@ import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; import com.google.cloud.bigquery.storage.v1.StreamWriter; import java.io.IOException; +import java.lang.reflect.Field; import java.time.Duration; import java.time.Instant; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.LogRecord; @@ -65,10 +69,6 @@ protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { return mockWriteClient; } - BigQueryWriteClient getMockWriteClient() { - return mockWriteClient; - } - @Override protected StreamWriter 
createWriter() { StreamWriter writer = mock(StreamWriter.class); @@ -102,6 +102,11 @@ public void tearDown() { pluginLogger.setLevel(originalLevel); } + @Test + public void getGcsOffloader_emptyBucketName_returnsNull() { + assertNull(pluginState.getGcsOffloader(config)); + } + @Test public void addPendingTask_removedTaskOnCompletion() { String invocationId = "testInvocation"; @@ -242,4 +247,38 @@ public void close_succeedsAndCleansUp() throws Exception { assertTrue(pluginState.getTraceManagers().isEmpty()); assertTrue(pluginState.getExecutor().isShutdown()); } + + @Test + public void close_respectsRemainingTimeoutBudget() throws Exception { + config = config.toBuilder().shutdownTimeout(Duration.ofMillis(500)).build(); + pluginState = new TestPluginState(config); + + ExecutorService mockOffloadExecutor = mock(ExecutorService.class); + Field field = PluginState.class.getDeclaredField("offloadExecutor"); + field.setAccessible(true); + field.set(pluginState, mockOffloadExecutor); + + pluginState + .getExecutor() + .execute( + () -> { + try { + Thread.sleep(200); + } catch (InterruptedException e) { + // ignore + } + }); + + when(mockOffloadExecutor.awaitTermination(any(Long.class), any(TimeUnit.class))) + .thenReturn(true); + + pluginState.close().test().awaitDone(2, SECONDS); + + ArgumentCaptor timeoutCaptor = ArgumentCaptor.forClass(Long.class); + verify(mockOffloadExecutor).awaitTermination(timeoutCaptor.capture(), any(TimeUnit.class)); + + long capturedTimeout = timeoutCaptor.getValue(); + assertTrue("Timeout should be less than 400", capturedTimeout < 400); + assertTrue("Timeout should be greater than 100", capturedTimeout > 100); + } } diff --git a/pom.xml b/pom.xml index 73e0ccbaa..7df9a72a9 100644 --- a/pom.xml +++ b/pom.xml @@ -75,6 +75,7 @@ 3.9.0 5.6 4.1.118.Final + 2.9.1 @{jacoco.agent.argLine} --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED 
--add-opens=java.base/java.text=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/jdk.internal.misc=ALL-UNNAMED -Dio.netty.tryReflectionSetAccessible=true @@ -294,6 +295,11 @@ assertj-core ${assertj.version} + + org.apache.tika + tika-core + ${tika.version} +