Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions core/src/main/java/com/google/adk/agents/LlmAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,16 @@
import com.google.adk.events.Event;
import com.google.adk.flows.llmflows.AutoFlow;
import com.google.adk.flows.llmflows.BaseLlmFlow;
import com.google.adk.flows.llmflows.RequestProcessor;
import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult;
import com.google.adk.flows.llmflows.SingleFlow;
import com.google.adk.models.BaseLlm;
import com.google.adk.models.LlmRegistry;
import com.google.adk.models.LlmRequest;
import com.google.adk.models.Model;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.BaseToolset;
import com.google.adk.tools.ToolContext;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
Expand All @@ -70,6 +74,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
Expand Down Expand Up @@ -735,6 +740,47 @@ public Flowable<BaseTool> canonicalTools(@Nullable ReadonlyContext context) {
return Flowable.concat(toolFlowables);
}

/**
* Constructs a {@link RequestProcessor} that sequentially applies the {@code processLlmRequest}
* methods of all tools and toolsets associated with this agent to the incoming {@link
* LlmRequest}.
*
* @return A {@link RequestProcessor} that applies tool-specific modifications to LLM requests.
*/
public RequestProcessor getRequestProcessorFromTools() {
return (context, request) -> {
ReadonlyContext readonlyContext = new ReadonlyContext(context);
List<BiFunction<LlmRequest.Builder, ToolContext, Completable>> processors = new ArrayList<>();

for (Object toolOrToolset : toolsUnion()) {
if (toolOrToolset instanceof BaseTool baseTool) {
processors.add(baseTool::processLlmRequest);
} else if (toolOrToolset instanceof BaseToolset baseToolset) {
// First apply the toolset's own request processor, then unwrap all tools from the toolset
// and apply each individual tool's request processor sequentially.
processors.add(
(builder, ctx) ->
baseToolset
.processLlmRequest(builder, ctx)
.andThen(baseToolset.getTools(readonlyContext))
.concatMapCompletable(b -> b.processLlmRequest(builder, ctx)));
} else {
throw new IllegalArgumentException(
"Object in tools list is not of a supported type: "
+ toolOrToolset.getClass().getName());
}
}

LlmRequest.Builder builder = request.toBuilder();
ToolContext toolContext = ToolContext.builder(context).build();
return Flowable.fromIterable(processors)
.concatMapCompletable(f -> f.apply(builder, toolContext))
.andThen(
Single.fromCallable(
() -> RequestProcessingResult.create(builder.build(), ImmutableList.of())));
};
}

