mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
server: support slot save/restore/erase for mtmd tokens and checkpoints (#1584)
Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
parent
0147cf4837
commit
5e8bb724ce
@ -1054,6 +1054,18 @@ llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
|
||||
return image_tokens->n_tokens();
|
||||
}
|
||||
|
||||
// Heap-allocate a fresh input chunk initialised as a text chunk with an empty
// token list and both trailing pointer members null. The raw pointer is handed
// to the caller.
mtmd_input_chunk * mtmd_create_input_chunk() {
    return new mtmd_input_chunk{
        MTMD_INPUT_CHUNK_TYPE_TEXT,
        std::vector<llama_token>{},
        nullptr,
        nullptr,
    };
}
|
||||
|
||||
|
||||
|
||||
// test function
|
||||
|
||||
mtmd_input_chunks * mtmd_test_create_input_chunks() {
|
||||
@ -1088,3 +1100,133 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
|
||||
|
||||
return chunks;
|
||||
}
|
||||
|
||||
// Serialize a single f32 image: the "nx" and "ny" fields plus the float
// buffer under "buf".
static json mtmd_clip_image_f32_to_json(const clip_image_f32 & clip) {
    json out = {
        {"nx",  clip.nx},
        {"ny",  clip.ny},
        {"buf", clip.buf},
    };
    return out;
}
|
||||
|
||||
// Rebuild a heap-allocated clip_image_f32 from the JSON object produced by
// mtmd_clip_image_f32_to_json ("nx", "ny", "buf").
//
// Returns a raw owning pointer; the caller is responsible for releasing it.
// Throws a nlohmann exception when a key is missing or has the wrong type.
static clip_image_f32 * mtmd_clip_image_f32_from_json(const json & j) {
    // Hold the allocation in a smart pointer while parsing so that it is not
    // leaked if any of the JSON accesses below throws.
    std::unique_ptr<clip_image_f32> clip(new clip_image_f32);

    // at() is used instead of operator[]: on a const json, operator[] with a
    // missing key is undefined behaviour, whereas at() throws out_of_range.
    j.at("nx").get_to(clip->nx);
    j.at("ny").get_to(clip->ny);
    j.at("buf").get_to(clip->buf);

    return clip.release();
}
|
||||
|
||||
// Serialize the batch header ("is_audio", "grid_x", "grid_y"). When `full` is
// true the individual images are serialized as well, under "entries".
static json mtmd_clip_image_f32_batch_to_json(const clip_image_f32_batch & batch, bool full = false) {
    json out;
    out["is_audio"] = batch.is_audio;
    out["grid_x"]   = batch.grid_x;
    out["grid_y"]   = batch.grid_y;

    if (!full) {
        return out;
    }

    std::vector<nlohmann::json> serialized;
    for (const auto & image : batch.entries) {
        serialized.push_back(mtmd_clip_image_f32_to_json(*image));
    }
    out["entries"] = serialized;

    return out;
}
|
||||
|
||||
// Rebuild a clip_image_f32_batch from JSON.
//
// If `j` has no "is_audio" key (e.g. it is null), a default-constructed batch
// is returned. "grid_x" / "grid_y" fall back to 0 when absent. The images
// under "entries" are only restored when `full` is true and the key holds an
// array.
static clip_image_f32_batch mtmd_clip_image_f32_batch_from_json(const json & j, bool full = false) {
    clip_image_f32_batch batch;

    if (!j.contains("is_audio")) {
        return batch;
    }

    // value() rather than const operator[]: the latter is undefined behaviour
    // when the key is missing.
    batch.is_audio = j.value("is_audio", false);
    batch.grid_x   = j.value("grid_x", 0);
    batch.grid_y   = j.value("grid_y", 0);

    if (!full) {
        return batch;
    }

    // Look the array up in place; copying it would duplicate every float buffer.
    const auto entries = j.find("entries");
    if (entries == j.end() || !entries->is_array()) {
        return batch;
    }

    for (const auto & entry : *entries) {
        // Take ownership immediately so the image is freed if push_back throws.
        clip_image_f32_ptr clip(mtmd_clip_image_f32_from_json(entry));
        batch.entries.push_back(std::move(clip));
    }

    return batch;
}
|
||||
|
||||
// Build an mtmd_audio_tokens value from JSON. Missing keys fall back to:
// n_tokens = 0, a batch parsed from a null JSON value, and an empty id.
static mtmd_audio_tokens mtmd_audio_tokens_from_json(json & j) {
    const uint32_t       n_tokens = j.value<uint32_t>("n_tokens", 0);
    clip_image_f32_batch batch    = mtmd_clip_image_f32_batch_from_json(j.value("batch_f32", json{}));
    std::string          id       = j.value("id","");

    return mtmd_audio_tokens{
        n_tokens,
        std::move(batch),
        std::move(id)
    };
}
|
||||
|
||||
// Build an mtmd_image_tokens value from JSON. Missing keys fall back to:
// nx = ny = 0, use_mrope_pos = false, a batch parsed from a null JSON value,
// and an empty id.
static mtmd_image_tokens mtmd_image_tokens_from_json(json & j) {
    const uint32_t       nx            = j.value<uint32_t>("nx", 0);
    const uint32_t       ny            = j.value<uint32_t>("ny", 0);
    const bool           use_mrope_pos = j.value("use_mrope_pos",false);
    clip_image_f32_batch batch         = mtmd_clip_image_f32_batch_from_json(j.value("batch_f32", json{}));
    std::string          id            = j.value("id","");

    return mtmd_image_tokens{
        nx,
        ny,
        use_mrope_pos,
        std::move(batch),
        std::move(id)
    };
}
|
||||
|
||||
// Serialize audio tokens as "n_tokens", "id" and the batch header under
// "batch_f32". A null pointer yields a null JSON value.
static json mtmd_audio_tokens_to_json(mtmd_audio_tokens * chunk) {
    json out;
    if (chunk == nullptr) {
        return out;
    }

    out["n_tokens"]  = chunk->n_tokens;
    out["id"]        = chunk->id;
    out["batch_f32"] = mtmd_clip_image_f32_batch_to_json(chunk->batch_f32);
    return out;
}
|
||||
|
||||
// Serialize image tokens as "nx", "ny", "use_mrope_pos", the batch header
// under "batch_f32", and "id". A null pointer yields a null JSON value.
static json mtmd_image_tokens_to_json(mtmd_image_tokens * chunk) {
    json out;
    if (chunk == nullptr) {
        return out;
    }

    out["nx"]            = chunk->nx;
    out["ny"]            = chunk->ny;
    out["use_mrope_pos"] = chunk->use_mrope_pos;
    out["batch_f32"]     = mtmd_clip_image_f32_batch_to_json(chunk->batch_f32);
    out["id"]            = chunk->id;
    return out;
}
|
||||
|
||||
// Rebuild an input chunk from the JSON written by mtmd_input_chunk_to_json.
//
// "type" defaults to MTMD_INPUT_CHUNK_TYPE_TEXT and "tokens_text" to an empty
// list when absent. Image / audio token blocks are only allocated when the
// corresponding key is present.
//
// Returns a raw owning pointer created through mtmd_create_input_chunk().
// Throws a nlohmann exception on malformed input; nothing is leaked then.
mtmd_input_chunk * mtmd_input_chunk_from_json(json & j) {
    // Guard the allocation: every JSON access below may throw, and the chunk
    // would otherwise be leaked. mtmd_create_input_chunk() allocates with
    // plain `new`, so the default deleter matches.
    std::unique_ptr<mtmd_input_chunk> chunk(mtmd_create_input_chunk());

    chunk->type        = j.value("type", MTMD_INPUT_CHUNK_TYPE_TEXT);
    chunk->tokens_text = j.value("tokens_text", chunk->tokens_text);
    // tokens_image / tokens_audio are already null in a freshly created chunk.

    if (j.contains("tokens_image")) {
        chunk->tokens_image = mtmd_image_tokens_ptr(new mtmd_image_tokens());
        // Parse the sub-object in place rather than copying it first.
        *chunk->tokens_image = mtmd_image_tokens_from_json(j.at("tokens_image"));
    }

    if (j.contains("tokens_audio")) {
        chunk->tokens_audio = mtmd_audio_tokens_ptr(new mtmd_audio_tokens());
        *chunk->tokens_audio = mtmd_audio_tokens_from_json(j.at("tokens_audio"));
    }

    return chunk.release();
}
|
||||
|
||||
// Write `chunk` into `j`. `j` is cleared first; for a null chunk nothing else
// is written. "tokens_image" / "tokens_audio" are emitted only when the chunk
// carries the corresponding payload.
void mtmd_input_chunk_to_json(mtmd_input_chunk * chunk, json & j) {
    j.clear();
    if (chunk == nullptr) {
        return;
    }

    j["type"]        = chunk->type;
    j["tokens_text"] = chunk->tokens_text;

    if (auto * image = chunk->tokens_image.get()) {
        j["tokens_image"] = mtmd_image_tokens_to_json(image);
    }
    if (auto * audio = chunk->tokens_audio.get()) {
        j["tokens_audio"] = mtmd_audio_tokens_to_json(audio);
    }
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@ -13,6 +13,8 @@
|
||||
#include <vector>
|
||||
#include <cinttypes>
|
||||
#include <memory>
|
||||
#include <nlohmann/json.hpp>
|
||||
using json = nlohmann::ordered_json;
|
||||
#endif
|
||||
|
||||
/**
|
||||
@ -215,6 +217,9 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
|
||||
// the reading size (in bytes) is equal to:
|
||||
// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
|
||||
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
||||
MTMD_API mtmd_input_chunk * mtmd_create_input_chunk(void);
|
||||
MTMD_API mtmd_input_chunk * mtmd_input_chunk_from_json(json & j);
|
||||
MTMD_API void mtmd_input_chunk_to_json(mtmd_input_chunk * chunk, json & j);
|
||||
|
||||
/////////////////////////////////////////
|
||||
|
||||
|
||||
@ -2164,6 +2164,47 @@ server_tokens server_tokens::clone() const {
|
||||
return res;
|
||||
}
|
||||
|
||||
// Serialize this token list: the has_mtmd flag, the token vector, and every
// non-null media chunk together with the index it is stored under.
json server_tokens::to_json() const
{
    std::vector<nlohmann::json> media_entries;
    for (const auto & [media_idx, media_chunk] : map_idx_to_media) {
        if (!media_chunk) {
            continue;
        }

        json chunk_json;
        mtmd_input_chunk_to_json(media_chunk.get(), chunk_json);

        nlohmann::json item;
        item["index"] = media_idx;
        item["chunk"] = chunk_json;
        media_entries.push_back(std::move(item));
    }

    json result;
    result = nlohmann::json{
        {"has_mtmd", has_mtmd},
        {"map_idx_to_media", media_entries},
        {"tokens", tokens}
    };
    return result;
}
|
||||
|
||||
void server_tokens::from_json(const json & j) {
|
||||
clear();
|
||||
map_idx_to_media.clear();
|
||||
has_mtmd = j.value("has_mtmd", has_mtmd);
|
||||
tokens = j.value("tokens", tokens);
|
||||
map_idx_to_media.clear();
|
||||
json media_array = j.at("map_idx_to_media");
|
||||
if (media_array.is_array()) {
|
||||
for (const auto & entry : media_array) {
|
||||
size_t idx = entry.at("index");
|
||||
json chunk_json = entry.at("chunk");
|
||||
mtmd_input_chunk * chunk = mtmd_input_chunk_from_json(chunk_json);
|
||||
map_idx_to_media[idx] = mtmd::input_chunk_ptr(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Keep the first n_keep and remove n_discard tokens from tokens
|
||||
void server_tokens::discard_n_tokens(int32_t n_keep, int32_t n_discard) {
|
||||
|
||||
|
||||
@ -349,6 +349,10 @@ public:
|
||||
|
||||
server_tokens(const llama_tokens& tokens, bool has_mtmd);
|
||||
|
||||
json to_json() const;
|
||||
|
||||
void from_json(const json & j);
|
||||
|
||||
// the next position after n_tokens. if n_tokens < 0, return the next position after all tokens.
|
||||
llama_pos pos_next(int64_t n_tokens = -1) const;
|
||||
|
||||
|
||||
@ -11,6 +11,8 @@
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-helper.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <regex>
|
||||
|
||||
static void log_text(const gpt_params & params_base, const std::string & text) {
|
||||
@ -1995,6 +1997,117 @@ void server_context::split_multiprompt_task(int id_multi, server_task& multiprom
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
static size_t save_checkpoints_to_file(const std::string & filename, const std::list<server_prompt_checkpoint> & checkpoints) {
|
||||
if (checkpoints.size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
std::ofstream file(filename, std::ios::binary);
|
||||
uint32_t magic = LLAMA_STATE_SEQ_MAGIC;
|
||||
file.write(reinterpret_cast<const char *>(&magic), sizeof(magic));
|
||||
uint32_t version = LLAMA_STATE_SEQ_VERSION;
|
||||
file.write(reinterpret_cast<const char *>(&version), sizeof(version));
|
||||
size_t count = checkpoints.size();
|
||||
file.write(reinterpret_cast<const char *>(&count), sizeof(count));
|
||||
|
||||
for (const auto & checkpoint : checkpoints) {
|
||||
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
|
||||
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
|
||||
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
|
||||
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
|
||||
size_t data_len = checkpoint.data.size();
|
||||
file.write(reinterpret_cast<const char *>(&data_len), sizeof(data_len));
|
||||
if (data_len > 0) {
|
||||
file.write(reinterpret_cast<const char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
|
||||
}
|
||||
}
|
||||
size_t pos = file.tellp();
|
||||
file.close();
|
||||
return pos;
|
||||
}
|
||||
|
||||
static size_t load_checkpoints_from_file(const std::string & filename, std::list<server_prompt_checkpoint> & checkpoints) {
|
||||
std::ifstream file(filename, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
return 0;
|
||||
}
|
||||
checkpoints.clear();
|
||||
// version checks
|
||||
{
|
||||
uint32_t magic;
|
||||
file.read(reinterpret_cast<char *>(&magic), sizeof(magic));
|
||||
uint32_t version;
|
||||
file.read(reinterpret_cast<char *>(&version), sizeof(version));
|
||||
|
||||
if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
|
||||
LLAMA_LOG_ERROR("%s: unknown (magic, version) for checkpoint file: %08x, %08x\n", __func__, magic, version);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
// load the checkpoints
|
||||
{
|
||||
size_t count;
|
||||
file.read(reinterpret_cast<char *>(&count), sizeof(count));
|
||||
|
||||
for (int i = 0; i < count; i++) {
|
||||
server_prompt_checkpoint checkpoint;
|
||||
file.read(reinterpret_cast<char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
|
||||
file.read(reinterpret_cast<char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
|
||||
file.read(reinterpret_cast<char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
|
||||
file.read(reinterpret_cast<char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
|
||||
|
||||
size_t data_len;
|
||||
file.read(reinterpret_cast<char *>(&data_len), sizeof(data_len));
|
||||
if (data_len > 0) {
|
||||
checkpoint.data.resize(data_len);
|
||||
file.read(reinterpret_cast<char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
|
||||
}
|
||||
checkpoints.push_back(checkpoint);
|
||||
}
|
||||
}
|
||||
size_t pos = file.tellg();
|
||||
file.close();
|
||||
return pos;
|
||||
}
|
||||
|
||||
static size_t save_server_tokens_to_file(const std::string & filename, const server_tokens & tokens) {
|
||||
std::ofstream file(filename, std::ios::binary);
|
||||
json token_json = tokens.to_json();
|
||||
token_json["magic"] = LLAMA_SERVER_MAGIC;
|
||||
token_json["version"] = LLAMA_SERVER_VERSION;
|
||||
size_t pos = 0;
|
||||
if (file.is_open()) {
|
||||
file << token_json;
|
||||
pos = file.tellp();
|
||||
file.close();
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
static size_t load_server_tokens_from_file(const std::string & filename, server_tokens & tokens) {
|
||||
std::ifstream file(filename, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
return 0;
|
||||
}
|
||||
size_t pos = 0;
|
||||
json token_json;
|
||||
if (file.is_open()) {
|
||||
file >> token_json;
|
||||
pos = file.tellg();
|
||||
file.close();
|
||||
}
|
||||
uint32_t magic = token_json.value<uint32_t>("magic", 0);
|
||||
uint32_t version = token_json.value<uint32_t>("version", 0);
|
||||
if (magic != LLAMA_SERVER_MAGIC || version != LLAMA_SERVER_VERSION) {
|
||||
LLAMA_LOG_ERROR("%s: unknown (magic, version) for token file: %08x, %08x\n", __func__, magic, version);
|
||||
return 0;
|
||||
}
|
||||
tokens.from_json(token_json);
|
||||
|
||||
return pos;
|
||||
}
|
||||
|
||||
void server_context::process_single_task(server_task&& task) {
|
||||
switch (task.type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
@ -2153,14 +2266,14 @@ void server_context::process_single_task(server_task&& task) {
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
|
||||
break;
|
||||
}
|
||||
|
||||
const size_t token_count = slot->cache_tokens.size();
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
std::string filename = task.data.at("filename");
|
||||
std::string filepath = task.data.at("filepath");
|
||||
save_server_tokens_to_file(filepath+".tokens.json", slot->cache_tokens);
|
||||
size_t saved = save_checkpoints_to_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
|
||||
|
||||
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
|
||||
|
||||
@ -2175,7 +2288,7 @@ void server_context::process_single_task(server_task&& task) {
|
||||
{ "id_slot", id_slot },
|
||||
{ "filename", filename },
|
||||
{ "n_saved", token_count }, // tokens saved
|
||||
{ "n_written", nwrite }, // bytes written
|
||||
{ "n_written", nwrite + saved }, // bytes written
|
||||
{ "timings", {
|
||||
{ "save_ms", t_save_ms }
|
||||
} }
|
||||
@ -2196,9 +2309,6 @@ void server_context::process_single_task(server_task&& task) {
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
|
||||
break;
|
||||
}
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
std::string filename = task.data.at("filename");
|
||||
@ -2212,10 +2322,9 @@ void server_context::process_single_task(server_task&& task) {
|
||||
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
||||
break;
|
||||
}
|
||||
slot->cache_tokens.resize(token_count);
|
||||
if (mctx) {
|
||||
slot->cache_tokens.has_mtmd = true;
|
||||
}
|
||||
load_server_tokens_from_file(filepath+".tokens.json", slot->cache_tokens);
|
||||
size_t loaded = load_checkpoints_from_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||
|
||||
@ -2248,14 +2357,13 @@ void server_context::process_single_task(server_task&& task) {
|
||||
queue_tasks.defer(std::move(task));
|
||||
break;
|
||||
}
|
||||
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
|
||||
break;
|
||||
}
|
||||
// Erase token cache
|
||||
const size_t n_erased = slot->cache_tokens.size();
|
||||
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
|
||||
slot->cache_tokens.clear();
|
||||
|
||||
slot->cache_tokens.keep_first(0);
|
||||
//slot->cache_tokens.clear();
|
||||
slot->server_cached_prompt.checkpoints.clear();
|
||||
slot->server_cached_prompt.data.clear();
|
||||
server_task_result result;
|
||||
result.id = task.id;
|
||||
result.stop = true;
|
||||
|
||||
@ -355,6 +355,22 @@ struct server_prompt_checkpoint {
|
||||
size_t size() const {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
// Export the four position bounds as a JSON object.
json to_json() {
    json out = {
        {"pos_min",        pos_min},
        {"pos_max",        pos_max},
        {"pos_min_prompt", pos_min_prompt},
        {"pos_max_prompt", pos_max_prompt},
    };
    return out;
}
|
||||
|
||||
void from_json(const json & j) {
|
||||
pos_min = j.value<llama_pos>("pos_min", 0);
|
||||
pos_max = j.value<llama_pos>("pos_max", 0);
|
||||
pos_min_prompt = j.value<llama_pos>("pos_min_prompt", 0);
|
||||
pos_max_prompt = j.value<llama_pos>("pos_max_prompt", 0);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -384,6 +400,22 @@ struct server_prompt {
|
||||
checkpoints
|
||||
};
|
||||
}
|
||||
|
||||
// Export the prompt as JSON: the serialized token list plus the kept and
// discarded prompt counters.
json to_json()
{
    json out;
    out["tokens"]             = tokens.to_json();
    out["n_kept_prompt"]      = n_kept_prompt;
    out["n_discarded_prompt"] = n_discarded_prompt;
    return out;
}
|
||||
|
||||
// Restore the prompt from the JSON written by to_json().
// "tokens" is required (at() throws if it is missing); the two counters
// default to 0 when absent.
void from_json(const json & j) {
    tokens.from_json(j.at("tokens"));
    // Each counter is read exactly once (the original assigned n_kept_prompt twice).
    n_kept_prompt      = j.value<llama_pos>("n_kept_prompt", 0);
    n_discarded_prompt = j.value<llama_pos>("n_discarded_prompt", 0);
}
|
||||
};
|
||||
|
||||
struct server_prompt_cache {
|
||||
|
||||
@ -52,6 +52,9 @@
|
||||
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||
#define LLAMA_STATE_SEQ_VERSION 3
|
||||
|
||||
#define LLAMA_SERVER_MAGIC 0x6c6d7376u // 'lmsv'
|
||||
#define LLAMA_SERVER_VERSION 1
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user