mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-06-28 04:30:15 -05:00
Avoid some code duplication
This commit is contained in:
parent
dbe44e939f
commit
43df4192d6
@ -997,6 +997,7 @@ static float prob_avx2(int n, const float * logits, float max_val) {
|
||||
}
|
||||
return 1.0f/sumf;
|
||||
}
|
||||
#endif
|
||||
static float prob_scalar(int n, const float * logits, float max_val) {
|
||||
double sum_exp = 0.0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
@ -1004,7 +1005,6 @@ static float prob_scalar(int n, const float * logits, float max_val) {
|
||||
}
|
||||
return (float)(1./sum_exp);
|
||||
}
|
||||
#endif
|
||||
|
||||
llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, float * out_prob) {
|
||||
GGML_UNUSED(gsmpl);
|
||||
@ -1036,11 +1036,7 @@ llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, str
|
||||
}
|
||||
|
||||
if (out_prob) {
|
||||
double sum_exp = 0.0;
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
sum_exp += exp((double)(logits[i] - max_val));
|
||||
}
|
||||
*out_prob = (float)(1.0 / sum_exp);
|
||||
*out_prob = prob_scalar(n_vocab, logits, max_val);
|
||||
}
|
||||
|
||||
return best_id;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user