server: support slot save/restore/erase for mtmd tokens and checkpoints (#1584)

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana 2026-04-05 01:41:04 -05:00 committed by GitHub
parent 0147cf4837
commit 5e8bb724ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 351 additions and 16 deletions

View File

@ -1054,6 +1054,18 @@ llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
return image_tokens->n_tokens();
}
mtmd_input_chunk * mtmd_create_input_chunk() {
auto * chunk = new mtmd_input_chunk{
MTMD_INPUT_CHUNK_TYPE_TEXT,
std::vector<llama_token>{},
nullptr,
nullptr
};
return chunk;
}
// test function
mtmd_input_chunks * mtmd_test_create_input_chunks() {
@ -1088,3 +1100,133 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
return chunks;
}
static json mtmd_clip_image_f32_to_json(const clip_image_f32 & clip) {
json j;
j["nx"] = clip.nx;
j["ny"] = clip.ny;
j["buf"] = clip.buf;
return j;
}
static clip_image_f32 * mtmd_clip_image_f32_from_json(const json & j) {
clip_image_f32 * clip = new clip_image_f32;
clip->nx = j["nx"];
clip->ny = j["ny"];
clip->buf = j["buf"].get<std::vector<float>>();
return clip;
}
static json mtmd_clip_image_f32_batch_to_json(const clip_image_f32_batch & batch, bool full = false) {
json j;
j["is_audio"] = batch.is_audio;
j["grid_x"] = batch.grid_x;
j["grid_y"] = batch.grid_y;
if (full) {
std::vector<nlohmann::json> entries;
for (auto & entry : batch.entries) {
entries.push_back(mtmd_clip_image_f32_to_json(*entry));
}
j["entries"] = entries;
}
return j;
}
static clip_image_f32_batch mtmd_clip_image_f32_batch_from_json(const json & j, bool full = false) {
clip_image_f32_batch batch;
if (j.contains("is_audio")) {
batch.is_audio = j["is_audio"];
batch.grid_x = j["grid_x"];
batch.grid_y = j["grid_y"];
if (full) {
auto entries = j["entries"];
if (entries.is_array()) {
for (auto & entry : entries) {
clip_image_f32 * clip = mtmd_clip_image_f32_from_json(entry);
batch.entries.push_back(clip_image_f32_ptr(clip));
}
}
}
}
return batch;
}
static mtmd_audio_tokens mtmd_audio_tokens_from_json(json & j) {
return mtmd_audio_tokens{
j.value<uint32_t>("n_tokens", 0),
mtmd_clip_image_f32_batch_from_json(j.value("batch_f32", json{})),
j.value("id","")
};
}
static mtmd_image_tokens mtmd_image_tokens_from_json(json & j) {
return mtmd_image_tokens{
j.value<uint32_t>("nx", 0),
j.value<uint32_t>("ny", 0),
j.value("use_mrope_pos",false),
mtmd_clip_image_f32_batch_from_json(j.value("batch_f32", json{})),
j.value("id","")
};
}
static json mtmd_audio_tokens_to_json(mtmd_audio_tokens * chunk) {
json j;
if (chunk) {
j["n_tokens"] = chunk->n_tokens;
j["id"] = chunk->id;
j["batch_f32"] = mtmd_clip_image_f32_batch_to_json(chunk->batch_f32);
}
return j;
}
static json mtmd_image_tokens_to_json(mtmd_image_tokens * chunk) {
json j;
if (chunk) {
j["nx"] = chunk->nx;
j["ny"] = chunk->ny;
j["use_mrope_pos"] = chunk->use_mrope_pos;
j["batch_f32"] = mtmd_clip_image_f32_batch_to_json(chunk->batch_f32);
j["id"] = chunk->id;
}
return j;
}
mtmd_input_chunk * mtmd_input_chunk_from_json(json & j) {
mtmd_input_chunk * chunk = mtmd_create_input_chunk();
chunk->type = j.value("type", MTMD_INPUT_CHUNK_TYPE_TEXT);
chunk->tokens_text = j.value("tokens_text", chunk->tokens_text);
chunk->tokens_image = nullptr;
chunk->tokens_audio = nullptr;
if (j.contains("tokens_image")) {
chunk->tokens_image = mtmd_image_tokens_ptr(new mtmd_image_tokens());
auto image_json = j.value("tokens_image", json::array());
*chunk->tokens_image = mtmd_image_tokens_from_json(image_json);
}
if (j.contains("tokens_audio")) {
chunk->tokens_audio = mtmd_audio_tokens_ptr(new mtmd_audio_tokens());
*chunk->tokens_audio = mtmd_audio_tokens_from_json(j.at("tokens_audio"));
}
return chunk;
}
void mtmd_input_chunk_to_json(mtmd_input_chunk * chunk, json & j) {
j.clear();
if (chunk) {
j["type"] = chunk->type;
j["tokens_text"] = chunk->tokens_text;
if (chunk->tokens_image) {
j["tokens_image"] = mtmd_image_tokens_to_json(chunk->tokens_image.get());
}
if (chunk->tokens_audio) {
j["tokens_audio"] = mtmd_audio_tokens_to_json(chunk->tokens_audio.get());
}
}
}

