mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
cli support router mode
Co-authored-by: Piotr Wilkin <ilintar@gmail.com>
This commit is contained in:
parent
85c58bbcd0
commit
1401fc3ca7
@ -28,7 +28,8 @@ static std::string join_path(const common_http_url & parts, const std::string &
|
||||
json cli_client::get(const std::string & path) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
auto res = cli.Get(join_path(parts, path));
|
||||
auto path_with_model = path + (model.empty() ? "" : ("?model=" + model));
|
||||
auto res = cli.Get(join_path(parts, path_with_model));
|
||||
if (!res) {
|
||||
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
|
||||
}
|
||||
@ -45,7 +46,11 @@ json cli_client::get(const std::string & path) {
|
||||
json cli_client::post(const std::string & path, const json & body) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
auto res = cli.Post(join_path(parts, path), body.dump(), "application/json");
|
||||
auto body_with_model = body;
|
||||
if (!model.empty()) {
|
||||
body_with_model["model"] = model;
|
||||
}
|
||||
auto res = cli.Post(join_path(parts, path), body_with_model.dump(), "application/json");
|
||||
if (!res) {
|
||||
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
|
||||
}
|
||||
@ -100,7 +105,11 @@ json cli_client::post_sse(const std::string & path,
|
||||
};
|
||||
|
||||
httplib::Headers headers = {{"Accept", "text/event-stream"}};
|
||||
auto res = cli.Post(join_path(parts, path), headers, body.dump(), "application/json", receiver);
|
||||
auto body_with_model = body;
|
||||
if (!model.empty()) {
|
||||
body_with_model["model"] = model;
|
||||
}
|
||||
auto res = cli.Post(join_path(parts, path), headers, body_with_model.dump(), "application/json", receiver);
|
||||
|
||||
if (!res) {
|
||||
if (res.error() == httplib::Error::Canceled && should_stop()) {
|
||||
@ -139,3 +148,17 @@ bool cli_client::wait_health(const std::function<bool()> & is_aborted) {
|
||||
last_error = "aborted while waiting for the server to become ready";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> cli_client::list_models() {
|
||||
json resp = get("/v1/models");
|
||||
if (!resp.contains("data") || !resp.at("data").is_array()) {
|
||||
throw std::runtime_error("invalid response from /v1/models");
|
||||
}
|
||||
std::vector<std::string> models;
|
||||
for (const auto & m : resp.at("data")) {
|
||||
if (m.contains("id") && m.at("id").is_string()) {
|
||||
models.push_back(m.at("id").get<std::string>());
|
||||
}
|
||||
}
|
||||
return models;
|
||||
}
|
||||
|
||||
@ -15,6 +15,8 @@ struct cli_client {
|
||||
std::string server_base; // base url, for example "http://127.0.0.1:8080"
|
||||
std::string last_error; // set when wait_health() fails
|
||||
|
||||
std::string model; // optional, set when the server has multiple models (router mode)
|
||||
|
||||
// simple GET request, returns the response json
|
||||
// throws std::runtime_error on transport error or non-2xx status
|
||||
json get(const std::string & path);
|
||||
@ -49,4 +51,6 @@ struct cli_client {
|
||||
json get_props() {
|
||||
return get("/props");
|
||||
}
|
||||
|
||||
std::vector<std::string> list_models();
|
||||
};
|
||||
|
||||
@ -68,7 +68,8 @@ bool cli_context::init() {
|
||||
|
||||
std::optional<view::spinner> spinner;
|
||||
|
||||
if (!params.server_base.empty()) {
|
||||
bool use_external_server = !params.server_base.empty();
|
||||
if (use_external_server) {
|
||||
std::string base = params.server_base;
|
||||
while (!base.empty() && base.back() == '/') {
|
||||
base.pop_back();
|
||||
@ -121,6 +122,15 @@ bool cli_context::init() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (use_external_server) {
|
||||
spinner.reset();
|
||||
if (!list_and_ask_models()) {
|
||||
return false;
|
||||
}
|
||||
// restore the spinner for the next step
|
||||
spinner.emplace("Waiting for server...");
|
||||
}
|
||||
|
||||
fetch_server_props();
|
||||
|
||||
return true;
|
||||
@ -149,6 +159,44 @@ void cli_context::fetch_server_props() {
|
||||
}
|
||||
}
|
||||
|
||||
bool cli_context::list_and_ask_models() {
|
||||
auto models = client.list_models();
|
||||
std::string message = "\nAvailable models:";
|
||||
if (!models.empty()) {
|
||||
for (size_t i = 0; i < models.size(); ++i) {
|
||||
message += "\n " + std::to_string(i + 1) + ". " + models[i];
|
||||
}
|
||||
}
|
||||
message += "\n";
|
||||
view::show_message(message);
|
||||
std::string selection;
|
||||
while (selection.empty()) {
|
||||
if (should_stop()) {
|
||||
return false;
|
||||
}
|
||||
view::user_turn user_turn;
|
||||
selection = user_turn.read_input(false, "Select model by number: ");
|
||||
if (selection.empty()) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
size_t idx = std::stoul(selection);
|
||||
if (idx > 0 && idx <= models.size()) {
|
||||
model_name = models[idx - 1];
|
||||
client.model = model_name;
|
||||
view::show_message("Selected model: " + model_name);
|
||||
break;
|
||||
}
|
||||
} catch (...) {
|
||||
// ignore
|
||||
}
|
||||
view::show_error("Invalid selection. Please enter a valid number.");
|
||||
selection.clear();
|
||||
continue;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void cli_context::add_system_prompt() {
|
||||
if (!params.system_prompt.empty()) {
|
||||
messages.push_back({
|
||||
|
||||
@ -52,6 +52,10 @@ private:
|
||||
void add_system_prompt();
|
||||
void push_user_message(const std::string & text);
|
||||
|
||||
// check if server have multiple models (router mode)
|
||||
// if yes, list them then ask; do nothing otherwise
|
||||
bool list_and_ask_models();
|
||||
|
||||
// read a file and stage it as a multimodal content part; type is one of
|
||||
// "image", "audio", "video"; returns false if the file cannot be read
|
||||
bool stage_media_file(const std::string & fname, const std::string & type);
|
||||
|
||||
@ -162,8 +162,12 @@ namespace view {
|
||||
console::log("\n> %s\n", buffer.c_str());
|
||||
}
|
||||
}
|
||||
std::string read_input(bool multiline_input) {
|
||||
console::log("\n> ");
|
||||
std::string read_input(bool multiline_input, const char * prompt = nullptr) {
|
||||
if (prompt) {
|
||||
console::log("%s", prompt);
|
||||
} else {
|
||||
console::log("\n> ");
|
||||
}
|
||||
std::string buffer;
|
||||
std::string line;
|
||||
bool another_line = true;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user