diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 71c6cdd01d..0b220af6fd 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1255,20 +1255,16 @@ struct clip_model_loader { } if (is_vision) { - int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN); - int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD); - GGML_ASSERT(idx_mean >= 0 && "image_mean not found"); - GGML_ASSERT(idx_std >= 0 && "image_std not found"); - const size_t n_mean = gguf_get_arr_n(ctx_gguf.get(), idx_mean); - const size_t n_std = gguf_get_arr_n(ctx_gguf.get(), idx_std); - if (n_mean < 3 || n_std < 3) { - throw std::runtime_error(string_format("%s: image_mean/image_std arrays must have at least 3 elements, got %zu and %zu\n", __func__, n_mean, n_std)); + std::vector image_mean; + std::vector image_std; + get_arr_f32(KEY_IMAGE_MEAN, image_mean, false); + get_arr_f32(KEY_IMAGE_STD , image_std, false); + if (image_mean.size() < 3 || image_std.size() < 3) { + throw std::runtime_error(string_format("%s: image_mean/image_std arrays must have at least 3 elements, got %zu and %zu\n", __func__, image_mean.size(), image_std.size())); } - const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean); - const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std); for (int i = 0; i < 3; ++i) { - hparams.image_mean[i] = mean_data[i]; - hparams.image_std[i] = std_data[i]; + hparams.image_mean[i] = image_mean[i]; + hparams.image_std[i] = image_std[i]; } } @@ -3078,6 +3074,29 @@ struct clip_model_loader { output = gguf_get_val_f32(ctx_gguf.get(), i); } + void get_arr_f32(const std::string & key, std::vector & output, bool required = true) const { + const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); + if (i < 0) { + if (required) { + throw std::runtime_error("Key not found: " + key); + } + return; + } + const auto type = gguf_get_arr_type(ctx_gguf.get(), i); + if (type != GGUF_TYPE_FLOAT32) { + throw std::runtime_error(string_format("%s: array '%s' has type %d, expected %d (GGUF_TYPE_FLOAT32)\n", __func__, key.c_str(), type, GGUF_TYPE_FLOAT32)); + } + const size_t n = gguf_get_arr_n(ctx_gguf.get(), i); + if (n > (size_t) std::numeric_limits::max()) { + throw std::runtime_error(string_format("%s: array '%s' is too large (%zu elements)\n", __func__, key.c_str(), n)); + } + output.resize(n); + const float * values = (const float *)gguf_get_arr_data(ctx_gguf.get(), i); + for (size_t j = 0; j < n; ++j) { + output[j] = values[j]; + } + } + void get_string(const std::string & key, std::string & output, bool required = true) const { const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); if (i < 0) {