diff --git a/common/common.h b/common/common.h index b6773d60..95dfb6c7 100644 --- a/common/common.h +++ b/common/common.h @@ -420,7 +420,7 @@ struct gpt_params { float slot_prompt_similarity = 0.1f; bool do_checkpoint = false; // do checkpoint for recurrent models only - int32_t ctx_checkpoints_n = 8; // max number of context checkpoints per slot + int32_t ctx_checkpoints_n = 32; // max number of context checkpoints per slot int32_t ctx_checkpoints_interval = 512; // minimum number of tokens between each context checkpoints int32_t ctx_checkpoints_tolerance = 5; // the number of tokens before the full prompt to create the checkpoint int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. diff --git a/common/regex-partial.cpp b/common/regex-partial.cpp index e667a209..c3138f45 100644 --- a/common/regex-partial.cpp +++ b/common/regex-partial.cpp @@ -101,7 +101,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) { sequence->back() += *it; auto is_star = *it == '*'; ++it; - if (is_star) { + if (it != end && is_star) { if (*it == '?') { ++it; } diff --git a/examples/server/server-common.cpp b/examples/server/server-common.cpp index db8c9f36..8c3e8b23 100644 --- a/examples/server/server-common.cpp +++ b/examples/server/server-common.cpp @@ -1791,263 +1791,472 @@ token_probabilities get_token_probabilities(llama_context* ctx, int idx, llama_t */ server_tokens::server_tokens(mtmd::input_chunks& mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { - for (size_t i = 0; i < mtmd_chunks.size(); ++i) { - push_back(mtmd_chunks[i]); - } + for (size_t i = 0; i < mtmd_chunks.size(); ++i) { + push_back(mtmd_chunks[i]); } +} -server_tokens::server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { - } +server_tokens::server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} - llama_pos server_tokens::pos_next() const { - if (!has_mtmd) { +llama_pos server_tokens::pos_next(int64_t n_tokens) const { + if (!has_mtmd) { + if (n_tokens < 0) { return tokens.size(); } + return n_tokens; + } + + if (n_tokens < 0) { llama_pos res = tokens.size(); for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { - const auto& chunk = it->second; + const auto & chunk = it->second; res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); } return res; } - // for debugging - std::string server_tokens::str() const { - std::ostringstream oss; - oss << "tokens: "; - for (size_t idx = 0; idx < tokens.size(); ++idx) { - llama_token t = tokens[idx]; - oss << "idx:" << idx << " "; - if (t == LLAMA_TOKEN_NULL) { - oss << " "; - } - else { - oss << t << " "; - } + int64_t idx = 0; + llama_pos pos = 0; + + GGML_ASSERT(n_tokens <= (int64_t)tokens.size()); + + while (idx < n_tokens) { + const auto media_it = map_idx_to_media.find(idx); + if (media_it != map_idx_to_media.end()) { + const auto & chunk = media_it->second; + const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); + const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get()); + + pos += n_pos; + idx += n_tok; + } else { + pos++; + idx++; } - oss << "\n"; - oss << "image idx: "; - for (const auto& it : map_idx_to_media) { - oss << it.first << ", "; - } - return oss.str(); } - const mtmd::input_chunk_ptr& server_tokens::find_chunk(size_t idx) const { - auto it = map_idx_to_media.find(idx); - if (it != map_idx_to_media.end()) { - return it->second; - } - throw std::runtime_error("Chunk not found"); + return pos; +} + + +size_t server_tokens::size_up_to_pos(llama_pos max_idx) const { + if (!has_mtmd) { + return std::min((size_t)max_idx+1, tokens.size()); } - void server_tokens::push_back(llama_token tok) { - if (tok == LLAMA_TOKEN_NULL) { - throw std::runtime_error("Invalid token"); + size_t idx = 0; + llama_pos pos = 0; + + while (idx < tokens.size()) { + const auto media_it = map_idx_to_media.find(idx); + if (media_it != map_idx_to_media.end()) { + const auto & chunk = media_it->second; + const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); + const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get()); + + pos += n_pos; + idx += n_tok; + } else { + pos++; + idx++; + } + + if (idx >= max_idx) { + break; } - tokens.emplace_back(tok); } - // will create a copy of the chunk if it contains non-text data - void server_tokens::push_back(const mtmd_input_chunk* chunk) { - auto type = mtmd_input_chunk_get_type(chunk); - if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { - GGML_ASSERT(has_mtmd); - const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); - size_t start_idx = tokens.size(); - for (size_t i = 0; i < n_tokens; ++i) { - tokens.emplace_back(LLAMA_TOKEN_NULL); - } - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_idx_to_media[start_idx] = std::move(new_chunk); - } - else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens; - const auto* text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); - for (size_t i = 0; i < n_tokens; ++i) { - push_back(text_tokens[i]); - } + return idx+1; +} + + +// for debugging +std::string server_tokens::str() const { + std::ostringstream oss; + oss << "tokens: "; + for (size_t idx = 0; idx < tokens.size(); ++idx) { + llama_token t = tokens[idx]; + oss << "idx:" << idx << " "; + if (t == LLAMA_TOKEN_NULL) { + oss << " "; } else { - GGML_ABORT("Invalid chunk type"); + oss << t << " "; } } + oss << "\n"; + oss << "image idx: "; + for (const auto& it : map_idx_to_media) { + oss << it.first << ", "; + } + return oss.str(); +} - // appends server tokens, updates the media map. copies media chunks. - void server_tokens::push_back(server_tokens& tokens) { - size_t start_idx = size(); - for (size_t i = 0; i < tokens.size(); i++) { - push_back(tokens[i]); +const mtmd::input_chunk_ptr& server_tokens::find_chunk(size_t idx) const { + auto it = map_idx_to_media.find(idx); + if (it != map_idx_to_media.end()) { + return it->second; + } + throw std::runtime_error("Chunk not found"); +} + +void server_tokens::push_back(llama_token tok) { + if (tok == LLAMA_TOKEN_NULL) { + throw std::runtime_error("Invalid token"); + } + tokens.emplace_back(tok); +} + +// will create a copy of the chunk if it contains non-text data +void server_tokens::push_back(const mtmd_input_chunk* chunk) { + auto type = mtmd_input_chunk_get_type(chunk); + if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { + GGML_ASSERT(has_mtmd); + const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); + size_t start_idx = tokens.size(); + for (size_t i = 0; i < n_tokens; ++i) { + tokens.emplace_back(LLAMA_TOKEN_NULL); } - if (tokens.has_mtmd) { - // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. - // We could also just check, but this will prevent silently dropping MTMD data. - GGML_ASSERT(has_mtmd); - for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) { - auto* chunk = tokens.map_idx_to_media[it->first].get(); - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_idx_to_media[start_idx + it->first] = std::move(new_chunk); + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_idx_to_media[start_idx] = std::move(new_chunk); + } + else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + const auto* text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + for (size_t i = 0; i < n_tokens; ++i) { + push_back(text_tokens[i]); + } + } + else { + GGML_ABORT("Invalid chunk type"); + } +} + +// appends server tokens, updates the media map. copies media chunks. +void server_tokens::push_back(server_tokens& tokens) { + size_t start_idx = size(); + for (size_t i = 0; i < tokens.size(); i++) { + push_back(tokens[i]); + } + if (tokens.has_mtmd) { + // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. + // We could also just check, but this will prevent silently dropping MTMD data. + GGML_ASSERT(has_mtmd); + for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) { + auto* chunk = tokens.map_idx_to_media[it->first].get(); + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_idx_to_media[start_idx + it->first] = std::move(new_chunk); + } + } +} + +// for compatibility with context shift and prompt truncation +void server_tokens::insert(const std::vector& inp_tokens) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); +} + +// for compatibility with context shift and prompt truncation +void server_tokens::resize(size_t size) { + //GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens.resize(size); +} + +llama_token* server_tokens::data() { + return tokens.data(); +} + +llama_tokens::iterator server_tokens::begin() { + return tokens.begin(); +} + +llama_tokens::iterator server_tokens::end() { + return tokens.end(); +} + +llama_tokens::const_iterator server_tokens::cbegin() { + return tokens.cbegin(); +} + +llama_tokens::const_iterator server_tokens::cend() { + return tokens.cend(); +} + +llama_tokens server_tokens::tokens_data() { + return tokens; +} + +// for compatibility with speculative decoding, ctx shift, slot save/load +const std::vector& server_tokens::get_text_tokens() const { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + return tokens; +} + +// for compatibility with speculative decoding +void server_tokens::set_token(llama_pos pos, llama_token id) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens[pos] = id; +} + +size_t server_tokens::size() const { + return tokens.size(); +} + +bool server_tokens::empty() const { + return tokens.empty(); +} + +void server_tokens::clear() { + tokens.clear(); +} + +void server_tokens::keep_first(size_t n) { + GGML_ASSERT(n <= tokens.size()); + if (has_mtmd) { + if (n == tokens.size()) { + return; // nothing to do + } + // we throw an error if we try to remove a token in the middle of an image + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // n 1 2 3 4 5 6 7 8 9 10 + // allowed to resize ^ ^ + // disallowed to resize ^ ^ ^ + if (n > 0) { + llama_token last_token = tokens[n - 1]; + // make sure we never remove tokens in the middle of an image + if (last_token == LLAMA_TOKEN_NULL) { + find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk + } + } + // remove all image chunks that are not used anymore + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) { + size_t idx = it->first; + if (idx >= n) { + it = map_idx_to_media.erase(it); + } + else { + ++it; } } } + tokens.resize(n); +} - // for compatibility with context shift and prompt truncation - void server_tokens::insert(const std::vector& inp_tokens) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); +std::string server_tokens::detokenize(const llama_context* ctx, bool special) const { + llama_tokens text_tokens; + text_tokens.reserve(tokens.size()); + for (const auto& t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + text_tokens.push_back(t); + } } + return common_detokenize(ctx, text_tokens, special); +} - // for compatibility with context shift and prompt truncation - void server_tokens::resize(size_t size) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens.resize(size); +std::string server_tokens::detokenize(const llama_context* ctx, bool special, size_t start, size_t length) const { + std::string str; + if (tokens.size() <= start || length == 0) { + return str; } - - llama_token* server_tokens::data() { - return tokens.data(); - } - - llama_tokens::iterator server_tokens::begin() { - return tokens.begin(); - } - - llama_tokens::iterator server_tokens::end() { - return tokens.end(); - } - - llama_tokens::const_iterator server_tokens::cbegin() { - return tokens.cbegin(); - } - - llama_tokens::const_iterator server_tokens::cend() { - return tokens.cend(); - } - - llama_tokens server_tokens::tokens_data() { - return tokens; - } - - // for compatibility with speculative decoding, ctx shift, slot save/load - const std::vector& server_tokens::get_text_tokens() const { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - return tokens; - } - - // for compatibility with speculative decoding - void server_tokens::set_token(llama_pos pos, llama_token id) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens[pos] = id; - } - - size_t server_tokens::size() const { - return tokens.size(); - } - - bool server_tokens::empty() const { - return tokens.empty(); - } - - void server_tokens::clear() { - tokens.clear(); - } - - void server_tokens::keep_first(size_t n) { - GGML_ASSERT(n <= tokens.size()); - if (has_mtmd) { - if (n == tokens.size()) { - return; // nothing to do + llama_tokens text_tokens; + text_tokens.reserve(tokens.size() - start); + size_t i = 0; + size_t count = 0; + for (const auto& t : tokens) { + if (t != LLAMA_TOKEN_NULL && i >= start) { + text_tokens.push_back(t); + ++count; + if (count >= length) { + break; } - // we throw an error if we try to remove a token in the middle of an image - // for ex. with input of 5 text tokens and 2 images: - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] - // n 1 2 3 4 5 6 7 8 9 10 - // allowed to resize ^ ^ - // disallowed to resize ^ ^ ^ - if (n > 0) { - llama_token last_token = tokens[n - 1]; - // make sure we never remove tokens in the middle of an image - if (last_token == LLAMA_TOKEN_NULL) { - find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk + } + ++i; + } + return common_detokenize(ctx, text_tokens, special); +} + +size_t server_tokens::find_n_from_tokens(const llama_context* ctx, const server_tokens& b, bool special, + size_t start, const size_t length) { + std::string str = detokenize(ctx, special, start, length); + std::vector tmp; + size_t n = find_n_tokens_from_string(ctx, b.tokens, start, length, tmp); + return n; +} + +size_t server_tokens::get_common_prefix_exact(const server_tokens& b) const { + const size_t max_idx = std::min(tokens.size(), b.tokens.size()); + + if (!has_mtmd) { + for (size_t i = 0; i < max_idx; ++i) { + if (tokens[i] == b.tokens[i]) { + continue; + } + return i; + } + return max_idx; + } + + for (size_t i = 0; i < max_idx; ++i) { + const llama_token ai = tokens[i]; + const llama_token bi = b.tokens[i]; + + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + const auto& a_chunk = find_chunk(i); + const auto& b_chunk = b.find_chunk(i); + + GGML_ASSERT(a_chunk && b_chunk); + + const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); + const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); + + const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); + const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); + + if (id_ai == id_bi && n_tok_a == n_tok_b) { + GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen + i += n_tok_a - 1; // will be +1 by the for loop + continue; + } + + return i; + } + + if (ai == bi) { + continue; + } + + return i; + } + + return max_idx; // all tokens are equal +} + +server_tokens server_tokens::get_tokens_exclude_think(const llama_context * ctx, const thinking_tokens & think_token) const { + if (!think_token.exclude) { + return clone(); + } + GGML_ASSERT((think_token.begin != "" && think_token.end != "") && "think tokens cannot be empty"); + std::string startStr = think_token.begin; + std::string endStr = think_token.end; + std::string str = detokenize(ctx, true, 0, n_tokens()); + + std::vector> results; + // Find all positions of start and end + std::vector startPositions; + std::vector endPositions; + + size_t pos = 0; + // Find all start positions + while ((pos = str.find(startStr, pos)) != std::string::npos) { + startPositions.push_back(pos); + pos += startStr.length(); + } + + pos = 0; + // Find all end positions + while ((pos = str.find(endStr, pos)) != std::string::npos) { + endPositions.push_back(pos + endStr.length()); + pos += endStr.length(); + } + + // For each start position, pair with all end positions that come after it + for (size_t i = 0; i < startPositions.size(); i++) { + for (size_t j = 0; j < endPositions.size(); j++) { + if (results.size()) { + // start must be after last end + if (startPositions[i] > results[results.size() - 1].second && endPositions[j] > startPositions[i]) { + results.push_back({ startPositions[i], endPositions[j] }); + break; } - } - // remove all image chunks that are not used anymore - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) { - size_t idx = it->first; - if (idx >= n) { - it = map_idx_to_media.erase(it); - } - else { - ++it; - } - } - } - tokens.resize(n); - } - - std::string server_tokens::detokenize(const llama_context* ctx, bool special) const { - llama_tokens text_tokens; - text_tokens.reserve(tokens.size()); - for (const auto& t : tokens) { - if (t != LLAMA_TOKEN_NULL) { - text_tokens.push_back(t); - } - } - return common_detokenize(ctx, text_tokens, special); - } - - std::string server_tokens::detokenize(const llama_context* ctx, bool special, size_t start, size_t length) const { - std::string str; - if (tokens.size() <= start || length == 0) { - return str; - } - llama_tokens text_tokens; - text_tokens.reserve(tokens.size() - start); - size_t i = 0; - size_t count = 0; - for (const auto& t : tokens) { - if (t != LLAMA_TOKEN_NULL && i >= start) { - text_tokens.push_back(t); - ++count; - if (count >= length) { + } else { + if (endPositions[j] > startPositions[i]) { + results.push_back({ startPositions[i], endPositions[j] }); break; } } + + } + } + if (!results.size()) { + return clone(); + } + + server_tokens res; + res.has_mtmd = has_mtmd; + // Exclude tokens + pos = 0; + size_t n = 0; + size_t string_len = 0; + auto model = llama_get_model(ctx); + for (n = 0; n < tokens.size(); ++n) { + if (tokens[n] != LLAMA_TOKEN_NULL) { + str = llama_token_to_piece(model, tokens[n], true); + string_len = string_len + str.size(); + } + if (string_len <= results[pos].first) { + res.tokens.push_back(tokens[n]); + auto it = map_idx_to_media.find(n); + if (it!= map_idx_to_media.end()) { + const mtmd::input_chunk_ptr & chunk = it->second; + res.map_idx_to_media[res.n_tokens()-1] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get())); + } + } else if (string_len <= results[pos].second) { + continue; + } else { + res.tokens.push_back(tokens[n]); + auto it = map_idx_to_media.find(n); + if (it != map_idx_to_media.end()) { + const mtmd::input_chunk_ptr & chunk = it->second; + res.map_idx_to_media[res.n_tokens() - 1] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get())); + } + if (pos + 1 < results.size()) { + pos++; + } + } + } + return res; +} + +common_prefix server_tokens::get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact) const { + common_prefix token_prefix; + + size_t n = get_common_prefix_exact(b); // strict token match as a starting point + token_prefix.first = n; + token_prefix.second = n; + + if (!has_mtmd) { + token_prefix = find_common_text_token_prefix(ctx, this->tokens, b.tokens, n, exact); + token_prefix.first += n; + token_prefix.second += n; + return token_prefix; + } + size_t i = n; + size_t j = n; + llama_tokens a_list; + llama_tokens b_list; + while (i < size() && j < b.size()) { + llama_token ai = tokens[i]; + llama_token bi = b.tokens[j]; + if (ai != LLAMA_TOKEN_NULL) { + a_list.push_back(ai); ++i; } - return common_detokenize(ctx, text_tokens, special); - } - - size_t server_tokens::find_n_from_tokens(const llama_context* ctx, const server_tokens& b, bool special, - size_t start, const size_t length) { - std::string str = detokenize(ctx, special, start, length); - std::vector tmp; - size_t n = find_n_tokens_from_string(ctx, b.tokens, start, length, tmp); - return n; - } - - size_t server_tokens::get_common_prefix_exact(const server_tokens& b) const { - const size_t max_idx = std::min(tokens.size(), b.tokens.size()); - - if (!has_mtmd) { - for (size_t i = 0; i < max_idx; ++i) { - if (tokens[i] == b.tokens[i]) { - continue; - } - return i; - } - return max_idx; + if (bi != LLAMA_TOKEN_NULL) { + b_list.push_back(bi); + ++j; } - - for (size_t i = 0; i < max_idx; ++i) { - const llama_token ai = tokens[i]; - const llama_token bi = b.tokens[i]; - - if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + common_prefix prefix = find_common_text_token_prefix(ctx, a_list, b_list, 0, exact); + // text match or empty + if (prefix.first == a_list.size() && prefix.second == b_list.size()) { + a_list.clear(); + b_list.clear(); const auto& a_chunk = find_chunk(i); - const auto& b_chunk = b.find_chunk(i); + const auto& b_chunk = b.find_chunk(j); GGML_ASSERT(a_chunk && b_chunk); @@ -2057,302 +2266,173 @@ server_tokens::server_tokens(const llama_tokens& tokens, bool has_mtmd) : has_mt const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); + // image match if (id_ai == id_bi && n_tok_a == n_tok_b) { GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen - i += n_tok_a - 1; // will be +1 by the for loop - continue; - } - - return i; - } - - if (ai == bi) { - continue; - } - - return i; - } - - return max_idx; // all tokens are equal - } - - llama_tokens server_tokens::get_text_tokens_exclude_think(const llama_context* ctx, const thinking_tokens& think_token) const { - if (!think_token.exclude) { - return get_text_tokens(); - } - GGML_ASSERT((think_token.begin != "" && think_token.end != "") && "think tokens cannot be empty"); - std::string startStr = think_token.begin; - std::string endStr = think_token.end; - - llama_tokens tokens = get_text_tokens(); - std::string str = common_detokenize(ctx, tokens, true); - - std::vector> results; - // Find all positions of start and end - std::vector startPositions; - std::vector endPositions; - - size_t pos = 0; - // Find all start positions - while ((pos = str.find(startStr, pos)) != std::string::npos) { - startPositions.push_back(pos); - pos += startStr.length(); - } - - pos = 0; - // Find all end positions - while ((pos = str.find(endStr, pos)) != std::string::npos) { - endPositions.push_back(pos + endStr.length()); - pos += endStr.length(); - } - - // For each start position, pair with all end positions that come after it - for (size_t i = 0; i < startPositions.size(); i++) { - for (size_t j = 0; j < endPositions.size(); j++) { - if (results.size()) { - // start must be after last end - if (startPositions[i] > results[results.size() - 1].second && endPositions[j] > startPositions[i]) { - results.push_back({ startPositions[i], endPositions[j] }); - break; - } + i += n_tok_a; + j += n_tok_a; + prefix.first += n_tok_a; + prefix.second += n_tok_a; + token_prefix = common_prefix_add(prefix, token_prefix); } else { - if (endPositions[j] > startPositions[i]) { - results.push_back({ startPositions[i], endPositions[j] }); - break; - } - } - - } - } - if (!results.size()) { - return tokens; - } - - // Exclude tokens - pos = 0; - size_t n = 0; - size_t string_len = 0; - llama_tokens tokens_new; - auto model = llama_get_model(ctx); - for (n = 0; n < tokens.size(); ++n) { - str = llama_token_to_piece(model, tokens[n], true); - string_len = string_len + str.size(); - if (string_len <= results[pos].first) { - tokens_new.push_back(tokens[n]); - } - else if (string_len <= results[pos].second) { - continue; - } - else { - tokens_new.push_back(tokens[n]); - if (pos+1 < results.size()) { - pos++; - } - } - } - return tokens_new; - } - - - common_prefix server_tokens::get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact) const { - common_prefix token_prefix; - - size_t n = get_common_prefix_exact(b); // strict token match as a starting point - token_prefix.first = n; - token_prefix.second = n; - - if (!has_mtmd) { - token_prefix = find_common_text_token_prefix(ctx, this->tokens, b.tokens, n, exact); - token_prefix.first += n; - token_prefix.second += n; - return token_prefix; - } - size_t i = n; - size_t j = n; - llama_tokens a_list; - llama_tokens b_list; - while (i < size() && j < b.size()) { - llama_token ai = tokens[i]; - llama_token bi = b.tokens[j]; - if (ai != LLAMA_TOKEN_NULL) { - a_list.push_back(ai); - ++i; - } - if (bi != LLAMA_TOKEN_NULL) { - b_list.push_back(bi); - ++j; - } - if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { - common_prefix prefix = find_common_text_token_prefix(ctx, a_list, b_list, 0, exact); - // text match or empty - if (prefix.first == a_list.size() && prefix.second == b_list.size()) { - a_list.clear(); - b_list.clear(); - const auto& a_chunk = find_chunk(i); - const auto& b_chunk = b.find_chunk(j); - - GGML_ASSERT(a_chunk && b_chunk); - - const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); - const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); - - const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); - const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); - - // image match - if (id_ai == id_bi && n_tok_a == n_tok_b) { - GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen - i += n_tok_a; - j += n_tok_a; - prefix.first += n_tok_a; - prefix.second += n_tok_a; - token_prefix = common_prefix_add(prefix, token_prefix); - } - else { - // do no include image token prefix - // only return text token prefix - token_prefix = common_prefix_add(prefix, token_prefix); - return token_prefix; - } - } - else { - // text not match + // do no include image token prefix + // only return text token prefix token_prefix = common_prefix_add(prefix, token_prefix); return token_prefix; } } - } - common_prefix prefix = find_common_text_token_prefix(ctx, a_list, b_list, 0, exact); - token_prefix = common_prefix_add(prefix, token_prefix); - - return token_prefix; - - } - - // take first n tokens of tokens list a - // find the common prefix between a and b - common_prefix server_tokens::get_common_prefix_first_n(const llama_context* ctx, const server_tokens& b, size_t n, bool exact) const { - // not work for mtmd - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - auto tokens = get_text_tokens(); - if (n > tokens.size()) { - n = tokens.size(); - } - llama_tokens copy(tokens.begin(), tokens.begin() + n); - server_tokens a = server_tokens(copy, false); - return a.get_common_prefix(ctx, b, exact); - } - - // make sure all text tokens are within the vocab range - bool server_tokens::validate(const struct llama_context* ctx) const { - const llama_model* model = llama_get_model(ctx); - const llama_vocab* vocab = llama_model_get_vocab(model); - const int32_t n_vocab = llama_vocab_n_tokens(vocab); - - for (size_t i = 0; i < tokens.size(); ++i) { - auto& t = tokens[i]; - if (t == LLAMA_TOKEN_NULL) { - try { - const auto& chunk = find_chunk(i); - size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get()); - i += n_tokens - 1; // will be +1 by the for loop - } - catch (const std::exception& e) { - return false; - } + else { + // text not match + token_prefix = common_prefix_add(prefix, token_prefix); + return token_prefix; } - else if (t < 0 || t >= n_vocab) { + } + } + common_prefix prefix = find_common_text_token_prefix(ctx, a_list, b_list, 0, exact); + token_prefix = common_prefix_add(prefix, token_prefix); + + return token_prefix; + +} + +// take first n tokens of tokens list a +// find the common prefix between a and b +common_prefix server_tokens::get_common_prefix_first_n(const llama_context* ctx, const server_tokens& b, size_t n, bool exact) const { + if (n > n_tokens()) { + n = n_tokens(); + } + server_tokens a = clone(); + a.keep_first(n); + return a.get_common_prefix(ctx, b, exact); +} + +// make sure all text tokens are within the vocab range +bool server_tokens::validate(const struct llama_context* ctx) const { + const llama_model* model = llama_get_model(ctx); + const llama_vocab* vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + + for (size_t i = 0; i < tokens.size(); ++i) { + auto& t = tokens[i]; + if (t == LLAMA_TOKEN_NULL) { + try { + const auto& chunk = find_chunk(i); + size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get()); + i += n_tokens - 1; // will be +1 by the for loop + } + catch (const std::exception& e) { return false; } } - return true; + else if (t < 0 || t >= n_vocab) { + return false; + } + } + return true; +} + +// encode and decode the image chunk +int32_t server_tokens::process_chunk( + llama_context* ctx, + mtmd_context* mctx, + size_t idx, + llama_pos pos, + int32_t seq_id, + size_t& n_tokens_out) const { + const auto& chunk = find_chunk(idx); + const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE + ? "image" : "audio"; + LLAMA_LOG_INFO("processing %s...\n", name); + int32_t n_batch = llama_n_batch(ctx); + int64_t t0 = ggml_time_ms(); + llama_pos new_n_past; // unused for now + int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, + chunk.get(), + pos, + seq_id, + n_batch, + true, // logits last + &new_n_past); + LLAMA_LOG_INFO("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); + if (result != 0) { + LLAMA_LOG_ERROR("mtmd_helper_eval failed with status %d", result); + n_tokens_out = 0; + return result; + } + n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); + return 0; +} + +server_tokens server_tokens::clone() const { + server_tokens res; + res.has_mtmd = has_mtmd; + res.tokens = tokens; + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { + size_t idx = it->first; + const mtmd::input_chunk_ptr & chunk = it->second; + res.map_idx_to_media[idx] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get())); + } + return res; +} + +// Keep the first n_keep and remove n_discard tokens from tokens +void server_tokens::discard_n_tokens(int32_t n_keep, int32_t n_discard) { + + if (n_discard <= 0 || n_keep + n_discard >=n_tokens()) { + return; + } + server_tokens res = clone(); + keep_first(n_keep); + tokens.resize(res.n_tokens()-n_discard); + for (size_t i = n_keep + n_discard; i < res.n_tokens(); i++) { + tokens[i - n_discard] = res.tokens[i]; + } + for (auto it = res.map_idx_to_media.begin(); it != res.map_idx_to_media.end(); ++it) { + size_t idx = it->first; + if (idx >= n_keep+ n_discard) { + const mtmd::input_chunk_ptr & chunk = it->second; + map_idx_to_media[idx - n_discard] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get())); + } } - // encode and decode the image chunk - int32_t server_tokens::process_chunk( - llama_context* ctx, - mtmd_context* mctx, - size_t idx, - llama_pos pos, - int32_t seq_id, - size_t& n_tokens_out) const { - const auto& chunk = find_chunk(idx); - const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE - ? "image" : "audio"; - LLAMA_LOG_INFO("processing %s...\n", name); - int32_t n_batch = llama_n_batch(ctx); - int64_t t0 = ggml_time_ms(); - llama_pos new_n_past; // unused for now - int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, - chunk.get(), - pos, - seq_id, - n_batch, - true, // logits last - &new_n_past); - LLAMA_LOG_INFO("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); - if (result != 0) { - LLAMA_LOG_ERROR("mtmd_helper_eval failed with status %d", result); - n_tokens_out = 0; - return result; - } - n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); - return 0; +} + +// Similarity between prompt and cached +float server_tokens::get_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep, int n_discard) const { + GGML_ASSERT(n_keep >= 0 && n_discard >= 0); + float sim_cur = 0; + if (n_keep == 0 && n_discard == 0) { + auto lcp_len = get_common_prefix(ctx, tokens); + sim_cur = get_slot_similarity(lcp_len.second, tokens.size(), size()); } - - // Keep the first n_keep and remove n_discard tokens from tokens - void server_tokens::discard_n_tokens(int32_t n_keep, int32_t n_discard) { - if (n_discard <= 0 || n_keep + n_discard >= size()) { - return; - } - - llama_tokens new_tokens = get_text_tokens(); // copy - for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { - new_tokens[i - n_discard] = new_tokens[i]; - } - int32_t token_size = (int32_t)size(); - new_tokens.resize(token_size - n_discard); - clear(); - insert(new_tokens); - + else { + // remove tokens due to context shift and compare + auto tokens_ctx_shift = tokens.clone(); // copy cache tokens + tokens_ctx_shift.discard_n_tokens(n_keep, n_discard); + auto lcp_len = get_common_prefix(ctx, tokens_ctx_shift); + sim_cur = get_slot_similarity(lcp_len.second, tokens_ctx_shift.size(), size()); } + return sim_cur; +} - // Similarity between prompt and cached - float server_tokens::get_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep, int n_discard) const { - GGML_ASSERT(n_keep >= 0 && n_discard >= 0); - float sim_cur = 0; - if (n_keep == 0 && n_discard == 0) { - auto lcp_len = get_common_prefix(ctx, tokens); - sim_cur = get_slot_similarity(lcp_len.second, tokens.size(), size()); - } - else { - // remove tokens due to context shift and compare - auto tokens_ctx_shift = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens - tokens_ctx_shift.discard_n_tokens(n_keep, n_discard); - auto lcp_len = get_common_prefix(ctx, tokens_ctx_shift); - sim_cur = get_slot_similarity(lcp_len.second, tokens_ctx_shift.size(), size()); - } - return sim_cur; +// Similarity between common part and cache +float server_tokens::get_cached_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep, int n_discard) const { + GGML_ASSERT(n_keep >= 0 && n_discard >= 0); + float sim_cur = 0; + if (n_keep == 0 && n_discard == 0) { + auto lcp_len = get_common_prefix(ctx, tokens); + sim_cur = (float)lcp_len.first / size(); } - - // Similarity between common part and cache - float server_tokens::get_cached_tokens_similarity(const llama_context* ctx, const server_tokens& tokens, int n_keep, int n_discard) const { - GGML_ASSERT(n_keep >= 0 && n_discard >= 0); - float sim_cur = 0; - if (n_keep == 0 && n_discard == 0) { - auto lcp_len = get_common_prefix(ctx, tokens); - sim_cur = (float)lcp_len.first / size(); - } - else { - // remove tokens due to context shift and compare - auto tokens_ctx_shift = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens - tokens_ctx_shift.discard_n_tokens(n_keep, n_discard); - auto lcp_len = get_common_prefix(ctx, tokens_ctx_shift); - sim_cur = (float)lcp_len.first / size(); - } - return sim_cur; + else { + // remove tokens due to context shift and compare + auto tokens_ctx_shift = tokens.clone(); // copy cache tokens + tokens_ctx_shift.discard_n_tokens(n_keep, n_discard); + auto lcp_len = get_common_prefix(ctx, tokens_ctx_shift); + sim_cur = (float)lcp_len.first / size(); } + return sim_cur; +} // Computes FNV-1a hash of the data diff --git a/examples/server/server-common.h b/examples/server/server-common.h index e611808a..ebccf169 100644 --- a/examples/server/server-common.h +++ b/examples/server/server-common.h @@ -352,12 +352,19 @@ public: server_tokens(const llama_tokens& tokens, bool has_mtmd); - llama_pos pos_next() const; + // the next position after n_tokens. if n_tokens < 0, return the next position after all tokens. + llama_pos pos_next(int64_t n_tokens = -1) const; + + // number of tokens with position <= max_pos + size_t size_up_to_pos(llama_pos max_pos) const; int n_tokens() const { return tokens.size(); } + bool has_mtmd_data() { + return !map_idx_to_media.empty(); + } // for debugging std::string str() const; @@ -412,7 +419,7 @@ public: size_t get_common_prefix_exact(const server_tokens& b) const; - llama_tokens get_text_tokens_exclude_think(const llama_context* ctx, const thinking_tokens& think_token) const; + server_tokens get_tokens_exclude_think(const llama_context * ctx, const thinking_tokens & think_token) const; common_prefix get_common_prefix(const llama_context* ctx, const server_tokens& b, bool exact = false) const; // take first n tokens of tokens list a @@ -431,6 +438,8 @@ public: int32_t seq_id, size_t& n_tokens_out) const; + server_tokens clone() const; + // Keep the first n_keep and remove n_discard tokens from tokens void discard_n_tokens(int32_t n_keep, int32_t n_discard); diff --git a/examples/server/server-context.cpp b/examples/server/server-context.cpp index e768eb16..28c0cbc3 100644 --- a/examples/server/server-context.cpp +++ b/examples/server/server-context.cpp @@ -87,11 +87,6 @@ bool server_context::load_model(const gpt_params& params_) { } LOG_INFO("loaded multimodal model, '%s'\n", mmproj_path.c_str()); - if (params_base.ctx_shift) { - params_base.ctx_shift = false; - LOG_WARNING("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); - } - //if (params.n_cache_reuse) { // params_base.n_cache_reuse = 0; // SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); @@ -298,9 +293,8 @@ void server_context::init() { } catch (const std::exception & e) { SRV_ERR("%s: chat template parsing error: %s\n", __func__, e.what()); - SRV_ERR("%s: please consider disabling jinja via --no-jinja, or use a custom chat template via --chat-template\n", __func__); - SRV_ERR("%s: for example: --no-jinja --chat-template chatml\n", __func__); - return; + SRV_ERR("%s: please consider enabling jinja via --jinja, or use a custom chat template via --chat-template\n", __func__); + SRV_ERR("%s: for example: --chat-template chatml\n", __func__); } // thinking is enabled if: @@ -375,6 +369,8 @@ void server_slot::reset() { generated_token_probs.clear(); checkpoint_pos = 0; + image_just_processed = false; + do_checkpoint = false; positional_bans.clear(); ban_phrases.clear(); @@ -463,6 +459,7 @@ void server_slot::release() { if (state == SLOT_STATE_PROCESSING) { t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; command = SLOT_COMMAND_RELEASE; + state = SLOT_STATE_IDLE; task.reset(); llama_decode_reset(); } @@ -697,7 +694,7 @@ std::pair server_context::calculate_slot_similarity(const } void server_context::copy_data_to_cached_prompt(const server_tokens & tokens, server_slot & slot) { - slot.server_cached_prompt.tokens = server_tokens(tokens.get_text_tokens(), false); // copy cache tokens + slot.server_cached_prompt.tokens = tokens.clone(); // copy cache tokens slot.server_cached_prompt.n_discarded_prompt = slot.n_discarded_prompt; slot.server_cached_prompt.n_kept_prompt = slot.n_kept_prompt; slot.server_cached_prompt.think_tokens = slot.params.think_tokens; @@ -722,13 +719,10 @@ server_slot* server_context::get_available_slot(const server_task& task) { if (cache_tokens.empty()) { continue; } - bool exclude_think = !cache_tokens.has_mtmd && slot.params.think_tokens.exclude; std::pair sim; - if (exclude_think) { - auto temp = slot.cache_tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens); - server_tokens cache_tokens_exclude_think = server_tokens(temp, false); - temp = task.tokens.get_text_tokens_exclude_think(slot.ctx, slot.params.think_tokens); - server_tokens prompt_tokens_exclude_think = server_tokens(temp, false); + if (slot.params.think_tokens.exclude) { + server_tokens cache_tokens_exclude_think = slot.cache_tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens); + server_tokens prompt_tokens_exclude_think = task.tokens.get_tokens_exclude_think(slot.ctx, slot.params.think_tokens); sim = calculate_slot_similarity(slot, ctx, cache_tokens_exclude_think, prompt_tokens_exclude_think); } else { @@ -780,13 +774,9 @@ server_slot* server_context::get_available_slot(const server_task& task) { float f_keep = 0; size_t cache_token_size = tokens.size(); if (!tokens.empty()) { - bool exclude_think = !tokens.has_mtmd && ret->params.think_tokens.exclude; - if (exclude_think) { - auto temp = tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens); - server_tokens cache_exclude_think = server_tokens(temp, false); - - temp = task.tokens.get_text_tokens_exclude_think(ret->ctx, ret->params.think_tokens); - server_tokens prompt_exclude_think = server_tokens(temp, false); + if (ret->params.think_tokens.exclude) { + server_tokens cache_exclude_think = tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens); + server_tokens prompt_exclude_think = task.tokens.get_tokens_exclude_think(ret->ctx, ret->params.think_tokens); cache_token_size = cache_exclude_think.size(); f_keep = calculate_slot_f_keep(*ret, ret->ctx, cache_exclude_think, prompt_exclude_think); @@ -807,9 +797,6 @@ server_slot* server_context::get_available_slot(const server_task& task) { // don't update the cache if the slot's context is above cache_ram_n_min update_cache = update_cache && cache_token_size >= cache_ram_n_min; - // TODO: mtmd does not support prompt cache - update_cache = update_cache && (ret->mctx == nullptr); - LLAMA_LOG_INFO("======== Prompt cache: cache size: %d, n_keep: %d, n_discarded_prompt: %d, cache_ram_n_min: %d, f_keep: %.2f, cache_ram_similarity: %.2f\n", (int)tokens.size(), ret->n_kept_prompt, ret->n_discarded_prompt, cache_ram_n_min, f_keep, cache_ram_similarity); if (update_cache) { @@ -829,7 +816,7 @@ server_slot* server_context::get_available_slot(const server_task& task) { ret->prompt_load(*prompt_cache, task.tokens); prompt_cache->update(); - ret->cache_tokens = server_tokens(ret->server_cached_prompt.tokens.get_text_tokens(), false); // recover cache tokens + ret->cache_tokens = ret->server_cached_prompt.tokens.clone(); // recover cache tokens ret->n_discarded_prompt = ret->server_cached_prompt.n_discarded_prompt; ret->n_kept_prompt = ret->server_cached_prompt.n_kept_prompt; @@ -1335,11 +1322,14 @@ bool server_context::launch_slot_with_task(server_slot& slot, server_task& task) // - the model architecture is marked as recurrent or hybrid // // TODO: try to make this conditional on the context or the memory module, instead of the model type - // do_checkpoint = do_checkpoint && llama_model_has_recurrent(model); params_base.do_checkpoint = do_checkpoint; if (slot.n_buffer != 0) { - LLAMA_LOG_WARN("Recurrent model does not support banned strings.\n"); + LLAMA_LOG_WARN("banned strings is not supported by recurrent model, it will be disabled.\n"); } + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + LOG_WARNING("%s\n", "ctx_shift is not supported by recurrent model, it will be disabled"); + } } { const auto& stop = data.find("stop"); @@ -1713,18 +1703,18 @@ void server_context::send_error(const int id_task, const int id_multi, const std {"error", error}, }); - server_task_result res; - res.id = id_task; - res.id_multi = id_multi; - res.stop = false; - res.error = true; - res.data = format_error_response(error, type); - - queue_results.send(res); + auto res = std::make_unique(); + res->id = id_task; + res->id_multi = id_multi; + res->stop = false; + res->error = true; + res->err_type = type; + res->err_msg = error; + queue_results.send(std::move(res)); } // if multimodal is enabled, send an error and return false -bool server_context::ensure_no_mtmd(const int id_task) { +bool server_context::check_no_mtmd(const int id_task) { if (mctx) { int id_multi = 0; send_error(id_task, id_multi, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); @@ -2127,9 +2117,6 @@ void server_context::process_single_task(server_task&& task) { } break; case SERVER_TASK_TYPE_SLOT_SAVE: { - if (!ensure_no_mtmd(task.id)) { - break; - } int id_slot = task.data.at("id_slot"); server_slot* slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2142,7 +2129,9 @@ void server_context::process_single_task(server_task&& task) { queue_tasks.defer(std::move(task)); break; } - + if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) { + break; + } const size_t token_count = slot->cache_tokens.size(); const int64_t t_start = ggml_time_us(); @@ -2171,7 +2160,6 @@ void server_context::process_single_task(server_task&& task) { } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { - if (!ensure_no_mtmd(task.id)) break; int id_slot = task.data.at("id_slot"); server_slot* slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2184,7 +2172,9 @@ void server_context::process_single_task(server_task&& task) { queue_tasks.defer(std::move(task)); break; } - + if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) { + break; + } const int64_t t_start = ggml_time_us(); std::string filename = task.data.at("filename"); @@ -2199,7 +2189,9 @@ void server_context::process_single_task(server_task&& task) { break; } slot->cache_tokens.resize(token_count); - + if (mctx) { + slot->cache_tokens.has_mtmd = true; + } const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -2220,7 +2212,6 @@ void server_context::process_single_task(server_task&& task) { } break; case SERVER_TASK_TYPE_SLOT_ERASE: { - if (!ensure_no_mtmd(task.id)) break; int id_slot = task.data.at("id_slot"); server_slot* slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2233,7 +2224,9 @@ void server_context::process_single_task(server_task&& task) { queue_tasks.defer(std::move(task)); break; } - + if (slot->cache_tokens.has_mtmd_data() && !check_no_mtmd(task.id)) { + break; + } // Erase token cache const size_t n_erased = slot->cache_tokens.size(); llama_kv_cache_seq_rm(ctx, slot->id, -1, -1); @@ -2489,13 +2482,59 @@ void server_context::print_tokens(const server_tokens& prompt, const server_toke } void server_context::discard_n_kv_and_cache_tokens(llama_context* ctx, server_slot& slot, int32_t n_keep, int32_t n_discard) { - llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); + auto kv_keep = slot.cache_tokens.pos_next(n_keep); + auto kv_discard = slot.cache_tokens.pos_next(n_keep + n_discard) - kv_keep; + auto kv_past = slot.cache_tokens.pos_next(slot.n_past); + int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id); + const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); + llama_kv_cache_seq_rm(ctx, slot.id, kv_keep, kv_keep + kv_discard); + llama_kv_cache_seq_add(ctx, slot.id, kv_keep + kv_discard, kv_past, -kv_discard); if (slot.params.cache_prompt) { slot.cache_tokens.discard_n_tokens(n_keep, n_discard); } } + +inline static bool tokens_support_context_shift(const server_tokens & tokens, int32_t n_keep, + int32_t n_discard) { + bool can_shift = !tokens.has_mtmd; + if (tokens.has_mtmd) { + can_shift = true; + if (n_keep > 0 && n_keep<= tokens.n_tokens()) { + can_shift = tokens[n_keep - 1] != LLAMA_TOKEN_NULL; + } + if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) { + can_shift = can_shift && tokens[n_discard + n_keep - 1] != LLAMA_TOKEN_NULL; + } + } + return can_shift; +} + +inline static void adjust_n_to_support_context_shift(const server_tokens & tokens, int32_t & n_keep, + int32_t & n_discard) { + if (!tokens.has_mtmd) { + return; + } + if (n_keep > 0 && n_keep <= tokens.n_tokens()) { + while (tokens[n_keep - 1] == LLAMA_TOKEN_NULL) { + n_keep--; + if (n_keep<1 || n_keep>tokens.size()) { + break; + } + } + } + if (n_discard + n_keep > 0 && n_discard + n_keep <= tokens.n_tokens()) { + while (tokens[n_discard + n_keep - 1] == LLAMA_TOKEN_NULL) { + n_discard++; + if (n_discard + n_keep<1 || n_discard + n_keep>tokens.size()) { + break; + } + } + } + +} + + // convert keep first few and discard next tokens in a to b void server_context::context_shift_find_n_tokens(llama_context* ctx, const server_tokens& a, const server_tokens& b, int32_t n_keep, int32_t n_discard, int32_t& n_kept, int32_t& n_discarded, bool exact) { @@ -2519,7 +2558,10 @@ void server_context::context_shift_prompt(llama_context* ctx, server_slot& slot, int n_keep = std::max(0, slot.params.n_keep + add_bos_token); const int n_left = slot.n_ctx - n_keep; int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); - + adjust_n_to_support_context_shift(slot.prompt_tokens, n_keep, n_discard); + if (n_discard<=0 || !tokens_support_context_shift(slot.prompt_tokens, n_keep, n_discard)) { + return; + } int n_discard_prompt = 0; // we still need to truncate input since we have not discarded enough tokens while (slot.n_prompt_tokens - slot.n_discarded_prompt >= slot.n_ctx) { @@ -2598,15 +2640,11 @@ void server_context::context_shift() { if (!params_base.ctx_shift) { // this check is redundant (for good) // we should never get here, because generation should already stopped in process_token() - send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + slot.print_timings(); slot.release(); + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); continue; } - if (mctx) { - // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded - // we don't support ctx_shift because an image chunk may contains multiple tokens - GGML_ABORT("not supported by multimodal"); - } // Shift context int n_keep = slot.params.n_keep < 0 ? slot.prompt_tokens.size() : slot.params.n_keep; if (add_bos_token) { @@ -2614,11 +2652,12 @@ void server_context::context_shift() { } n_keep = std::min(slot.n_ctx - 4, n_keep); - const int n_left = (int)system_tokens.size() + slot.n_past - n_keep; - const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + const int32_t n_left = (int)system_tokens.size() + slot.n_past - n_keep; + int32_t n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); int32_t n_kept; int32_t n_discard_cache; - if (n_discard > 0) { + adjust_n_to_support_context_shift(slot.cache_tokens, n_keep, n_discard); + if (n_discard > 0 && tokens_support_context_shift(slot.cache_tokens, n_keep, n_discard)) { context_shift_find_n_tokens(ctx, slot.prompt_tokens, slot.cache_tokens, n_keep, n_discard, n_kept, n_discard_cache); LOG_INFO("slot context shift", { @@ -2725,21 +2764,21 @@ void server_context::create_checkpoint_at_interval(server_slot & slot, const gp if (params_base.do_checkpoint && params_base.ctx_checkpoints_interval > 0) { auto pos = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); if (slot.checkpoint_pos + params_base.ctx_checkpoints_interval <= 1 + pos) { - create_checkpoint(slot); - slot.checkpoint_pos = pos; + bool created = create_checkpoint(slot); + if (created) { + slot.checkpoint_pos = pos; + } } } } void server_context::apply_checkpoint(server_slot & slot) { - const auto pos_min_thold = std::max(0, slot.n_past - 1); - if (!mctx && slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) { + llama_pos pos_next = slot.cache_tokens.pos_next(slot.n_past); + const auto pos_min_thold = std::max(0, pos_next - 1); + if (slot.n_past > 0 && slot.n_past < slot.cache_tokens.n_tokens()) { int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id); if (pos_min > pos_min_thold) { - // TODO: support can be added in the future when corresponding vision models get released - GGML_ASSERT(!slot.cache_tokens.has_mtmd); - SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int)slot.cache_tokens.size(), slot.id, pos_min); // search for a context checkpoint @@ -2765,8 +2804,10 @@ void server_context::apply_checkpoint(server_slot & slot) { do_reset = true; //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); } else { - slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)); - slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt + 1, it->pos_max_prompt)); + slot.n_past = std::min(slot.n_past, std::max(it->pos_min+1, it->pos_max)); + slot.n_past = slot.cache_tokens.size_up_to_pos(slot.n_past-1); + slot.n_past_prompt = std::min(slot.n_past_prompt, std::max(it->pos_min_prompt+1, it->pos_max_prompt)); + slot.n_past_prompt = slot.prompt_tokens.size_up_to_pos(slot.n_past_prompt-1); SLT_WRN(slot, "restored context checkpoint took %.2f ms (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", (ggml_time_us() - t_start) / 1000.0, it->pos_min, it->pos_max, (float)checkpoint_size / 1024 / 1024); } } @@ -2794,8 +2835,8 @@ void server_context::apply_checkpoint(server_slot & slot) { } } -void server_context::create_checkpoint(server_slot & slot) { - bool do_checkpoint = true; +bool server_context::create_checkpoint(server_slot & slot) { + bool do_checkpoint = !slot.image_just_processed; int32_t pos_min = llama_kv_cache_seq_pos_min(slot.ctx, slot.id); const auto pos_max = llama_kv_cache_seq_pos_max(slot.ctx, slot.id); @@ -2833,6 +2874,7 @@ void server_context::create_checkpoint(server_slot & slot) { (int)slot.server_cached_prompt.checkpoints.size(), params_base.ctx_checkpoints_n, cur.pos_min, cur.pos_max, (float)cur.data.size() / 1024 / 1024, (ggml_time_us() - t_start) / 1000.0); } + return do_checkpoint; } void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t n_batch, int32_t & batch_type) { @@ -2935,12 +2977,6 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t slot.release(); continue; } - if (mctx) { - // we should never reach this because params.ctx_shift is automatically disabled if mmproj is loaded - // we don't support ctx_shift because an image chunk may contains multiple tokens - GGML_ABORT("not supported by multimodal"); - } - context_shift_prompt(ctx, slot); slot.truncated = true; LOG_VERBOSE("input truncated", { @@ -3100,7 +3136,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t slot.n_past += n_tokens_out; slot.n_past_prompt += n_tokens_out; slot.n_prompt_tokens_processed += n_tokens_out; - + slot.image_just_processed = true; // do not checkpoint right after an image chunk } @@ -3137,7 +3173,7 @@ void server_context::batch_pending_prompt(const int32_t n_ubatch, const int32_t slot_npast++; slot.n_past_prompt++; slot.n_past++; - slot.do_checkpoint = false; + slot.image_just_processed = false; if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) { slot.do_checkpoint = true; break; @@ -3286,6 +3322,8 @@ void server_context::speculative_decoding_accept() { if (slot.n_buffer == 0 || !params_base.can_ban_phrases) { if (!process_token(result, slot)) { // release slot because of stop condition + slot.cache_tokens.push_back(slot.sampled); + slot.n_past++; send_final_response(slot); release_slot_after_final_response(slot); break; @@ -3338,6 +3376,8 @@ void server_context::send_token_results(completion_token_outputs& results, serve if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) { continue; } + slot.cache_tokens.push_back(slot.sampled); + slot.n_past++; send_final_response(slot); release_slot_after_final_response(slot); released = true; @@ -3349,6 +3389,8 @@ void server_context::send_token_results(completion_token_outputs& results, serve } if (!released && slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) { + slot.cache_tokens.push_back(slot.sampled); + slot.n_past++; send_final_response(slot); release_slot_after_final_response(slot); } @@ -3381,10 +3423,10 @@ inline int32_t check_ban_phrase(server_slot& slot) { if (start != std::string::npos) { if (start < best_start) { best_start = start; - found = true; - } + found = true; } } + } // 2. Check regex for (const auto& pattern : slot.ban_regex) { @@ -3424,8 +3466,8 @@ inline int32_t check_ban_phrase(server_slot& slot) { if (best_start >= token_offsets[i] && best_start < token_offsets[i] + len) { token_idx = (int32_t)i; break; - } - } + } +} if (token_idx != -1) { int32_t abs_pos = slot.n_past - (int32_t)slot.token_buffer.size() + 1 + token_idx; @@ -3449,7 +3491,7 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) { llama_token banned_tok = result->tok; if (n == 0) { - LLAMA_LOG_INFO("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n", + LLAMA_LOG_DEBUG("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n", ban_pos, banned_tok, result->text_to_send.c_str()); } @@ -3462,11 +3504,11 @@ inline void rewind_context(server_slot& slot, int32_t ban_pos) { } int32_t n_rewind_total = (slot.n_past + 1) - ban_pos; - + size_t n_keep_cache = 0; if (ban_pos > 0) { n_keep_cache = (size_t)(ban_pos - 1); - } +} if (n_keep_cache > slot.cache_tokens.size()) { n_keep_cache = slot.cache_tokens.size(); @@ -3516,7 +3558,7 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ int32_t n_keep_buffer = ban_pos - buffer_start_pos; if (n_keep_buffer < 0) n_keep_buffer = 0; n_rewind = (int32_t)slot.token_buffer.size() - n_keep_buffer; - } + } } bool allow_rewind = true; @@ -3559,16 +3601,16 @@ void server_context::buffer_and_check_string_ban(server_slot & slot, completion_ send_token_results(slot.token_buffer, slot, 1); } if (slot.sparams.adaptive_target >= 0.0f) { - sent_results = true; - } + sent_results = true; + } } else { // buffer the result, wait for more tokens to validate string slot.sampled = result.tok; } if (slot.sparams.adaptive_target >= 0.0f) { - slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind; - } + slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind; +} } void server_context::process_batch_tokens(int32_t & n_batch) { diff --git a/examples/server/server-context.h b/examples/server/server-context.h index 38675d81..f04c8a54 100644 --- a/examples/server/server-context.h +++ b/examples/server/server-context.h @@ -110,6 +110,7 @@ struct server_slot { size_t checkpoint_pos = 0; bool do_checkpoint = false; + bool image_just_processed = false; // sampling llama_token sampled; // in speculative mode, this is the last accepted token @@ -302,7 +303,7 @@ struct server_context { void send_error(const int id_task, const int id_multi, const std::string& error, const enum error_type type = ERROR_TYPE_SERVER); // if multimodal is enabled, send an error and return false - bool ensure_no_mtmd(const int id_task); + bool check_no_mtmd(const int id_task); void send_partial_response(server_slot& slot, completion_token_output tkn); @@ -363,7 +364,7 @@ struct server_context { // Re-aggregates all active vectors and updates the model state bool apply_control_vectors_internal(); - void create_checkpoint(server_slot & slot); + bool create_checkpoint(server_slot & slot); void apply_checkpoint(server_slot & slot); diff --git a/examples/server/server-task.cpp b/examples/server/server-task.cpp index 5f3d30cf..421f1620 100644 --- a/examples/server/server-task.cpp +++ b/examples/server/server-task.cpp @@ -1081,12 +1081,12 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token server_tokens prompt_tokens; server_tokens tokens_new_ex; if (think_tokens.exclude) { - prompt_tokens = server_tokens(prompt.tokens.get_text_tokens_exclude_think(ctx, think_tokens), false); - tokens_new_ex = server_tokens(tokens_new.get_text_tokens_exclude_think(ctx, think_tokens), false); + prompt_tokens = prompt.tokens.get_tokens_exclude_think(ctx, think_tokens); + tokens_new_ex = tokens_new.get_tokens_exclude_think(ctx, think_tokens); } else { - prompt_tokens = std::move(prompt.tokens); //server_tokens(prompt.tokens.get_text_tokens(), false); - tokens_new_ex = server_tokens(tokens_new.get_text_tokens(), false); + prompt_tokens = std::move(prompt.tokens); + tokens_new_ex = tokens_new.clone(); } const auto lcp_best = prompt_tokens.get_common_prefix(ctx, tokens_new_ex); float f_keep_best = float(lcp_best.second) / prompt_tokens.size(); @@ -1099,7 +1099,7 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token for (auto it = states.begin(); it != states.end(); ++it) { server_tokens tokens; if (think_tokens.exclude) { - tokens = server_tokens(it->tokens.get_text_tokens_exclude_think(ctx, think_tokens), false); + tokens = it->tokens.get_tokens_exclude_think(ctx, think_tokens); } else { tokens = std::move(it->tokens); @@ -1136,7 +1136,7 @@ bool server_prompt_cache::load(server_prompt& prompt, const server_tokens& token server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t state_size) { for (auto it = states.begin(); it != states.end();) { - auto tokens_ctx_shift = server_tokens(prompt.tokens.get_text_tokens(), false); // copy cache tokens + auto tokens_ctx_shift = prompt.tokens.clone(); // copy cache tokens tokens_ctx_shift.discard_n_tokens(prompt.n_kept_prompt, prompt.n_discarded_prompt); auto prefix = it->tokens.get_common_prefix(ctx, tokens_ctx_shift); const size_t len = prefix.first; @@ -1177,7 +1177,7 @@ server_prompt* server_prompt_cache::alloc(const server_prompt& prompt, size_t st // TODO: for some reason we can't copy server_tokens, so we have to do this workaround auto& cur = states.emplace_back(); cur = { - /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), + /*.tokens =*/ prompt.tokens.clone(), /*.n_keep =*/ prompt.n_kept_prompt, /*.n_discarded_prompt =*/ prompt.n_discarded_prompt, /*.think_tokens =*/ prompt.think_tokens, diff --git a/examples/server/server-task.h b/examples/server/server-task.h index f3d3705f..1cce0907 100644 --- a/examples/server/server-task.h +++ b/examples/server/server-task.h @@ -371,6 +371,16 @@ struct server_prompt { return tokens.size(); } + server_prompt clone() const { + return server_prompt{ + tokens.clone(), + n_kept_prompt, + n_discarded_prompt, + think_tokens, + data, + checkpoints + }; + } }; struct server_prompt_cache { diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 120c7846..b07b15a4 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -125,7 +125,7 @@ ggml_cgraph * llm_build_context::build_k_shift() { GGML_ASSERT(kv_self.size == n_ctx); - const auto & rope_type_shift = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE + const auto & rope_type_shift = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE // @ngxson : this is a workaround // for M-RoPE, we want to rotate the whole vector when doing KV shift // a normal RoPE should work, we just need to use the correct ordering diff --git a/src/llama.cpp b/src/llama.cpp index c399ba23..961bd6bf 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4058,12 +4058,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0); } +static bool get_can_shift(struct llama_context & lctx) { + bool no_shift = lctx.model.arch == LLM_ARCH_DEEPSEEK2 || lctx.model.arch == LLM_ARCH_GLM_DSA; // not supported due to MLA + no_shift = no_shift || lctx.model.hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE; + return !no_shift; +} + static int32_t llama_kv_cache_update_internal(struct llama_context & lctx) { bool need_reserve = false; // apply K-shift if needed if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) { - if (lctx.model.arch == LLM_ARCH_DEEPSEEK2 || lctx.model.arch == LLM_ARCH_GLM_DSA) { // not supported due to MLA + if (!get_can_shift(lctx)) { return 1; }