diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 6f24f83ef3..8f25743fcf 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -4,7 +4,9 @@ #include +#include #include +#include #include #include @@ -51,11 +53,51 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp SRV_DBG("response: %s\n", res.body.c_str()); } +// For Google Cloud Platform deployment compatibility +struct gcp_params { + bool enabled; + std::string path_health; + std::string path_predict; + int port; + + // Ref: https://docs.cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables + gcp_params() { + enabled = getenv("AIP_MODE", "") == "PREDICTION"; + path_health = getenv("AIP_HEALTH_ROUTE", "", true); // default: using the route defined in server.cpp + path_predict = getenv("AIP_PREDICT_ROUTE", "/predict", true); + port = std::stoi(getenv("AIP_HTTP_PORT", "8080")); + } + + static std::string getenv(const char * name, const std::string & default_value, bool ensure_leading_slash = false) { + const char * value = std::getenv(name); + if (value == nullptr || value[0] == '\0') { + return default_value; + } + std::string val = value; + if (ensure_leading_slash && !val.empty() && val[0] != '/') { + val.insert(val.begin(), '/'); + } + return val; + } +}; + bool server_http_context::init(const common_params & params) { + const gcp_params gcp; + path_prefix = params.api_prefix; port = params.port; hostname = params.hostname; + if (gcp.enabled) { + LOG_INF("%s: Google Cloud Platform compat: health route = %s, predict route = %s, port = %d\n", __func__, gcp.path_health.c_str(), gcp.path_predict.c_str(), gcp.port); + + if (port != gcp.port) { + LOG_WRN("%s: Google Cloud Platform compat: overriding server port %d with AIP_HTTP_PORT %d\n", __func__, port, gcp.port); + } + + port = gcp.port; + } + auto & srv = pimpl->srv; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -420,6 +462,7 @@ static void process_handler_response(server_http_req_ptr && request, server_http } void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const { + handlers.emplace(path, handler); pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { server_http_req_ptr request = std::make_unique(server_http_req{ get_params(req), @@ -436,6 +479,7 @@ void server_http_context::get(const std::string & path, const server_http_contex } void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const { + handlers.emplace(path, handler); pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { std::string body = req.body; std::map files; @@ -481,3 +525,176 @@ void server_http_context::post(const std::string & path, const server_http_conte }); } +// +// Vertex AI Prediction protocol (AIP_PREDICT_ROUTE) +// https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements +// + +// Derives the camelCase @requestFormat alias for a registered path. +// e.g. "/v1/chat/completions" -> "chatCompletions", "/apply-template" -> "applyTemplate" +static std::string path_to_gcp_format(const std::string & path) { + std::string s = path; + if (s.size() > 3 && s[0] == '/' && s[1] == 'v' && s[2] == '1') { + s = s.substr(3); + } + if (!s.empty() && s[0] == '/') { + s = s.substr(1); + } + std::string result; + bool cap = false; + for (unsigned char c : s) { + if (c == ':') break; // stop before path parameters + if (c == '/' || c == '-' || c == '_') { + cap = true; + } else { + result += cap ? (char)std::toupper(c) : (char)c; + cap = false; + } + } + return result; +} + +static json parse_gcp_predict_response(const server_http_res_ptr & res) { + if (res == nullptr) { + throw std::runtime_error("empty response from internal handler"); + } + if (res->is_stream()) { + throw std::invalid_argument("predict route does not support streaming responses"); + } + if (res->data.empty()) { + return nullptr; + } + try { + return json::parse(res->data); + } catch (...) { + return res->data; + } +} + +void server_http_context::register_gcp_compat() { + const gcp_params gcp; + + if (!gcp.enabled) { + // do nothing + return; + } + + if (handlers.count(gcp.path_predict)) { + LOG_ERR("%s: AIP_PREDICT_ROUTE=%s conflicts with an existing llama-server route\n", __func__, gcp.path_predict.c_str()); + exit(1); + } + + // camelCase alias -> canonical path (first registration wins on collision) + // e.g. "chatCompletions" -> "/v1/chat/completions" + std::unordered_map alias_to_path; + for (const auto & [path, _] : handlers) { + alias_to_path.emplace(path_to_gcp_format(path), path); + } + + if (!gcp.path_health.empty()) { + auto health_handler = handlers.find("/health"); + GGML_ASSERT(health_handler != handlers.end()); + get(gcp.path_health, health_handler->second); + } + + post(gcp.path_predict, [this, alias_to_path = std::move(alias_to_path)](const server_http_req & req) -> server_http_res_ptr { + static const auto build_error = [](const std::string & message, error_type type) -> json { + return json {{"error", format_error_response(message, type)}}; + }; + + json data; + try { + data = json::parse(req.body); + } catch (const std::exception & e) { + auto res = std::make_unique(); + res->status = 400; + res->data = safe_json_to_str({{"error", format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)}}); + return res; + } + if (!data.is_object()) { + auto res = std::make_unique(); + res->status = 400; + res->data = safe_json_to_str({{"error", format_error_response("request body must be a JSON object", ERROR_TYPE_INVALID_REQUEST)}}); + return res; + } + if (!data.contains("instances") || !data.at("instances").is_array()) { + auto res = std::make_unique(); + res->status = 400; + res->data = safe_json_to_str({{"error", format_error_response("request body must include an array field named instances", ERROR_TYPE_INVALID_REQUEST)}}); + return res; + } + + const json & instances = data.at("instances"); + static const size_t MAX_INSTANCES = 128; + if (instances.size() > MAX_INSTANCES) { + auto res = std::make_unique(); + res->status = 400; + res->data = safe_json_to_str({{"error", format_error_response("instances array exceeds maximum size of " + std::to_string(MAX_INSTANCES), ERROR_TYPE_INVALID_REQUEST)}}); + return res; + } + + std::vector> futures; + futures.reserve(instances.size()); + + for (const auto & instance : instances) { + futures.push_back(std::async(std::launch::async, [this, &req, &alias_to_path, instance]() -> json { + if (!instance.is_object()) { + return build_error("each instance must be a JSON object", ERROR_TYPE_INVALID_REQUEST); + } + if (!instance.contains("@requestFormat") || !instance.at("@requestFormat").is_string()) { + return build_error("each instance must include a string @requestFormat", ERROR_TYPE_INVALID_REQUEST); + } + + try { + json payload = instance; + const std::string format = payload.at("@requestFormat").get(); + payload.erase("@requestFormat"); + + if (payload.contains("stream")) { + LOG_WRN("%s: ignoring client-provided stream field in instance, streaming is not supported in predict route\n", __func__); + payload["stream"] = false; + } + + // accept both camelCase aliases (e.g. "chatCompletions") and direct paths + std::string dispatch_path; + auto it_alias = alias_to_path.find(format); + if (it_alias != alias_to_path.end()) { + dispatch_path = it_alias->second; + } else if (handlers.count(format)) { + dispatch_path = format; + } else { + return build_error("no handler registered for @requestFormat: " + format, ERROR_TYPE_INVALID_REQUEST); + } + + const server_http_req internal_req { + req.params, + req.headers, + path_prefix + dispatch_path, + req.query_string, + payload.dump(), + {}, + req.should_stop, + }; + + server_http_res_ptr internal_res = handlers.at(dispatch_path)(internal_req); + return parse_gcp_predict_response(internal_res); + } catch (const std::invalid_argument & e) { + return build_error(e.what(), ERROR_TYPE_INVALID_REQUEST); + } catch (const std::exception & e) { + return build_error(e.what(), ERROR_TYPE_SERVER); + } catch (...) { + return build_error("unknown error", ERROR_TYPE_SERVER); + } + })); + } + + json predictions = json::array(); + for (auto & future : futures) { + predictions.push_back(future.get()); + } + + auto res = std::make_unique(); + res->data = safe_json_to_str({{"predictions", predictions}}); + return res; + }); +} diff --git a/tools/server/server-http.h b/tools/server/server-http.h index d4d3b6e536..66ee555f50 100644 --- a/tools/server/server-http.h +++ b/tools/server/server-http.h @@ -67,6 +67,10 @@ struct server_http_context { std::thread thread; // server thread std::atomic is_ready = false; + // note: the handler should never throw exceptions + using handler_t = std::function; + mutable std::unordered_map handlers; + std::string path_prefix; std::string hostname; int port; @@ -78,12 +82,13 @@ struct server_http_context { bool start(); void stop() const; - // note: the handler should never throw exceptions - using handler_t = std::function; - void get(const std::string & path, const handler_t & handler) const; void post(const std::string & path, const handler_t & handler) const; + // Register the Google Cloud Platform (Vertex AI) compat (AIP_PREDICT_ROUTE env var, or /predict) + // Must be called AFTER all other API routes are registered + void register_gcp_compat(); + // for debugging std::string listening_address; }; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 77fb7b23ba..371b13c44a 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -204,6 +204,10 @@ int main(int argc, char ** argv) { // Save & load slots ctx_http.get ("/slots", ex_wrapper(routes.get_slots)); ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots)); + + // Google Cloud Platform (Vertex AI) compat + ctx_http.register_gcp_compat(); + // CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP) if (params.webui_mcp_proxy) { SRV_WRN("%s", "-----------------\n"); diff --git a/tools/server/tests/unit/test_compat_gcp.py b/tools/server/tests/unit/test_compat_gcp.py new file mode 100644 index 0000000000..aba67bb353 --- /dev/null +++ b/tools/server/tests/unit/test_compat_gcp.py @@ -0,0 +1,60 @@ +import pytest +from utils import * + +server: ServerProcess + + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.gcp_compat = True + + +def test_gcp_predict_camel_case(): + global server + server.start() + res = server.make_request("POST", "/predict", data={ + "instances": [ + { + "@requestFormat": "chatCompletions", + "max_tokens": 8, + "messages": [ + {"role": "user", "content": "What is the meaning of life?"}, + ], + } + ], + }) + assert res.status_code == 200 + assert "predictions" in res.body + assert len(res.body["predictions"]) == 1 + prediction = res.body["predictions"][0] + assert "choices" in prediction + assert len(prediction["choices"]) == 1 + assert prediction["choices"][0]["message"]["role"] == "assistant" + assert len(prediction["choices"][0]["message"]["content"]) > 0 + + +def test_gcp_predict_multiple_instances(): + global server + server.n_slots = 2 + server.start() + res = server.make_request("POST", "/predict", data={ + "instances": [ + { + "@requestFormat": "chatCompletions", + "max_tokens": 8, + "messages": [{"role": "user", "content": "Say hello"}], + }, + { + "@requestFormat": "chatCompletions", + "max_tokens": 8, + "messages": [{"role": "user", "content": "Say world"}], + }, + ], + }) + assert res.status_code == 200 + assert len(res.body["predictions"]) == 2 + for prediction in res.body["predictions"]: + assert "choices" in prediction + assert len(prediction["choices"][0]["message"]["content"]) > 0 diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index 15f9bd95d7..ce93903872 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -108,6 +108,7 @@ class ServerProcess: no_cache_idle_slots: bool = False log_path: str | None = None webui_mcp_proxy: bool = False + gcp_compat: bool = False # session variables process: subprocess.Popen | None = None @@ -122,6 +123,9 @@ class ServerProcess: self.external_server = "DEBUG_EXTERNAL" in os.environ def start(self, timeout_seconds: int = DEFAULT_HTTP_TIMEOUT) -> None: + env = {**os.environ} + if "LLAMA_CACHE" not in os.environ: + env["LLAMA_CACHE"] = "tmp" if self.external_server: print(f"[external_server]: Assuming external server running on {self.server_host}:{self.server_port}") return @@ -248,6 +252,8 @@ class ServerProcess: server_args.append("--no-cache-idle-slots") if self.webui_mcp_proxy: server_args.append("--webui-mcp-proxy") + if self.gcp_compat: + env["AIP_MODE"] = "PREDICTION" args = [str(arg) for arg in [server_path, *server_args]] print(f"tests: starting server with: {' '.join(args)}") @@ -268,7 +274,7 @@ class ServerProcess: creationflags=flags, stdout=self._log, stderr=self._log if self._log != sys.stdout else sys.stdout, - env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None, + env=env, ) server_instances.add(self)