diff --git a/common/arg.cpp b/common/arg.cpp index a859aac4fe..1b6884781d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1615,7 +1615,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()), [](common_params & params, const std::string & value) { const auto sampler_names = string_split(value, ';'); - params.sampling.samplers = common_sampler_types_from_names(sampler_names, true); + params.sampling.samplers = common_sampler_types_from_names(sampler_names); params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS; } ).set_sampling()); diff --git a/common/common.cpp b/common/common.cpp index b6a7626f2a..b01772e1cb 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1148,7 +1148,7 @@ static void common_init_sampler_from_model( if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) { const std::vector sampler_names = string_split(std::string(buf), ';'); if (!sampler_names.empty()) { - sparams.samplers = common_sampler_types_from_names(sampler_names, true); + sparams.samplers = common_sampler_types_from_names(sampler_names); } } } diff --git a/common/sampling.cpp b/common/sampling.cpp index 85f8ed50b3..c537f33503 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -769,54 +769,63 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { } } -std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names) { - std::unordered_map sampler_canonical_name_map { - { "dry", COMMON_SAMPLER_TYPE_DRY }, - { "top_k", COMMON_SAMPLER_TYPE_TOP_K }, - { "top_p", COMMON_SAMPLER_TYPE_TOP_P }, - { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, - { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, - { "min_p", COMMON_SAMPLER_TYPE_MIN_P }, - { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, - { "xtc", COMMON_SAMPLER_TYPE_XTC }, - { "infill", COMMON_SAMPLER_TYPE_INFILL }, - { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, - { "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P }, - }; - - // since samplers names are written multiple ways - // make it ready for both system names and input names - std::unordered_map sampler_alt_name_map { - { "top-k", COMMON_SAMPLER_TYPE_TOP_K }, - { "top-p", COMMON_SAMPLER_TYPE_TOP_P }, - { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, - { "nucleus", COMMON_SAMPLER_TYPE_TOP_P }, - { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, - { "typical", COMMON_SAMPLER_TYPE_TYPICAL_P }, - { "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, - { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, - { "min-p", COMMON_SAMPLER_TYPE_MIN_P }, - { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, - { "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P }, - }; +std::vector common_sampler_types_from_names(const std::vector & names) { + // sampler names can be written multiple ways; generate aliases from canonical names + static const auto sampler_name_map = []{ + // canonical sampler name mapping + std::unordered_map canonical_name_map { + { "dry", COMMON_SAMPLER_TYPE_DRY }, + { "top_k", COMMON_SAMPLER_TYPE_TOP_K }, + { "top_p", COMMON_SAMPLER_TYPE_TOP_P }, + { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, + { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, + { "min_p", COMMON_SAMPLER_TYPE_MIN_P }, + { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, + { "xtc", COMMON_SAMPLER_TYPE_XTC }, + { "infill", COMMON_SAMPLER_TYPE_INFILL }, + { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, + { "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P } + }; + std::unordered_map alias_name_map; + for (const auto & entry : canonical_name_map) { + const std::string & canonical = entry.first; + if (canonical.find('_') == std::string::npos) { + continue; + } + // kebab-case: "top-k", "min-p", etc. + { + std::string kebab_case = canonical; + std::replace(kebab_case.begin(), kebab_case.end(), '_', '-'); + alias_name_map.insert({kebab_case, entry.second}); + } + // no dash: "topk", "minp", etc. + { + std::string no_dash = canonical; + no_dash.erase(std::remove(no_dash.begin(), no_dash.end(), '_'), no_dash.end()); + alias_name_map.insert({no_dash, entry.second}); + } + } + // misc. aliases + alias_name_map.insert({"nucleus", COMMON_SAMPLER_TYPE_TOP_P}); + alias_name_map.insert({"temp", COMMON_SAMPLER_TYPE_TEMPERATURE}); + alias_name_map.insert({"typ", COMMON_SAMPLER_TYPE_TYPICAL_P}); + // include aliases + canonical names in the complete mapping + alias_name_map.merge(canonical_name_map); + return alias_name_map; + }(); std::vector samplers; samplers.reserve(names.size()); for (const auto & name : names) { - auto sampler = sampler_canonical_name_map.find(name); - if (sampler != sampler_canonical_name_map.end()) { + std::string name_lower = name; + std::transform(name_lower.begin(), name_lower.end(), name_lower.begin(), ::tolower); + auto sampler = sampler_name_map.find(name_lower); + if (sampler != sampler_name_map.end()) { samplers.push_back(sampler->second); continue; } - if (allow_alt_names) { - sampler = sampler_alt_name_map.find(name); - if (sampler != sampler_alt_name_map.end()) { - samplers.push_back(sampler->second); - continue; - } - } - LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str()); + LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name_lower.c_str()); } return samplers; diff --git a/common/sampling.h b/common/sampling.h index 19cbbbaba3..4191988bb8 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -109,7 +109,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, char common_sampler_type_to_chr(enum common_sampler_type cnstr); std::string common_sampler_type_to_str(enum common_sampler_type cnstr); -std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector common_sampler_types_from_names(const std::vector & names); std::vector common_sampler_types_from_chars(const std::string & chars); llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 33de2e4d9c..842be2ad3d 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -605,7 +605,7 @@ task_params server_task::params_from_json_cmpl( const auto samplers = data.find("samplers"); if (samplers != data.end()) { if (samplers->is_array()) { - params.sampling.samplers = common_sampler_types_from_names(*samplers, false); + params.sampling.samplers = common_sampler_types_from_names(*samplers); } else if (samplers->is_string()){ params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); }