View File

@ -13,6 +13,8 @@
#include <vector>
#include <cinttypes>
#include <memory>
#include <nlohmann/json.hpp>
using json = nlohmann::ordered_json;
#endif
/**
@ -215,6 +217,9 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
// the reading size (in bytes) is equal to:
// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
MTMD_API mtmd_input_chunk * mtmd_create_input_chunk(void);
MTMD_API mtmd_input_chunk * mtmd_input_chunk_from_json(json & j);
MTMD_API void mtmd_input_chunk_to_json(mtmd_input_chunk * chunk, json & j);
/////////////////////////////////////////

View File

@ -2164,6 +2164,47 @@ server_tokens server_tokens::clone() const {
return res;
}
json server_tokens::to_json() const
{
json j;
std::vector<nlohmann::json> media_array;
for (auto & [idx, chunk_ptr] : map_idx_to_media) { // or direct access if friend
if (chunk_ptr) {
nlohmann::json obj;
obj["index"] = idx;
json j;
mtmd_input_chunk_to_json(chunk_ptr.get(), j);
obj["chunk"] = j;
media_array.push_back(std::move(obj));
}
}
j = nlohmann::json{
{"has_mtmd", has_mtmd},
{"map_idx_to_media", media_array},
{"tokens", tokens}
};
return j;
}
void server_tokens::from_json(const json & j) {
clear();
map_idx_to_media.clear();
has_mtmd = j.value("has_mtmd", has_mtmd);
tokens = j.value("tokens", tokens);
map_idx_to_media.clear();
json media_array = j.at("map_idx_to_media");
if (media_array.is_array()) {
for (const auto & entry : media_array) {
size_t idx = entry.at("index");
json chunk_json = entry.at("chunk");
mtmd_input_chunk * chunk = mtmd_input_chunk_from_json(chunk_json);
map_idx_to_media[idx] = mtmd::input_chunk_ptr(chunk);
}
}
}
// Keep the first n_keep and remove n_discard tokens from tokens
void server_tokens::discard_n_tokens(int32_t n_keep, int32_t n_discard) {

View File

@ -349,6 +349,10 @@ public:
server_tokens(const llama_tokens& tokens, bool has_mtmd);
json to_json() const;
void from_json(const json & j);
// the next position after n_tokens. if n_tokens < 0, return the next position after all tokens.
llama_pos pos_next(int64_t n_tokens = -1) const;

View File

@ -11,6 +11,8 @@
#include "mtmd.h"
#include "mtmd-helper.h"
#include <fstream>
#include <iostream>
#include <regex>
static void log_text(const gpt_params & params_base, const std::string & text) {
@ -1995,6 +1997,117 @@ void server_context::split_multiprompt_task(int id_multi, server_task& multiprom
}
}
static size_t save_checkpoints_to_file(const std::string & filename, const std::list<server_prompt_checkpoint> & checkpoints) {
if (checkpoints.size() == 0) {
return 0;
}
std::ofstream file(filename, std::ios::binary);
uint32_t magic = LLAMA_STATE_SEQ_MAGIC;
file.write(reinterpret_cast<const char *>(&magic), sizeof(magic));
uint32_t version = LLAMA_STATE_SEQ_VERSION;
file.write(reinterpret_cast<const char *>(&version), sizeof(version));
size_t count = checkpoints.size();
file.write(reinterpret_cast<const char *>(&count), sizeof(count));
for (const auto & checkpoint : checkpoints) {
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
file.write(reinterpret_cast<const char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
file.write(reinterpret_cast<const char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
size_t data_len = checkpoint.data.size();
file.write(reinterpret_cast<const char *>(&data_len), sizeof(data_len));
if (data_len > 0) {
file.write(reinterpret_cast<const char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
}
}
size_t pos = file.tellp();
file.close();
return pos;
}
static size_t load_checkpoints_from_file(const std::string & filename, std::list<server_prompt_checkpoint> & checkpoints) {
std::ifstream file(filename, std::ios::binary);
if (!file.is_open()) {
return 0;
}
checkpoints.clear();
// version checks
{
uint32_t magic;
file.read(reinterpret_cast<char *>(&magic), sizeof(magic));
uint32_t version;
file.read(reinterpret_cast<char *>(&version), sizeof(version));
if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
LLAMA_LOG_ERROR("%s: unknown (magic, version) for checkpoint file: %08x, %08x\n", __func__, magic, version);
return 0;
}
}
// load the checkpoints
{
size_t count;
file.read(reinterpret_cast<char *>(&count), sizeof(count));
for (int i = 0; i < count; i++) {
server_prompt_checkpoint checkpoint;
file.read(reinterpret_cast<char *>(&checkpoint.pos_min), sizeof(checkpoint.pos_min));
file.read(reinterpret_cast<char *>(&checkpoint.pos_max), sizeof(checkpoint.pos_max));
file.read(reinterpret_cast<char *>(&checkpoint.pos_min_prompt), sizeof(checkpoint.pos_min_prompt));
file.read(reinterpret_cast<char *>(&checkpoint.pos_max_prompt), sizeof(checkpoint.pos_max_prompt));
size_t data_len;
file.read(reinterpret_cast<char *>(&data_len), sizeof(data_len));
if (data_len > 0) {
checkpoint.data.resize(data_len);
file.read(reinterpret_cast<char *>(checkpoint.data.data()), data_len * sizeof(uint8_t));
}
checkpoints.push_back(checkpoint);
}
}
size_t pos = file.tellg();
file.close();
return pos;
}
static size_t save_server_tokens_to_file(const std::string & filename, const server_tokens & tokens) {
std::ofstream file(filename, std::ios::binary);
json token_json = tokens.to_json();
token_json["magic"] = LLAMA_SERVER_MAGIC;
token_json["version"] = LLAMA_SERVER_VERSION;
size_t pos = 0;
if (file.is_open()) {
file << token_json;
pos = file.tellp();
file.close();
}
return pos;
}
static size_t load_server_tokens_from_file(const std::string & filename, server_tokens & tokens) {
std::ifstream file(filename, std::ios::binary);
if (!file.is_open()) {
return 0;
}
size_t pos = 0;
json token_json;
if (file.is_open()) {
file >> token_json;
pos = file.tellg();
file.close();
}
uint32_t magic = token_json.value<uint32_t>("magic", 0);
uint32_t version = token_json.value<uint32_t>("version", 0);
if (magic != LLAMA_SERVER_MAGIC || version != LLAMA_SERVER_VERSION) {
LLAMA_LOG_ERROR("%s: unknown (magic, version) for token file: %08x, %08x\n", __func__, magic, version);
return 0;
}
tokens.from_json(token_json);
return pos;
}
void server_context::process_single_task(server_task&& task) {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
@ -2153,14 +2266,14 @@ void server_context::process_single_task(server_task&& task) {
queue_tasks.defer(std::move(task));
break;
}
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
break;
}
const size_t token_count = slot->cache_tokens.size();
const int64_t t_start = ggml_time_us();
std::string filename = task.data.at("filename");
std::string filepath = task.data.at("filepath");
save_server_tokens_to_file(filepath+".tokens.json", slot->cache_tokens);
size_t saved = save_checkpoints_to_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
@ -2175,7 +2288,7 @@ void server_context::process_single_task(server_task&& task) {
{ "id_slot", id_slot },
{ "filename", filename },
{ "n_saved", token_count }, // tokens saved
{ "n_written", nwrite }, // bytes written
{ "n_written", nwrite + saved }, // bytes written
{ "timings", {
{ "save_ms", t_save_ms }
} }
@ -2196,9 +2309,6 @@ void server_context::process_single_task(server_task&& task) {
queue_tasks.defer(std::move(task));
break;
}
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
break;
}
const int64_t t_start = ggml_time_us();
std::string filename = task.data.at("filename");
@ -2212,10 +2322,9 @@ void server_context::process_single_task(server_task&& task) {
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
break;
}
slot->cache_tokens.resize(token_count);
if (mctx) {
slot->cache_tokens.has_mtmd = true;
}
load_server_tokens_from_file(filepath+".tokens.json", slot->cache_tokens);
size_t loaded = load_checkpoints_from_file(filepath + ".checkpoints", slot->server_cached_prompt.checkpoints);
const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;
@ -2248,14 +2357,13 @@ void server_context::process_single_task(server_task&& task) {
queue_tasks.defer(std::move(task));
break;
}
if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) {
break;
}
// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
slot->cache_tokens.clear();
slot->cache_tokens.keep_first(0);
//slot->cache_tokens.clear();
slot->server_cached_prompt.checkpoints.clear();
slot->server_cached_prompt.data.clear();
server_task_result result;
result.id = task.id;
result.stop = true;

View File

@ -355,6 +355,22 @@ struct server_prompt_checkpoint {
size_t size() const {
return data.size();
}
json to_json() {
json j;
j["pos_min"] = pos_min;
j["pos_max"] = pos_max;
j["pos_min_prompt"] = pos_min_prompt;
j["pos_max_prompt"] = pos_max_prompt;
return j;
}
void from_json(const json & j) {
pos_min = j.value<llama_pos>("pos_min", 0);
pos_max = j.value<llama_pos>("pos_max", 0);
pos_min_prompt = j.value<llama_pos>("pos_min_prompt", 0);
pos_max_prompt = j.value<llama_pos>("pos_max_prompt", 0);
}
};
@ -384,6 +400,22 @@ struct server_prompt {
checkpoints
};
}
json to_json()
{
json j;
j["tokens"] = tokens.to_json();
j["n_kept_prompt"] = n_kept_prompt;
j["n_discarded_prompt"] = n_discarded_prompt;
return j;
}
void from_json(const json & j) {
tokens.from_json(j.at("tokens"));
n_kept_prompt = j.value<llama_pos>("n_kept_prompt", 0);
n_discarded_prompt = j.value<llama_pos>("n_discarded_prompt", 0);
n_kept_prompt = j.value<llama_pos>("n_kept_prompt", 0);
}
};
struct server_prompt_cache {

View File

@ -52,6 +52,9 @@
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 3
#define LLAMA_SERVER_MAGIC 0x6c6d7376u // 'lmsv'
#define LLAMA_SERVER_VERSION 1
#ifdef __cplusplus
extern "C" {
#endif