diff --git a/apps/sim/app/api/mcp/oauth/callback/route.ts b/apps/sim/app/api/mcp/oauth/callback/route.ts new file mode 100644 index 00000000000..55cefdfa500 --- /dev/null +++ b/apps/sim/app/api/mcp/oauth/callback/route.ts @@ -0,0 +1,120 @@ +import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js' +import { db } from '@sim/db' +import { mcpServers } from '@sim/db/schema' +import { createLogger } from '@sim/logger' +import { toError } from '@sim/utils/errors' +import { and, eq, isNull } from 'drizzle-orm' +import type { NextRequest } from 'next/server' +import { NextResponse } from 'next/server' +import { getSession } from '@/lib/auth' +import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { + clearState, + clearVerifier, + loadOauthRowByState, + loadPreregisteredClient, + SimMcpOauthProvider, +} from '@/lib/mcp/oauth' +import { mcpService } from '@/lib/mcp/service' + +const logger = createLogger('McpOauthCallbackAPI') + +export const dynamic = 'force-dynamic' + +function escapeHtml(value: string): string { + return value + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, ''') +} + +function htmlClose(message: string, ok: boolean, serverId?: string): NextResponse { + const safeMessage = escapeHtml(message) + const title = ok ? 'Connected' : 'Connection failed' + const serverIdLiteral = serverId + ? JSON.stringify(serverId).replace(//g, '\\u003e') + : 'undefined' + const body = `${title}

${safeMessage}

` + return new NextResponse(body, { + headers: { 'Content-Type': 'text/html; charset=utf-8' }, + }) +} + +export const GET = withRouteHandler(async (request: NextRequest) => { + const url = new URL(request.url) + const state = url.searchParams.get('state') + const code = url.searchParams.get('code') + const errorParam = url.searchParams.get('error') + + if (errorParam) { + logger.warn(`MCP OAuth callback received error: ${errorParam}`) + return htmlClose(`Authorization failed: ${errorParam}`, false) + } + if (!state || !code) { + return htmlClose('Missing state or code in callback URL.', false) + } + + let serverId: string | undefined + try { + const session = await getSession() + if (!session?.user?.id) { + return htmlClose('You must be signed in to complete authorization.', false) + } + + const row = await loadOauthRowByState(state) + if (!row) { + return htmlClose('Invalid or expired authorization state.', false) + } + serverId = row.mcpServerId + + if (session.user.id !== row.userId) { + return htmlClose( + 'You must be signed in as the same user that initiated the flow.', + false, + serverId + ) + } + + const [server] = await db + .select({ id: mcpServers.id, url: mcpServers.url, workspaceId: mcpServers.workspaceId }) + .from(mcpServers) + .where(and(eq(mcpServers.id, row.mcpServerId), isNull(mcpServers.deletedAt))) + .limit(1) + if (!server || !server.url) { + return htmlClose('Server no longer exists.', false, serverId) + } + + // Burn state before token exchange so a replayed callback cannot reuse it. + await clearState(row.id) + + const preregistered = await loadPreregisteredClient(server.id) + const provider = new SimMcpOauthProvider({ row, preregistered }) + const result = await mcpAuth(provider, { + serverUrl: server.url, + authorizationCode: code, + }) + + await clearVerifier(row.id) + + if (result !== 'AUTHORIZED') { + return htmlClose('Authorization did not complete.', false, server.id) + } + + try { + await mcpService.clearCache(server.workspaceId) + await mcpService.discoverServerTools(row.userId, server.id, server.workspaceId) + } catch (e) { + logger.warn('Post-auth tools refresh failed', toError(e).message) + } + + return htmlClose('Connected. You can close this window.', true, server.id) + } catch (error) { + logger.error('MCP OAuth callback failed', error) + return htmlClose('Authorization failed. Please try again.', false, serverId) + } +}) diff --git a/apps/sim/app/api/mcp/oauth/start/route.ts b/apps/sim/app/api/mcp/oauth/start/route.ts new file mode 100644 index 00000000000..372bb66ec84 --- /dev/null +++ b/apps/sim/app/api/mcp/oauth/start/route.ts @@ -0,0 +1,95 @@ +import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js' +import { db } from '@sim/db' +import { mcpServers } from '@sim/db/schema' +import { createLogger } from '@sim/logger' +import { toError } from '@sim/utils/errors' +import { and, eq, isNull } from 'drizzle-orm' +import type { NextRequest } from 'next/server' +import { NextResponse } from 'next/server' +import { startMcpOauthQuerySchema } from '@/lib/api/contracts/mcp' +import { validationErrorResponse } from '@/lib/api/server' +import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { withMcpAuth } from '@/lib/mcp/middleware' +import { + getOrCreateOauthRow, + loadPreregisteredClient, + McpOauthRedirectRequired, + SimMcpOauthProvider, +} from '@/lib/mcp/oauth' +import { createMcpErrorResponse } from '@/lib/mcp/utils' + +const logger = createLogger('McpOauthStartAPI') + +export const dynamic = 'force-dynamic' + +export const GET = withRouteHandler( + withMcpAuth('write')(async (request: NextRequest, { userId, workspaceId, requestId }) => { + try { + const queryResult = startMcpOauthQuerySchema.safeParse( + Object.fromEntries(new URL(request.url).searchParams) + ) + if (!queryResult.success) { + return validationErrorResponse(queryResult.error) + } + const { serverId } = queryResult.data + + const [server] = await db + .select() + .from(mcpServers) + .where( + and( + eq(mcpServers.id, serverId), + eq(mcpServers.workspaceId, workspaceId), + isNull(mcpServers.deletedAt) + ) + ) + .limit(1) + + if (!server) { + return createMcpErrorResponse(new Error('Server not found'), 'Server not found', 404) + } + if (server.authType !== 'oauth') { + return createMcpErrorResponse( + new Error(`Server authType is "${server.authType}", not oauth`), + 'Server is not configured for OAuth', + 400 + ) + } + if (!server.url) { + return createMcpErrorResponse(new Error('Server has no URL'), 'Missing server URL', 400) + } + + const row = await getOrCreateOauthRow({ + mcpServerId: server.id, + userId, + workspaceId, + }) + const preregistered = await loadPreregisteredClient(server.id) + const provider = new SimMcpOauthProvider({ row, preregistered }) + + try { + const result = await mcpAuth(provider, { serverUrl: server.url }) + if (result === 'AUTHORIZED') { + return NextResponse.json({ status: 'already_authorized' }) + } + return createMcpErrorResponse( + new Error('Provider did not capture redirect URL'), + 'Failed to start OAuth flow', + 500 + ) + } catch (e) { + if (e instanceof McpOauthRedirectRequired) { + logger.info(`[${requestId}] OAuth redirect for server ${serverId}`) + return NextResponse.json({ + status: 'redirect', + authorizationUrl: e.authorizationUrl, + }) + } + throw e + } + } catch (error) { + logger.error(`[${requestId}] Error starting MCP OAuth flow:`, error) + return createMcpErrorResponse(toError(error), 'Failed to start OAuth flow', 500) + } + }) +) diff --git a/apps/sim/app/api/mcp/servers/[id]/route.ts b/apps/sim/app/api/mcp/servers/[id]/route.ts index b2b3b35f5b9..0e28c4a99a0 100644 --- a/apps/sim/app/api/mcp/servers/[id]/route.ts +++ b/apps/sim/app/api/mcp/servers/[id]/route.ts @@ -1,11 +1,12 @@ import { AuditAction, AuditResourceType, recordAudit } from '@sim/audit' import { db } from '@sim/db' -import { mcpServers } from '@sim/db/schema' +import { mcpServerOauth, mcpServers } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import { and, eq, isNull } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { updateMcpServerBodySchema } from '@/lib/api/contracts/mcp' +import { decryptSecret, encryptSecret } from '@/lib/core/security/encryption' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { McpDnsResolutionError, @@ -52,8 +53,16 @@ export const PATCH = withRouteHandler( } ) - // Remove workspaceId from body to prevent it from being updated - const { workspaceId: _, ...updateData } = body + const { workspaceId: _, oauthClientSecret, ...updateData } = body + const finalUpdateData: Record = { ...updateData } + if (oauthClientSecret !== undefined) { + finalUpdateData.oauthClientSecret = oauthClientSecret + ? (await encryptSecret(oauthClientSecret)).encrypted + : null + } + if (updateData.oauthClientId !== undefined) { + finalUpdateData.oauthClientId = updateData.oauthClientId || null + } if (updateData.url) { try { @@ -78,9 +87,13 @@ export const PATCH = withRouteHandler( } } - // Get the current server to check if URL is changing const [currentServer] = await db - .select({ url: mcpServers.url }) + .select({ + url: mcpServers.url, + authType: mcpServers.authType, + oauthClientId: mcpServers.oauthClientId, + oauthClientSecret: mcpServers.oauthClientSecret, + }) .from(mcpServers) .where( and( @@ -91,20 +104,60 @@ export const PATCH = withRouteHandler( ) .limit(1) - const [updatedServer] = await db - .update(mcpServers) - .set({ - ...updateData, - updatedAt: new Date(), - }) - .where( - and( - eq(mcpServers.id, serverId), - eq(mcpServers.workspaceId, workspaceId), - isNull(mcpServers.deletedAt) + // Adding OAuth client credentials to a non-OAuth server promotes it + // to OAuth so the connect-with-OAuth UI becomes reachable. + if ( + body.oauthClientId && + currentServer && + currentServer.authType !== 'oauth' && + finalUpdateData.authType === undefined + ) { + finalUpdateData.authType = 'oauth' + } + + const urlChanged = body.url !== undefined && currentServer?.url !== body.url + const clientIdChanged = + body.oauthClientId !== undefined && + (body.oauthClientId || null) !== (currentServer?.oauthClientId ?? null) + let clientSecretChanged = false + if (oauthClientSecret !== undefined) { + if (!oauthClientSecret) { + clientSecretChanged = currentServer?.oauthClientSecret != null + } else if (!currentServer?.oauthClientSecret) { + clientSecretChanged = true + } else { + const currentPlaintext = (await decryptSecret(currentServer.oauthClientSecret)) + .decrypted + clientSecretChanged = currentPlaintext !== oauthClientSecret + } + } + const oauthCredsChanged = clientIdChanged || clientSecretChanged + const shouldClearOauth = urlChanged || oauthCredsChanged + + const updatedServer = await db.transaction(async (tx) => { + const [updated] = await tx + .update(mcpServers) + .set({ + ...finalUpdateData, + updatedAt: new Date(), + }) + .where( + and( + eq(mcpServers.id, serverId), + eq(mcpServers.workspaceId, workspaceId), + isNull(mcpServers.deletedAt) + ) ) - ) - .returning() + .returning() + + if (!updated) return null + + if (shouldClearOauth) { + await tx.delete(mcpServerOauth).where(eq(mcpServerOauth.mcpServerId, serverId)) + } + + return updated + }) if (!updatedServer) { return createMcpErrorResponse( @@ -114,8 +167,15 @@ export const PATCH = withRouteHandler( ) } + if (shouldClearOauth) { + logger.info( + `[${requestId}] Cleared OAuth credentials for server ${serverId} due to ${urlChanged ? 'URL' : 'OAuth credential'} change` + ) + } + const shouldClearCache = - (body.url !== undefined && currentServer?.url !== body.url) || + urlChanged || + oauthCredsChanged || body.enabled !== undefined || body.headers !== undefined || body.timeout !== undefined || @@ -149,7 +209,10 @@ export const PATCH = withRouteHandler( request, }) - return createMcpSuccessResponse({ server: updatedServer }) + const { oauthClientSecret: _secret, ...rest } = updatedServer + return createMcpSuccessResponse({ + server: { ...rest, hasOauthClientSecret: !!_secret }, + }) } catch (error) { logger.error(`[${requestId}] Error updating MCP server:`, error) return createMcpErrorResponse(toError(error), 'Failed to update MCP server', 500) diff --git a/apps/sim/app/api/mcp/servers/route.ts b/apps/sim/app/api/mcp/servers/route.ts index d2666431506..9ea255daed4 100644 --- a/apps/sim/app/api/mcp/servers/route.ts +++ b/apps/sim/app/api/mcp/servers/route.ts @@ -1,6 +1,6 @@ import { AuditAction, AuditResourceType, recordAudit } from '@sim/audit' import { db } from '@sim/db' -import { mcpServers } from '@sim/db/schema' +import { mcpServerOauth, mcpServers } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import { generateId } from '@sim/utils/id' @@ -8,6 +8,7 @@ import { and, eq, isNull } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { createMcpServerBodySchema, deleteMcpServerByQuerySchema } from '@/lib/api/contracts/mcp' import { validationErrorResponse } from '@/lib/api/server' +import { decryptSecret, encryptSecret } from '@/lib/core/security/encryption' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { McpDnsResolutionError, @@ -17,6 +18,7 @@ import { validateMcpServerSsrf, } from '@/lib/mcp/domain-check' import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { detectMcpAuthType } from '@/lib/mcp/oauth' import { mcpService } from '@/lib/mcp/service' import { createMcpErrorResponse, @@ -37,11 +39,16 @@ export const GET = withRouteHandler( try { logger.info(`[${requestId}] Listing MCP servers for workspace ${workspaceId}`) - const servers = await db + const rows = await db .select() .from(mcpServers) .where(and(eq(mcpServers.workspaceId, workspaceId), isNull(mcpServers.deletedAt))) + const servers = rows.map(({ oauthClientSecret: _secret, ...rest }) => ({ + ...rest, + hasOauthClientSecret: !!_secret, + })) + logger.info( `[${requestId}] Listed ${servers.length} MCP servers for workspace ${workspaceId}` ) @@ -105,34 +112,113 @@ export const POST = withRouteHandler( const serverId = body.url ? generateMcpServerId(workspaceId, body.url) : generateId() + const oauthClientSecretProvided = body.oauthClientSecret !== undefined + const oauthClientSecretEncrypted = body.oauthClientSecret + ? (await encryptSecret(body.oauthClientSecret)).encrypted + : null + const oauthClientIdProvided = body.oauthClientId !== undefined + const oauthClientId = body.oauthClientId || null + const [existingServer] = await db - .select({ id: mcpServers.id, deletedAt: mcpServers.deletedAt }) + .select({ + id: mcpServers.id, + deletedAt: mcpServers.deletedAt, + url: mcpServers.url, + authType: mcpServers.authType, + oauthClientId: mcpServers.oauthClientId, + oauthClientSecret: mcpServers.oauthClientSecret, + }) .from(mcpServers) .where(and(eq(mcpServers.id, serverId), eq(mcpServers.workspaceId, workspaceId))) .limit(1) + const urlChanged = existingServer ? existingServer.url !== body.url : true + const hasHeaders = body.headers && Object.keys(body.headers).length > 0 + + let resolvedAuthType: 'none' | 'headers' | 'oauth' = body.authType ?? 'headers' + if (!body.authType) { + if (existingServer && !urlChanged) { + // Preserve existing authType on edits that don't change the URL — re-probing + // can flip a working OAuth+DCR server to 'headers' on a transient 401/timeout. + resolvedAuthType = (existingServer.authType ?? 'headers') as + | 'none' + | 'headers' + | 'oauth' + } else if (body.url && !hasHeaders) { + try { + resolvedAuthType = await detectMcpAuthType(body.url) + logger.info(`[${requestId}] Probed ${body.url}: authType=${resolvedAuthType}`) + } catch (e) { + logger.warn(`[${requestId}] Probe failed for ${body.url}, defaulting to headers`, e) + resolvedAuthType = 'headers' + } + } + } + + // User-supplied client credentials imply OAuth; pin authType regardless of probe. + if (body.oauthClientId) resolvedAuthType = 'oauth' + if (existingServer) { logger.info( `[${requestId}] Server with ID ${serverId} already exists, updating instead of creating` ) - await db - .update(mcpServers) - .set({ + const clientIdChanged = + oauthClientIdProvided && + (oauthClientId || null) !== (existingServer.oauthClientId ?? null) + let clientSecretChanged = false + if (oauthClientSecretProvided) { + if (!body.oauthClientSecret) { + clientSecretChanged = existingServer.oauthClientSecret != null + } else if (!existingServer.oauthClientSecret) { + clientSecretChanged = true + } else { + const currentPlaintext = (await decryptSecret(existingServer.oauthClientSecret)) + .decrypted + clientSecretChanged = currentPlaintext !== body.oauthClientSecret + } + } + const oauthCredsChanged = clientIdChanged || clientSecretChanged + + const isRevival = existingServer.deletedAt !== null + const shouldClearOauth = urlChanged || oauthCredsChanged || isRevival + + await db.transaction(async (tx) => { + if (shouldClearOauth) { + await tx.delete(mcpServerOauth).where(eq(mcpServerOauth.mcpServerId, serverId)) + } + const updateValues: Record = { name: body.name, description: body.description, transport: body.transport, url: body.url, + authType: resolvedAuthType, headers: body.headers || {}, timeout: body.timeout || 30000, retries: body.retries || 3, enabled: body.enabled !== false, - connectionStatus: 'connected', - lastConnected: new Date(), + connectionStatus: resolvedAuthType === 'oauth' ? 'disconnected' : 'connected', + lastConnected: resolvedAuthType === 'oauth' ? null : new Date(), updatedAt: new Date(), deletedAt: null, - }) - .where(eq(mcpServers.id, serverId)) + } + if (oauthClientIdProvided) updateValues.oauthClientId = oauthClientId + if (oauthClientSecretProvided) { + updateValues.oauthClientSecret = oauthClientSecretEncrypted + } + await tx.update(mcpServers).set(updateValues).where(eq(mcpServers.id, serverId)) + }) + + if (shouldClearOauth) { + const reason = isRevival + ? 'server revival' + : urlChanged + ? 'URL change' + : 'OAuth credential change' + logger.info( + `[${requestId}] Cleared OAuth credentials for server ${serverId} due to ${reason}` + ) + } await mcpService.clearCache(workspaceId) @@ -140,7 +226,10 @@ export const POST = withRouteHandler( `[${requestId}] Successfully updated MCP server: ${body.name} (ID: ${serverId})` ) - return createMcpSuccessResponse({ serverId, updated: true }, 200) + return createMcpSuccessResponse( + { serverId, updated: true, authType: resolvedAuthType }, + 200 + ) } await db @@ -153,12 +242,15 @@ export const POST = withRouteHandler( description: body.description, transport: body.transport, url: body.url, + authType: resolvedAuthType, + oauthClientId, + oauthClientSecret: oauthClientSecretEncrypted, headers: body.headers || {}, timeout: body.timeout || 30000, retries: body.retries || 3, enabled: body.enabled !== false, - connectionStatus: 'connected', - lastConnected: new Date(), + connectionStatus: resolvedAuthType === 'oauth' ? 'disconnected' : 'connected', + lastConnected: resolvedAuthType === 'oauth' ? null : new Date(), createdAt: new Date(), updatedAt: new Date(), }) @@ -178,9 +270,7 @@ export const POST = withRouteHandler( transport: body.transport, workspaceId, }) - } catch (_e) { - // Silently fail - } + } catch (_e) {} const sourceParam = body.source as string | undefined const source = @@ -217,7 +307,7 @@ export const POST = withRouteHandler( request, }) - return createMcpSuccessResponse({ serverId }, 201) + return createMcpSuccessResponse({ serverId, authType: resolvedAuthType }, 201) } catch (error) { logger.error(`[${requestId}] Error registering MCP server:`, error) return createMcpErrorResponse(toError(error), 'Failed to register MCP server', 500) diff --git a/apps/sim/app/api/mcp/tools/execute/route.ts b/apps/sim/app/api/mcp/tools/execute/route.ts index d9458deceab..8599a5fcadf 100644 --- a/apps/sim/app/api/mcp/tools/execute/route.ts +++ b/apps/sim/app/api/mcp/tools/execute/route.ts @@ -1,5 +1,7 @@ +import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { createLogger } from '@sim/logger' import type { NextRequest } from 'next/server' +import { NextResponse } from 'next/server' import { mcpToolExecutionBodySchema } from '@/lib/api/contracts/mcp' import { getHighestPrioritySubscription } from '@/lib/billing/core/plan' import { getExecutionTimeout } from '@/lib/core/execution-limits' @@ -7,8 +9,14 @@ import type { SubscriptionPlan } from '@/lib/core/rate-limiter/types' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { SIM_VIA_HEADER } from '@/lib/execution/call-chain' import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { McpOauthRedirectRequired } from '@/lib/mcp/oauth' import { mcpService } from '@/lib/mcp/service' -import type { McpTool, McpToolCall, McpToolResult } from '@/lib/mcp/types' +import { + McpOauthAuthorizationRequiredError, + type McpTool, + type McpToolCall, + type McpToolResult, +} from '@/lib/mcp/types' import { categorizeError, createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils' import { assertPermissionsAllowed, @@ -43,6 +51,7 @@ function hasType(prop: unknown): prop is SchemaProperty { */ export const POST = withRouteHandler( withMcpAuth('read')(async (request: NextRequest, { userId, workspaceId, requestId }) => { + let serverId: string | undefined try { const rawBody = getParsedBody(request) ?? (await request.json()) const parsedBody = mcpToolExecutionBodySchema.safeParse(rawBody) @@ -63,7 +72,8 @@ export const POST = withRouteHandler( userId: userId, }) - const { serverId, toolName, arguments: rawArgs } = body + const { toolName, arguments: rawArgs } = body + serverId = body.serverId const args = rawArgs || {} try { @@ -101,7 +111,8 @@ export const POST = withRouteHandler( if (tool.inputSchema?.properties) { for (const [paramName, paramSchema] of Object.entries(tool.inputSchema.properties)) { - const schema = paramSchema as any + const schema = hasType(paramSchema) ? paramSchema : null + if (!schema) continue const value = args[paramName] if (value === undefined || value === null) { @@ -185,12 +196,18 @@ export const POST = withRouteHandler( extraHeaders[SIM_VIA_HEADER] = simViaHeader } + let timeoutHandle: ReturnType | undefined const result = await Promise.race([ mcpService.executeTool(userId, serverId, toolCall, workspaceId, extraHeaders), - new Promise((_, reject) => - setTimeout(() => reject(new Error('Tool execution timeout')), executionTimeout) - ), - ]) + new Promise((_, reject) => { + timeoutHandle = setTimeout( + () => reject(new Error('Tool execution timeout')), + executionTimeout + ) + }), + ]).finally(() => { + if (timeoutHandle !== undefined) clearTimeout(timeoutHandle) + }) const transformedResult = transformToolResult(result) @@ -218,6 +235,27 @@ export const POST = withRouteHandler( return createMcpSuccessResponse(transformedResult) } catch (error) { + if ( + error instanceof McpOauthAuthorizationRequiredError || + error instanceof McpOauthRedirectRequired || + error instanceof UnauthorizedError + ) { + const errorServerId = + error instanceof McpOauthAuthorizationRequiredError ? error.serverId : serverId + logger.warn(`[${requestId}] OAuth re-authorization required for MCP tool execution`, { + serverId: errorServerId, + }) + return NextResponse.json( + { + success: false, + error: 'OAuth re-authorization required', + code: 'reauth_required', + serverId: errorServerId, + }, + { status: 401 } + ) + } + logger.error(`[${requestId}] Error executing MCP tool:`, error) const { message, status } = categorizeError(error) diff --git a/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/components/mcp-server-form-modal/mcp-server-form-modal.tsx b/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/components/mcp-server-form-modal/mcp-server-form-modal.tsx index 5e6f8db16a4..55f9a5f854f 100644 --- a/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/components/mcp-server-form-modal/mcp-server-form-modal.tsx +++ b/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/components/mcp-server-form-modal/mcp-server-form-modal.tsx @@ -1,7 +1,9 @@ 'use client' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useEffect, useRef, useState } from 'react' import { createLogger } from '@sim/logger' +import { toError } from '@sim/utils/errors' +import { ChevronDown, ChevronRight } from 'lucide-react' import { Button, Input as EmcnInput, @@ -35,6 +37,9 @@ interface McpServerFormData { url?: string timeout?: number headers?: HeaderEntry[] + oauthClientId?: string + oauthClientSecret?: string + hasOauthClientSecret?: boolean } export interface McpServerFormConfig { @@ -43,6 +48,8 @@ export interface McpServerFormConfig { url: string headers: Record timeout: number + oauthClientId?: string + oauthClientSecret?: string } export interface McpServerFormModalProps { @@ -322,6 +329,9 @@ export function McpServerFormModal({ const [urlScrollLeft, setUrlScrollLeft] = useState(0) const [headerScrollLeft, setHeaderScrollLeft] = useState>({}) + const [showAdvanced, setShowAdvanced] = useState(false) + const [oauthClientSecretTouched, setOauthClientSecretTouched] = useState(false) + const [prevOpen, setPrevOpen] = useState(false) if (open && !prevOpen) { const data = initialData ?? DEFAULT_FORM_DATA @@ -337,6 +347,8 @@ export function McpServerFormModal({ setActiveHeaderIndex(null) setUrlScrollLeft(0) setHeaderScrollLeft({}) + setShowAdvanced(!!(data.oauthClientId || data.oauthClientSecret || data.hasOauthClientSecret)) + setOauthClientSecretTouched(false) } if (open !== prevOpen) { setPrevOpen(open) @@ -350,76 +362,72 @@ export function McpServerFormModal({ } }, [open, clearTestResult]) - const resetEnvVarState = useCallback(() => { + const resetEnvVarState = () => { setShowEnvVars(false) setActiveInputField(null) setActiveHeaderIndex(null) - }, []) - - const handleInputChange = useCallback( - (field: InputFieldType, value: string, headerIndex?: number) => { - const input = document.activeElement as HTMLInputElement - const pos = input?.selectionStart || 0 - setCursorPosition(pos) - - if (testResult) clearTestResult() - if (submitError) setSubmitError(null) - - const envVarTrigger = checkEnvVarTrigger(value, pos) - setShowEnvVars(envVarTrigger.show) - setEnvSearchTerm(envVarTrigger.show ? envVarTrigger.searchTerm : '') - - if (envVarTrigger.show) { - setActiveInputField(field) - setActiveHeaderIndex(headerIndex ?? null) - } else { - resetEnvVarState() - } + } - if (field === 'url') { - setFormData((prev) => ({ ...prev, url: value })) - } else if (headerIndex !== undefined) { - const headerField = field === 'header-key' ? 'key' : 'value' - setFormData((prev) => ({ - ...prev, - headers: updateHeadersArray(prev.headers || [], headerIndex, headerField, value), - })) - } - }, - [testResult, clearTestResult, submitError, resetEnvVarState] - ) + const handleInputChange = (field: InputFieldType, value: string, headerIndex?: number) => { + const input = document.activeElement as HTMLInputElement + const pos = input?.selectionStart || 0 + setCursorPosition(pos) - const handleEnvVarSelect = useCallback( - (newValue: string) => { - if (activeInputField === 'url') { - setFormData((prev) => ({ ...prev, url: newValue })) - } else if (activeHeaderIndex !== null) { - const field = activeInputField === 'header-key' ? 'key' : 'value' - const processedValue = field === 'key' ? newValue.replace(/[{}]/g, '') : newValue - setFormData((prev) => ({ - ...prev, - headers: updateHeadersArray(prev.headers || [], activeHeaderIndex, field, processedValue), - })) - } + if (testResult) clearTestResult() + if (submitError) setSubmitError(null) + + const envVarTrigger = checkEnvVarTrigger(value, pos) + setShowEnvVars(envVarTrigger.show) + setEnvSearchTerm(envVarTrigger.show ? envVarTrigger.searchTerm : '') + + if (envVarTrigger.show) { + setActiveInputField(field) + setActiveHeaderIndex(headerIndex ?? null) + } else { resetEnvVarState() - }, - [activeInputField, activeHeaderIndex, resetEnvVarState] - ) + } - const handleHeaderScroll = useCallback((key: string, sl: number) => { + if (field === 'url') { + setFormData((prev) => ({ ...prev, url: value })) + } else if (headerIndex !== undefined) { + const headerField = field === 'header-key' ? 'key' : 'value' + setFormData((prev) => ({ + ...prev, + headers: updateHeadersArray(prev.headers || [], headerIndex, headerField, value), + })) + } + } + + const handleEnvVarSelect = (newValue: string) => { + if (activeInputField === 'url') { + setFormData((prev) => ({ ...prev, url: newValue })) + } else if (activeHeaderIndex !== null) { + const field = activeInputField === 'header-key' ? 'key' : 'value' + const processedValue = field === 'key' ? newValue.replace(/[{}]/g, '') : newValue + setFormData((prev) => ({ + ...prev, + headers: updateHeadersArray(prev.headers || [], activeHeaderIndex, field, processedValue), + })) + } + resetEnvVarState() + } + + const handleHeaderScroll = (key: string, sl: number) => { setHeaderScrollLeft((prev) => ({ ...prev, [key]: sl })) - }, []) + } const isDomainBlocked = !!formData.url?.trim() && !isDomainAllowed(formData.url, allowedMcpDomains) const isFormValid = !!(formData.name.trim() && formData.url?.trim()) const testButtonLabel = getTestButtonLabel(testResult, isTestingConnection) - const hasChanges = useMemo(() => { + const computeHasChanges = (): boolean => { if (mode === 'add') return true if (formData.name !== originalData.name) return true if (formData.url !== originalData.url) return true if (formData.transport !== originalData.transport) return true + if ((formData.oauthClientId || '') !== (originalData.oauthClientId || '')) return true + if (oauthClientSecretTouched) return true const currentHeaders = formData.headers || [] const origHeaders = originalData.headers || [] if (currentHeaders.length !== origHeaders.length) return true @@ -431,49 +439,49 @@ export function McpServerFormModal({ return true } return false - }, [mode, formData, originalData]) - - const parseJsonConfig = useCallback( - (json: string): { name: string; url: string; headers: Record } | null => { - try { - const parsed = JSON.parse(json) - - if (parsed.mcpServers && typeof parsed.mcpServers === 'object') { - const entries = Object.entries(parsed.mcpServers) - if (entries.length === 0) { - setJsonError('No servers found in mcpServers') - return null - } - if (entries.length > 1) { - setJsonError( - `Only the first server ("${entries[0][0]}") will be imported. Paste each config separately to add others.` - ) - } - const [name, config] = entries[0] as [string, Record] - if (!config.url || typeof config.url !== 'string') { - setJsonError('Server config must include a "url" field') - return null - } - if (entries.length <= 1) setJsonError(null) - return { name, url: config.url, headers: extractStringHeaders(config.headers) } - } + } + const hasChanges = computeHasChanges() + + const parseJsonConfig = ( + json: string + ): { name: string; url: string; headers: Record } | null => { + try { + const parsed = JSON.parse(json) - if (parsed.url && typeof parsed.url === 'string') { - setJsonError(null) - return { name: '', url: parsed.url, headers: extractStringHeaders(parsed.headers) } + if (parsed.mcpServers && typeof parsed.mcpServers === 'object') { + const entries = Object.entries(parsed.mcpServers) + if (entries.length === 0) { + setJsonError('No servers found in mcpServers') + return null + } + if (entries.length > 1) { + setJsonError( + `Only the first server ("${entries[0][0]}") will be imported. Paste each config separately to add others.` + ) + } + const [name, config] = entries[0] as [string, Record] + if (!config.url || typeof config.url !== 'string') { + setJsonError('Server config must include a "url" field') + return null } + if (entries.length <= 1) setJsonError(null) + return { name, url: config.url, headers: extractStringHeaders(config.headers) } + } - setJsonError('JSON must contain "mcpServers" or a "url" field') - return null - } catch { - setJsonError('Invalid JSON') - return null + if (parsed.url && typeof parsed.url === 'string') { + setJsonError(null) + return { name: '', url: parsed.url, headers: extractStringHeaders(parsed.headers) } } - }, - [] - ) - const handleTestConnection = useCallback(async () => { + setJsonError('JSON must contain "mcpServers" or a "url" field') + return null + } catch { + setJsonError('Invalid JSON') + return null + } + } + + const handleTestConnection = async () => { if (!isFormValid) return await testConnection({ @@ -484,15 +492,20 @@ export function McpServerFormModal({ timeout: formData.timeout, workspaceId, }) - }, [formData, isFormValid, testConnection, workspaceId]) + } - const handleSubmitForm = useCallback(async () => { + const handleSubmitForm = async () => { if (!isFormValid || isDomainBlocked) return setIsSubmitting(true) setSubmitError(null) try { const headers = headersToRecord(formData.headers) + const oauthClientId = formData.oauthClientId?.trim() + const oauthClientSecret = formData.oauthClientSecret?.trim() + const originalClientId = (originalData.oauthClientId || '').trim() + const oauthClientIdChanged = (oauthClientId || '') !== originalClientId + const connectionResult = await testConnection({ name: formData.name, transport: formData.transport, @@ -503,10 +516,18 @@ export function McpServerFormModal({ }) if (!connectionResult.success) { - setSubmitError( - connectionResult.error || 'Connection test failed. Please check the URL and try again.' - ) - return + const errorText = (connectionResult.error || '').toLowerCase() + const looksLikeAuthRequired = + /\b401\b/.test(errorText) || + errorText.includes('unauthorized') || + errorText.includes('oauth') || + errorText.includes('authentication') + if (!looksLikeAuthRequired) { + setSubmitError( + connectionResult.error || 'Connection test failed. Please check the URL and try again.' + ) + return + } } await onSubmit({ @@ -515,19 +536,30 @@ export function McpServerFormModal({ url: formData.url!, headers, timeout: formData.timeout || 30000, + oauthClientId: + mode === 'edit' + ? oauthClientIdChanged + ? (oauthClientId ?? '') + : undefined + : oauthClientId || undefined, + oauthClientSecret: + mode === 'edit' + ? oauthClientSecretTouched + ? (oauthClientSecret ?? '') + : undefined + : oauthClientSecret || undefined, }) onOpenChange(false) } catch (error) { - const message = error instanceof Error ? error.message : 'Failed to save server' - setSubmitError(message) + setSubmitError(toError(error).message || 'Failed to save server') logger.error('Failed to save MCP server:', error) } finally { setIsSubmitting(false) } - }, [formData, isFormValid, isDomainBlocked, testConnection, workspaceId, onSubmit, onOpenChange]) + } - const handleSubmitJson = useCallback(async () => { + const handleSubmitJson = async () => { const config = parseJsonConfig(jsonInput) if (!config) return @@ -570,21 +602,12 @@ export function McpServerFormModal({ onOpenChange(false) } catch (error) { - const message = error instanceof Error ? error.message : 'Failed to save server' - setSubmitError(message) + setSubmitError(toError(error).message || 'Failed to save server') logger.error('Failed to save MCP server from JSON:', error) } finally { setIsSubmitting(false) } - }, [ - jsonInput, - parseJsonConfig, - allowedMcpDomains, - testConnection, - workspaceId, - onSubmit, - onOpenChange, - ]) + } const isSubmitDisabled = isSubmitting || !isFormValid || isDomainBlocked || (mode === 'edit' && !hasChanges) @@ -676,6 +699,53 @@ export function McpServerFormModal({ ))} + + + {showAdvanced && ( +
+ + { + if (testResult) clearTestResult() + if (submitError) setSubmitError(null) + setFormData((prev) => ({ ...prev, oauthClientId: e.target.value })) + }} + className='h-9' + /> + + + { + if (testResult) clearTestResult() + if (submitError) setSubmitError(null) + setOauthClientSecretTouched(true) + setFormData((prev) => ({ ...prev, oauthClientSecret: e.target.value })) + }} + className='h-9' + /> + +

+ Only needed for servers that don't support automatic client registration. +

+
+ )} )} @@ -710,7 +780,7 @@ export function McpServerFormModal({ )}
- {formMode === 'json' ? ( diff --git a/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/mcp.tsx b/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/mcp.tsx index 0d2c047549a..e0bb726cfb3 100644 --- a/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/mcp.tsx +++ b/apps/sim/app/workspace/[workspaceId]/settings/components/mcp/mcp.tsx @@ -1,6 +1,6 @@ 'use client' -import { useCallback, useEffect, useMemo, useState } from 'react' +import { useEffect, useState } from 'react' import { createLogger } from '@sim/logger' import { ChevronDown, Plus, Search } from 'lucide-react' import { useParams } from 'next/navigation' @@ -25,6 +25,7 @@ import { type McpToolIssue, } from '@/lib/mcp/tool-validation' import type { McpTransport } from '@/lib/mcp/types' +import { useMcpOauthPopup } from '@/hooks/mcp/use-mcp-oauth-popup' import { type McpServer, type McpTool, @@ -100,7 +101,10 @@ function ServerListItem({ ({transportLabel})

{isRefreshing ? 'Refreshing...' @@ -121,14 +125,29 @@ function ServerListItem({ ) } +function buildEditInitialData(server: McpServer) { + const entries: { key: string; value: string }[] = server.headers + ? Object.entries(server.headers).map(([key, value]) => ({ key, value })) + : [] + if (entries.length === 0) entries.push({ key: '', value: '' }) + const last = entries[entries.length - 1] + if (last.key !== '' || last.value !== '') entries.push({ key: '', value: '' }) + + return { + name: server.name || '', + transport: (server.transport as McpTransport) || 'streamable-http', + url: server.url || '', + timeout: 30000, + headers: entries, + oauthClientId: server.oauthClientId || undefined, + hasOauthClientSecret: server.hasOauthClientSecret === true, + } +} + interface MCPProps { initialServerId?: string | null } -/** - * MCP Settings component for managing Model Context Protocol servers. - * Handles server CRUD operations, connection testing, and environment variable integration. - */ export function MCP({ initialServerId }: MCPProps) { const params = useParams() const workspaceId = params.workspaceId as string @@ -145,7 +164,8 @@ export function MCP({ initialServerId }: MCPProps) { isFetching: toolsFetching, } = useMcpToolsQuery(workspaceId) const { data: storedTools = [], refetch: refetchStoredTools } = useStoredMcpTools(workspaceId) - const forceRefreshTools = useForceRefreshMcpTools() + const forceRefreshToolsMutation = useForceRefreshMcpTools() + const forceRefreshTools = forceRefreshToolsMutation.mutate const createServerMutation = useCreateMcpServer() const deleteServerMutation = useDeleteMcpServer() const refreshServerMutation = useRefreshMcpServer() @@ -154,23 +174,16 @@ export function MCP({ initialServerId }: MCPProps) { const { data: allowedMcpDomains = null } = useAllowedMcpDomains() const [showAddModal, setShowAddModal] = useState(false) - const [showEditModal, setShowEditModal] = useState(false) - const [editInitialData, setEditInitialData] = useState< - | { - name: string - transport: McpTransport - url?: string - timeout?: number - headers?: { key: string; value: string }[] - } - | undefined - >(undefined) + const [editingServerId, setEditingServerId] = useState(null) const [searchTerm, setSearchTerm] = useState('') const [deletingServers, setDeletingServers] = useState>(() => new Set()) + const { connectingServers: connectingOauthServers, startOauthForServer } = useMcpOauthPopup({ + workspaceId, + }) - const [showDeleteDialog, setShowDeleteDialog] = useState(false) - const [serverToDelete, setServerToDelete] = useState<{ id: string; name: string } | null>(null) + const [serverToDeleteId, setServerToDeleteId] = useState(null) + const showDeleteDialog = serverToDeleteId !== null const [selectedServerId, setSelectedServerId] = useState(initialServerId ?? null) @@ -183,28 +196,23 @@ export function MCP({ initialServerId }: MCPProps) { } }, []) - const [refreshingServers, setRefreshingServers] = useState< - Record - >({}) const [expandedTools, setExpandedTools] = useState>(() => new Set()) - const handleRemoveServer = useCallback((serverId: string, serverName: string) => { - setServerToDelete({ id: serverId, name: serverName }) - setShowDeleteDialog(true) - }, []) + const handleRemoveServer = (serverId: string) => { + setServerToDeleteId(serverId) + } - const confirmDeleteServer = useCallback(async () => { - if (!serverToDelete) return + const confirmDeleteServer = async () => { + if (!serverToDeleteId) return - setShowDeleteDialog(false) - const { id: serverId, name: serverName } = serverToDelete - setServerToDelete(null) + const serverId = serverToDeleteId + setServerToDeleteId(null) setDeletingServers((prev) => new Set(prev).add(serverId)) try { await deleteServerMutation.mutateAsync({ workspaceId, serverId }) - logger.info(`Removed MCP server: ${serverName}`) + logger.info(`Removed MCP server: ${serverId}`) } catch (error) { logger.error('Failed to remove MCP server:', error) } finally { @@ -214,43 +222,36 @@ export function MCP({ initialServerId }: MCPProps) { return newSet }) } - }, [serverToDelete, deleteServerMutation, workspaceId]) - - const toolsByServer = useMemo(() => { - return (mcpToolsData || []).reduce( - (acc, tool) => { - if (!tool?.serverId) return acc - if (!acc[tool.serverId]) { - acc[tool.serverId] = [] - } - acc[tool.serverId].push(tool) - return acc - }, - {} as Record - ) - }, [mcpToolsData]) - - const filteredServers = useMemo(() => { - return (servers || []).filter((server) => - server.name?.toLowerCase().includes(searchTerm.toLowerCase()) - ) - }, [servers, searchTerm]) + } - const handleViewDetails = useCallback( - (serverId: string) => { - setSelectedServerId(serverId) - forceRefreshTools(workspaceId) - refetchStoredTools() + const toolsByServer = (mcpToolsData || []).reduce( + (acc, tool) => { + if (!tool?.serverId) return acc + if (!acc[tool.serverId]) { + acc[tool.serverId] = [] + } + acc[tool.serverId].push(tool) + return acc }, - [workspaceId, forceRefreshTools, refetchStoredTools] + {} as Record + ) + + const filteredServers = (servers || []).filter((server) => + server.name?.toLowerCase().includes(searchTerm.toLowerCase()) ) - const handleBackToList = useCallback(() => { + const handleViewDetails = (serverId: string) => { + setSelectedServerId(serverId) + forceRefreshTools(workspaceId) + refetchStoredTools() + } + + const handleBackToList = () => { setSelectedServerId(null) setExpandedTools(new Set()) - }, []) + } - const toggleToolExpanded = useCallback((toolName: string) => { + const toggleToolExpanded = (toolName: string) => { setExpandedTools((prev) => { const newSet = new Set(prev) if (newSet.has(toolName)) { @@ -260,131 +261,109 @@ export function MCP({ initialServerId }: MCPProps) { } return newSet }) - }, []) + } - const handleRefreshServer = useCallback( - async (serverId: string) => { - try { - setRefreshingServers((prev) => ({ ...prev, [serverId]: { status: 'refreshing' } })) - const result = await refreshServerMutation.mutateAsync({ workspaceId, serverId }) - logger.info( - `Refreshed MCP server: ${serverId}, workflows updated: ${result.workflowsUpdated}` - ) - - const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId - if (activeWorkflowId && result.updatedWorkflowIds?.includes(activeWorkflowId)) { - logger.info(`Active workflow ${activeWorkflowId} was updated, reloading subblock values`) - try { - const { data: workflowData } = await requestJson(getWorkflowStateContract, { - params: { id: activeWorkflowId }, - }) - if (workflowData?.state?.blocks) { - useSubBlockStore - .getState() - .initializeFromWorkflow( - activeWorkflowId, - workflowData.state.blocks as Record - ) - } - } catch (reloadError) { - logger.warn('Failed to reload workflow subblock values:', reloadError) - } - } + const handleRefreshServer = async (serverId: string) => { + try { + const result = await refreshServerMutation.mutateAsync({ workspaceId, serverId }) + logger.info( + `Refreshed MCP server: ${serverId}, workflows updated: ${result.workflowsUpdated}` + ) - setRefreshingServers((prev) => ({ - ...prev, - [serverId]: { status: 'refreshed', workflowsUpdated: result.workflowsUpdated }, - })) - setTimeout(() => { - setRefreshingServers((prev) => { - const newState = { ...prev } - delete newState[serverId] - return newState + const activeWorkflowId = useWorkflowRegistry.getState().activeWorkflowId + if (activeWorkflowId && result.updatedWorkflowIds?.includes(activeWorkflowId)) { + logger.info(`Active workflow ${activeWorkflowId} was updated, reloading subblock values`) + try { + const { data: workflowData } = await requestJson(getWorkflowStateContract, { + params: { id: activeWorkflowId }, }) - }, 3000) - } catch (error) { - logger.error('Failed to refresh MCP server:', error) - setRefreshingServers((prev) => { - const newState = { ...prev } - delete newState[serverId] - return newState - }) + if (workflowData?.state?.blocks) { + useSubBlockStore + .getState() + .initializeFromWorkflow( + activeWorkflowId, + workflowData.state.blocks as Record + ) + } + } catch (reloadError) { + logger.warn('Failed to reload workflow subblock values:', reloadError) + } } - }, - [refreshServerMutation, workspaceId] - ) - - const handleOpenEditModal = useCallback((server: McpServer) => { - const headers: { key: string; value: string }[] = server.headers - ? Object.entries(server.headers).map(([key, value]) => ({ key, value })) - : [{ key: '', value: '' }] - if (headers.length === 0) headers.push({ key: '', value: '' }) - - const lastHeader = headers[headers.length - 1] - if (lastHeader.key !== '' || lastHeader.value !== '') { - headers.push({ key: '', value: '' }) + } catch (error) { + logger.error('Failed to refresh MCP server:', error) } + } - setEditInitialData({ - name: server.name || '', - transport: (server.transport as McpTransport) || 'streamable-http', - url: server.url || '', - timeout: 30000, - headers, - }) - setShowEditModal(true) - }, []) - - const selectedServer = useMemo(() => { + useEffect(() => { + if (!refreshServerMutation.isSuccess) return + const timeout = window.setTimeout(() => refreshServerMutation.reset(), 3000) + return () => window.clearTimeout(timeout) + // eslint-disable-next-line react-hooks/exhaustive-deps -- mutation object is unstable; isSuccess flag is the trigger + }, [refreshServerMutation.isSuccess]) + + const refreshingServerId = refreshServerMutation.isPending + ? refreshServerMutation.variables?.serverId + : null + const refreshedServerId = refreshServerMutation.isSuccess + ? refreshServerMutation.variables?.serverId + : null + const refreshedWorkflowsUpdated = refreshServerMutation.data?.workflowsUpdated + + const editingServer = editingServerId + ? (servers.find((s) => s.id === editingServerId) as McpServer | undefined) + : undefined + const editInitialData = editingServer ? buildEditInitialData(editingServer) : undefined + + const selectedServer = (() => { if (!selectedServerId) return null const server = servers.find((s) => s.id === selectedServerId) as McpServer | undefined if (!server) return null const serverTools = (toolsByServer[selectedServerId] || []) as McpTool[] return { server, tools: serverTools } - }, [selectedServerId, servers, toolsByServer]) + })() + + const getStoredToolIssues = ( + serverId: string, + toolName: string + ): { issue: McpToolIssue; workflowName: string }[] => { + const relevantStoredTools = storedTools.filter( + (st) => st.serverId === serverId && st.toolName === toolName + ) - const getStoredToolIssues = useCallback( - (serverId: string, toolName: string): { issue: McpToolIssue; workflowName: string }[] => { - const relevantStoredTools = storedTools.filter( - (st) => st.serverId === serverId && st.toolName === toolName + const serverStates = servers.map((s) => ({ + id: s.id, + url: s.url, + connectionStatus: s.connectionStatus, + lastError: s.lastError || undefined, + })) + + const discoveredTools = mcpToolsData.map((t) => ({ + serverId: t.serverId, + name: t.name, + inputSchema: t.inputSchema, + })) + + const issues: { issue: McpToolIssue; workflowName: string }[] = [] + + for (const storedTool of relevantStoredTools) { + const issue = getMcpToolIssue( + { + serverId: storedTool.serverId, + serverUrl: storedTool.serverUrl, + toolName: storedTool.toolName, + schema: storedTool.schema, + }, + serverStates, + discoveredTools ) - const serverStates = servers.map((s) => ({ - id: s.id, - url: s.url, - connectionStatus: s.connectionStatus, - lastError: s.lastError || undefined, - })) - - const discoveredTools = mcpToolsData.map((t) => ({ - serverId: t.serverId, - name: t.name, - inputSchema: t.inputSchema, - })) - - const issues: { issue: McpToolIssue; workflowName: string }[] = [] - - for (const storedTool of relevantStoredTools) { - const issue = getMcpToolIssue( - { - serverId: storedTool.serverId, - serverUrl: storedTool.serverUrl, - toolName: storedTool.toolName, - schema: storedTool.schema, - }, - serverStates, - discoveredTools - ) - - if (issue) { - issues.push({ issue, workflowName: storedTool.workflowName }) - } + if (issue) { + issues.push({ issue, workflowName: storedTool.workflowName }) } + } - return issues - }, - [storedTools, servers, mcpToolsData] - ) + return issues + } const error = toolsError || serversError const hasServers = servers && servers.length > 0 @@ -420,12 +399,32 @@ export function MCP({ initialServerId }: MCPProps) { {server.connectionStatus === 'error' && (

Status -

+

{server.lastError || 'Unable to connect'}

)} + {server.authType === 'oauth' && server.connectionStatus !== 'connected' && ( +
+ + Authentication + +
+ +
+
+ )} +
Tools ({tools.length}) @@ -448,11 +447,12 @@ export function MCP({ initialServerId }: MCPProps) { key={tool.name} className='overflow-hidden rounded-md border bg-[var(--surface-3)]' > - + {isExpanded && hasParams && (
@@ -561,25 +561,27 @@ export function MCP({ initialServerId }: MCPProps) { -
{ + if (!open) setEditingServerId(null) + }} mode='edit' initialData={editInitialData} onSubmit={async (config) => { @@ -618,7 +620,7 @@ export function MCP({ initialServerId }: MCPProps) { /> @@ -626,7 +628,7 @@ export function MCP({ initialServerId }: MCPProps) {
{error ? (
-

+

{error instanceof Error ? error.message : 'Failed to load MCP servers'}

@@ -654,8 +656,8 @@ export function MCP({ initialServerId }: MCPProps) { tools={tools} isDeleting={deletingServers.has(server.id)} isLoadingTools={isLoadingTools} - isRefreshing={refreshingServers[server.id]?.status === 'refreshing'} - onRemove={() => handleRemoveServer(server.id, server.name || 'this server')} + isRefreshing={refreshingServerId === server.id} + onRemove={() => handleRemoveServer(server.id)} onViewDetails={() => handleViewDetails(server.id)} /> ) @@ -675,28 +677,38 @@ export function MCP({ initialServerId }: MCPProps) { onOpenChange={setShowAddModal} mode='add' onSubmit={async (config) => { - await createServerMutation.mutateAsync({ + const result = await createServerMutation.mutateAsync({ workspaceId, config: { ...config, enabled: true }, }) + if (result.authType === 'oauth') { + await startOauthForServer(result.serverId) + } }} workspaceId={workspaceId} availableEnvVars={availableEnvVars} allowedMcpDomains={allowedMcpDomains} /> - + { + if (!open) setServerToDeleteId(null) + }} + > Delete MCP Server

Are you sure you want to delete{' '} - {serverToDelete?.name} + + {servers.find((s) => s.id === serverToDeleteId)?.name || 'this server'} + ? This action cannot be undone.

-