From e27f3085973722407518ea4822fb3e0a2b41df9c Mon Sep 17 00:00:00 2001 From: Matti4 Date: Sat, 20 Jun 2026 15:34:47 +0200 Subject: [PATCH] server: avoid forwarding auth headers in CORS proxy (#24373) * server: avoid forwarding auth headers in CORS proxy * format * fix test * fix e2e test --------- Co-authored-by: Xuan Son Nguyen --- tools/server/server-cors-proxy.h | 22 ++++++- tools/server/tests/unit/test_security.py | 45 ++++++++++++++ tools/ui/src/lib/constants/mcp.ts | 3 + tools/ui/src/lib/services/mcp.service.ts | 23 +++++-- tools/ui/src/lib/utils/api-headers.ts | 15 ++++- tools/ui/src/lib/utils/cors-proxy.ts | 8 ++- tools/ui/tests/e2e/pwa.e2e.ts | 10 +-- tools/ui/tests/unit/mcp-service.test.ts | 64 +++++++++++++++++++- tools/ui/tests/unit/sanitize-headers.test.ts | 18 ++++++ 9 files changed, 187 insertions(+), 21 deletions(-) diff --git a/tools/server/server-cors-proxy.h b/tools/server/server-cors-proxy.h index 2af0c7e1c2..53a6909ed2 100644 --- a/tools/server/server-cors-proxy.h +++ b/tools/server/server-cors-proxy.h @@ -7,9 +7,18 @@ #include #include #include +#include +#include #include "server-http.h" +static std::string proxy_header_to_lower(std::string header) { + std::transform(header.begin(), header.end(), header.begin(), [](unsigned char c) { + return std::tolower(c); + }); + return header; +} + static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) { std::string target_url = req.get_param("url"); common_http_url parsed_url = common_http_parse_url(target_url); @@ -33,11 +42,18 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str()); std::map headers; + const std::string proxy_header_prefix = "x-llama-server-proxy-header-"; for (auto [key, value] : req.headers) { - auto new_key = key; - if (string_starts_with(new_key, "x-proxy-header-")) { - string_replace_all(new_key, "x-proxy-header-", ""); + const std::string lowered_key = proxy_header_to_lower(key); + if (!string_starts_with(lowered_key, proxy_header_prefix)) { + continue; } + + auto new_key = key.substr(proxy_header_prefix.size()); + if (new_key.empty()) { + continue; + } + headers[new_key] = value; } diff --git a/tools/server/tests/unit/test_security.py b/tools/server/tests/unit/test_security.py index 02d0b1afbc..a0c3e214ae 100644 --- a/tools/server/tests/unit/test_security.py +++ b/tools/server/tests/unit/test_security.py @@ -1,6 +1,8 @@ import pytest from openai import OpenAI from utils import * +import threading +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer server = ServerPreset.tinyllama2() @@ -105,6 +107,49 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str): assert res.headers[cors_header] == cors_header_value +def test_cors_proxy_only_forwards_explicit_proxy_headers(): + class CaptureHeadersHandler(BaseHTTPRequestHandler): + def do_GET(self): + self.server.captured_headers = dict(self.headers) + self.send_response(200) + self.end_headers() + self.wfile.write(b"ok") + + def log_message(self, format, *args): + pass + + target = ThreadingHTTPServer(("127.0.0.1", 0), CaptureHeadersHandler) + target.captured_headers = {} + target_thread = threading.Thread(target=target.serve_forever, daemon=True) + target_thread.start() + + try: + server = ServerPreset.tinyllama2() + server.api_key = TEST_API_KEY + server.ui_mcp_proxy = True + server.start() + + res = server.make_request("GET", f"/cors-proxy?url=http://127.0.0.1:{target.server_port}/capture", headers={ + "Authorization": f"Bearer {TEST_API_KEY}", + "Proxy-Authorization": "Basic secret", + "X-Api-Key": TEST_API_KEY, + "Cookie": "session=secret", + "x-llama-server-proxy-header-accept": "application/json", + "x-llama-server-proxy-header-authorization": "Bearer explicit", + }) + + assert res.status_code == 200 + captured = {key.lower(): value for key, value in target.captured_headers.items()} + assert captured["accept"] == "application/json" + assert captured["authorization"] == "Bearer explicit" + assert "proxy-authorization" not in captured + assert "x-api-key" not in captured + assert "cookie" not in captured + finally: + target.shutdown() + target.server_close() + + @pytest.mark.parametrize( "media_path, image_url, success", [ diff --git a/tools/ui/src/lib/constants/mcp.ts b/tools/ui/src/lib/constants/mcp.ts index 5b11f989e2..a7381df0bf 100644 --- a/tools/ui/src/lib/constants/mcp.ts +++ b/tools/ui/src/lib/constants/mcp.ts @@ -51,6 +51,9 @@ export const EXPECTED_THEMED_ICON_PAIR_COUNT = 2; /** CORS proxy URL query parameter name */ export const CORS_PROXY_URL_PARAM = 'url'; +/** Header prefix for headers that should be forwarded by the CORS proxy */ +export const CORS_PROXY_HEADER_PREFIX = 'x-llama-server-proxy-header-'; + /** Number of trailing characters to keep visible when partially redacting mcp-session-id */ export const MCP_SESSION_ID_VISIBLE_CHARS = 5; diff --git a/tools/ui/src/lib/services/mcp.service.ts b/tools/ui/src/lib/services/mcp.service.ts index 0aa58dc5d8..90de0d5d88 100644 --- a/tools/ui/src/lib/services/mcp.service.ts +++ b/tools/ui/src/lib/services/mcp.service.ts @@ -16,6 +16,7 @@ import { DEFAULT_MCP_CONFIG, DEFAULT_CLIENT_VERSION, DEFAULT_IMAGE_MIME_TYPE, + CORS_PROXY_HEADER_PREFIX, MCP_PARTIAL_REDACT_HEADERS, CORS_PROXY_ENDPOINT } from '$lib/constants'; @@ -133,6 +134,20 @@ export class MCPService { return details; } + private static addRequestHeaders( + requestHeaders: Headers, + headers: HeadersInit, + useProxy: boolean + ) { + for (const [key, value] of new Headers(headers).entries()) { + const proxiedKey = + useProxy && !key.toLowerCase().startsWith(CORS_PROXY_HEADER_PREFIX) + ? `${CORS_PROXY_HEADER_PREFIX}${key}` + : key; + requestHeaders.set(proxiedKey, value); + } + } + private static summarizeError(error: unknown): Record { if (error instanceof Error) { return { @@ -271,15 +286,11 @@ export class MCPService { const requestHeaders = new Headers(baseInit.headers); if (typeof Request !== 'undefined' && input instanceof Request) { - for (const [key, value] of input.headers.entries()) { - requestHeaders.set(key, value); - } + this.addRequestHeaders(requestHeaders, input.headers, useProxy); } if (init?.headers) { - for (const [key, value] of new Headers(init.headers).entries()) { - requestHeaders.set(key, value); - } + this.addRequestHeaders(requestHeaders, init.headers, useProxy); } const request = this.createDiagnosticRequestDetails( diff --git a/tools/ui/src/lib/utils/api-headers.ts b/tools/ui/src/lib/utils/api-headers.ts index c0a5309b99..a2b70d492a 100644 --- a/tools/ui/src/lib/utils/api-headers.ts +++ b/tools/ui/src/lib/utils/api-headers.ts @@ -1,5 +1,5 @@ import { config } from '$lib/stores/settings.svelte'; -import { REDACTED_HEADERS } from '$lib/constants'; +import { CORS_PROXY_HEADER_PREFIX, REDACTED_HEADERS } from '$lib/constants'; import { redactValue } from './redact'; /** @@ -52,11 +52,20 @@ export function sanitizeHeaders( for (const [key, value] of normalized.entries()) { const normalizedKey = key.toLowerCase(); - const partialChars = partialRedactHeaders?.get(normalizedKey); + const unproxiedKey = normalizedKey.startsWith(CORS_PROXY_HEADER_PREFIX) + ? normalizedKey.slice(CORS_PROXY_HEADER_PREFIX.length) + : normalizedKey; + const partialChars = + partialRedactHeaders?.get(normalizedKey) ?? partialRedactHeaders?.get(unproxiedKey); if (partialChars !== undefined) { sanitized[key] = redactValue(value, partialChars); - } else if (REDACTED_HEADERS.has(normalizedKey) || redactedHeaders.has(normalizedKey)) { + } else if ( + REDACTED_HEADERS.has(normalizedKey) || + REDACTED_HEADERS.has(unproxiedKey) || + redactedHeaders.has(normalizedKey) || + redactedHeaders.has(unproxiedKey) + ) { sanitized[key] = redactValue(value); } else { sanitized[key] = value; diff --git a/tools/ui/src/lib/utils/cors-proxy.ts b/tools/ui/src/lib/utils/cors-proxy.ts index 47caf27427..1694b7dbe6 100644 --- a/tools/ui/src/lib/utils/cors-proxy.ts +++ b/tools/ui/src/lib/utils/cors-proxy.ts @@ -3,7 +3,11 @@ */ import { base } from '$app/paths'; -import { CORS_PROXY_ENDPOINT, CORS_PROXY_URL_PARAM } from '$lib/constants'; +import { + CORS_PROXY_ENDPOINT, + CORS_PROXY_HEADER_PREFIX, + CORS_PROXY_URL_PARAM +} from '$lib/constants'; /** * Build a proxied URL that routes through llama-server's CORS proxy. @@ -28,7 +32,7 @@ export function buildProxiedHeaders(headers: Record): Record = {}; for (const [key, value] of Object.entries(headers)) { - proxiedHeaders[`x-proxy-header-${key}`] = value; + proxiedHeaders[`${CORS_PROXY_HEADER_PREFIX}${key}`] = value; } return proxiedHeaders; diff --git a/tools/ui/tests/e2e/pwa.e2e.ts b/tools/ui/tests/e2e/pwa.e2e.ts index be7642b191..e21672239b 100644 --- a/tools/ui/tests/e2e/pwa.e2e.ts +++ b/tools/ui/tests/e2e/pwa.e2e.ts @@ -39,8 +39,8 @@ test.describe('PWA Service Worker', () => { const swContent = await swResponse.text(); // Precache contains SvelteKit content-hashed bundle paths - expect(swContent).toMatch(/"_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"/); - expect(swContent).toMatch(/"_app\/immutable\/assets\/bundle\.[a-zA-Z0-9-]+\.css"/); + expect(swContent).toMatch(/"_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"/); + expect(swContent).toMatch(/"_app\/immutable\/assets\/bundle\.[a-zA-Z0-9_-]+\.css"/); expect(swContent).toMatch(/"manifest\.webmanifest"/); expect(swContent).toMatch(/"_app\/version\.json"/); expect(swContent).toMatch(/NavigationRoute/); @@ -99,8 +99,8 @@ test.describe('PWA Service Worker', () => { const html = await response.text(); // SvelteKit outputs content-hashed bundle names in _app/immutable/ - expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"/); - expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/assets\/bundle\.[a-zA-Z0-9-]+\.css"/); - expect(html).toMatch(/import\("(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"\)/); + expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"/); + expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/assets\/bundle\.[a-zA-Z0-9_-]+\.css"/); + expect(html).toMatch(/import\("(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"\)/); }); }); diff --git a/tools/ui/tests/unit/mcp-service.test.ts b/tools/ui/tests/unit/mcp-service.test.ts index afd3bdd5cf..1f6fdda377 100644 --- a/tools/ui/tests/unit/mcp-service.test.ts +++ b/tools/ui/tests/unit/mcp-service.test.ts @@ -3,6 +3,7 @@ import { Client } from '@modelcontextprotocol/sdk/client'; import { MCPService } from '$lib/services/mcp.service'; import { MCPConnectionPhase, MCPTransportType } from '$lib/enums'; import type { MCPConnectionLog, MCPServerConfig } from '$lib/types'; +import { CORS_PROXY_HEADER_PREFIX } from '$lib/constants'; type DiagnosticFetchFactory = ( serverName: string, @@ -16,11 +17,12 @@ type DiagnosticFetchFactory = ( const createDiagnosticFetch = ( config: MCPServerConfig, onLog?: (log: MCPConnectionLog) => void, - baseInit: RequestInit = {} + baseInit: RequestInit = {}, + useProxy = false ) => ( MCPService as unknown as { createDiagnosticFetch: DiagnosticFetchFactory } - ).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), false, onLog); + ).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), useProxy, onLog); describe('MCPService', () => { afterEach(() => { @@ -94,6 +96,64 @@ describe('MCPService', () => { }); }); + it('wraps dynamic request headers when using the CORS proxy', async () => { + const logs: MCPConnectionLog[] = []; + const proxiedAuthToken = `${CORS_PROXY_HEADER_PREFIX}x-auth-token`; + const proxiedContentType = `${CORS_PROXY_HEADER_PREFIX}content-type`; + const proxiedSessionId = `${CORS_PROXY_HEADER_PREFIX}mcp-session-id`; + const response = new Response('{}', { + status: 200, + headers: { 'content-type': 'application/json' } + }); + const fetchMock = vi.fn().mockResolvedValue(response); + + vi.stubGlobal('fetch', fetchMock); + + const config: MCPServerConfig = { + url: 'https://example.com/mcp', + transport: MCPTransportType.STREAMABLE_HTTP, + useProxy: true + }; + + const controller = createDiagnosticFetch( + config, + (log) => logs.push(log), + { + headers: { + authorization: 'Bearer llama-server-key', + [proxiedAuthToken]: 'target-token' + } + }, + true + ); + + await controller.fetch('http://localhost:8080/cors-proxy?url=https%3A%2F%2Fexample.com%2Fmcp', { + method: 'POST', + headers: { + 'content-type': 'application/json', + 'mcp-session-id': 'session-request-12345' + }, + body: '{}' + }); + + const sentHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Headers; + expect(sentHeaders.get('authorization')).toBe('Bearer llama-server-key'); + expect(sentHeaders.get(proxiedAuthToken)).toBe('target-token'); + expect(sentHeaders.get(proxiedContentType)).toBe('application/json'); + expect(sentHeaders.get(proxiedSessionId)).toBe('session-request-12345'); + expect(sentHeaders.has('content-type')).toBe(false); + expect(sentHeaders.has('mcp-session-id')).toBe(false); + expect(logs[0].details).toMatchObject({ + request: { + headers: { + authorization: '[redacted]', + [proxiedAuthToken]: '[redacted]', + [proxiedSessionId]: '....12345' + } + } + }); + }); + it('partially redacts mcp-session-id in diagnostic request and response logs', async () => { const logs: MCPConnectionLog[] = []; const response = new Response('{}', { diff --git a/tools/ui/tests/unit/sanitize-headers.test.ts b/tools/ui/tests/unit/sanitize-headers.test.ts index f5a682d863..8cc1fcdfc8 100644 --- a/tools/ui/tests/unit/sanitize-headers.test.ts +++ b/tools/ui/tests/unit/sanitize-headers.test.ts @@ -1,5 +1,6 @@ import { describe, expect, it } from 'vitest'; import { sanitizeHeaders } from '$lib/utils/api-headers'; +import { CORS_PROXY_HEADER_PREFIX } from '$lib/constants'; describe('sanitizeHeaders', () => { it('returns empty object for undefined input', () => { @@ -52,4 +53,21 @@ describe('sanitizeHeaders', () => { const result = sanitizeHeaders(headers, ['X-CUSTOM-TOKEN']); expect(result['x-custom-token']).toBe('[redacted]'); }); + + it('redacts proxied sensitive and custom target headers', () => { + const proxiedAuthorization = `${CORS_PROXY_HEADER_PREFIX}authorization`; + const proxiedSessionId = `${CORS_PROXY_HEADER_PREFIX}mcp-session-id`; + const proxiedVendorKey = `${CORS_PROXY_HEADER_PREFIX}x-vendor-key`; + const headers = new Headers({ + [proxiedAuthorization]: 'Bearer secret', + [proxiedSessionId]: 'session-12345', + [proxiedVendorKey]: 'vendor-secret' + }); + const partial = new Map([['mcp-session-id', 5]]); + const result = sanitizeHeaders(headers, ['x-vendor-key'], partial); + + expect(result[proxiedAuthorization]).toBe('[redacted]'); + expect(result[proxiedSessionId]).toBe('....12345'); + expect(result[proxiedVendorKey]).toBe('[redacted]'); + }); });