public Instruction instruction() {
return instruction;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@
import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LiveRequest;
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.ReadonlyContext;
import com.google.adk.agents.RunConfig.StreamingMode;
import com.google.adk.events.Event;
import com.google.adk.flows.BaseFlow;
import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult;
import com.google.adk.flows.llmflows.ResponseProcessor.ResponseProcessingResult;
import com.google.adk.models.BaseLlm;
import com.google.adk.models.BaseLlmConnection;
Expand All @@ -38,7 +36,6 @@
import com.google.adk.models.LlmRequest;
import com.google.adk.models.LlmResponse;
import com.google.adk.telemetry.Tracing;
import com.google.adk.tools.ToolContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.genai.types.FunctionResponse;
Expand Down Expand Up @@ -96,20 +93,8 @@ private Flowable<Event> preprocess(
Context currentContext = Context.current();
LlmAgent agent = (LlmAgent) context.agent();

RequestProcessor toolsProcessor =
(ctx, req) -> {
LlmRequest.Builder builder = req.toBuilder();
return agent
.canonicalTools(new ReadonlyContext(ctx))
.concatMapCompletable(
tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build()))
.andThen(
Single.fromCallable(
() -> RequestProcessingResult.create(builder.build(), ImmutableList.of())));
};

Iterable<RequestProcessor> allProcessors =
Iterables.concat(requestProcessors, ImmutableList.of(toolsProcessor));
Iterables.concat(requestProcessors, ImmutableList.of(agent.getRequestProcessorFromTools()));

return Flowable.fromIterable(allProcessors)
.concatMap(
Expand Down
24 changes: 16 additions & 8 deletions core/src/main/java/com/google/adk/skills/AbstractSkillSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.google.adk.skills;

import static com.google.adk.skills.SkillSourceException.SKILL_FORMAT_ERROR;
import static com.google.adk.skills.SkillSourceException.SKILL_LOAD_ERROR;
import static java.nio.channels.Channels.newReader;
import static java.nio.charset.StandardCharsets.UTF_8;

Expand Down Expand Up @@ -82,12 +84,14 @@ private Frontmatter loadFrontmatter(String skillName, PathT skillMdPath)
Frontmatter frontmatter = yamlMapper.readValue(yaml, Frontmatter.class);
if (!frontmatter.name().equals(skillName)) {
throw new SkillSourceException(
"Skill name '%s' does not match directory name '%s'."
.formatted(frontmatter.name(), skillName));
"Skill name in the frontmatter '%s' does not match skill name '%s'."
.formatted(frontmatter.name(), skillName),
SKILL_LOAD_ERROR);
}
return frontmatter;
} catch (IOException e) {
throw new SkillSourceException("Cannot load frontmatter for skill '" + skillName + "'", e);
throw new SkillSourceException(
"Cannot load frontmatter for skill '" + skillName + "'", SKILL_LOAD_ERROR, e);
}
}

Expand All @@ -100,7 +104,9 @@ public Single<String> loadInstructions(String skillName) {
return readInstructions(reader);
} catch (IOException e) {
throw new SkillSourceException(
"Failed to load instruction for skill '" + skillName + "'", e);
"Failed to load instruction for skill '" + skillName + "'",
SKILL_LOAD_ERROR,
e);
}
});
}
Expand Down Expand Up @@ -140,7 +146,8 @@ private String readFrontmatterYaml(BufferedReader reader)
throws IOException, SkillSourceException {
String line = reader.readLine();
if (line == null || !line.trim().equals(THREE_DASHES)) {
throw new SkillSourceException("Skill file must start with " + THREE_DASHES);
throw new SkillSourceException(
"Skill file must start with " + THREE_DASHES, SKILL_FORMAT_ERROR);
}

StringBuilder sb = new StringBuilder();
Expand All @@ -151,14 +158,15 @@ private String readFrontmatterYaml(BufferedReader reader)
sb.append(line).append("\n");
}
throw new SkillSourceException(
"Skill file frontmatter not properly closed with " + THREE_DASHES);
"Skill file frontmatter not properly closed with " + THREE_DASHES, SKILL_FORMAT_ERROR);
}

private String readInstructions(BufferedReader reader) throws IOException, SkillSourceException {
// Skip the frontmatter block
String line = reader.readLine();
if (line == null || !line.trim().equals(THREE_DASHES)) {
throw new SkillSourceException("Skill file must start with " + THREE_DASHES);
throw new SkillSourceException(
"Skill file must start with " + THREE_DASHES, SKILL_FORMAT_ERROR);
}
boolean dashClosed = false;
while ((line = reader.readLine()) != null) {
Expand All @@ -169,7 +177,7 @@ private String readInstructions(BufferedReader reader) throws IOException, Skill
}
if (!dashClosed) {
throw new SkillSourceException(
"Skill file frontmatter not properly closed with " + THREE_DASHES);
"Skill file frontmatter not properly closed with " + THREE_DASHES, SKILL_FORMAT_ERROR);
}
// Read the instructions till the end of the file
StringBuilder sb = new StringBuilder();
Expand Down
15 changes: 11 additions & 4 deletions core/src/main/java/com/google/adk/skills/InMemorySkillSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package com.google.adk.skills;

