mtmd: add more validations (#25013)

* mtmd: add more validations

* fix

* refactor a bit

* type check for get_arr_int
This commit is contained in:
Xuan-Son Nguyen 2026-06-26 08:43:29 +02:00 committed by GitHub
parent f818065d75
commit b11f7c16bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 59 additions and 25 deletions

View File

@ -55,8 +55,7 @@ struct clip_hparams {
int32_t n_head = 0;
int32_t n_head_kv = 0;
int32_t n_layer = 0;
// idefics3
int32_t n_merge = 0; // number of patch merges **per-side**
int32_t n_merge = 1; // number of patch merges **per-side**
// for preprocessor
int32_t image_longest_edge = 0;
@ -135,8 +134,7 @@ struct clip_hparams {
int32_t custom_image_max_tokens = -1;
void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
const int cur_merge = n_merge == 0 ? 1 : n_merge;
const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
const int patch_area = patch_size * patch_size * n_merge * n_merge;
image_min_pixels = (custom_image_min_tokens > 0 ? custom_image_min_tokens : n_tokens_min) * patch_area;
image_max_pixels = (custom_image_max_tokens > 0 ? custom_image_max_tokens : n_tokens_max) * patch_area;
warmup_image_size = static_cast<int>(std::sqrt(image_max_pixels));
@ -145,8 +143,7 @@ struct clip_hparams {
void set_warmup_n_tokens(int n_tokens) {
int n_tok_per_side = static_cast<int>(std::sqrt(n_tokens));
GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n");
const int cur_merge = n_merge == 0 ? 1 : n_merge;
warmup_image_size = n_tok_per_side * patch_size * cur_merge;
warmup_image_size = n_tok_per_side * patch_size * n_merge;
// TODO: support warmup size for custom token numbers
}
// sam vit deepseek-ocr

View File

@ -1210,6 +1210,9 @@ struct clip_model_loader {
{
std::vector<int> pinpoints;
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, pinpoints, false);
if (pinpoints.size() % 2 != 0) {
throw std::runtime_error(string_format("%s: image_grid_pinpoints must have an even number of elements, got %zu\n", __func__, pinpoints.size()));
}
if (!pinpoints.empty()) {
for (size_t i = 0; i < pinpoints.size(); i += 2) {
hparams.image_res_candidates.push_back({
@ -1252,15 +1255,16 @@ struct clip_model_loader {
}
if (is_vision) {
int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
GGML_ASSERT(idx_std >= 0 && "image_std not found");
const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean);
const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std);
std::vector<float> image_mean;
std::vector<float> image_std;
get_arr_f32(KEY_IMAGE_MEAN, image_mean, false);
get_arr_f32(KEY_IMAGE_STD , image_std, false);
if (image_mean.size() < 3 || image_std.size() < 3) {
throw std::runtime_error(string_format("%s: image_mean/image_std arrays must have at least 3 elements, got %zu and %zu\n", __func__, image_mean.size(), image_std.size()));
}
for (int i = 0; i < 3; ++i) {
hparams.image_mean[i] = mean_data[i];
hparams.image_std[i] = std_data[i];
hparams.image_mean[i] = image_mean[i];
hparams.image_std[i] = image_std[i];
}
}
@ -1686,8 +1690,8 @@ struct clip_model_loader {
if (hparams.image_size > 65536) {
throw std::runtime_error(string_format("%s: image_size (%d) is too large (max 65536)\n", __func__, hparams.image_size));
}
if (hparams.patch_size <= 0) {
throw std::runtime_error(string_format("%s: patch_size (%d) must be greater than 0\n", __func__, hparams.patch_size));
if (hparams.patch_size <= 0 || hparams.patch_size >= 65536) {
throw std::runtime_error(string_format("%s: patch_size (%d) must be positive and less than 65536\n", __func__, hparams.patch_size));
}
if (hparams.n_embd <= 0) {
throw std::runtime_error(string_format("%s: n_embd (%d) must be greater than 0\n", __func__, hparams.n_embd));
@ -1695,6 +1699,9 @@ struct clip_model_loader {
if (hparams.image_max_pixels < hparams.image_min_pixels) {
throw std::runtime_error(string_format("%s: image_max_pixels (%d) is less than image_min_pixels (%d)\n", __func__, hparams.image_max_pixels, hparams.image_min_pixels));
}
if (hparams.n_merge < 0 || hparams.n_merge >= 65536) {
throw std::runtime_error(string_format("%s: n_merge (%d) must be greater than 0 and less than 65536\n", __func__, hparams.n_merge));
}
}
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
@ -3067,6 +3074,29 @@ struct clip_model_loader {
output = gguf_get_val_f32(ctx_gguf.get(), i);
}
void get_arr_f32(const std::string & key, std::vector<float> & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
const auto type = gguf_get_arr_type(ctx_gguf.get(), i);
if (type != GGUF_TYPE_FLOAT32) {
throw std::runtime_error(string_format("%s: array '%s' has type %d, expected %d (GGUF_TYPE_FLOAT32)\n", __func__, key.c_str(), type, GGUF_TYPE_FLOAT32));
}
const size_t n = gguf_get_arr_n(ctx_gguf.get(), i);
if (n > (size_t) std::numeric_limits<int>::max()) {
throw std::runtime_error(string_format("%s: array '%s' is too large (%zu elements)\n", __func__, key.c_str(), n));
}
output.resize(n);
const float * values = (const float *)gguf_get_arr_data(ctx_gguf.get(), i);
for (size_t j = 0; j < n; ++j) {
output[j] = values[j];
}
}
void get_string(const std::string & key, std::string & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
@ -3086,11 +3116,18 @@ struct clip_model_loader {
}
return;
}
int n = gguf_get_arr_n(ctx_gguf.get(), i);
const auto type = gguf_get_arr_type(ctx_gguf.get(), i);
if (type != GGUF_TYPE_INT32) {
throw std::runtime_error(string_format("%s: array '%s' has type %d, expected %d (GGUF_TYPE_INT32)\n", __func__, key.c_str(), type, GGUF_TYPE_INT32));
}
const size_t n = gguf_get_arr_n(ctx_gguf.get(), i);
if (n > (size_t) std::numeric_limits<int>::max()) {
throw std::runtime_error(string_format("%s: array '%s' is too large (%zu elements)\n", __func__, key.c_str(), n));
}
output.resize(n);
const int32_t * values = (const int32_t *)gguf_get_arr_data(ctx_gguf.get(), i);
for (int i = 0; i < n; ++i) {
output[i] = values[i];
for (size_t j = 0; j < n; ++j) {
output[j] = values[j];
}
}
@ -3364,8 +3401,8 @@ int clip_n_output_tokens(const clip_ctx * ctx, const clip_image_f32 * img) {
{
// dynamic size
int n_merge = ctx->model.hparams.n_merge;
int n_patches_x = img->nx() / patch_size / (n_merge > 0 ? n_merge : 1);
int n_patches_y = img->ny() / patch_size / (n_merge > 0 ? n_merge : 1);
int n_patches_x = img->nx() / patch_size / n_merge;
int n_patches_y = img->ny() / patch_size / n_merge;
if (ctx->model.token_embd_img_break) {
n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
} else {

View File

@ -63,8 +63,8 @@ ggml_cgraph * clip_graph_pixtral::build() {
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
// after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
const int p_y = n_patches_y / n_merge;
const int p_x = n_patches_x / n_merge;
const int p_total = p_x * p_y;
const int n_embd_text = cur->ne[0];
const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row

View File

@ -628,7 +628,7 @@ mtmd_image_preproc_out mtmd_image_preprocessor_llava_uhd::preprocess(const clip_
mtmd_image_preprocessor_llava_uhd::slice_instructions mtmd_image_preprocessor_llava_uhd::get_slice_instructions(const clip_image_size & original_size) {
mtmd_image_preprocessor_llava_uhd::slice_instructions res;
// align slices by patch_size * n_merge so an integer number of merger output tokens fits per slice
const int n_merge = hparams.n_merge > 0 ? hparams.n_merge : 1;
const int n_merge = hparams.n_merge;
const int patch_size = hparams.patch_size * n_merge;
const int slice_size = hparams.image_size;
const int original_width = original_size.width;
@ -894,7 +894,7 @@ mtmd_image_preproc_out mtmd_image_preprocessor_dyn_size::preprocess(const clip_i
clip_image_u8 resized_image;
const clip_image_size original_size = img.get_size();
// the original pixtral model doesn't have n_merge
const int cur_merge = hparams.n_merge == 0 ? 1 : hparams.n_merge;
const int cur_merge = hparams.n_merge;
const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
original_size,
hparams.patch_size * cur_merge,