mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
server: avoid forwarding auth headers in CORS proxy (#24373)
* server: avoid forwarding auth headers in CORS proxy * format * fix test * fix e2e test --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
parent
67e9fd3b74
commit
e27f308597
@ -7,9 +7,18 @@
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cctype>
|
||||||
|
|
||||||
#include "server-http.h"
|
#include "server-http.h"
|
||||||
|
|
||||||
|
static std::string proxy_header_to_lower(std::string header) {
|
||||||
|
std::transform(header.begin(), header.end(), header.begin(), [](unsigned char c) {
|
||||||
|
return std::tolower(c);
|
||||||
|
});
|
||||||
|
return header;
|
||||||
|
}
|
||||||
|
|
||||||
static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) {
|
static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) {
|
||||||
std::string target_url = req.get_param("url");
|
std::string target_url = req.get_param("url");
|
||||||
common_http_url parsed_url = common_http_parse_url(target_url);
|
common_http_url parsed_url = common_http_parse_url(target_url);
|
||||||
@ -33,11 +42,18 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
|
|||||||
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
|
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
|
||||||
|
|
||||||
std::map<std::string, std::string> headers;
|
std::map<std::string, std::string> headers;
|
||||||
|
const std::string proxy_header_prefix = "x-llama-server-proxy-header-";
|
||||||
for (auto [key, value] : req.headers) {
|
for (auto [key, value] : req.headers) {
|
||||||
auto new_key = key;
|
const std::string lowered_key = proxy_header_to_lower(key);
|
||||||
if (string_starts_with(new_key, "x-proxy-header-")) {
|
if (!string_starts_with(lowered_key, proxy_header_prefix)) {
|
||||||
string_replace_all(new_key, "x-proxy-header-", "");
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto new_key = key.substr(proxy_header_prefix.size());
|
||||||
|
if (new_key.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
headers[new_key] = value;
|
headers[new_key] = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from utils import *
|
from utils import *
|
||||||
|
import threading
|
||||||
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
|
|
||||||
server = ServerPreset.tinyllama2()
|
server = ServerPreset.tinyllama2()
|
||||||
|
|
||||||
@ -105,6 +107,49 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
|
|||||||
assert res.headers[cors_header] == cors_header_value
|
assert res.headers[cors_header] == cors_header_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_cors_proxy_only_forwards_explicit_proxy_headers():
|
||||||
|
class CaptureHeadersHandler(BaseHTTPRequestHandler):
|
||||||
|
def do_GET(self):
|
||||||
|
self.server.captured_headers = dict(self.headers)
|
||||||
|
self.send_response(200)
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(b"ok")
|
||||||
|
|
||||||
|
def log_message(self, format, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
target = ThreadingHTTPServer(("127.0.0.1", 0), CaptureHeadersHandler)
|
||||||
|
target.captured_headers = {}
|
||||||
|
target_thread = threading.Thread(target=target.serve_forever, daemon=True)
|
||||||
|
target_thread.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
server = ServerPreset.tinyllama2()
|
||||||
|
server.api_key = TEST_API_KEY
|
||||||
|
server.ui_mcp_proxy = True
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
res = server.make_request("GET", f"/cors-proxy?url=http://127.0.0.1:{target.server_port}/capture", headers={
|
||||||
|
"Authorization": f"Bearer {TEST_API_KEY}",
|
||||||
|
"Proxy-Authorization": "Basic secret",
|
||||||
|
"X-Api-Key": TEST_API_KEY,
|
||||||
|
"Cookie": "session=secret",
|
||||||
|
"x-llama-server-proxy-header-accept": "application/json",
|
||||||
|
"x-llama-server-proxy-header-authorization": "Bearer explicit",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
captured = {key.lower(): value for key, value in target.captured_headers.items()}
|
||||||
|
assert captured["accept"] == "application/json"
|
||||||
|
assert captured["authorization"] == "Bearer explicit"
|
||||||
|
assert "proxy-authorization" not in captured
|
||||||
|
assert "x-api-key" not in captured
|
||||||
|
assert "cookie" not in captured
|
||||||
|
finally:
|
||||||
|
target.shutdown()
|
||||||
|
target.server_close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"media_path, image_url, success",
|
"media_path, image_url, success",
|
||||||
[
|
[
|
||||||
|
|||||||
@ -51,6 +51,9 @@ export const EXPECTED_THEMED_ICON_PAIR_COUNT = 2;
|
|||||||
/** CORS proxy URL query parameter name */
|
/** CORS proxy URL query parameter name */
|
||||||
export const CORS_PROXY_URL_PARAM = 'url';
|
export const CORS_PROXY_URL_PARAM = 'url';
|
||||||
|
|
||||||
|
/** Header prefix for headers that should be forwarded by the CORS proxy */
|
||||||
|
export const CORS_PROXY_HEADER_PREFIX = 'x-llama-server-proxy-header-';
|
||||||
|
|
||||||
/** Number of trailing characters to keep visible when partially redacting mcp-session-id */
|
/** Number of trailing characters to keep visible when partially redacting mcp-session-id */
|
||||||
export const MCP_SESSION_ID_VISIBLE_CHARS = 5;
|
export const MCP_SESSION_ID_VISIBLE_CHARS = 5;
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,7 @@ import {
|
|||||||
DEFAULT_MCP_CONFIG,
|
DEFAULT_MCP_CONFIG,
|
||||||
DEFAULT_CLIENT_VERSION,
|
DEFAULT_CLIENT_VERSION,
|
||||||
DEFAULT_IMAGE_MIME_TYPE,
|
DEFAULT_IMAGE_MIME_TYPE,
|
||||||
|
CORS_PROXY_HEADER_PREFIX,
|
||||||
MCP_PARTIAL_REDACT_HEADERS,
|
MCP_PARTIAL_REDACT_HEADERS,
|
||||||
CORS_PROXY_ENDPOINT
|
CORS_PROXY_ENDPOINT
|
||||||
} from '$lib/constants';
|
} from '$lib/constants';
|
||||||
@ -133,6 +134,20 @@ export class MCPService {
|
|||||||
return details;
|
return details;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static addRequestHeaders(
|
||||||
|
requestHeaders: Headers,
|
||||||
|
headers: HeadersInit,
|
||||||
|
useProxy: boolean
|
||||||
|
) {
|
||||||
|
for (const [key, value] of new Headers(headers).entries()) {
|
||||||
|
const proxiedKey =
|
||||||
|
useProxy && !key.toLowerCase().startsWith(CORS_PROXY_HEADER_PREFIX)
|
||||||
|
? `${CORS_PROXY_HEADER_PREFIX}${key}`
|
||||||
|
: key;
|
||||||
|
requestHeaders.set(proxiedKey, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static summarizeError(error: unknown): Record<string, unknown> {
|
private static summarizeError(error: unknown): Record<string, unknown> {
|
||||||
if (error instanceof Error) {
|
if (error instanceof Error) {
|
||||||
return {
|
return {
|
||||||
@ -271,15 +286,11 @@ export class MCPService {
|
|||||||
const requestHeaders = new Headers(baseInit.headers);
|
const requestHeaders = new Headers(baseInit.headers);
|
||||||
|
|
||||||
if (typeof Request !== 'undefined' && input instanceof Request) {
|
if (typeof Request !== 'undefined' && input instanceof Request) {
|
||||||
for (const [key, value] of input.headers.entries()) {
|
this.addRequestHeaders(requestHeaders, input.headers, useProxy);
|
||||||
requestHeaders.set(key, value);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (init?.headers) {
|
if (init?.headers) {
|
||||||
for (const [key, value] of new Headers(init.headers).entries()) {
|
this.addRequestHeaders(requestHeaders, init.headers, useProxy);
|
||||||
requestHeaders.set(key, value);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const request = this.createDiagnosticRequestDetails(
|
const request = this.createDiagnosticRequestDetails(
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import { config } from '$lib/stores/settings.svelte';
|
import { config } from '$lib/stores/settings.svelte';
|
||||||
import { REDACTED_HEADERS } from '$lib/constants';
|
import { CORS_PROXY_HEADER_PREFIX, REDACTED_HEADERS } from '$lib/constants';
|
||||||
import { redactValue } from './redact';
|
import { redactValue } from './redact';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -52,11 +52,20 @@ export function sanitizeHeaders(
|
|||||||
|
|
||||||
for (const [key, value] of normalized.entries()) {
|
for (const [key, value] of normalized.entries()) {
|
||||||
const normalizedKey = key.toLowerCase();
|
const normalizedKey = key.toLowerCase();
|
||||||
const partialChars = partialRedactHeaders?.get(normalizedKey);
|
const unproxiedKey = normalizedKey.startsWith(CORS_PROXY_HEADER_PREFIX)
|
||||||
|
? normalizedKey.slice(CORS_PROXY_HEADER_PREFIX.length)
|
||||||
|
: normalizedKey;
|
||||||
|
const partialChars =
|
||||||
|
partialRedactHeaders?.get(normalizedKey) ?? partialRedactHeaders?.get(unproxiedKey);
|
||||||
|
|
||||||
if (partialChars !== undefined) {
|
if (partialChars !== undefined) {
|
||||||
sanitized[key] = redactValue(value, partialChars);
|
sanitized[key] = redactValue(value, partialChars);
|
||||||
} else if (REDACTED_HEADERS.has(normalizedKey) || redactedHeaders.has(normalizedKey)) {
|
} else if (
|
||||||
|
REDACTED_HEADERS.has(normalizedKey) ||
|
||||||
|
REDACTED_HEADERS.has(unproxiedKey) ||
|
||||||
|
redactedHeaders.has(normalizedKey) ||
|
||||||
|
redactedHeaders.has(unproxiedKey)
|
||||||
|
) {
|
||||||
sanitized[key] = redactValue(value);
|
sanitized[key] = redactValue(value);
|
||||||
} else {
|
} else {
|
||||||
sanitized[key] = value;
|
sanitized[key] = value;
|
||||||
|
|||||||
@ -3,7 +3,11 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import { base } from '$app/paths';
|
import { base } from '$app/paths';
|
||||||
import { CORS_PROXY_ENDPOINT, CORS_PROXY_URL_PARAM } from '$lib/constants';
|
import {
|
||||||
|
CORS_PROXY_ENDPOINT,
|
||||||
|
CORS_PROXY_HEADER_PREFIX,
|
||||||
|
CORS_PROXY_URL_PARAM
|
||||||
|
} from '$lib/constants';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build a proxied URL that routes through llama-server's CORS proxy.
|
* Build a proxied URL that routes through llama-server's CORS proxy.
|
||||||
@ -28,7 +32,7 @@ export function buildProxiedHeaders(headers: Record<string, string>): Record<str
|
|||||||
const proxiedHeaders: Record<string, string> = {};
|
const proxiedHeaders: Record<string, string> = {};
|
||||||
|
|
||||||
for (const [key, value] of Object.entries(headers)) {
|
for (const [key, value] of Object.entries(headers)) {
|
||||||
proxiedHeaders[`x-proxy-header-${key}`] = value;
|
proxiedHeaders[`${CORS_PROXY_HEADER_PREFIX}${key}`] = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
return proxiedHeaders;
|
return proxiedHeaders;
|
||||||
|
|||||||
@ -39,8 +39,8 @@ test.describe('PWA Service Worker', () => {
|
|||||||
const swContent = await swResponse.text();
|
const swContent = await swResponse.text();
|
||||||
|
|
||||||
// Precache contains SvelteKit content-hashed bundle paths
|
// Precache contains SvelteKit content-hashed bundle paths
|
||||||
expect(swContent).toMatch(/"_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"/);
|
expect(swContent).toMatch(/"_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"/);
|
||||||
expect(swContent).toMatch(/"_app\/immutable\/assets\/bundle\.[a-zA-Z0-9-]+\.css"/);
|
expect(swContent).toMatch(/"_app\/immutable\/assets\/bundle\.[a-zA-Z0-9_-]+\.css"/);
|
||||||
expect(swContent).toMatch(/"manifest\.webmanifest"/);
|
expect(swContent).toMatch(/"manifest\.webmanifest"/);
|
||||||
expect(swContent).toMatch(/"_app\/version\.json"/);
|
expect(swContent).toMatch(/"_app\/version\.json"/);
|
||||||
expect(swContent).toMatch(/NavigationRoute/);
|
expect(swContent).toMatch(/NavigationRoute/);
|
||||||
@ -99,8 +99,8 @@ test.describe('PWA Service Worker', () => {
|
|||||||
const html = await response.text();
|
const html = await response.text();
|
||||||
|
|
||||||
// SvelteKit outputs content-hashed bundle names in _app/immutable/
|
// SvelteKit outputs content-hashed bundle names in _app/immutable/
|
||||||
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"/);
|
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"/);
|
||||||
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/assets\/bundle\.[a-zA-Z0-9-]+\.css"/);
|
expect(html).toMatch(/href="(\.\/|\/)_app\/immutable\/assets\/bundle\.[a-zA-Z0-9_-]+\.css"/);
|
||||||
expect(html).toMatch(/import\("(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9-]+\.js"\)/);
|
expect(html).toMatch(/import\("(\.\/|\/)_app\/immutable\/bundle\.[a-zA-Z0-9_-]+\.js"\)/);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import { Client } from '@modelcontextprotocol/sdk/client';
|
|||||||
import { MCPService } from '$lib/services/mcp.service';
|
import { MCPService } from '$lib/services/mcp.service';
|
||||||
import { MCPConnectionPhase, MCPTransportType } from '$lib/enums';
|
import { MCPConnectionPhase, MCPTransportType } from '$lib/enums';
|
||||||
import type { MCPConnectionLog, MCPServerConfig } from '$lib/types';
|
import type { MCPConnectionLog, MCPServerConfig } from '$lib/types';
|
||||||
|
import { CORS_PROXY_HEADER_PREFIX } from '$lib/constants';
|
||||||
|
|
||||||
type DiagnosticFetchFactory = (
|
type DiagnosticFetchFactory = (
|
||||||
serverName: string,
|
serverName: string,
|
||||||
@ -16,11 +17,12 @@ type DiagnosticFetchFactory = (
|
|||||||
const createDiagnosticFetch = (
|
const createDiagnosticFetch = (
|
||||||
config: MCPServerConfig,
|
config: MCPServerConfig,
|
||||||
onLog?: (log: MCPConnectionLog) => void,
|
onLog?: (log: MCPConnectionLog) => void,
|
||||||
baseInit: RequestInit = {}
|
baseInit: RequestInit = {},
|
||||||
|
useProxy = false
|
||||||
) =>
|
) =>
|
||||||
(
|
(
|
||||||
MCPService as unknown as { createDiagnosticFetch: DiagnosticFetchFactory }
|
MCPService as unknown as { createDiagnosticFetch: DiagnosticFetchFactory }
|
||||||
).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), false, onLog);
|
).createDiagnosticFetch('test-server', config, baseInit, new URL(config.url), useProxy, onLog);
|
||||||
|
|
||||||
describe('MCPService', () => {
|
describe('MCPService', () => {
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
@ -94,6 +96,64 @@ describe('MCPService', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('wraps dynamic request headers when using the CORS proxy', async () => {
|
||||||
|
const logs: MCPConnectionLog[] = [];
|
||||||
|
const proxiedAuthToken = `${CORS_PROXY_HEADER_PREFIX}x-auth-token`;
|
||||||
|
const proxiedContentType = `${CORS_PROXY_HEADER_PREFIX}content-type`;
|
||||||
|
const proxiedSessionId = `${CORS_PROXY_HEADER_PREFIX}mcp-session-id`;
|
||||||
|
const response = new Response('{}', {
|
||||||
|
status: 200,
|
||||||
|
headers: { 'content-type': 'application/json' }
|
||||||
|
});
|
||||||
|
const fetchMock = vi.fn().mockResolvedValue(response);
|
||||||
|
|
||||||
|
vi.stubGlobal('fetch', fetchMock);
|
||||||
|
|
||||||
|
const config: MCPServerConfig = {
|
||||||
|
url: 'https://example.com/mcp',
|
||||||
|
transport: MCPTransportType.STREAMABLE_HTTP,
|
||||||
|
useProxy: true
|
||||||
|
};
|
||||||
|
|
||||||
|
const controller = createDiagnosticFetch(
|
||||||
|
config,
|
||||||
|
(log) => logs.push(log),
|
||||||
|
{
|
||||||
|
headers: {
|
||||||
|
authorization: 'Bearer llama-server-key',
|
||||||
|
[proxiedAuthToken]: 'target-token'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
await controller.fetch('http://localhost:8080/cors-proxy?url=https%3A%2F%2Fexample.com%2Fmcp', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'content-type': 'application/json',
|
||||||
|
'mcp-session-id': 'session-request-12345'
|
||||||
|
},
|
||||||
|
body: '{}'
|
||||||
|
});
|
||||||
|
|
||||||
|
const sentHeaders = fetchMock.mock.calls[0]?.[1]?.headers as Headers;
|
||||||
|
expect(sentHeaders.get('authorization')).toBe('Bearer llama-server-key');
|
||||||
|
expect(sentHeaders.get(proxiedAuthToken)).toBe('target-token');
|
||||||
|
expect(sentHeaders.get(proxiedContentType)).toBe('application/json');
|
||||||
|
expect(sentHeaders.get(proxiedSessionId)).toBe('session-request-12345');
|
||||||
|
expect(sentHeaders.has('content-type')).toBe(false);
|
||||||
|
expect(sentHeaders.has('mcp-session-id')).toBe(false);
|
||||||
|
expect(logs[0].details).toMatchObject({
|
||||||
|
request: {
|
||||||
|
headers: {
|
||||||
|
authorization: '[redacted]',
|
||||||
|
[proxiedAuthToken]: '[redacted]',
|
||||||
|
[proxiedSessionId]: '....12345'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it('partially redacts mcp-session-id in diagnostic request and response logs', async () => {
|
it('partially redacts mcp-session-id in diagnostic request and response logs', async () => {
|
||||||
const logs: MCPConnectionLog[] = [];
|
const logs: MCPConnectionLog[] = [];
|
||||||
const response = new Response('{}', {
|
const response = new Response('{}', {
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import { describe, expect, it } from 'vitest';
|
import { describe, expect, it } from 'vitest';
|
||||||
import { sanitizeHeaders } from '$lib/utils/api-headers';
|
import { sanitizeHeaders } from '$lib/utils/api-headers';
|
||||||
|
import { CORS_PROXY_HEADER_PREFIX } from '$lib/constants';
|
||||||
|
|
||||||
describe('sanitizeHeaders', () => {
|
describe('sanitizeHeaders', () => {
|
||||||
it('returns empty object for undefined input', () => {
|
it('returns empty object for undefined input', () => {
|
||||||
@ -52,4 +53,21 @@ describe('sanitizeHeaders', () => {
|
|||||||
const result = sanitizeHeaders(headers, ['X-CUSTOM-TOKEN']);
|
const result = sanitizeHeaders(headers, ['X-CUSTOM-TOKEN']);
|
||||||
expect(result['x-custom-token']).toBe('[redacted]');
|
expect(result['x-custom-token']).toBe('[redacted]');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('redacts proxied sensitive and custom target headers', () => {
|
||||||
|
const proxiedAuthorization = `${CORS_PROXY_HEADER_PREFIX}authorization`;
|
||||||
|
const proxiedSessionId = `${CORS_PROXY_HEADER_PREFIX}mcp-session-id`;
|
||||||
|
const proxiedVendorKey = `${CORS_PROXY_HEADER_PREFIX}x-vendor-key`;
|
||||||
|
const headers = new Headers({
|
||||||
|
[proxiedAuthorization]: 'Bearer secret',
|
||||||
|
[proxiedSessionId]: 'session-12345',
|
||||||
|
[proxiedVendorKey]: 'vendor-secret'
|
||||||
|
});
|
||||||
|
const partial = new Map([['mcp-session-id', 5]]);
|
||||||
|
const result = sanitizeHeaders(headers, ['x-vendor-key'], partial);
|
||||||
|
|
||||||
|
expect(result[proxiedAuthorization]).toBe('[redacted]');
|
||||||
|
expect(result[proxiedSessionId]).toBe('....12345');
|
||||||
|
expect(result[proxiedVendorKey]).toBe('[redacted]');
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user