server: support Vertex AI compatible API (#22545)

* server: support Vertex AI compatible API

* a bit safer

* support other AIP_* env var

* various fixes

* if AIP_MODE is unset, do nothing

* fix test case

* fix windows build
This commit is contained in:
Xuan-Son Nguyen 2026-05-08 15:23:04 +02:00 committed by GitHub
parent 9dcf835528
commit 29debb3a6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 296 additions and 4 deletions

View File

@ -4,7 +4,9 @@
#include <cpp-httplib/httplib.h>
#include <cstdlib>
#include <functional>
#include <future>
#include <string>
#include <thread>
@ -51,11 +53,51 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
SRV_DBG("response: %s\n", res.body.c_str());
}
// For Google Cloud Platform deployment compatibility
struct gcp_params {
bool enabled;
std::string path_health;
std::string path_predict;
int port;
// Ref: https://docs.cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables
gcp_params() {
enabled = getenv("AIP_MODE", "") == "PREDICTION";
path_health = getenv("AIP_HEALTH_ROUTE", "", true); // default: using the route defined in server.cpp
path_predict = getenv("AIP_PREDICT_ROUTE", "/predict", true);
port = std::stoi(getenv("AIP_HTTP_PORT", "8080"));
}
static std::string getenv(const char * name, const std::string & default_value, bool ensure_leading_slash = false) {
const char * value = std::getenv(name);
if (value == nullptr || value[0] == '\0') {
return default_value;
}
std::string val = value;
if (ensure_leading_slash && !val.empty() && val[0] != '/') {
val.insert(val.begin(), '/');
}
return val;
}
};
bool server_http_context::init(const common_params & params) {
const gcp_params gcp;
path_prefix = params.api_prefix;
port = params.port;
hostname = params.hostname;
if (gcp.enabled) {
LOG_INF("%s: Google Cloud Platform compat: health route = %s, predict route = %s, port = %d\n", __func__, gcp.path_health.c_str(), gcp.path_predict.c_str(), gcp.port);
if (port != gcp.port) {
LOG_WRN("%s: Google Cloud Platform compat: overriding server port %d with AIP_HTTP_PORT %d\n", __func__, port, gcp.port);
}
port = gcp.port;
}
auto & srv = pimpl->srv;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@ -420,6 +462,7 @@ static void process_handler_response(server_http_req_ptr && request, server_http
}
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
handlers.emplace(path, handler);
pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
get_params(req),
@ -436,6 +479,7 @@ void server_http_context::get(const std::string & path, const server_http_contex
}
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
handlers.emplace(path, handler);
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
std::string body = req.body;
std::map<std::string, uploaded_file> files;
@ -481,3 +525,176 @@ void server_http_context::post(const std::string & path, const server_http_conte
});
}
//
// Vertex AI Prediction protocol (AIP_PREDICT_ROUTE)
// https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements
//
// Derives the camelCase @requestFormat alias for a registered path.
// e.g. "/v1/chat/completions" -> "chatCompletions", "/apply-template" -> "applyTemplate"
static std::string path_to_gcp_format(const std::string & path) {
std::string s = path;
if (s.size() > 3 && s[0] == '/' && s[1] == 'v' && s[2] == '1') {
s = s.substr(3);
}
if (!s.empty() && s[0] == '/') {
s = s.substr(1);
}
std::string result;
bool cap = false;
for (unsigned char c : s) {
if (c == ':') break; // stop before path parameters
if (c == '/' || c == '-' || c == '_') {
cap = true;
} else {
result += cap ? (char)std::toupper(c) : (char)c;
cap = false;
}
}
return result;
}
static json parse_gcp_predict_response(const server_http_res_ptr & res) {
if (res == nullptr) {
throw std::runtime_error("empty response from internal handler");
}
if (res->is_stream()) {
throw std::invalid_argument("predict route does not support streaming responses");
}
if (res->data.empty()) {
return nullptr;
}
try {
return json::parse(res->data);
} catch (...) {
return res->data;
}
}
void server_http_context::register_gcp_compat() {
const gcp_params gcp;
if (!gcp.enabled) {
// do nothing
return;
}
if (handlers.count(gcp.path_predict)) {
LOG_ERR("%s: AIP_PREDICT_ROUTE=%s conflicts with an existing llama-server route\n", __func__, gcp.path_predict.c_str());
exit(1);
}
// camelCase alias -> canonical path (first registration wins on collision)
// e.g. "chatCompletions" -> "/v1/chat/completions"
std::unordered_map<std::string, std::string> alias_to_path;
for (const auto & [path, _] : handlers) {
alias_to_path.emplace(path_to_gcp_format(path), path);
}
if (!gcp.path_health.empty()) {
auto health_handler = handlers.find("/health");
GGML_ASSERT(health_handler != handlers.end());
get(gcp.path_health, health_handler->second);
}
post(gcp.path_predict, [this, alias_to_path = std::move(alias_to_path)](const server_http_req & req) -> server_http_res_ptr {
static const auto build_error = [](const std::string & message, error_type type) -> json {
return json {{"error", format_error_response(message, type)}};
};
json data;
try {
data = json::parse(req.body);
} catch (const std::exception & e) {
auto res = std::make_unique<server_http_res>();
res->status = 400;
res->data = safe_json_to_str({{"error", format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)}});
return res;
}
if (!data.is_object()) {
auto res = std::make_unique<server_http_res>();
res->status = 400;
res->data = safe_json_to_str({{"error", format_error_response("request body must be a JSON object", ERROR_TYPE_INVALID_REQUEST)}});
return res;
}
if (!data.contains("instances") || !data.at("instances").is_array()) {
auto res = std::make_unique<server_http_res>();
res->status = 400;
res->data = safe_json_to_str({{"error", format_error_response("request body must include an array field named instances", ERROR_TYPE_INVALID_REQUEST)}});
return res;
}
const json & instances = data.at("instances");
static const size_t MAX_INSTANCES = 128;
if (instances.size() > MAX_INSTANCES) {
auto res = std::make_unique<server_http_res>();
res->status = 400;
res->data = safe_json_to_str({{"error", format_error_response("instances array exceeds maximum size of " + std::to_string(MAX_INSTANCES), ERROR_TYPE_INVALID_REQUEST)}});
return res;
}
std::vector<std::future<json>> futures;
futures.reserve(instances.size());
for (const auto & instance : instances) {
futures.push_back(std::async(std::launch::async, [this, &req, &alias_to_path, instance]() -> json {
if (!instance.is_object()) {
return build_error("each instance must be a JSON object", ERROR_TYPE_INVALID_REQUEST);
}
if (!instance.contains("@requestFormat") || !instance.at("@requestFormat").is_string()) {
return build_error("each instance must include a string @requestFormat", ERROR_TYPE_INVALID_REQUEST);
}
try {
json payload = instance;
const std::string format = payload.at("@requestFormat").get<std::string>();
payload.erase("@requestFormat");
if (payload.contains("stream")) {
LOG_WRN("%s: ignoring client-provided stream field in instance, streaming is not supported in predict route\n", __func__);
payload["stream"] = false;
}
// accept both camelCase aliases (e.g. "chatCompletions") and direct paths
std::string dispatch_path;
auto it_alias = alias_to_path.find(format);
if (it_alias != alias_to_path.end()) {
dispatch_path = it_alias->second;
} else if (handlers.count(format)) {
dispatch_path = format;
} else {
return build_error("no handler registered for @requestFormat: " + format, ERROR_TYPE_INVALID_REQUEST);
}
const server_http_req internal_req {
req.params,
req.headers,
path_prefix + dispatch_path,
req.query_string,
payload.dump(),
{},
req.should_stop,
};
server_http_res_ptr internal_res = handlers.at(dispatch_path)(internal_req);
return parse_gcp_predict_response(internal_res);
} catch (const std::invalid_argument & e) {
return build_error(e.what(), ERROR_TYPE_INVALID_REQUEST);
} catch (const std::exception & e) {
return build_error(e.what(), ERROR_TYPE_SERVER);
} catch (...) {
return build_error("unknown error", ERROR_TYPE_SERVER);
}
}));
}
json predictions = json::array();
for (auto & future : futures) {
predictions.push_back(future.get());
}
auto res = std::make_unique<server_http_res>();
res->data = safe_json_to_str({{"predictions", predictions}});
return res;
});
}

