From 1907aac51e8c6dcb117763fbdc3a0263064234e3 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 8 Apr 2026 03:30:04 -0700 Subject: [PATCH] feat: Add tools and toolset to use SkillSource in ADK agents This change introduces a set of tools and a toolset to enable ADK agents to interact with and use skills loaded via `SkillSource`. Key changes: - ListSkillsTool: Lists available skills. - LoadSkillTool: Loads skill instructions. - LoadSkillResourceTool: Loads skill resources. - SkillToolset: Groups the above tools. - Integrates SkillSource into LlmAgent and BaseLlmFlow. - Tests for all new tools. PiperOrigin-RevId: 896391372 --- .../java/com/google/adk/agents/LlmAgent.java | 46 +++ .../adk/flows/llmflows/BaseLlmFlow.java | 17 +- .../adk/skills/AbstractSkillSource.java | 24 +- .../adk/skills/InMemorySkillSource.java | 15 +- .../google/adk/skills/LocalSkillSource.java | 27 +- .../adk/skills/SkillSourceException.java | 36 +- .../com/google/adk/tools/BaseToolset.java | 10 +- .../adk/tools/skills/ListSkillsTool.java | 59 ++++ .../tools/skills/LoadSkillResourceTool.java | 243 +++++++++++++ .../adk/tools/skills/LoadSkillTool.java | 87 +++++ .../google/adk/tools/skills/SkillToolset.java | 131 +++++++ .../adk/skills/LocalSkillSourceTest.java | 118 +++++++ .../java/com/google/adk/testing/TestLlm.java | 2 +- .../adk/tools/skills/ListSkillsToolTest.java | 151 ++++++++ .../skills/LoadSkillResourceToolTest.java | 330 ++++++++++++++++++ .../adk/tools/skills/LoadSkillToolTest.java | 161 +++++++++ .../adk/tools/skills/SkillToolsetTest.java | 143 ++++++++ 17 files changed, 1561 insertions(+), 39 deletions(-) create mode 100644 core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java create mode 100644 core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java create mode 100644 core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java create mode 100644 core/src/main/java/com/google/adk/tools/skills/SkillToolset.java create mode 100644 core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java create mode 100644 core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java create mode 100644 core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java create mode 100644 core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 98bba4606..addbf59d9 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -47,12 +47,16 @@ import com.google.adk.events.Event; import com.google.adk.flows.llmflows.AutoFlow; import com.google.adk.flows.llmflows.BaseLlmFlow; +import com.google.adk.flows.llmflows.RequestProcessor; +import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; import com.google.adk.flows.llmflows.SingleFlow; import com.google.adk.models.BaseLlm; import com.google.adk.models.LlmRegistry; +import com.google.adk.models.LlmRequest; import com.google.adk.models.Model; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ToolContext; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -70,6 +74,7 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.Executor; +import java.util.function.BiFunction; import java.util.function.Function; import org.jspecify.annotations.Nullable; import org.slf4j.Logger; @@ -735,6 +740,47 @@ public Flowable canonicalTools(@Nullable ReadonlyContext context) { return Flowable.concat(toolFlowables); } + /** + * Constructs a {@link RequestProcessor} that sequentially applies the {@code processLlmRequest} + * methods of all tools and toolsets associated with this agent to the incoming {@link + * LlmRequest}. + * + * @return A {@link RequestProcessor} that applies tool-specific modifications to LLM requests. + */ + public RequestProcessor getRequestProcessorFromTools() { + return (context, request) -> { + ReadonlyContext readonlyContext = new ReadonlyContext(context); + List> processors = new ArrayList<>(); + + for (Object toolOrToolset : toolsUnion()) { + if (toolOrToolset instanceof BaseTool baseTool) { + processors.add(baseTool::processLlmRequest); + } else if (toolOrToolset instanceof BaseToolset baseToolset) { + // First apply the toolset's own request processor, then unwrap all tools from the toolset + // and apply each individual tool's request processor sequentially. + processors.add( + (builder, ctx) -> + baseToolset + .processLlmRequest(builder, ctx) + .andThen(baseToolset.getTools(readonlyContext)) + .concatMapCompletable(b -> b.processLlmRequest(builder, ctx))); + } else { + throw new IllegalArgumentException( + "Object in tools list is not of a supported type: " + + toolOrToolset.getClass().getName()); + } + } + + LlmRequest.Builder builder = request.toBuilder(); + ToolContext toolContext = ToolContext.builder(context).build(); + return Flowable.fromIterable(processors) + .concatMapCompletable(f -> f.apply(builder, toolContext)) + .andThen( + Single.fromCallable( + () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); + }; + } + public Instruction instruction() { return instruction; } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fffeab698..c64fc9695 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -25,11 +25,9 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequest; import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.ReadonlyContext; import com.google.adk.agents.RunConfig.StreamingMode; import com.google.adk.events.Event; import com.google.adk.flows.BaseFlow; -import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; import com.google.adk.flows.llmflows.ResponseProcessor.ResponseProcessingResult; import com.google.adk.models.BaseLlm; import com.google.adk.models.BaseLlmConnection; @@ -38,7 +36,6 @@ import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.telemetry.Tracing; -import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.genai.types.FunctionResponse; @@ -96,20 +93,8 @@ private Flowable preprocess( Context currentContext = Context.current(); LlmAgent agent = (LlmAgent) context.agent(); - RequestProcessor toolsProcessor = - (ctx, req) -> { - LlmRequest.Builder builder = req.toBuilder(); - return agent - .canonicalTools(new ReadonlyContext(ctx)) - .concatMapCompletable( - tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build())) - .andThen( - Single.fromCallable( - () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); - }; - Iterable allProcessors = - Iterables.concat(requestProcessors, ImmutableList.of(toolsProcessor)); + Iterables.concat(requestProcessors, ImmutableList.of(agent.getRequestProcessorFromTools())); return Flowable.fromIterable(allProcessors) .concatMap( diff --git a/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java b/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java index aca399f92..31f17a18e 100644 --- a/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java +++ b/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java @@ -16,6 +16,8 @@ package com.google.adk.skills; +import static com.google.adk.skills.SkillSourceException.SKILL_FORMAT_ERROR; +import static com.google.adk.skills.SkillSourceException.SKILL_LOAD_ERROR; import static java.nio.channels.Channels.newReader; import static java.nio.charset.StandardCharsets.UTF_8; @@ -82,12 +84,14 @@ private Frontmatter loadFrontmatter(String skillName, PathT skillMdPath) Frontmatter frontmatter = yamlMapper.readValue(yaml, Frontmatter.class); if (!frontmatter.name().equals(skillName)) { throw new SkillSourceException( - "Skill name '%s' does not match directory name '%s'." - .formatted(frontmatter.name(), skillName)); + "Skill name in the frontmatter '%s' does not match skill name '%s'." + .formatted(frontmatter.name(), skillName), + SKILL_LOAD_ERROR); } return frontmatter; } catch (IOException e) { - throw new SkillSourceException("Cannot load frontmatter for skill '" + skillName + "'", e); + throw new SkillSourceException( + "Cannot load frontmatter for skill '" + skillName + "'", SKILL_LOAD_ERROR, e); } } @@ -100,7 +104,9 @@ public Single loadInstructions(String skillName) { return readInstructions(reader); } catch (IOException e) { throw new SkillSourceException( - "Failed to load instruction for skill '" + skillName + "'", e); + "Failed to load instruction for skill '" + skillName + "'", + SKILL_LOAD_ERROR, + e); } }); } @@ -140,7 +146,8 @@ private String readFrontmatterYaml(BufferedReader reader) throws IOException, SkillSourceException { String line = reader.readLine(); if (line == null || !line.trim().equals(THREE_DASHES)) { - throw new SkillSourceException("Skill file must start with " + THREE_DASHES); + throw new SkillSourceException( + "Skill file must start with " + THREE_DASHES, SKILL_FORMAT_ERROR); } StringBuilder sb = new StringBuilder(); @@ -151,14 +158,15 @@ private String readFrontmatterYaml(BufferedReader reader) sb.append(line).append("\n"); } throw new SkillSourceException( - "Skill file frontmatter not properly closed with " + THREE_DASHES); + "Skill file frontmatter not properly closed with " + THREE_DASHES, SKILL_FORMAT_ERROR); } private String readInstructions(BufferedReader reader) throws IOException, SkillSourceException { // Skip the frontmatter block String line = reader.readLine(); if (line == null || !line.trim().equals(THREE_DASHES)) { - throw new SkillSourceException("Skill file must start with " + THREE_DASHES); + throw new SkillSourceException( + "Skill file must start with " + THREE_DASHES, SKILL_FORMAT_ERROR); } boolean dashClosed = false; while ((line = reader.readLine()) != null) { @@ -169,7 +177,7 @@ private String readInstructions(BufferedReader reader) throws IOException, Skill } if (!dashClosed) { throw new SkillSourceException( - "Skill file frontmatter not properly closed with " + THREE_DASHES); + "Skill file frontmatter not properly closed with " + THREE_DASHES, SKILL_FORMAT_ERROR); } // Read the instructions till the end of the file StringBuilder sb = new StringBuilder(); diff --git a/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java b/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java index 42916e36a..d299dfb21 100644 --- a/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java +++ b/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java @@ -16,6 +16,8 @@ package com.google.adk.skills; +import static com.google.adk.skills.SkillSourceException.RESOURCE_NOT_FOUND; +import static com.google.adk.skills.SkillSourceException.SKILL_NOT_FOUND; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.nio.charset.StandardCharsets.UTF_8; @@ -56,7 +58,8 @@ public Single> listFrontmatters() { public Single> listResources(String skillName, String resourceDirectory) { SkillData data = skills.get(skillName); if (data == null) { - return Single.error(new SkillSourceException("Skill not found: " + skillName)); + return Single.error( + new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND)); } String prefix = resourceDirectory.isEmpty() @@ -67,7 +70,8 @@ public Single> listResources(String skillName, String reso && data.resources().keySet().stream().noneMatch(path -> path.startsWith(prefix))) { return Single.error( new SkillSourceException( - "Resource directory not found: " + resourceDirectory + " for skill: " + skillName)); + "Resource directory not found: " + resourceDirectory + " for skill: " + skillName, + RESOURCE_NOT_FOUND)); } return Single.just( @@ -92,13 +96,16 @@ public Single loadResource(String skillName, String resourcePath) { .map(SkillData::resources) .mapOptional(m -> Optional.ofNullable(m.get(resourcePath))) .switchIfEmpty( - Single.error(new SkillSourceException("Resource not found: " + resourcePath))); + Single.error( + new SkillSourceException( + "Resource not found: " + resourcePath, RESOURCE_NOT_FOUND))); } private Single getSkillData(String skillName) { SkillData data = skills.get(skillName); if (data == null) { - return Single.error(new SkillSourceException("Skill not found: " + skillName)); + return Single.error( + new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND)); } return Single.just(data); } diff --git a/core/src/main/java/com/google/adk/skills/LocalSkillSource.java b/core/src/main/java/com/google/adk/skills/LocalSkillSource.java index 939c30b3c..b4b7d4876 100644 --- a/core/src/main/java/com/google/adk/skills/LocalSkillSource.java +++ b/core/src/main/java/com/google/adk/skills/LocalSkillSource.java @@ -16,6 +16,10 @@ package com.google.adk.skills; +import static com.google.adk.skills.SkillSourceException.RESOURCE_LOAD_ERROR; +import static com.google.adk.skills.SkillSourceException.RESOURCE_NOT_FOUND; +import static com.google.adk.skills.SkillSourceException.SKILL_LOAD_ERROR; +import static com.google.adk.skills.SkillSourceException.SKILL_NOT_FOUND; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.nio.file.Files.isDirectory; @@ -43,14 +47,16 @@ public LocalSkillSource(Path skillsBasePath) { public Single> listResources(String skillName, String resourceDirectory) { Path skillDir = skillsBasePath.resolve(skillName); if (!isDirectory(skillDir)) { - return Single.error(new SkillSourceException("Skill not found: " + skillName)); + return Single.error( + new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND)); } Path resourceDir = skillDir.resolve(resourceDirectory); if (!isDirectory(resourceDir)) { return Single.error( new SkillSourceException( "Resource directory '%s' not found for skill '%s'" - .formatted(resourceDirectory, skillName))); + .formatted(resourceDirectory, skillName), + RESOURCE_NOT_FOUND)); } return Single.fromCallable( @@ -67,7 +73,9 @@ public Single> listResources(String skillName, String reso t -> Single.error( new SkillSourceException( - "Failed to traverse resource directory: " + resourceDirectory, t))); + "Failed to traverse resource directory: " + resourceDirectory, + RESOURCE_LOAD_ERROR, + t))); } @Override @@ -78,7 +86,9 @@ protected Flowable listSkills() { t -> Flowable.error( new SkillSourceException( - "Failed to list skills in directory: " + skillsBasePath, t))) + "Failed to list skills in directory: " + skillsBasePath, + SKILL_LOAD_ERROR, + t))) .filter(Files::isDirectory) .mapOptional(this::findSkillMd) .map(skillMd -> new SkillMdPath(skillMd.getParent().getFileName().toString(), skillMd)); @@ -88,7 +98,8 @@ protected Flowable listSkills() { protected Single findResourcePath(String skillName, String resourcePath) { Path file = skillsBasePath.resolve(skillName).resolve(resourcePath); if (!Files.exists(file)) { - return Single.error(new SkillSourceException("Resource not found: " + file)); + return Single.error( + new SkillSourceException("Resource not found: " + file, RESOURCE_NOT_FOUND)); } return Single.just(file); } @@ -97,11 +108,13 @@ protected Single findResourcePath(String skillName, String resourcePath) { protected Single findSkillMdPath(String skillName) { Path skillDir = skillsBasePath.resolve(skillName); if (!isDirectory(skillDir)) { - return Single.error(new SkillSourceException("Skill directory not found: " + skillName)); + return Single.error( + new SkillSourceException("Skill directory not found: " + skillName, SKILL_NOT_FOUND)); } return Maybe.fromOptional(findSkillMd(skillDir)) .switchIfEmpty( - Single.error(new SkillSourceException("SKILL.md not found in " + skillName))); + Single.error( + new SkillSourceException("SKILL.md not found in " + skillName, SKILL_NOT_FOUND))); } @Override diff --git a/core/src/main/java/com/google/adk/skills/SkillSourceException.java b/core/src/main/java/com/google/adk/skills/SkillSourceException.java index be23291da..273428897 100644 --- a/core/src/main/java/com/google/adk/skills/SkillSourceException.java +++ b/core/src/main/java/com/google/adk/skills/SkillSourceException.java @@ -22,11 +22,43 @@ */ public final class SkillSourceException extends Exception { - public SkillSourceException(String message) { + public static final String SKILL_LOAD_ERROR = "SKILL_LOAD_ERROR"; + public static final String SKILL_NOT_FOUND = "SKILL_NOT_FOUND"; + public static final String SKILL_FORMAT_ERROR = "SKILL_FORMAT_ERROR"; + public static final String RESOURCE_LOAD_ERROR = "RESOURCE_LOAD_ERROR"; + public static final String RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"; + + private final String errorCode; + + /** + * Constructs a new exception with the specified detail message and error code. + * + * @param message The detail message. + * @param errorCode The specific error code categorizing the failure. + */ + public SkillSourceException(String message, String errorCode) { super(message); + this.errorCode = errorCode; } - public SkillSourceException(String message, Throwable cause) { + /** + * Constructs a new exception with the specified detail message, error code, and cause. + * + * @param message The detail message. + * @param errorCode The specific error code categorizing the failure. + * @param cause The cause. + */ + public SkillSourceException(String message, String errorCode, Throwable cause) { super(message, cause); + this.errorCode = errorCode; + } + + /** + * Returns the error code categorizing the failure. + * + * @return The error code string. + */ + public String getErrorCode() { + return errorCode; } } diff --git a/core/src/main/java/com/google/adk/tools/BaseToolset.java b/core/src/main/java/com/google/adk/tools/BaseToolset.java index 76369e5b9..84a5d8fc2 100644 --- a/core/src/main/java/com/google/adk/tools/BaseToolset.java +++ b/core/src/main/java/com/google/adk/tools/BaseToolset.java @@ -17,6 +17,8 @@ package com.google.adk.tools; import com.google.adk.agents.ReadonlyContext; +import com.google.adk.models.LlmRequest; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import java.util.List; import org.jspecify.annotations.Nullable; @@ -24,11 +26,17 @@ /** Base interface for toolsets. */ public interface BaseToolset extends AutoCloseable { + /** Processes the outgoing {@link LlmRequest.Builder}. */ + default Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.complete(); + } + /** * Return all tools in the toolset based on the provided context. * * @param readonlyContext Context used to filter tools available to the agent. - * @return A Single emitting a list of tools available under the specified context. + * @return A Flowable emitting tools available under the specified context. */ Flowable getTools(ReadonlyContext readonlyContext); diff --git a/core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java b/core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java new file mode 100644 index 000000000..bc669632a --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/skills/ListSkillsTool.java @@ -0,0 +1,59 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import com.google.adk.skills.SkillSource; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import com.google.genai.types.Type; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import java.util.Optional; + +/** Tool to list all available skills. */ +final class ListSkillsTool extends BaseTool { + private final SkillSource skillSource; + + ListSkillsTool(SkillSource skillSource) { + super("list_skills", "Lists all available skills with their names and descriptions."); + this.skillSource = skillSource; + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name(name()) + .description(description()) + .parameters( + Schema.builder().type(Type.Known.OBJECT).properties(ImmutableMap.of()).build()) + .build()); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + return skillSource + .listFrontmatters() + .map(ImmutableMap::values) + .map(SkillToolset::getSkillsPrompt) + .>map(skills -> ImmutableMap.of("skills_xml", skills)) + .onErrorResumeNext(SkillToolset::createErrorResponse); + } +} diff --git a/core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java b/core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java new file mode 100644 index 000000000..6f1e7e042 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/skills/LoadSkillResourceTool.java @@ -0,0 +1,243 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.adk.tools.skills.SkillToolset.createErrorResponse; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.net.URLConnection.guessContentTypeFromName; +import static java.net.URLConnection.guessContentTypeFromStream; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.adk.models.LlmRequest; +import com.google.adk.skills.SkillSource; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.io.ByteSource; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import com.google.genai.types.Schema; +import com.google.genai.types.Type; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Single; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Stream; + +/** Tool to load resources (references, assets, or scripts) from a skill. */ +final class LoadSkillResourceTool extends BaseTool { + + private static final ImmutableSet EXTRA_TEXT_MIME_TYPES = + ImmutableSet.of( + // go/keep-sorted start + "application/json", + "application/x-python", + "application/x-sh", + "application/x-shar", + "application/x-shellscript", + "application/xml", + "application/yaml" + // go/keep-sorted end + ); + private static final String BINARY_FILE_DETECTED_MSG = + "Binary file detected. The content has been included in the next part of the function" + + " response for you to analyze."; + private static final String SKILL_NAME = "skill_name"; + private static final String FILE_PATH = "file_path"; + private static final String CONTENT = "content"; + private static final String MIME_TYPE = "mime_type"; + + private final SkillSource skillSource; + + LoadSkillResourceTool(SkillSource skillSource) { + super( + "load_skill_resource", + "Loads a resource file (from references/, assets/, or scripts/) from within a skill."); + this.skillSource = skillSource; + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name(name()) + .description(description()) + .parameters( + Schema.builder() + .type(Type.Known.OBJECT) + .properties( + ImmutableMap.of( + SKILL_NAME, + Schema.builder() + .type(Type.Known.STRING) + .description("The name of the skill.") + .build(), + FILE_PATH, + Schema.builder() + .type(Type.Known.STRING) + .description( + "The relative path to the resource (e.g.," + + " 'references/my_doc.md', 'assets/template.txt'," + + " or 'scripts/setup.sh').") + .build())) + .required(ImmutableList.of(SKILL_NAME, FILE_PATH)) + .build()) + .build()); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + String skillName = (String) args.get(SKILL_NAME); + String resourcePath = (String) args.get(FILE_PATH); + + if (Strings.isNullOrEmpty(skillName)) { + return createErrorResponse("Skill name is required.", "MISSING_SKILL_NAME"); + } + if (Strings.isNullOrEmpty(resourcePath)) { + return createErrorResponse("Resource path is required.", "MISSING_RESOURCE_PATH"); + } + if (!resourcePath.startsWith("references/") + && !resourcePath.startsWith("assets/") + && !resourcePath.startsWith("scripts/")) { + return createErrorResponse( + "Path must start with 'references/', 'assets/', or 'scripts/'.", "INVALID_RESOURCE_PATH"); + } + + return skillSource + .loadResource(skillName, resourcePath) + .>map( + contentSource -> createResult(skillName, resourcePath, contentSource)) + .onErrorResumeNext(SkillToolset::createErrorResponse); + } + + private boolean hasBinaryContentResponse(FunctionResponse functionResponse) { + return functionResponse + .response() + .filter( + resp -> + resp.containsKey(SKILL_NAME) + && resp.containsKey(MIME_TYPE) + && resp.get(CONTENT) instanceof byte[]) + .isPresent(); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return super.processLlmRequest(llmRequestBuilder, toolContext) + .andThen( + Completable.fromRunnable( + () -> { + List contents = new ArrayList<>(llmRequestBuilder.build().contents()); + if (contents.isEmpty()) { + return; + } + + Content lastContent = Iterables.getLast(contents); + List parts = lastContent.parts().orElse(ImmutableList.of()); + + // Extract raw binary content into a dedicated binary Part + ImmutableList newParts = + parts.stream().flatMap(this::processPart).collect(toImmutableList()); + + if (!newParts.isEmpty()) { + contents.replaceAll( + content -> + content == lastContent + ? content.toBuilder().parts(newParts).build() + : content); + + llmRequestBuilder.contents(contents); + } + })); + } + + /** + * Processes a {@link Part} to extract raw binary content from a function response. + * + *

If the part is a function response from this tool containing binary data, it returns a + * stream containing the updated function response part (with a placeholder message) and a new + * part containing the raw binary data. Otherwise, it returns an empty stream. + * + * @param part the {@link Part} to process + * @return a stream containing the processed parts, or an empty stream if the part does not + * contain a binary function response from this tool + */ + private Stream processPart(Part part) { + return part + .functionResponse() + .filter(funcResp -> funcResp.name().orElse("").equals(name())) + .filter(this::hasBinaryContentResponse) + .stream() + .flatMap( + funcResp -> + funcResp.response().stream() + .flatMap( + response -> { + Map newResponse = new HashMap<>(response); + + String mimeType = newResponse.remove(MIME_TYPE).toString(); + byte[] binaryContent = + (byte[]) newResponse.replace(CONTENT, BINARY_FILE_DETECTED_MSG); + + Part updatedPart = + part.toBuilder() + .functionResponse(funcResp.toBuilder().response(newResponse)) + .build(); + Part binaryPart = Part.fromBytes(binaryContent, mimeType); + + return Stream.of(updatedPart, binaryPart); + })); + } + + private ImmutableMap createResult( + String skillName, String resourcePath, ByteSource contentSource) throws IOException { + byte[] bytes = contentSource.read(); + // Special handling of shell script as the guessContentTypeFromName would return + // application/x-shar + String contentType = + resourcePath.endsWith(".sh") || resourcePath.endsWith(".bash") + ? "application/x-sh" + : guessContentTypeFromName(resourcePath); + if (contentType == null) { + contentType = guessContentTypeFromStream(new ByteArrayInputStream(bytes)); + } + if (contentType == null) { + contentType = "application/octet-stream"; + } + ImmutableMap.Builder builder = ImmutableMap.builder(); + builder.put(SKILL_NAME, skillName).put(FILE_PATH, resourcePath).put(MIME_TYPE, contentType); + + if (contentType.startsWith("text/") || EXTRA_TEXT_MIME_TYPES.contains(contentType)) { + builder.put(CONTENT, new String(bytes, UTF_8)); + } else { + builder.put(CONTENT, bytes); + } + return builder.buildOrThrow(); + } +} diff --git a/core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java b/core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java new file mode 100644 index 000000000..6d47e297a --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/skills/LoadSkillTool.java @@ -0,0 +1,87 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.adk.tools.skills.SkillToolset.createErrorResponse; + +import com.google.adk.skills.SkillSource; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.Schema; +import com.google.genai.types.Type; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import java.util.Optional; + +/** Tool to load a skill's instructions. */ +final class LoadSkillTool extends BaseTool { + + private static final String SKILL_NAME = "skill_name"; + private final SkillSource skillSource; + + LoadSkillTool(SkillSource skillSource) { + super("load_skill", "Loads the SKILL.md instructions for a given skill."); + this.skillSource = skillSource; + } + + @Override + public Optional declaration() { + return Optional.of( + FunctionDeclaration.builder() + .name(name()) + .description(description()) + .parameters( + Schema.builder() + .type(Type.Known.OBJECT) + .properties( + ImmutableMap.of( + SKILL_NAME, + Schema.builder() + .type(Type.Known.STRING) + .description("The name of the skill to load.") + .build())) + .required(ImmutableList.of(SKILL_NAME)) + .build()) + .build()); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + String skillName = (String) args.get(SKILL_NAME); + if (Strings.isNullOrEmpty(skillName)) { + return createErrorResponse("Skill name is required.", "MISSING_SKILL_NAME"); + } + + return skillSource + .loadFrontmatter(skillName) + .>zipWith( + skillSource.loadInstructions(skillName), + (frontmatter, instructions) -> + ImmutableMap.of( + "skill_name", + skillName, + "frontmatter", + frontmatter, + "instructions", + instructions)) + .onErrorResumeNext(SkillToolset::createErrorResponse); + } +} diff --git a/core/src/main/java/com/google/adk/tools/skills/SkillToolset.java b/core/src/main/java/com/google/adk/tools/skills/SkillToolset.java new file mode 100644 index 000000000..baf54381a --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/skills/SkillToolset.java @@ -0,0 +1,131 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static java.util.Optional.ofNullable; + +import com.google.adk.agents.ReadonlyContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.skills.Frontmatter; +import com.google.adk.skills.SkillSource; +import com.google.adk.skills.SkillSourceException; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import java.util.Collection; +import java.util.Map; +import java.util.StringJoiner; + +/** + * A toolset for managing and interacting with agent skills. Provides tools to list, load, and run + * skills. + */ +public class SkillToolset implements BaseToolset { + + private static final String DEFAULT_SKILL_SYSTEM_INSTRUCTION = + """ + You can use specialized 'skills' to help you with complex tasks. You MUST use the skill tools to interact with these skills. + + Skills are folders of instructions and resources that extend your capabilities for specialized tasks. Each skill folder contains: + - **SKILL.md** (required): The main instruction file with skill metadata and detailed markdown instructions. + - **references/** (Optional): Additional documentation or examples for skill usage. + - **assets/** (Optional): Templates, scripts or other resources used by the skill. + - **scripts/** (Optional): Executable scripts that can be run via bash. + + This is very important: + + 1. If a skill seems relevant to the current user query, you MUST use the `load_skill` tool with `skill_name=""` to read its full instructions before proceeding. + 2. Once you have read the instructions, follow them exactly as documented before replying to the user. For example, If the instruction lists multiple steps, please make sure you complete all of them in order. + 3. The `load_skill_resource` tool is for viewing files within a skill's directory (e.g., `references/*`, `assets/*`, `scripts/*`). Do NOT use other tools to access these files. + 4. Use `run_skill_script` to run scripts from a skill's `scripts/` directory. Use `load_skill_resource` to view script content first if needed. + """; + + private final SkillSource skillSource; + private final ImmutableList coreTools; + private final String systemInstruction; + + /** Initializes the SkillToolset with a SkillSource and default execution settings. */ + public SkillToolset(SkillSource skillSource) { + this(skillSource, DEFAULT_SKILL_SYSTEM_INSTRUCTION); + } + + /** Initializes the SkillToolset with a SkillSource. */ + public SkillToolset(SkillSource skillSource, String systemInstruction) { + this.skillSource = skillSource; + this.systemInstruction = systemInstruction; + this.coreTools = + ImmutableList.of( + new ListSkillsTool(skillSource), + new LoadSkillTool(skillSource), + new LoadSkillResourceTool(skillSource)); + } + + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.fromIterable(ImmutableList.builder().addAll(coreTools).build()); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromSingle( + skillSource + .listFrontmatters() + .map(ImmutableMap::values) + .map(SkillToolset::getSkillsPrompt) + .map( + skills -> + llmRequestBuilder.appendInstructions( + ImmutableList.of(systemInstruction, skills)))); + } + + @Override + public void close() throws Exception { + // No resources to release for now + } + + static Single> createErrorResponse(String errorMessage, String errorCode) { + return Single.just(ImmutableMap.of("error", errorMessage, "error_code", errorCode)); + } + + static Single> createErrorResponse(Throwable t) { + if (t instanceof SkillSourceException ex) { + return Single.just( + ImmutableMap.of( + "error", + ofNullable(ex.getMessage()).orElse(ex.toString()), + "error_code", + ex.getErrorCode())); + } + return Single.error(t); + } + + static String getSkillsPrompt(Collection frontmatters) { + return frontmatters.stream() + .map(Frontmatter::toXml) + .reduce( + new StringJoiner("\n", "", "").setEmptyValue(""), + StringJoiner::add, + StringJoiner::merge) + .toString(); + } +} diff --git a/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java b/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java index 256f1d66a..ee6684f02 100644 --- a/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java +++ b/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java @@ -250,4 +250,122 @@ public void testListSkillMdPaths_skillSourceException() throws IOException { RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); } + + @Test + public void testLoadFrontmatter_missingStartDashes() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString( + skillDir.resolve("SKILL.md"), + """ + name: my-skill + description: This is a test skill + --- + body + """); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadFrontmatter("my-skill"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + assertThat(exception) + .hasCauseThat() + .hasMessageThat() + .contains("Skill file must start with ---"); + } + + @Test + public void testLoadInstructions_missingStartDashes() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString( + skillDir.resolve("SKILL.md"), + """ + name: my-skill + description: Test + --- + Some Markdown Body + """); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadInstructions("my-skill"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + assertThat(exception) + .hasCauseThat() + .hasMessageThat() + .contains("Skill file must start with ---"); + } + + @Test + public void testLoadFrontmatter_nameMismatch() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString( + skillDir.resolve("SKILL.md"), + """ + --- + name: other-skill + description: This is a test skill + --- + body + """); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadFrontmatter("my-skill"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + assertThat(exception) + .hasCauseThat() + .hasMessageThat() + .contains( + "Skill name in the frontmatter 'other-skill' does not match skill name 'my-skill'."); + } + + @Test + public void testLoadFrontmatter_emptyFile() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString(skillDir.resolve("SKILL.md"), ""); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadFrontmatter("my-skill"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + assertThat(exception) + .hasCauseThat() + .hasMessageThat() + .contains("Skill file must start with ---"); + } + + @Test + public void testLoadInstructions_emptyFile() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString(skillDir.resolve("SKILL.md"), ""); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadInstructions("my-skill"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + assertThat(exception) + .hasCauseThat() + .hasMessageThat() + .contains("Skill file must start with ---"); + } } diff --git a/core/src/test/java/com/google/adk/testing/TestLlm.java b/core/src/test/java/com/google/adk/testing/TestLlm.java index aaacf00a0..fc9ce3850 100644 --- a/core/src/test/java/com/google/adk/testing/TestLlm.java +++ b/core/src/test/java/com/google/adk/testing/TestLlm.java @@ -42,7 +42,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Predicate; import java.util.function.Supplier; -import javax.annotation.Nullable; +import org.jspecify.annotations.Nullable; /** * A test implementation of {@link BaseLlm}. diff --git a/core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java b/core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java new file mode 100644 index 000000000..fe5b202a2 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/skills/ListSkillsToolTest.java @@ -0,0 +1,151 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.adk.skills.SkillSourceException.SKILL_LOAD_ERROR; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.sessions.Session; +import com.google.adk.skills.Frontmatter; +import com.google.adk.skills.InMemorySkillSource; +import com.google.adk.skills.SkillSource; +import com.google.adk.skills.SkillSourceException; +import com.google.adk.testing.TestBaseAgent; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ListSkillsToolTest { + + @Test + public void call_listSkillsTool_success() { + Frontmatter testFrontmatter = + Frontmatter.builder().name("test-skill").description("test skill").build(); + + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = + InMemorySkillSource.builder() + .skill(testFrontmatter.name()) + .frontmatter(testFrontmatter) + .instructions("Test instructions") + .build(); + ListSkillsTool listSkillsTool = new ListSkillsTool(skillSource); + Map response = + listSkillsTool + .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skills_xml", "" + testFrontmatter.toXml() + ""); + } + + @Test + public void call_listSkillsTool_empty() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + ListSkillsTool listSkillsTool = new ListSkillsTool(InMemorySkillSource.builder().build()); + Map response = + listSkillsTool + .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response).containsExactly("skills_xml", ""); + } + + @Test + public void call_listSkillsTool_skillSourceException() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = mock(SkillSource.class); + when(skillSource.listFrontmatters()) + .thenReturn( + Single.error(new SkillSourceException("Failed to list skills", SKILL_LOAD_ERROR))); + + ListSkillsTool listSkillsTool = new ListSkillsTool(skillSource); + Map response = + listSkillsTool + .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response) + .containsExactly("error", "Failed to list skills", "error_code", "SKILL_LOAD_ERROR"); + } + + @Test + public void call_listSkillsTool_otherException() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = mock(SkillSource.class); + RuntimeException expectedException = new RuntimeException("Unexpected error"); + when(skillSource.listFrontmatters()).thenReturn(Single.error(expectedException)); + + ListSkillsTool listSkillsTool = new ListSkillsTool(skillSource); + var single = + listSkillsTool.runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()); + + RuntimeException thrown = assertThrows(RuntimeException.class, single::blockingGet); + + assertThat(thrown).hasMessageThat().contains("Unexpected error"); + } + + @Test + public void call_listSkillsTool_declaration() { + ListSkillsTool listSkillsTool = new ListSkillsTool(mock(SkillSource.class)); + assertThat(listSkillsTool.declaration().get().name()).hasValue("list_skills"); + } +} diff --git a/core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java b/core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java new file mode 100644 index 000000000..3ab9b2a17 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/skills/LoadSkillResourceToolTest.java @@ -0,0 +1,330 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.Session; +import com.google.adk.skills.Frontmatter; +import com.google.adk.skills.InMemorySkillSource; +import com.google.adk.skills.SkillSource; +import com.google.adk.testing.TestBaseAgent; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import java.util.List; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class LoadSkillResourceToolTest { + + @Test + public void call_loadSkillResourceTool_reference_success() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "references/my_doc.md"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skill_name", "test-skill", + "file_path", "references/my_doc.md", + "mime_type", "text/markdown", + "content", "doc content"); + } + + @Test + public void call_loadSkillResourceTool_asset_success() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "assets/template.txt"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skill_name", "test-skill", + "file_path", "assets/template.txt", + "mime_type", "text/plain", + "content", "asset content"); + } + + @Test + public void call_loadSkillResourceTool_script_success() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "scripts/setup.sh"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skill_name", "test-skill", + "file_path", "scripts/setup.sh", + "mime_type", "application/x-sh", + "content", "echo hello"); + } + + @Test + public void call_loadSkillResourceTool_streamDetection_success() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "assets/data_no_ext"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skill_name", "test-skill", + "file_path", "assets/data_no_ext", + "mime_type", "application/xml", + "content", ""); + } + + @Test + public void call_loadSkillResourceTool_binaryReference_detected() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + ToolContext toolContext = createToolContext(); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "references/binary.dat"), + toolContext) + .blockingGet(); + + Part partFunctionResponse = + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(toolContext.functionCallId().orElse("")) + .name(loadSkillResourceTool.name()) + .response(response) + .build()) + .build(); + + // Binary data is added as separate part in the next request to LLM + LlmRequest.Builder builder = + LlmRequest.builder() + .contents( + ImmutableList.of( + Content.builder().role("user").parts(partFunctionResponse).build())); + loadSkillResourceTool.processLlmRequest(builder, toolContext).blockingAwait(); + + List contents = builder.build().contents(); + assertThat(contents).hasSize(1); + List parts = contents.get(0).parts().get(); + assertThat(parts).hasSize(2); + + FunctionResponse updatedFunctionResponse = parts.get(0).functionResponse().get(); + assertThat(updatedFunctionResponse.response().get()) + .containsExactly( + "skill_name", + "test-skill", + "file_path", + "references/binary.dat", + "content", + "Binary file detected. The content has been included in the next part of the function" + + " response for you to analyze."); + + Part binaryPart = parts.get(1); + assertThat(binaryPart.inlineData().get().mimeType()).hasValue("application/octet-stream"); + assertThat(binaryPart.inlineData().get().data().get()).isEqualTo(new byte[] {0, 1, 2, 3}); + } + + @Test + public void call_loadSkillResourceTool_nonBinaryReference_notChanged() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + ToolContext toolContext = createToolContext(); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "references/my_doc.md"), + toolContext) + .blockingGet(); + + Part partFunctionResponse = + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(toolContext.functionCallId().orElse("")) + .name(loadSkillResourceTool.name()) + .response(response) + .build()) + .build(); + + LlmRequest.Builder builder = + LlmRequest.builder() + .contents( + ImmutableList.of( + Content.builder().role("user").parts(partFunctionResponse).build())); + List expectedContents = builder.build().contents(); + + loadSkillResourceTool.processLlmRequest(builder, toolContext).blockingAwait(); + + assertThat(builder.build().contents()).isEqualTo(expectedContents); + } + + @Test + public void call_loadSkillResourceTool_missingSkillName() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync(ImmutableMap.of("file_path", "references/my_doc.md"), createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "error", "Skill name is required.", + "error_code", "MISSING_SKILL_NAME"); + } + + @Test + public void call_loadSkillResourceTool_missingPath() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync(ImmutableMap.of("skill_name", "test-skill"), createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "error", "Resource path is required.", + "error_code", "MISSING_RESOURCE_PATH"); + } + + @Test + public void call_loadSkillResourceTool_skillNotFound() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "other-skill", "file_path", "references/my_doc.md"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "error", "Skill not found: other-skill", + "error_code", "SKILL_NOT_FOUND"); + } + + @Test + public void call_loadSkillResourceTool_invalidPathPrefix() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "invalid/my_doc.md"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "error", "Path must start with 'references/', 'assets/', or 'scripts/'.", + "error_code", "INVALID_RESOURCE_PATH"); + } + + @Test + public void call_loadSkillResourceTool_resourceNotFound() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(createTestSkillSource()); + Map response = + loadSkillResourceTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill", "file_path", "references/missing.md"), + createToolContext()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "error", "Resource not found: references/missing.md", + "error_code", "RESOURCE_NOT_FOUND"); + } + + @Test + public void call_loadSkillResourceTool_declaration() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(mock(SkillSource.class)); + assertThat(loadSkillResourceTool.declaration().get().name()).hasValue("load_skill_resource"); + } + + @Test + public void call_loadSkillResourceTool_processLlmRequest_emptyContents() { + LoadSkillResourceTool loadSkillResourceTool = + new LoadSkillResourceTool(mock(SkillSource.class)); + LlmRequest.Builder builder = LlmRequest.builder().contents(ImmutableList.of()); + loadSkillResourceTool.processLlmRequest(builder, createToolContext()).blockingAwait(); + assertThat(builder.build().contents()).isEmpty(); + } + + private ToolContext createToolContext() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + return ToolContext.builder(invocationContext).build(); + } + + private SkillSource createTestSkillSource() { + return InMemorySkillSource.builder() + .skill("test-skill") + .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) + .instructions("Test instructions") + .addResource("references/my_doc.md", "doc content".getBytes(UTF_8)) + .addResource("references/binary.dat", new byte[] {0, 1, 2, 3}) + .addResource("assets/template.txt", "asset content".getBytes(UTF_8)) + .addResource("scripts/setup.sh", "echo hello".getBytes(UTF_8)) + .addResource("assets/data_no_ext", "".getBytes(UTF_8)) + .build(); + } +} diff --git a/core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java b/core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java new file mode 100644 index 000000000..29b127755 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/skills/LoadSkillToolTest.java @@ -0,0 +1,161 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.adk.skills.SkillSourceException.SKILL_NOT_FOUND; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.sessions.Session; +import com.google.adk.skills.Frontmatter; +import com.google.adk.skills.InMemorySkillSource; +import com.google.adk.skills.SkillSource; +import com.google.adk.skills.SkillSourceException; +import com.google.adk.testing.TestBaseAgent; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class LoadSkillToolTest { + + @Test + public void call_loadSkillTool_success() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = + InMemorySkillSource.builder() + .skill("test-skill") + .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) + .instructions("Test instructions") + .build(); + LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); + Map response = + loadSkillTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill"), + ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response) + .containsExactly( + "skill_name", + "test-skill", + "instructions", + "Test instructions", + "frontmatter", + Frontmatter.builder().name("test-skill").description("test skill").build()); + } + + @Test + public void call_loadSkillTool_missingSkillName() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = mock(SkillSource.class); + LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); + Map response = + loadSkillTool + .runAsync(ImmutableMap.of(), ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response) + .containsExactly("error", "Skill name is required.", "error_code", "MISSING_SKILL_NAME"); + } + + @Test + public void call_loadSkillTool_skillSourceException() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = mock(SkillSource.class); + when(skillSource.loadFrontmatter("test-skill")) + .thenReturn(Single.error(new SkillSourceException("Skill not found", SKILL_NOT_FOUND))); + when(skillSource.loadInstructions("test-skill")).thenReturn(Single.just("instructions")); + + LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); + Map response = + loadSkillTool + .runAsync( + ImmutableMap.of("skill_name", "test-skill"), + ToolContext.builder(invocationContext).build()) + .blockingGet(); + + assertThat(response).containsExactly("error", "Skill not found", "error_code", SKILL_NOT_FOUND); + } + + @Test + public void call_loadSkillTool_otherException() { + TestBaseAgent testAgent = + new TestBaseAgent( + "test agent", "test agent", ImmutableList.of(), ImmutableList.of(), Flowable::empty); + Session session = Session.builder("session").build(); + + InvocationContext invocationContext = mock(InvocationContext.class); + when(invocationContext.agent()).thenReturn(testAgent); + when(invocationContext.session()).thenReturn(session); + + SkillSource skillSource = mock(SkillSource.class); + RuntimeException expectedException = new RuntimeException("Unexpected error"); + when(skillSource.loadFrontmatter("test-skill")).thenReturn(Single.error(expectedException)); + when(skillSource.loadInstructions("test-skill")).thenReturn(Single.just("instructions")); + + LoadSkillTool loadSkillTool = new LoadSkillTool(skillSource); + var single = + loadSkillTool.runAsync( + ImmutableMap.of("skill_name", "test-skill"), + ToolContext.builder(invocationContext).build()); + + RuntimeException thrown = assertThrows(RuntimeException.class, single::blockingGet); + + assertThat(thrown).hasMessageThat().contains("Unexpected error"); + } + + @Test + public void call_loadSkillTool_declaration() { + LoadSkillTool loadSkillTool = new LoadSkillTool(mock(SkillSource.class)); + assertThat(loadSkillTool.declaration().get().name()).hasValue("load_skill"); + } +} diff --git a/core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java b/core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java new file mode 100644 index 000000000..4be781469 --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/skills/SkillToolsetTest.java @@ -0,0 +1,143 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.skills; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; + +import com.google.adk.agents.ReadonlyContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.skills.Frontmatter; +import com.google.adk.skills.InMemorySkillSource; +import com.google.adk.skills.SkillSource; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.truth.Correspondence; +import io.reactivex.rxjava3.core.Flowable; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SkillToolsetTest { + + @Test + public void getTools_returnsCoreTools() throws Exception { + SkillSource mockSkillSource = mock(SkillSource.class); + try (SkillToolset toolSet = new SkillToolset(mockSkillSource)) { + Flowable tools = toolSet.getTools(null); + List baseTools = tools.toList().blockingGet(); + + assertThat(baseTools) + .comparingElementsUsing(Correspondence.transforming(BaseTool::name, "Tool name")) + .containsExactly("list_skills", "load_skill", "load_skill_resource"); + } + } + + @Test + public void getTools_withInMemorySkills() throws Exception { + SkillSource skillSource = + InMemorySkillSource.builder() + .skill("test-skill") + .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) + .instructions("Test instructions") + .build(); + try (SkillToolset toolSet = new SkillToolset(skillSource)) { + + Flowable tools = toolSet.getTools(null); + List baseTools = tools.toList().blockingGet(); + + assertThat(baseTools) + .comparingElementsUsing(Correspondence.transforming(BaseTool::name, "Tool name")) + .containsExactly("list_skills", "load_skill", "load_skill_resource"); + } + } + + @Test + public void processLlmRequest_addsInstructions() throws Exception { + SkillSource skillSource = + InMemorySkillSource.builder() + .skill("test-skill") + .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) + .instructions("Test instructions") + .build(); + try (SkillToolset toolSet = new SkillToolset(skillSource)) { + + LlmRequest.Builder requestBuilder = LlmRequest.builder(); + ToolContext mockToolContext = mock(ToolContext.class); + + toolSet.processLlmRequest(requestBuilder, mockToolContext).blockingAwait(); + + LlmRequest request = requestBuilder.build(); + ImmutableList instructions = request.getSystemInstructions(); + + assertThat(instructions).isNotEmpty(); + String instruction = instructions.get(0); + assertThat(instruction) + .contains("You can use specialized 'skills' to help you with complex tasks"); + assertThat(instruction).contains(""); + assertThat(instruction).contains("test-skill"); + } + } + + @Test + public void processLlmRequest_withCustomSystemInstruction_addsCustomInstructions() + throws Exception { + SkillSource skillSource = + InMemorySkillSource.builder() + .skill("test-skill") + .frontmatter(Frontmatter.builder().name("test-skill").description("test skill").build()) + .instructions("Test instructions") + .build(); + String customInstruction = "Custom system instruction for testing."; + try (SkillToolset toolSet = new SkillToolset(skillSource, customInstruction)) { + + LlmRequest.Builder requestBuilder = LlmRequest.builder(); + ToolContext mockToolContext = mock(ToolContext.class); + + toolSet.processLlmRequest(requestBuilder, mockToolContext).blockingAwait(); + + LlmRequest request = requestBuilder.build(); + ImmutableList instructions = request.getSystemInstructions(); + + assertThat(instructions).isNotEmpty(); + String instruction = instructions.get(0); + assertThat(instruction).contains(customInstruction); + assertThat(instruction).contains(""); + assertThat(instruction).contains("test-skill"); + } + } + + @Test + public void baseToolset_defaultProcessLlmRequest() throws Exception { + try (BaseToolset baseToolset = + new BaseToolset() { + @Override + public Flowable getTools(ReadonlyContext context) { + return Flowable.empty(); + } + + @Override + public void close() {} + }) { + baseToolset.processLlmRequest(LlmRequest.builder(), mock(ToolContext.class)).blockingAwait(); + } + } +}