refactor a bit

This commit is contained in:
Xuan Son Nguyen 2026-06-25 18:49:30 +02:00
parent 732d5b6fd8
commit a4b1c14c1a

View File

@ -1255,20 +1255,16 @@ struct clip_model_loader {
}
if (is_vision) {
int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
GGML_ASSERT(idx_std >= 0 && "image_std not found");
const size_t n_mean = gguf_get_arr_n(ctx_gguf.get(), idx_mean);
const size_t n_std = gguf_get_arr_n(ctx_gguf.get(), idx_std);
if (n_mean < 3 || n_std < 3) {
throw std::runtime_error(string_format("%s: image_mean/image_std arrays must have at least 3 elements, got %zu and %zu\n", __func__, n_mean, n_std));
std::vector<float> image_mean;
std::vector<float> image_std;
get_arr_f32(KEY_IMAGE_MEAN, image_mean, false);
get_arr_f32(KEY_IMAGE_STD , image_std, false);
if (image_mean.size() < 3 || image_std.size() < 3) {
throw std::runtime_error(string_format("%s: image_mean/image_std arrays must have at least 3 elements, got %zu and %zu\n", __func__, image_mean.size(), image_std.size()));
}
const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean);
const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std);
for (int i = 0; i < 3; ++i) {
hparams.image_mean[i] = mean_data[i];
hparams.image_std[i] = std_data[i];
hparams.image_mean[i] = image_mean[i];
hparams.image_std[i] = image_std[i];
}
}
@ -3078,6 +3074,29 @@ struct clip_model_loader {
output = gguf_get_val_f32(ctx_gguf.get(), i);
}
void get_arr_f32(const std::string & key, std::vector<float> & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {
if (required) {
throw std::runtime_error("Key not found: " + key);
}
return;
}
const auto type = gguf_get_arr_type(ctx_gguf.get(), i);
if (type != GGUF_TYPE_FLOAT32) {
throw std::runtime_error(string_format("%s: array '%s' has type %d, expected %d (GGUF_TYPE_FLOAT32)\n", __func__, key.c_str(), type, GGUF_TYPE_FLOAT32));
}
const size_t n = gguf_get_arr_n(ctx_gguf.get(), i);
if (n > (size_t) std::numeric_limits<int>::max()) {
throw std::runtime_error(string_format("%s: array '%s' is too large (%zu elements)\n", __func__, key.c_str(), n));
}
output.resize(n);
const float * values = (const float *)gguf_get_arr_data(ctx_gguf.get(), i);
for (size_t j = 0; j < n; ++j) {
output[j] = values[j];
}
}
void get_string(const std::string & key, std::string & output, bool required = true) const {
const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
if (i < 0) {