mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
server: support Vertex AI compatible API (#22545)
* server: support Vertex AI compatible API * a bit safer * support other AIP_* env var * various fixes * if AIP_MODE is unset, do nothing * fix test case * fix windows build
This commit is contained in:
parent
9dcf835528
commit
29debb3a6a
@ -4,7 +4,9 @@
|
||||
|
||||
#include <cpp-httplib/httplib.h>
|
||||
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
@ -51,11 +53,51 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
|
||||
SRV_DBG("response: %s\n", res.body.c_str());
|
||||
}
|
||||
|
||||
// For Google Cloud Platform deployment compatibility
|
||||
struct gcp_params {
|
||||
bool enabled;
|
||||
std::string path_health;
|
||||
std::string path_predict;
|
||||
int port;
|
||||
|
||||
// Ref: https://docs.cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables
|
||||
gcp_params() {
|
||||
enabled = getenv("AIP_MODE", "") == "PREDICTION";
|
||||
path_health = getenv("AIP_HEALTH_ROUTE", "", true); // default: using the route defined in server.cpp
|
||||
path_predict = getenv("AIP_PREDICT_ROUTE", "/predict", true);
|
||||
port = std::stoi(getenv("AIP_HTTP_PORT", "8080"));
|
||||
}
|
||||
|
||||
static std::string getenv(const char * name, const std::string & default_value, bool ensure_leading_slash = false) {
|
||||
const char * value = std::getenv(name);
|
||||
if (value == nullptr || value[0] == '\0') {
|
||||
return default_value;
|
||||
}
|
||||
std::string val = value;
|
||||
if (ensure_leading_slash && !val.empty() && val[0] != '/') {
|
||||
val.insert(val.begin(), '/');
|
||||
}
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
bool server_http_context::init(const common_params & params) {
|
||||
const gcp_params gcp;
|
||||
|
||||
path_prefix = params.api_prefix;
|
||||
port = params.port;
|
||||
hostname = params.hostname;
|
||||
|
||||
if (gcp.enabled) {
|
||||
LOG_INF("%s: Google Cloud Platform compat: health route = %s, predict route = %s, port = %d\n", __func__, gcp.path_health.c_str(), gcp.path_predict.c_str(), gcp.port);
|
||||
|
||||
if (port != gcp.port) {
|
||||
LOG_WRN("%s: Google Cloud Platform compat: overriding server port %d with AIP_HTTP_PORT %d\n", __func__, port, gcp.port);
|
||||
}
|
||||
|
||||
port = gcp.port;
|
||||
}
|
||||
|
||||
auto & srv = pimpl->srv;
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
@ -420,6 +462,7 @@ static void process_handler_response(server_http_req_ptr && request, server_http
|
||||
}
|
||||
|
||||
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||
handlers.emplace(path, handler);
|
||||
pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||
server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
|
||||
get_params(req),
|
||||
@ -436,6 +479,7 @@ void server_http_context::get(const std::string & path, const server_http_contex
|
||||
}
|
||||
|
||||
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||
handlers.emplace(path, handler);
|
||||
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||
std::string body = req.body;
|
||||
std::map<std::string, uploaded_file> files;
|
||||
@ -481,3 +525,176 @@ void server_http_context::post(const std::string & path, const server_http_conte
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// Vertex AI Prediction protocol (AIP_PREDICT_ROUTE)
|
||||
// https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements
|
||||
//
|
||||
|
||||
// Derives the camelCase @requestFormat alias for a registered path.
|
||||
// e.g. "/v1/chat/completions" -> "chatCompletions", "/apply-template" -> "applyTemplate"
|
||||
static std::string path_to_gcp_format(const std::string & path) {
|
||||
std::string s = path;
|
||||
if (s.size() > 3 && s[0] == '/' && s[1] == 'v' && s[2] == '1') {
|
||||
s = s.substr(3);
|
||||
}
|
||||
if (!s.empty() && s[0] == '/') {
|
||||
s = s.substr(1);
|
||||
}
|
||||
std::string result;
|
||||
bool cap = false;
|
||||
for (unsigned char c : s) {
|
||||
if (c == ':') break; // stop before path parameters
|
||||
if (c == '/' || c == '-' || c == '_') {
|
||||
cap = true;
|
||||
} else {
|
||||
result += cap ? (char)std::toupper(c) : (char)c;
|
||||
cap = false;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static json parse_gcp_predict_response(const server_http_res_ptr & res) {
|
||||
if (res == nullptr) {
|
||||
throw std::runtime_error("empty response from internal handler");
|
||||
}
|
||||
if (res->is_stream()) {
|
||||
throw std::invalid_argument("predict route does not support streaming responses");
|
||||
}
|
||||
if (res->data.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
try {
|
||||
return json::parse(res->data);
|
||||
} catch (...) {
|
||||
return res->data;
|
||||
}
|
||||
}
|
||||
|
||||
void server_http_context::register_gcp_compat() {
|
||||
const gcp_params gcp;
|
||||
|
||||
if (!gcp.enabled) {
|
||||
// do nothing
|
||||
return;
|
||||
}
|
||||
|
||||
if (handlers.count(gcp.path_predict)) {
|
||||
LOG_ERR("%s: AIP_PREDICT_ROUTE=%s conflicts with an existing llama-server route\n", __func__, gcp.path_predict.c_str());
|
||||
exit(1);
|
||||
}
|
||||
|
||||
// camelCase alias -> canonical path (first registration wins on collision)
|
||||
// e.g. "chatCompletions" -> "/v1/chat/completions"
|
||||
std::unordered_map<std::string, std::string> alias_to_path;
|
||||
for (const auto & [path, _] : handlers) {
|
||||
alias_to_path.emplace(path_to_gcp_format(path), path);
|
||||
}
|
||||
|
||||
if (!gcp.path_health.empty()) {
|
||||
auto health_handler = handlers.find("/health");
|
||||
GGML_ASSERT(health_handler != handlers.end());
|
||||
get(gcp.path_health, health_handler->second);
|
||||
}
|
||||
|
||||
post(gcp.path_predict, [this, alias_to_path = std::move(alias_to_path)](const server_http_req & req) -> server_http_res_ptr {
|
||||
static const auto build_error = [](const std::string & message, error_type type) -> json {
|
||||
return json {{"error", format_error_response(message, type)}};
|
||||
};
|
||||
|
||||
json data;
|
||||
try {
|
||||
data = json::parse(req.body);
|
||||
} catch (const std::exception & e) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str({{"error", format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)}});
|
||||
return res;
|
||||
}
|
||||
if (!data.is_object()) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str({{"error", format_error_response("request body must be a JSON object", ERROR_TYPE_INVALID_REQUEST)}});
|
||||
return res;
|
||||
}
|
||||
if (!data.contains("instances") || !data.at("instances").is_array()) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str({{"error", format_error_response("request body must include an array field named instances", ERROR_TYPE_INVALID_REQUEST)}});
|
||||
return res;
|
||||
}
|
||||
|
||||
const json & instances = data.at("instances");
|
||||
static const size_t MAX_INSTANCES = 128;
|
||||
if (instances.size() > MAX_INSTANCES) {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str({{"error", format_error_response("instances array exceeds maximum size of " + std::to_string(MAX_INSTANCES), ERROR_TYPE_INVALID_REQUEST)}});
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<std::future<json>> futures;
|
||||
futures.reserve(instances.size());
|
||||
|
||||
for (const auto & instance : instances) {
|
||||
futures.push_back(std::async(std::launch::async, [this, &req, &alias_to_path, instance]() -> json {
|
||||
if (!instance.is_object()) {
|
||||
return build_error("each instance must be a JSON object", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
if (!instance.contains("@requestFormat") || !instance.at("@requestFormat").is_string()) {
|
||||
return build_error("each instance must include a string @requestFormat", ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
|
||||
try {
|
||||
json payload = instance;
|
||||
const std::string format = payload.at("@requestFormat").get<std::string>();
|
||||
payload.erase("@requestFormat");
|
||||
|
||||
if (payload.contains("stream")) {
|
||||
LOG_WRN("%s: ignoring client-provided stream field in instance, streaming is not supported in predict route\n", __func__);
|
||||
payload["stream"] = false;
|
||||
}
|
||||
|
||||
// accept both camelCase aliases (e.g. "chatCompletions") and direct paths
|
||||
std::string dispatch_path;
|
||||
auto it_alias = alias_to_path.find(format);
|
||||
if (it_alias != alias_to_path.end()) {
|
||||
dispatch_path = it_alias->second;
|
||||
} else if (handlers.count(format)) {
|
||||
dispatch_path = format;
|
||||
} else {
|
||||
return build_error("no handler registered for @requestFormat: " + format, ERROR_TYPE_INVALID_REQUEST);
|
||||
}
|
||||
|
||||
const server_http_req internal_req {
|
||||
req.params,
|
||||
req.headers,
|
||||
path_prefix + dispatch_path,
|
||||
req.query_string,
|
||||
payload.dump(),
|
||||
{},
|
||||
req.should_stop,
|
||||
};
|
||||
|
||||
server_http_res_ptr internal_res = handlers.at(dispatch_path)(internal_req);
|
||||
return parse_gcp_predict_response(internal_res);
|
||||
} catch (const std::invalid_argument & e) {
|
||||
return build_error(e.what(), ERROR_TYPE_INVALID_REQUEST);
|
||||
} catch (const std::exception & e) {
|
||||
return build_error(e.what(), ERROR_TYPE_SERVER);
|
||||
} catch (...) {
|
||||
return build_error("unknown error", ERROR_TYPE_SERVER);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
json predictions = json::array();
|
||||
for (auto & future : futures) {
|
||||
predictions.push_back(future.get());
|
||||
}
|
||||
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
res->data = safe_json_to_str({{"predictions", predictions}});
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
@ -67,6 +67,10 @@ struct server_http_context {
|
||||
std::thread thread; // server thread
|
||||
std::atomic<bool> is_ready = false;
|
||||
|
||||
// note: the handler should never throw exceptions
|
||||
using handler_t = std::function<server_http_res_ptr(const server_http_req & req)>;
|
||||
mutable std::unordered_map<std::string, handler_t> handlers;
|
||||
|
||||
std::string path_prefix;
|
||||
std::string hostname;
|
||||
int port;
|
||||
@ -78,12 +82,13 @@ struct server_http_context {
|
||||
bool start();
|
||||
void stop() const;
|
||||
|
||||
// note: the handler should never throw exceptions
|
||||
using handler_t = std::function<server_http_res_ptr(const server_http_req & req)>;
|
||||
|
||||
void get(const std::string & path, const handler_t & handler) const;
|
||||
void post(const std::string & path, const handler_t & handler) const;
|
||||
|
||||
// Register the Google Cloud Platform (Vertex AI) compat (AIP_PREDICT_ROUTE env var, or /predict)
|
||||
// Must be called AFTER all other API routes are registered
|
||||
void register_gcp_compat();
|
||||
|
||||
// for debugging
|
||||
std::string listening_address;
|
||||
};
|
||||
|
||||
@ -204,6 +204,10 @@ int main(int argc, char ** argv) {
|
||||
// Save & load slots
|
||||
ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
|
||||
ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
|
||||
|
||||
// Google Cloud Platform (Vertex AI) compat
|
||||
ctx_http.register_gcp_compat();
|
||||
|
||||
// CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP)
|
||||
if (params.webui_mcp_proxy) {
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
|
||||
60
tools/server/tests/unit/test_compat_gcp.py
Normal file
60
tools/server/tests/unit/test_compat_gcp.py
Normal file
@ -0,0 +1,60 @@
|
||||
import pytest
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.gcp_compat = True
|
||||
|
||||
|
||||
def test_gcp_predict_camel_case():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/predict", data={
|
||||
"instances": [
|
||||
{
|
||||
"@requestFormat": "chatCompletions",
|
||||
"max_tokens": 8,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the meaning of life?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "predictions" in res.body
|
||||
assert len(res.body["predictions"]) == 1
|
||||
prediction = res.body["predictions"][0]
|
||||
assert "choices" in prediction
|
||||
assert len(prediction["choices"]) == 1
|
||||
assert prediction["choices"][0]["message"]["role"] == "assistant"
|
||||
assert len(prediction["choices"][0]["message"]["content"]) > 0
|
||||
|
||||
|
||||
def test_gcp_predict_multiple_instances():
|
||||
global server
|
||||
server.n_slots = 2
|
||||
server.start()
|
||||
res = server.make_request("POST", "/predict", data={
|
||||
"instances": [
|
||||
{
|
||||
"@requestFormat": "chatCompletions",
|
||||
"max_tokens": 8,
|
||||
"messages": [{"role": "user", "content": "Say hello"}],
|
||||
},
|
||||
{
|
||||
"@requestFormat": "chatCompletions",
|
||||
"max_tokens": 8,
|
||||
"messages": [{"role": "user", "content": "Say world"}],
|
||||
},
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["predictions"]) == 2
|
||||
for prediction in res.body["predictions"]:
|
||||
assert "choices" in prediction
|
||||
assert len(prediction["choices"][0]["message"]["content"]) > 0
|
||||
@ -108,6 +108,7 @@ class ServerProcess:
|
||||
no_cache_idle_slots: bool = False
|
||||
log_path: str | None = None
|
||||
webui_mcp_proxy: bool = False
|
||||
gcp_compat: bool = False
|
||||
|
||||
# session variables
|
||||
process: subprocess.Popen | None = None
|
||||
@ -122,6 +123,9 @@ class ServerProcess:
|
||||
self.external_server = "DEBUG_EXTERNAL" in os.environ
|
||||
|
||||
def start(self, timeout_seconds: int = DEFAULT_HTTP_TIMEOUT) -> None:
|
||||
env = {**os.environ}
|
||||
if "LLAMA_CACHE" not in os.environ:
|
||||
env["LLAMA_CACHE"] = "tmp"
|
||||
if self.external_server:
|
||||
print(f"[external_server]: Assuming external server running on {self.server_host}:{self.server_port}")
|
||||
return
|
||||
@ -248,6 +252,8 @@ class ServerProcess:
|
||||
server_args.append("--no-cache-idle-slots")
|
||||
if self.webui_mcp_proxy:
|
||||
server_args.append("--webui-mcp-proxy")
|
||||
if self.gcp_compat:
|
||||
env["AIP_MODE"] = "PREDICTION"
|
||||
|
||||
args = [str(arg) for arg in [server_path, *server_args]]
|
||||
print(f"tests: starting server with: {' '.join(args)}")
|
||||
@ -268,7 +274,7 @@ class ServerProcess:
|
||||
creationflags=flags,
|
||||
stdout=self._log,
|
||||
stderr=self._log if self._log != sys.stdout else sys.stdout,
|
||||
env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None,
|
||||
env=env,
|
||||
)
|
||||
server_instances.add(self)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user