Avoid some code duplication

This commit is contained in:
Kawrakow 2026-05-08 13:46:10 +00:00
parent dbe44e939f
commit 43df4192d6

View File

@@ -997,6 +997,7 @@ static float prob_avx2(int n, const float * logits, float max_val) {
}
return 1.0f/sumf;
}
#endif
static float prob_scalar(int n, const float * logits, float max_val) {
double sum_exp = 0.0;
for (int i = 0; i < n; ++i) {
@@ -1004,7 +1005,6 @@ static float prob_scalar(int n, const float * logits, float max_val) {
}
return (float)(1./sum_exp);
}
#endif
llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, float * out_prob) {
GGML_UNUSED(gsmpl);
@@ -1036,11 +1036,7 @@ llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, str
}
if (out_prob) {
double sum_exp = 0.0;
for (int i = 0; i < n_vocab; ++i) {
sum_exp += exp((double)(logits[i] - max_val));
}
*out_prob = (float)(1.0 / sum_exp);
*out_prob = prob_scalar(n_vocab, logits, max_val);
}
return best_id;