import static com.google.adk.skills.SkillSourceException.RESOURCE_NOT_FOUND;
import static com.google.adk.skills.SkillSourceException.SKILL_NOT_FOUND;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.nio.charset.StandardCharsets.UTF_8;
Expand Down Expand Up @@ -56,7 +58,8 @@ public Single<ImmutableMap<String, Frontmatter>> listFrontmatters() {
public Single<ImmutableList<String>> listResources(String skillName, String resourceDirectory) {
SkillData data = skills.get(skillName);
if (data == null) {
return Single.error(new SkillSourceException("Skill not found: " + skillName));
return Single.error(
new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND));
}
String prefix =
resourceDirectory.isEmpty()
Expand All @@ -67,7 +70,8 @@ public Single<ImmutableList<String>> listResources(String skillName, String reso
&& data.resources().keySet().stream().noneMatch(path -> path.startsWith(prefix))) {
return Single.error(
new SkillSourceException(
"Resource directory not found: " + resourceDirectory + " for skill: " + skillName));
"Resource directory not found: " + resourceDirectory + " for skill: " + skillName,
RESOURCE_NOT_FOUND));
}

return Single.just(
Expand All @@ -92,13 +96,16 @@ public Single<ByteSource> loadResource(String skillName, String resourcePath) {
.map(SkillData::resources)
.mapOptional(m -> Optional.ofNullable(m.get(resourcePath)))
.switchIfEmpty(
Single.error(new SkillSourceException("Resource not found: " + resourcePath)));
Single.error(
new SkillSourceException(
"Resource not found: " + resourcePath, RESOURCE_NOT_FOUND)));
}

private Single<SkillData> getSkillData(String skillName) {
SkillData data = skills.get(skillName);
if (data == null) {
return Single.error(new SkillSourceException("Skill not found: " + skillName));
return Single.error(
new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND));
}
return Single.just(data);
}
Expand Down
27 changes: 20 additions & 7 deletions core/src/main/java/com/google/adk/skills/LocalSkillSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

package com.google.adk.skills;

import static com.google.adk.skills.SkillSourceException.RESOURCE_LOAD_ERROR;
import static com.google.adk.skills.SkillSourceException.RESOURCE_NOT_FOUND;
import static com.google.adk.skills.SkillSourceException.SKILL_LOAD_ERROR;
import static com.google.adk.skills.SkillSourceException.SKILL_NOT_FOUND;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.nio.file.Files.isDirectory;

Expand Down Expand Up @@ -43,14 +47,16 @@ public LocalSkillSource(Path skillsBasePath) {
public Single<ImmutableList<String>> listResources(String skillName, String resourceDirectory) {
Path skillDir = skillsBasePath.resolve(skillName);
if (!isDirectory(skillDir)) {
return Single.error(new SkillSourceException("Skill not found: " + skillName));
return Single.error(
new SkillSourceException("Skill not found: " + skillName, SKILL_NOT_FOUND));
}
Path resourceDir = skillDir.resolve(resourceDirectory);
if (!isDirectory(resourceDir)) {
return Single.error(
new SkillSourceException(
"Resource directory '%s' not found for skill '%s'"
.formatted(resourceDirectory, skillName)));
.formatted(resourceDirectory, skillName),
RESOURCE_NOT_FOUND));
}

return Single.fromCallable(
Expand All @@ -67,7 +73,9 @@ public Single<ImmutableList<String>> listResources(String skillName, String reso
t ->
Single.error(
new SkillSourceException(
"Failed to traverse resource directory: " + resourceDirectory, t)));
"Failed to traverse resource directory: " + resourceDirectory,
RESOURCE_LOAD_ERROR,
t)));
}

