sycl : clamp softmax input to avoid underflow (#24941)

This commit is contained in:
Jassieluo 2026-06-26 15:02:42 +08:00 committed by GitHub
parent b11f7c16bc
commit e7e3f35090
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -126,7 +126,7 @@ static void soft_max_f32(const float * x,
break; break;
} }
const float val = sycl::native::exp(vals[col] - max_val); const float val = sycl::native::exp(sycl::max(vals[col] - max_val, -80.0f));
tmp += val; tmp += val;
vals[col] = val; vals[col] = val;
} }
@ -154,7 +154,7 @@ static void soft_max_f32(const float * x,
tmp = warp_reduce_sum<WARP_SIZE>(tmp); tmp = warp_reduce_sum<WARP_SIZE>(tmp);
} }
if (sinks) { if (sinks) {
tmp += sycl::native::exp(sinks[i02] - max_val); tmp += sycl::native::exp(sycl::max(sinks[i02] - max_val, -80.0f));
} }
const float inv_sum = 1.0f / tmp; const float inv_sum = 1.0f / tmp;