mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
sycl : clamp softmax input to avoid underflow (#24941)
This commit is contained in:
parent
b11f7c16bc
commit
e7e3f35090
@ -126,7 +126,7 @@ static void soft_max_f32(const float * x,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float val = sycl::native::exp(vals[col] - max_val);
|
const float val = sycl::native::exp(sycl::max(vals[col] - max_val, -80.0f));
|
||||||
tmp += val;
|
tmp += val;
|
||||||
vals[col] = val;
|
vals[col] = val;
|
||||||
}
|
}
|
||||||
@ -154,7 +154,7 @@ static void soft_max_f32(const float * x,
|
|||||||
tmp = warp_reduce_sum<WARP_SIZE>(tmp);
|
tmp = warp_reduce_sum<WARP_SIZE>(tmp);
|
||||||
}
|
}
|
||||||
if (sinks) {
|
if (sinks) {
|
||||||
tmp += sycl::native::exp(sinks[i02] - max_val);
|
tmp += sycl::native::exp(sycl::max(sinks[i02] - max_val, -80.0f));
|
||||||
}
|
}
|
||||||
const float inv_sum = 1.0f / tmp;
|
const float inv_sum = 1.0f / tmp;
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user