server: avoid forwarding auth headers in CORS proxy (#24373)

* server: avoid forwarding auth headers in CORS proxy

* format

* fix test

* fix e2e test

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
Matti4 2026-06-20 15:34:47 +02:00 committed by GitHub
parent 67e9fd3b74
commit e27f308597
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 187 additions and 21 deletions

View File

@ -7,9 +7,18 @@
#include <unordered_set>
#include <list>
#include <map>
#include <algorithm>
#include <cctype>
#include "server-http.h"
static std::string proxy_header_to_lower(std::string header) {
std::transform(header.begin(), header.end(), header.begin(), [](unsigned char c) {
return std::tolower(c);
});
return header;
}
static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) {
std::string target_url = req.get_param("url");
common_http_url parsed_url = common_http_parse_url(target_url);
@ -33,11 +42,18 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
std::map<std::string, std::string> headers;
const std::string proxy_header_prefix = "x-llama-server-proxy-header-";
for (auto [key, value] : req.headers) {
auto new_key = key;
if (string_starts_with(new_key, "x-proxy-header-")) {
string_replace_all(new_key, "x-proxy-header-", "");
const std::string lowered_key = proxy_header_to_lower(key);
if (!string_starts_with(lowered_key, proxy_header_prefix)) {
continue;
}
auto new_key = key.substr(proxy_header_prefix.size());
if (new_key.empty()) {
continue;
}
headers[new_key] = value;
}

View File

@ -1,6 +1,8 @@
import pytest
from openai import OpenAI
from utils import *
import threading
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
server = ServerPreset.tinyllama2()
@ -105,6 +107,49 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
assert res.headers[cors_header] == cors_header_value
def test_cors_proxy_only_forwards_explicit_proxy_headers():
class CaptureHeadersHandler(BaseHTTPRequestHandler):
def do_GET(self):
self.server.captured_headers = dict(self.headers)
self.send_response(200)
self.end_headers()
self.wfile.write(b"ok")
def log_message(self, format, *args):
pass
target = ThreadingHTTPServer(("127.0.0.1", 0), CaptureHeadersHandler)
target.captured_headers = {}
target_thread = threading.Thread(target=target.serve_forever, daemon=True)
target_thread.start()
try:
server = ServerPreset.tinyllama2()
server.api_key = TEST_API_KEY
server.ui_mcp_proxy = True
server.start()
res = server.make_request("GET", f"/cors-proxy?url=http://127.0.0.1:{target.server_port}/capture", headers={
"Authorization": f"Bearer {TEST_API_KEY}",
"Proxy-Authorization": "Basic secret",
"X-Api-Key": TEST_API_KEY,
"Cookie": "session=secret",
"x-llama-server-proxy-header-accept": "application/json",
"x-llama-server-proxy-header-authorization": "Bearer explicit",
})
assert res.status_code == 200
captured = {key.lower(): value for key, value in target.captured_headers.items()}
assert captured["accept"] == "application/json"
assert captured["authorization"] == "Bearer explicit"
assert "proxy-authorization" not in captured
assert "x-api-key" not in captured
assert "cookie" not in captured
finally:
target.shutdown()
target.server_close()
@pytest.mark.parametrize(
"media_path, image_url, success",
[

View File

@ -51,6 +51,9 @@ export const EXPECTED_THEMED_ICON_PAIR_COUNT = 2;
/** CORS proxy URL query parameter name */
export const CORS_PROXY_URL_PARAM = 'url';
/** Header prefix for headers that should be forwarded by the CORS proxy */
export const CORS_PROXY_HEADER_PREFIX = 'x-llama-server-proxy-header-';
/** Number of trailing characters to keep visible when partially redacting mcp-session-id */
export const MCP_SESSION_ID_VISIBLE_CHARS = 5;

View File

@ -16,6 +16,7 @@ import {
DEFAULT_MCP_CONFIG,
DEFAULT_CLIENT_VERSION,
DEFAULT_IMAGE_MIME_TYPE,
CORS_PROXY_HEADER_PREFIX,
MCP_PARTIAL_REDACT_HEADERS,
CORS_PROXY_ENDPOINT
} from '$lib/constants';
@ -133,6 +134,20 @@ export class MCPService {
return details;
}
private static addRequestHeaders(
requestHeaders: Headers,
headers: HeadersInit,
useProxy: boolean
) {
for (const [key, value] of new Headers(headers).entries()) {
const proxiedKey =
useProxy && !key.toLowerCase().startsWith(CORS_PROXY_HEADER_PREFIX)
? `${CORS_PROXY_HEADER_PREFIX}${key}`
: key;
requestHeaders.set(proxiedKey, value);
}
}
private static summarizeError(error: unknown): Record<string, unknown> {
if (error instanceof Error) {
return {
@ -271,15 +286,11 @@ export class MCPService {
const requestHeaders = new Headers(baseInit.headers);
if (typeof Request !== 'undefined' && input instanceof Request) {
for (const [key, value] of input.headers.entries()) {
requestHeaders.set(key, value);
}
this.addRequestHeaders(requestHeaders, input.headers, useProxy);
}
if (init?.headers) {
for (const [key, value] of new Headers(init.headers).entries()) {
requestHeaders.set(key, value);
}
this.addRequestHeaders(requestHeaders, init.headers, useProxy);
}
const request = this.createDiagnosticRequestDetails(

View File

@ -1,5 +1,5 @@
import { config } from '$lib/stores/settings.svelte';
import { REDACTED_HEADERS } from '$lib/constants';
import { CORS_PROXY_HEADER_PREFIX, REDACTED_HEADERS } from '$lib/constants';
import { redactValue } from './redact';
/**
@ -52,11 +52,20 @@ export function sanitizeHeaders(
for (const [key, value] of normalized.entries()) {
const normalizedKey = key.toLowerCase();
const partialChars = partialRedactHeaders?.get(normalizedKey);
const unproxiedKey = normalizedKey.startsWith(CORS_PROXY_HEADER_PREFIX)
? normalizedKey.slice(CORS_PROXY_HEADER_PREFIX.length)
: normalizedKey;
const partialChars =
partialRedactHeaders?.get(normalizedKey) ?? partialRedactHeaders?.get(unproxiedKey);
if (partialChars !== undefined) {
sanitized[key] = redactValue(value, partialChars);
} else if (REDACTED_HEADERS.has(normalizedKey) || redactedHeaders.has(normalizedKey)) {
} else if (
REDACTED_HEADERS.has(normalizedKey) ||
REDACTED_HEADERS.has(unproxiedKey) ||
redactedHeaders.has(normalizedKey) ||
redactedHeaders.has(unproxiedKey)
) {
sanitized[key] = redactValue(value);
} else {
sanitized[key] = value;

View File

@ -3,7 +3,11 @@
*/
import { base } from '$app/paths';
import { CORS_PROXY_ENDPOINT, CORS_PROXY_URL_PARAM } from '$lib/constants';
import {
CORS_PROXY_ENDPOINT,
CORS_PROXY_HEADER_PREFIX,
CORS_PROXY_URL_PARAM
} from '$lib/constants';
/**
* Build a proxied URL that routes through llama-server's CORS proxy.
@ -28,7 +32,7 @@ export function buildProxiedHeaders(headers: Record<string, string>): Record<str
const proxiedHeaders: Record<string, string> = {};
for (const [key, value] of Object.entries(headers)) {
proxiedHeaders[`x-proxy-header-${key}`] = value;
proxiedHeaders[`${CORS_PROXY_HEADER_PREFIX}${key}`] = value;
}
return proxiedHeaders;

View File

@ -39,8 +39,8 @@ test.describe('PWA Service Worker', () => {
const swContent = await swResponse.text();
// Precache contains SvelteKit content-hashed bundle paths
expect(swContent).toMatch(/"_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"/);
expect(swContent).toMatch(/"_app\/immutable\/assets\/bundle\.[a-zA-Z0-9-]+\.css"/);
expect(swContent).toMatch(/"_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"/);
expect(swContent).toMatch(/"_app\/immutable\/assets\/bundle\.[a-zA-Z0-9_-]+\.css"/);
expect(swContent).toMatch(/"manifest\.webmanifest"/);
expect(swContent).toMatch(/"_app\/version\.json"/);
expect(swContent).toMatch(/NavigationRoute/);
@ -99,8 +99,8 @@ test.describe('PWA Service Worker', () => {
const html = await response.text();
// SvelteKit outputs content-hashed bundle names in _app/immutable/
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"/);
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/assets\/bundle\.[a-zA-Z0-9-]+\.css"/);
expect(html).toMatch(/import\("(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"\)/);
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"/);
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/assets\/bundle\.[a-zA-Z0-9_-]+\.css"/);
expect(html).toMatch(/import\("(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"\)/);
});
});

View File

@ -3,6 +3,7 @@ import { Client } from '@modelcontextprotocol/sdk/client';
import { MCPService } from '$lib/services/mcp.service';
import { MCPConnectionPhase, MCPTransportType } from '$lib/enums';
import type { MCPConnectionLog, MCPServerConfig } from '$lib/types';
import { CORS_PROXY_HEADER_PREFIX } from '$lib/constants';
type DiagnosticFetchFactory = (
serverName: string,
@ -16,11 +17,12 @@ type DiagnosticFetchFactory = (
const createDiagnosticFetch = (
config: MCPServerConfig,
onLog?: (log: MCPConnectionLog) => void,
baseInit: RequestInit = {}
baseInit: RequestInit = {},
useProxy = false
) =>
(
MCPService as unknown as { createDiagnosticFetch: DiagnosticFetchFactory }
).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), false, onLog);
).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), useProxy, onLog);
describe('MCPService', () => {
afterEach(() => {
@ -94,6 +96,64 @@ describe('MCPService', () => {
});
});
it('wraps dynamic request headers when using the CORS proxy', async () => {
const logs: MCPConnectionLog[] = [];
const proxiedAuthToken = `${CORS_PROXY_HEADER_PREFIX}x-auth-token`;
const proxiedContentType = `${CORS_PROXY_HEADER_PREFIX}content-type`;
const proxiedSessionId = `${CORS_PROXY_HEADER_PREFIX}mcp-session-id`;
const response = new Response('{}', {
status: 200,
headers: { 'content-type': 'application/json' }
});
const fetchMock = vi.fn().mockResolvedValue(response);
vi.stubGlobal('fetch', fetchMock);
const config: MCPServerConfig = {
url: 'https://example.com/mcp',
transport: MCPTransportType.STREAMABLE_HTTP,
useProxy: true
};
const controller = createDiagnosticFetch(
config,
(log) => logs.push(log),
{
headers: {
authorization: 'Bearer llama-server-key',
[proxiedAuthToken]: 'target-token'
}
},
true
);
await controller.fetch('http://localhost:8080/cors-proxy?url=https%3A%2F%2Fexample.com%2Fmcp', {
method: 'POST',
headers: {
'content-type': 'application/json',
'mcp-session-id': 'session-request-12345'
},
body: '{}'
});
const sentHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Headers;
expect(sentHeaders.get('authorization')).toBe('Bearer llama-server-key');
expect(sentHeaders.get(proxiedAuthToken)).toBe('target-token');
expect(sentHeaders.get(proxiedContentType)).toBe('application/json');
expect(sentHeaders.get(proxiedSessionId)).toBe('session-request-12345');
expect(sentHeaders.has('content-type')).toBe(false);
expect(sentHeaders.has('mcp-session-id')).toBe(false);
expect(logs[0].details).toMatchObject({
request: {
headers: {
authorization: '[redacted]',
[proxiedAuthToken]: '[redacted]',
[proxiedSessionId]: '....12345'
}
}
});
});
it('partially redacts mcp-session-id in diagnostic request and response logs', async () => {
const logs: MCPConnectionLog[] = [];
const response = new Response('{}', {

View File

@ -1,5 +1,6 @@
import { describe, expect, it } from 'vitest';
import { sanitizeHeaders } from '$lib/utils/api-headers';
import { CORS_PROXY_HEADER_PREFIX } from '$lib/constants';
describe('sanitizeHeaders', () => {
it('returns empty object for undefined input', () => {
@ -52,4 +53,21 @@ describe('sanitizeHeaders', () => {
const result = sanitizeHeaders(headers, ['X-CUSTOM-TOKEN']);
expect(result['x-custom-token']).toBe('[redacted]');
});
it('redacts proxied sensitive and custom target headers', () => {
const proxiedAuthorization = `${CORS_PROXY_HEADER_PREFIX}authorization`;
const proxiedSessionId = `${CORS_PROXY_HEADER_PREFIX}mcp-session-id`;
const proxiedVendorKey = `${CORS_PROXY_HEADER_PREFIX}x-vendor-key`;
const headers = new Headers({
[proxiedAuthorization]: 'Bearer secret',
[proxiedSessionId]: 'session-12345',
[proxiedVendorKey]: 'vendor-secret'
});
const partial = new Map([['mcp-session-id', 5]]);
const result = sanitizeHeaders(headers, ['x-vendor-key'], partial);
expect(result[proxiedAuthorization]).toBe('[redacted]');
expect(result[proxiedSessionId]).toBe('....12345');
expect(result[proxiedVendorKey]).toBe('[redacted]');
});
});