mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
common : support manually triggering the reasoning budget end sequence (#23949)
This commit is contained in:
parent
e22b0de60d
commit
5254a7994d
@ -247,3 +247,24 @@ common_reasoning_budget_state common_reasoning_budget_get_state(const struct lla
|
||||
}
|
||||
return ((const common_reasoning_budget_ctx *)smpl->ctx)->state;
|
||||
}
|
||||
|
||||
bool common_reasoning_budget_force(struct llama_sampler * smpl) {
|
||||
if (!smpl) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
|
||||
|
||||
// only a sampler that is actively counting down the budget may be forced;
|
||||
// any other state (idle, already forcing/waiting, or done) is left untouched
|
||||
if (ctx->state != REASONING_BUDGET_COUNTING) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: forced into forcing state (manual transition)\n");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -40,3 +40,7 @@ struct llama_sampler * common_reasoning_budget_init(
|
||||
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE);
|
||||
|
||||
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
|
||||
|
||||
// Manually transition the reasoning budget sampler into the FORCING state.
|
||||
// Returns true if the transition occurred.
|
||||
bool common_reasoning_budget_force(struct llama_sampler * smpl);
|
||||
|
||||
@ -661,6 +661,14 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||
return llama_sampler_get_seed(gsmpl->chain);
|
||||
}
|
||||
|
||||
bool common_sampler_reasoning_budget_force(struct common_sampler * gsmpl) {
|
||||
if (!gsmpl) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return common_reasoning_budget_force(gsmpl->rbudget);
|
||||
}
|
||||
|
||||
// helpers
|
||||
|
||||
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
|
||||
|
||||
@ -87,6 +87,9 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
|
||||
|
||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
||||
|
||||
// force the reasoning budget sampler (if any) to begin forcing its end sequence now.
|
||||
bool common_sampler_reasoning_budget_force(struct common_sampler * gsmpl);
|
||||
|
||||
// helpers
|
||||
|
||||
// access the internal list of current candidate tokens
|
||||
|
||||
@ -184,6 +184,76 @@ static void test_reasoning_budget_clone_mid_forcing() {
|
||||
llama_sampler_free(sampler);
|
||||
}
|
||||
|
||||
static void test_reasoning_budget_force_manual() {
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101};
|
||||
|
||||
// if COUNTING, force() succeeds and begins forcing the end sequence from the start
|
||||
{
|
||||
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 5, REASONING_BUDGET_IDLE);
|
||||
|
||||
llama_sampler_accept(sampler, 100); // COUNTING, remaining=5
|
||||
llama_sampler_accept(sampler, 50); // COUNTING, remaining=4
|
||||
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_COUNTING);
|
||||
|
||||
GGML_ASSERT(common_reasoning_budget_force(sampler) && "force() should succeed from COUNTING");
|
||||
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_FORCING);
|
||||
|
||||
// forces the configured sequence from force_pos=0, then transitions to DONE
|
||||
GGML_ASSERT(get_forced_token(sampler, 102) == 102);
|
||||
llama_sampler_accept(sampler, 102);
|
||||
GGML_ASSERT(get_forced_token(sampler, 102) == 101);
|
||||
llama_sampler_accept(sampler, 101);
|
||||
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_DONE);
|
||||
|
||||
llama_sampler_free(sampler);
|
||||
}
|
||||
|
||||
// if IDLE, force() is a no-op
|
||||
{
|
||||
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 5, REASONING_BUDGET_IDLE);
|
||||
|
||||
GGML_ASSERT(!common_reasoning_budget_force(sampler) && "force() must not transition from IDLE");
|
||||
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_IDLE);
|
||||
|
||||
llama_sampler_free(sampler);
|
||||
}
|
||||
|
||||
// if DONE, force() is a no-op
|
||||
{
|
||||
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 5, REASONING_BUDGET_IDLE);
|
||||
|
||||
llama_sampler_accept(sampler, 100); // COUNTING
|
||||
llama_sampler_accept(sampler, 101); // natural end -> DONE
|
||||
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_DONE);
|
||||
|
||||
GGML_ASSERT(!common_reasoning_budget_force(sampler) && "force() must not transition from DONE");
|
||||
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_DONE);
|
||||
|
||||
llama_sampler_free(sampler);
|
||||
}
|
||||
|
||||
// if FORCING, force() is a no-op and must not rewind the force position
|
||||
{
|
||||
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 0, REASONING_BUDGET_FORCING);
|
||||
|
||||
GGML_ASSERT(get_forced_token(sampler, 102) == 102);
|
||||
llama_sampler_accept(sampler, 102); // advance to the second forced token (force_pos=1)
|
||||
|
||||
GGML_ASSERT(!common_reasoning_budget_force(sampler) && "force() must not transition from FORCING");
|
||||
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_FORCING);
|
||||
GGML_ASSERT(get_forced_token(sampler, 102) == 101 && "force() must not rewind the force position");
|
||||
|
||||
llama_sampler_free(sampler);
|
||||
}
|
||||
|
||||
// a null sampler is safely ignored
|
||||
GGML_ASSERT(!common_reasoning_budget_force(nullptr));
|
||||
|
||||
fprintf(stderr, " Test 'manual force transition' passed\n");
|
||||
}
|
||||
|
||||
// UTF-8 boundary detection unit test
|
||||
// Tests common_utf8_is_complete() from reasoning-budget.h
|
||||
static void test_utf8_boundary_detection() {
|
||||
@ -312,8 +382,9 @@ int main(void) {
|
||||
|
||||
test_reasoning_budget_clone_mid_counting();
|
||||
test_reasoning_budget_clone_mid_forcing();
|
||||
test_reasoning_budget_force_manual();
|
||||
|
||||
printf("OK (8 tests passed)\n");
|
||||
printf("OK (9 tests passed)\n");
|
||||
|
||||
printf("Testing UTF-8 boundary detection... ");
|
||||
test_utf8_boundary_detection();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user