@Override
Expand All @@ -78,7 +86,9 @@ protected Flowable<SkillMdPath> listSkills() {
t ->
Flowable.error(
new SkillSourceException(
"Failed to list skills in directory: " + skillsBasePath, t)))
"Failed to list skills in directory: " + skillsBasePath,
SKILL_LOAD_ERROR,
t)))
.filter(Files::isDirectory)
.mapOptional(this::findSkillMd)
.map(skillMd -> new SkillMdPath(skillMd.getParent().getFileName().toString(), skillMd));
Expand All @@ -88,7 +98,8 @@ protected Flowable<SkillMdPath> listSkills() {
protected Single<Path> findResourcePath(String skillName, String resourcePath) {
Path file = skillsBasePath.resolve(skillName).resolve(resourcePath);
if (!Files.exists(file)) {
return Single.error(new SkillSourceException("Resource not found: " + file));
return Single.error(
new SkillSourceException("Resource not found: " + file, RESOURCE_NOT_FOUND));
}
return Single.just(file);
}
Expand All @@ -97,11 +108,13 @@ protected Single<Path> findResourcePath(String skillName, String resourcePath) {
protected Single<Path> findSkillMdPath(String skillName) {
Path skillDir = skillsBasePath.resolve(skillName);
if (!isDirectory(skillDir)) {
return Single.error(new SkillSourceException("Skill directory not found: " + skillName));
return Single.error(
new SkillSourceException("Skill directory not found: " + skillName, SKILL_NOT_FOUND));
}
return Maybe.fromOptional(findSkillMd(skillDir))
.switchIfEmpty(
Single.error(new SkillSourceException("SKILL.md not found in " + skillName)));
Single.error(
new SkillSourceException("SKILL.md not found in " + skillName, SKILL_NOT_FOUND)));
}

@Override
Expand Down
36 changes: 34 additions & 2 deletions core/src/main/java/com/google/adk/skills/SkillSourceException.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,43 @@
*/
public final class SkillSourceException extends Exception {

public SkillSourceException(String message) {
public static final String SKILL_LOAD_ERROR = "SKILL_LOAD_ERROR";
public static final String SKILL_NOT_FOUND = "SKILL_NOT_FOUND";
public static final String SKILL_FORMAT_ERROR = "SKILL_FORMAT_ERROR";
public static final String RESOURCE_LOAD_ERROR = "RESOURCE_LOAD_ERROR";
public static final String RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND";

private final String errorCode;

/**
* Constructs a new exception with the specified detail message and error code.
*
* @param message The detail message.
* @param errorCode The specific error code categorizing the failure.
*/
public SkillSourceException(String message, String errorCode) {
super(message);
this.errorCode = errorCode;
}

public SkillSourceException(String message, Throwable cause) {
/**
* Constructs a new exception with the specified detail message, error code, and cause.
*
* @param message The detail message.
* @param errorCode The specific error code categorizing the failure.
* @param cause The cause.
*/
public SkillSourceException(String message, String errorCode, Throwable cause) {
super(message, cause);
this.errorCode = errorCode;
}

/**
* Returns the error code categorizing the failure.
*
* @return The error code string.
*/
public String getErrorCode() {
return errorCode;
}
}
10 changes: 9 additions & 1 deletion core/src/main/java/com/google/adk/tools/BaseToolset.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,26 @@
package com.google.adk.tools;

import com.google.adk.agents.ReadonlyContext;
import com.google.adk.models.LlmRequest;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Flowable;
import java.util.List;
import org.jspecify.annotations.Nullable;

/** Base interface for toolsets. */
public interface BaseToolset extends AutoCloseable {

/** Processes the outgoing {@link LlmRequest.Builder}. */
default Completable processLlmRequest(
LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) {
return Completable.complete();
}

/**
* Return all tools in the toolset based on the provided context.
*
* @param readonlyContext Context used to filter tools available to the agent.
* @return A Single emitting a list of tools available under the specified context.
* @return A Flowable emitting tools available under the specified context.
*/
Flowable<BaseTool> getTools(ReadonlyContext readonlyContext);

Expand Down
Loading