mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
server: avoid forwarding auth headers in CORS proxy (#24373)
* server: avoid forwarding auth headers in CORS proxy * format * fix test * fix e2e test --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
parent
67e9fd3b74
commit
e27f308597
@ -7,9 +7,18 @@
|
||||
#include <unordered_set>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
|
||||
#include "server-http.h"
|
||||
|
||||
static std::string proxy_header_to_lower(std::string header) {
|
||||
std::transform(header.begin(), header.end(), header.begin(), [](unsigned char c) {
|
||||
return std::tolower(c);
|
||||
});
|
||||
return header;
|
||||
}
|
||||
|
||||
static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) {
|
||||
std::string target_url = req.get_param("url");
|
||||
common_http_url parsed_url = common_http_parse_url(target_url);
|
||||
@ -33,11 +42,18 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
|
||||
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
|
||||
|
||||
std::map<std::string, std::string> headers;
|
||||
const std::string proxy_header_prefix = "x-llama-server-proxy-header-";
|
||||
for (auto [key, value] : req.headers) {
|
||||
auto new_key = key;
|
||||
if (string_starts_with(new_key, "x-proxy-header-")) {
|
||||
string_replace_all(new_key, "x-proxy-header-", "");
|
||||
const std::string lowered_key = proxy_header_to_lower(key);
|
||||
if (!string_starts_with(lowered_key, proxy_header_prefix)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto new_key = key.substr(proxy_header_prefix.size());
|
||||
if (new_key.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
headers[new_key] = value;
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
from utils import *
|
||||
import threading
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
@ -105,6 +107,49 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
|
||||
assert res.headers[cors_header] == cors_header_value
|
||||
|
||||
|
||||
def test_cors_proxy_only_forwards_explicit_proxy_headers():
|
||||
class CaptureHeadersHandler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
self.server.captured_headers = dict(self.headers)
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"ok")
|
||||
|
||||
def log_message(self, format, *args):
|
||||
pass
|
||||
|
||||
target = ThreadingHTTPServer(("127.0.0.1", 0), CaptureHeadersHandler)
|
||||
target.captured_headers = {}
|
||||
target_thread = threading.Thread(target=target.serve_forever, daemon=True)
|
||||
target_thread.start()
|
||||
|
||||
try:
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.api_key = TEST_API_KEY
|
||||
server.ui_mcp_proxy = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("GET", f"/cors-proxy?url=http://127.0.0.1:{target.server_port}/capture", headers={
|
||||
"Authorization": f"Bearer {TEST_API_KEY}",
|
||||
"Proxy-Authorization": "Basic secret",
|
||||
"X-Api-Key": TEST_API_KEY,
|
||||
"Cookie": "session=secret",
|
||||
"x-llama-server-proxy-header-accept": "application/json",
|
||||
"x-llama-server-proxy-header-authorization": "Bearer explicit",
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
captured = {key.lower(): value for key, value in target.captured_headers.items()}
|
||||
assert captured["accept"] == "application/json"
|
||||
assert captured["authorization"] == "Bearer explicit"
|
||||
assert "proxy-authorization" not in captured
|
||||
assert "x-api-key" not in captured
|
||||
assert "cookie" not in captured
|
||||
finally:
|
||||
target.shutdown()
|
||||
target.server_close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"media_path, image_url, success",
|
||||
[
|
||||
|
||||
@ -51,6 +51,9 @@ export const EXPECTED_THEMED_ICON_PAIR_COUNT = 2;
|
||||
/** CORS proxy URL query parameter name */
|
||||
export const CORS_PROXY_URL_PARAM = 'url';
|
||||
|
||||
/** Header prefix for headers that should be forwarded by the CORS proxy */
|
||||
export const CORS_PROXY_HEADER_PREFIX = 'x-llama-server-proxy-header-';
|
||||
|
||||
/** Number of trailing characters to keep visible when partially redacting mcp-session-id */
|
||||
export const MCP_SESSION_ID_VISIBLE_CHARS = 5;
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ import {
|
||||
DEFAULT_MCP_CONFIG,
|
||||
DEFAULT_CLIENT_VERSION,
|
||||
DEFAULT_IMAGE_MIME_TYPE,
|
||||
CORS_PROXY_HEADER_PREFIX,
|
||||
MCP_PARTIAL_REDACT_HEADERS,
|
||||
CORS_PROXY_ENDPOINT
|
||||
} from '$lib/constants';
|
||||
@ -133,6 +134,20 @@ export class MCPService {
|
||||
return details;
|
||||
}
|
||||
|
||||
private static addRequestHeaders(
|
||||
requestHeaders: Headers,
|
||||
headers: HeadersInit,
|
||||
useProxy: boolean
|
||||
) {
|
||||
for (const [key, value] of new Headers(headers).entries()) {
|
||||
const proxiedKey =
|
||||
useProxy && !key.toLowerCase().startsWith(CORS_PROXY_HEADER_PREFIX)
|
||||
? `${CORS_PROXY_HEADER_PREFIX}${key}`
|
||||
: key;
|
||||
requestHeaders.set(proxiedKey, value);
|
||||
}
|
||||
}
|
||||
|
||||
private static summarizeError(error: unknown): Record<string, unknown> {
|
||||
if (error instanceof Error) {
|
||||
return {
|
||||
@ -271,15 +286,11 @@ export class MCPService {
|
||||
const requestHeaders = new Headers(baseInit.headers);
|
||||
|
||||
if (typeof Request !== 'undefined' && input instanceof Request) {
|
||||
for (const [key, value] of input.headers.entries()) {
|
||||
requestHeaders.set(key, value);
|
||||
}
|
||||
this.addRequestHeaders(requestHeaders, input.headers, useProxy);
|
||||
}
|
||||
|
||||
if (init?.headers) {
|
||||
for (const [key, value] of new Headers(init.headers).entries()) {
|
||||
requestHeaders.set(key, value);
|
||||
}
|
||||
this.addRequestHeaders(requestHeaders, init.headers, useProxy);
|
||||
}
|
||||
|
||||
const request = this.createDiagnosticRequestDetails(
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { REDACTED_HEADERS } from '$lib/constants';
|
||||
import { CORS_PROXY_HEADER_PREFIX, REDACTED_HEADERS } from '$lib/constants';
|
||||
import { redactValue } from './redact';
|
||||
|
||||
/**
|
||||
@ -52,11 +52,20 @@ export function sanitizeHeaders(
|
||||
|
||||
for (const [key, value] of normalized.entries()) {
|
||||
const normalizedKey = key.toLowerCase();
|
||||
const partialChars = partialRedactHeaders?.get(normalizedKey);
|
||||
const unproxiedKey = normalizedKey.startsWith(CORS_PROXY_HEADER_PREFIX)
|
||||
? normalizedKey.slice(CORS_PROXY_HEADER_PREFIX.length)
|
||||
: normalizedKey;
|
||||
const partialChars =
|
||||
partialRedactHeaders?.get(normalizedKey) ?? partialRedactHeaders?.get(unproxiedKey);
|
||||
|
||||
if (partialChars !== undefined) {
|
||||
sanitized[key] = redactValue(value, partialChars);
|
||||
} else if (REDACTED_HEADERS.has(normalizedKey) || redactedHeaders.has(normalizedKey)) {
|
||||
} else if (
|
||||
REDACTED_HEADERS.has(normalizedKey) ||
|
||||
REDACTED_HEADERS.has(unproxiedKey) ||
|
||||
redactedHeaders.has(normalizedKey) ||
|
||||
redactedHeaders.has(unproxiedKey)
|
||||
) {
|
||||
sanitized[key] = redactValue(value);
|
||||
} else {
|
||||
sanitized[key] = value;
|
||||
|
||||
@ -3,7 +3,11 @@
|
||||
*/
|
||||
|
||||
import { base } from '$app/paths';
|
||||
import { CORS_PROXY_ENDPOINT, CORS_PROXY_URL_PARAM } from '$lib/constants';
|
||||
import {
|
||||
CORS_PROXY_ENDPOINT,
|
||||
CORS_PROXY_HEADER_PREFIX,
|
||||
CORS_PROXY_URL_PARAM
|
||||
} from '$lib/constants';
|
||||
|
||||
/**
|
||||
* Build a proxied URL that routes through llama-server's CORS proxy.
|
||||
@ -28,7 +32,7 @@ export function buildProxiedHeaders(headers: Record<string, string>): Record<str
|
||||
const proxiedHeaders: Record<string, string> = {};
|
||||
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
proxiedHeaders[`x-proxy-header-${key}`] = value;
|
||||
proxiedHeaders[`${CORS_PROXY_HEADER_PREFIX}${key}`] = value;
|
||||
}
|
||||
|
||||
return proxiedHeaders;
|
||||
|
||||
@ -39,8 +39,8 @@ test.describe('PWA Service Worker', () => {
|
||||
const swContent = await swResponse.text();
|
||||
|
||||
// Precache contains SvelteKit content-hashed bundle paths
|
||||
expect(swContent).toMatch(/"_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"/);
|
||||
expect(swContent).toMatch(/"_app\/immutable\/assets\/bundle\.[a-zA-Z0-9-]+\.css"/);
|
||||
expect(swContent).toMatch(/"_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"/);
|
||||
expect(swContent).toMatch(/"_app\/immutable\/assets\/bundle\.[a-zA-Z0-9_-]+\.css"/);
|
||||
expect(swContent).toMatch(/"manifest\.webmanifest"/);
|
||||
expect(swContent).toMatch(/"_app\/version\.json"/);
|
||||
expect(swContent).toMatch(/NavigationRoute/);
|
||||
@ -99,8 +99,8 @@ test.describe('PWA Service Worker', () => {
|
||||
const html = await response.text();
|
||||
|
||||
// SvelteKit outputs content-hashed bundle names in _app/immutable/
|
||||
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"/);
|
||||
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/assets\/bundle\.[a-zA-Z0-9-]+\.css"/);
|
||||
expect(html).toMatch(/import\("(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"\)/);
|
||||
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"/);
|
||||
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/assets\/bundle\.[a-zA-Z0-9_-]+\.css"/);
|
||||
expect(html).toMatch(/import\("(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"\)/);
|
||||
});
|
||||
});
|
||||
|
||||
@ -3,6 +3,7 @@ import { Client } from '@modelcontextprotocol/sdk/client';
|
||||
import { MCPService } from '$lib/services/mcp.service';
|
||||
import { MCPConnectionPhase, MCPTransportType } from '$lib/enums';
|
||||
import type { MCPConnectionLog, MCPServerConfig } from '$lib/types';
|
||||
import { CORS_PROXY_HEADER_PREFIX } from '$lib/constants';
|
||||
|
||||
type DiagnosticFetchFactory = (
|
||||
serverName: string,
|
||||
@ -16,11 +17,12 @@ type DiagnosticFetchFactory = (
|
||||
const createDiagnosticFetch = (
|
||||
config: MCPServerConfig,
|
||||
onLog?: (log: MCPConnectionLog) => void,
|
||||
baseInit: RequestInit = {}
|
||||
baseInit: RequestInit = {},
|
||||
useProxy = false
|
||||
) =>
|
||||
(
|
||||
MCPService as unknown as { createDiagnosticFetch: DiagnosticFetchFactory }
|
||||
).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), false, onLog);
|
||||
).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), useProxy, onLog);
|
||||
|
||||
describe('MCPService', () => {
|
||||
afterEach(() => {
|
||||
@ -94,6 +96,64 @@ describe('MCPService', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('wraps dynamic request headers when using the CORS proxy', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const proxiedAuthToken = `${CORS_PROXY_HEADER_PREFIX}x-auth-token`;
|
||||
const proxiedContentType = `${CORS_PROXY_HEADER_PREFIX}content-type`;
|
||||
const proxiedSessionId = `${CORS_PROXY_HEADER_PREFIX}mcp-session-id`;
|
||||
const response = new Response('{}', {
|
||||
status: 200,
|
||||
headers: { 'content-type': 'application/json' }
|
||||
});
|
||||
const fetchMock = vi.fn().mockResolvedValue(response);
|
||||
|
||||
vi.stubGlobal('fetch', fetchMock);
|
||||
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://example.com/mcp',
|
||||
transport: MCPTransportType.STREAMABLE_HTTP,
|
||||
useProxy: true
|
||||
};
|
||||
|
||||
const controller = createDiagnosticFetch(
|
||||
config,
|
||||
(log) => logs.push(log),
|
||||
{
|
||||
headers: {
|
||||
authorization: 'Bearer llama-server-key',
|
||||
[proxiedAuthToken]: 'target-token'
|
||||
}
|
||||
},
|
||||
true
|
||||
);
|
||||
|
||||
await controller.fetch('http://localhost:8080/cors-proxy?url=https%3A%2F%2Fexample.com%2Fmcp', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'content-type': 'application/json',
|
||||
'mcp-session-id': 'session-request-12345'
|
||||
},
|
||||
body: '{}'
|
||||
});
|
||||
|
||||
const sentHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Headers;
|
||||
expect(sentHeaders.get('authorization')).toBe('Bearer llama-server-key');
|
||||
expect(sentHeaders.get(proxiedAuthToken)).toBe('target-token');
|
||||
expect(sentHeaders.get(proxiedContentType)).toBe('application/json');
|
||||
expect(sentHeaders.get(proxiedSessionId)).toBe('session-request-12345');
|
||||
expect(sentHeaders.has('content-type')).toBe(false);
|
||||
expect(sentHeaders.has('mcp-session-id')).toBe(false);
|
||||
expect(logs[0].details).toMatchObject({
|
||||
request: {
|
||||
headers: {
|
||||
authorization: '[redacted]',
|
||||
[proxiedAuthToken]: '[redacted]',
|
||||
[proxiedSessionId]: '....12345'
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('partially redacts mcp-session-id in diagnostic request and response logs', async () => {
|
||||
const logs: MCPConnectionLog[] = [];
|
||||
const response = new Response('{}', {
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { sanitizeHeaders } from '$lib/utils/api-headers';
|
||||
import { CORS_PROXY_HEADER_PREFIX } from '$lib/constants';
|
||||
|
||||
describe('sanitizeHeaders', () => {
|
||||
it('returns empty object for undefined input', () => {
|
||||
@ -52,4 +53,21 @@ describe('sanitizeHeaders', () => {
|
||||
const result = sanitizeHeaders(headers, ['X-CUSTOM-TOKEN']);
|
||||
expect(result['x-custom-token']).toBe('[redacted]');
|
||||
});
|
||||
|
||||
it('redacts proxied sensitive and custom target headers', () => {
|
||||
const proxiedAuthorization = `${CORS_PROXY_HEADER_PREFIX}authorization`;
|
||||
const proxiedSessionId = `${CORS_PROXY_HEADER_PREFIX}mcp-session-id`;
|
||||
const proxiedVendorKey = `${CORS_PROXY_HEADER_PREFIX}x-vendor-key`;
|
||||
const headers = new Headers({
|
||||
[proxiedAuthorization]: 'Bearer secret',
|
||||
[proxiedSessionId]: 'session-12345',
|
||||
[proxiedVendorKey]: 'vendor-secret'
|
||||
});
|
||||
const partial = new Map([['mcp-session-id', 5]]);
|
||||
const result = sanitizeHeaders(headers, ['x-vendor-key'], partial);
|
||||
|
||||
expect(result[proxiedAuthorization]).toBe('[redacted]');
|
||||
expect(result[proxiedSessionId]).toBe('....12345');
|
||||
expect(result[proxiedVendorKey]).toBe('[redacted]');
|
||||
});
|
||||
});
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user