View File

@ -67,6 +67,10 @@ struct server_http_context {
std::thread thread; // server thread
std::atomic<bool> is_ready = false;
// note: the handler should never throw exceptions
using handler_t = std::function<server_http_res_ptr(const server_http_req & req)>;
mutable std::unordered_map<std::string, handler_t> handlers;
std::string path_prefix;
std::string hostname;
int port;
@ -78,12 +82,13 @@ struct server_http_context {
bool start();
void stop() const;
// note: the handler should never throw exceptions
using handler_t = std::function<server_http_res_ptr(const server_http_req & req)>;
void get(const std::string & path, const handler_t & handler) const;
void post(const std::string & path, const handler_t & handler) const;
// Register the Google Cloud Platform (Vertex AI) compat (AIP_PREDICT_ROUTE env var, or /predict)
// Must be called AFTER all other API routes are registered
void register_gcp_compat();
// for debugging
std::string listening_address;
};

View File

@ -204,6 +204,10 @@ int main(int argc, char ** argv) {
// Save & load slots
ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
// Google Cloud Platform (Vertex AI) compat
ctx_http.register_gcp_compat();
// CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP)
if (params.webui_mcp_proxy) {
SRV_WRN("%s", "-----------------\n");

View File

@ -0,0 +1,60 @@
import pytest
from utils import *
server: ServerProcess
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.gcp_compat = True
def test_gcp_predict_camel_case():
global server
server.start()
res = server.make_request("POST", "/predict", data={
"instances": [
{
"@requestFormat": "chatCompletions",
"max_tokens": 8,
"messages": [
{"role": "user", "content": "What is the meaning of life?"},
],
}
],
})
assert res.status_code == 200
assert "predictions" in res.body
assert len(res.body["predictions"]) == 1
prediction = res.body["predictions"][0]
assert "choices" in prediction
assert len(prediction["choices"]) == 1
assert prediction["choices"][0]["message"]["role"] == "assistant"
assert len(prediction["choices"][0]["message"]["content"]) > 0
def test_gcp_predict_multiple_instances():
global server
server.n_slots = 2
server.start()
res = server.make_request("POST", "/predict", data={
"instances": [
{
"@requestFormat": "chatCompletions",
"max_tokens": 8,
"messages": [{"role": "user", "content": "Say hello"}],
},
{
"@requestFormat": "chatCompletions",
"max_tokens": 8,
"messages": [{"role": "user", "content": "Say world"}],
},
],
})
assert res.status_code == 200
assert len(res.body["predictions"]) == 2
for prediction in res.body["predictions"]:
assert "choices" in prediction
assert len(prediction["choices"][0]["message"]["content"]) > 0

View File

@ -108,6 +108,7 @@ class ServerProcess:
no_cache_idle_slots: bool = False
log_path: str | None = None
webui_mcp_proxy: bool = False
gcp_compat: bool = False
# session variables
process: subprocess.Popen | None = None
@ -122,6 +123,9 @@ class ServerProcess:
self.external_server = "DEBUG_EXTERNAL" in os.environ
def start(self, timeout_seconds: int = DEFAULT_HTTP_TIMEOUT) -> None:
env = {**os.environ}
if "LLAMA_CACHE" not in os.environ:
env["LLAMA_CACHE"] = "tmp"
if self.external_server:
print(f"[external_server]: Assuming external server running on {self.server_host}:{self.server_port}")
return
@ -248,6 +252,8 @@ class ServerProcess:
server_args.append("--no-cache-idle-slots")
if self.webui_mcp_proxy:
server_args.append("--webui-mcp-proxy")
if self.gcp_compat:
env["AIP_MODE"] = "PREDICTION"
args = [str(arg) for arg in [server_path, *server_args]]
print(f"tests: starting server with: {' '.join(args)}")
@ -268,7 +274,7 @@ class ServerProcess:
creationflags=flags,
stdout=self._log,
stderr=self._log if self._log != sys.stdout else sys.stdout,
env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None,
env=env,
)
server_instances.add(self)