From 4fb16eccce5e451b40014355f97374d692480a4d Mon Sep 17 00:00:00 2001 From: Mikhail Podvitskii Date: Tue, 2 Jun 2026 21:11:12 +0200 Subject: [PATCH] model: add Mellum architecture (#23966) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * model: support for Mellum architecture * model: improve mellum.py formatting * model: improve mellum.py formatting once again * deps: downgrade transformers to 4.57.6 (to fix CI) * deps: remove huggingface_hub dependency * deps: remove huggingface_hub from test requirements --------- Co-authored-by: Sigbjørn Skjæret --- README.md | 1 + conversion/__init__.py | 1 + conversion/base.py | 3 + conversion/mellum.py | 61 +++++ convert_hf_to_gguf_update.py | 1 + gguf-py/gguf/constants.py | 19 ++ pyproject.toml | 2 +- .../requirements-convert_legacy_llama.txt | 2 +- requirements/requirements-tool_bench.txt | 1 - src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + src/llama-model-saver.cpp | 1 + src/llama-model.cpp | 10 +- src/llama-model.h | 1 + src/llama-vocab.cpp | 4 + src/llama-vocab.h | 1 + src/models/mellum.cpp | 225 ++++++++++++++++++ src/models/models.h | 12 + tests/test-llama-archs.cpp | 1 + tools/server/tests/requirements.txt | 1 - 20 files changed, 344 insertions(+), 5 deletions(-) create mode 100644 conversion/mellum.py create mode 100644 src/models/mellum.cpp diff --git a/README.md b/README.md index dbe2c363a5..ae37b13e12 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) - [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7) - [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86) +- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum) #### Multimodal diff --git a/conversion/__init__.py b/conversion/__init__.py index 3ceb2d3853..8415c65f94 100644 --- a/conversion/__init__.py +++ b/conversion/__init__.py @@ -135,6 +135,7 @@ TEXT_MODEL_MAP: dict[str, str] = { "Mamba2ForCausalLM": "mamba", "MambaForCausalLM": "mamba", "MambaLMHeadModel": "mamba", + "MellumForCausalLM": "mellum", "MiMoV2FlashForCausalLM": "mimo", "MiMoV2ForCausalLM": "mimo", "MiniCPM3ForCausalLM": "minicpm", diff --git a/conversion/base.py b/conversion/base.py index 69bc472b72..408e209aa8 100644 --- a/conversion/base.py +++ b/conversion/base.py @@ -1663,6 +1663,9 @@ class TextModel(ModelBase): if chkhsh == "789696f5946cc0fc59371f39f6097cafed196b3acded6140432f26bbb1ae1669": # ref: https://huggingface.co/ibm-granite/granite-embedding-311m-multilingual-r2 res = "granite-embed-multi-311m" + if chkhsh == "9dcf830ee9990cdbf78cc523a5f7bd9ad8f3f9890c2d3581d2785ad10f07049d": + # ref: https://huggingface.co/JetBrains/Mellum2-12B-A2.5B-Base + res = "mellum2" if res is None: logger.warning("\n") diff --git a/conversion/mellum.py b/conversion/mellum.py new file mode 100644 index 0000000000..79bc6755cc --- /dev/null +++ b/conversion/mellum.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import Iterable, TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from torch import Tensor + +from .base import ModelBase, TextModel, gguf, logger + + +@ModelBase.register("MellumForCausalLM") +class MellumModel(TextModel): + model_arch = gguf.MODEL_ARCH.MELLUM + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") + + use_sliding_window = self.hparams.get("use_sliding_window") + sliding_window = self.hparams.get("sliding_window") + if (use_sliding_window is True or use_sliding_window is None) and sliding_window is not None: + self.gguf_writer.add_sliding_window(sliding_window) + logger.info(f"gguf: sliding window = {sliding_window}") + self.gguf_writer.add_sliding_window_pattern([t == "sliding_attention" for t in self.hparams["layer_types"]]) + logger.info(f"gguf: sliding window pattern length = {len(self.hparams['layer_types'])}") + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.find("experts") != -1: + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + yield from super().modify_tensors(data_torch, merged_name, bid) + return + else: + return + + yield from super().modify_tensors(data_torch, name, bid) diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 00e5888970..b4c8a7cf00 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -160,6 +160,7 @@ models = [ {"name": "minicpm5", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openbmb/MiniCPM5-1B"}, {"name": "granite-embed-multi-97m", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-embedding-97m-multilingual-r2", }, {"name": "granite-embed-multi-311m", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-embedding-311m-multilingual-r2", }, + {"name": "mellum2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum2-12B-A2.5B-Base"}, ] # some models are known to be broken upstream, so we will skip them as exceptions diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index fc54063fea..207cc2a193 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -510,6 +510,7 @@ class MODEL_ARCH(IntEnum): MAINCODER = auto() KIMI_LINEAR = auto() TALKIE = auto() + MELLUM = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -1030,6 +1031,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.MAINCODER: "maincoder", MODEL_ARCH.KIMI_LINEAR: "kimi-linear", MODEL_ARCH.TALKIE: "talkie", + MODEL_ARCH.MELLUM: "mellum", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -4093,6 +4095,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_UP, MODEL_TENSOR.LAYER_OUT_SCALE, ], + MODEL_ARCH.MELLUM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], # TODO } diff --git a/pyproject.toml b/pyproject.toml index e4f8c86b95..46cf68ca1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ requires-python = '>=3.10,<3.15' dependencies = [ 'numpy (>=1.26.4,<3.0.0)', 'sentencepiece (>=0.1.98,<0.3.0)', - 'transformers (==5.5.1)', + 'transformers (==4.57.6)', 'protobuf (>=4.21.0,<5.0.0)', 'torch (>=2.6.0,<3.0.0)', 'gguf @ ./gguf-py', diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt index 18d3980106..28221fad0c 100644 --- a/requirements/requirements-convert_legacy_llama.txt +++ b/requirements/requirements-convert_legacy_llama.txt @@ -1,7 +1,7 @@ numpy~=1.26.4 sentencepiece>=0.1.98,<0.3.0 -transformers==5.5.1 +transformers==4.57.6 gguf>=0.1.0 protobuf>=4.21.0,<5.0.0 diff --git a/requirements/requirements-tool_bench.txt b/requirements/requirements-tool_bench.txt index 17d6b866c6..3e6f824165 100644 --- a/requirements/requirements-tool_bench.txt +++ b/requirements/requirements-tool_bench.txt @@ -1,6 +1,5 @@ aiohttp~=3.9.3 pytest~=8.3.3 -huggingface_hub>=1.5.0,<2.0 matplotlib~=3.10.0 numpy~=1.26.4 openai~=2.14.0 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 9d5a7b6e9e..8f462396f4 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -135,6 +135,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, { LLM_ARCH_TALKIE, "talkie" }, + { LLM_ARCH_MELLUM, "mellum" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; diff --git a/src/llama-arch.h b/src/llama-arch.h index 233b29de67..b47c05d90d 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -139,6 +139,7 @@ enum llm_arch { LLM_ARCH_MAINCODER, LLM_ARCH_KIMI_LINEAR, LLM_ARCH_TALKIE, + LLM_ARCH_MELLUM, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp index 528e4c9c06..539d17eebc 100644 --- a/src/llama-model-saver.cpp +++ b/src/llama-model-saver.cpp @@ -29,6 +29,7 @@ bool llama_model_saver_supports_arch(llm_arch arch) { case LLM_ARCH_APERTUS: case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: + case LLM_ARCH_MELLUM: return false; default: return true; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 42d104d22c..bd5635ed45 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -81,6 +81,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_mpt(params); case LLM_ARCH_STABLELM: return new llama_model_stablelm(params); + case LLM_ARCH_MELLUM: + return new llama_model_mellum(params); case LLM_ARCH_QWEN: return new llama_model_qwen(params); case LLM_ARCH_QWEN2: @@ -764,6 +766,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_7B_A1B: return "7B.A1B"; case LLM_TYPE_8B_A1B: return "8B.A1B"; + case LLM_TYPE_12B_A2_5B: return "12B.A2.5B"; case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_24B_A2B: return "24B.A2B"; @@ -1816,7 +1819,11 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + if (arch == LLM_ARCH_MELLUM || + arch == LLM_ARCH_QWEN3MOE || + arch == LLM_ARCH_OPENAI_MOE || + arch == LLM_ARCH_QWEN3VLMOE || + arch == LLM_ARCH_RND1) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } @@ -2404,6 +2411,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: case LLM_ARCH_TALKIE: + case LLM_ARCH_MELLUM: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index d510d4a938..a561374ed9 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -116,6 +116,7 @@ enum llm_type { LLM_TYPE_A13B, LLM_TYPE_7B_A1B, LLM_TYPE_8B_A1B, // lfm2moe + LLM_TYPE_12B_A2_5B, LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_24B_A2B, // lfm2moe diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 79f14ac248..5205023981 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -353,6 +353,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_CODESHELL: case LLAMA_VOCAB_PRE_TYPE_EXAONE: case LLAMA_VOCAB_PRE_TYPE_MINERVA: + case LLAMA_VOCAB_PRE_TYPE_MELLUM2: regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -2325,6 +2326,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "solar-open") { pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN; clean_spaces = false; + } else if ( + tokenizer_pre == "mellum2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MELLUM2; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 416eab522b..b3991b5322 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -63,6 +63,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_MINICPM5 = 52, LLAMA_VOCAB_PRE_TYPE_WHITESPACE = 53, LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI = 54, + LLAMA_VOCAB_PRE_TYPE_MELLUM2 = 55, }; struct LLM_KV; diff --git a/src/models/mellum.cpp b/src/models/mellum.cpp new file mode 100644 index 0000000000..a2372399bb --- /dev/null +++ b/src/models/mellum.cpp @@ -0,0 +1,225 @@ +#include "models.h" + +void llama_model_mellum::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + uint32_t swa_period = 4; + const auto res = ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + if (res) { + hparams.set_swa_pattern(swa_period); + } else { + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + } + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_12B_A2_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mellum::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for Mellum"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for Mellum"); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_mellum::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } + return std::make_unique>(*this, params); +} + +template +llama_model_mellum::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + const bool is_swa = hparams.is_swa(il); + + if (is_swa) { + // For sliding window layers, use regular rope with no yarn rope scaling. + // This is achieved here by setting freq_scale and attn_factor to 1. + // We also set ext_factor to 0 to avoid a few unnecessary computations. + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +template struct llama_model_mellum::graph; +template struct llama_model_mellum::graph; diff --git a/src/models/models.h b/src/models/models.h index cbef040870..866e0d0be3 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -411,6 +411,18 @@ struct llama_model_stablelm : public llama_model_base { std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; +struct llama_model_mellum : public llama_model_base { + llama_model_mellum(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; struct llama_model_qwen : public llama_model_base { llama_model_qwen(const struct llama_model_params & params) : llama_model_base(params) {} diff --git a/tests/test-llama-archs.cpp b/tests/test-llama-archs.cpp index 1def7faff6..4fe585e29a 100644 --- a/tests/test-llama-archs.cpp +++ b/tests/test-llama-archs.cpp @@ -357,6 +357,7 @@ static bool moe_mandatory(const llm_arch arch) { case LLM_ARCH_KIMI_LINEAR: case LLM_ARCH_STEP35: case LLM_ARCH_MISTRAL4: + case LLM_ARCH_MELLUM: return true; default: return false; diff --git a/tools/server/tests/requirements.txt b/tools/server/tests/requirements.txt index 92d27e2a13..ca7a0281fa 100644 --- a/tools/server/tests/requirements.txt +++ b/tools/server/tests/requirements.txt @@ -1,6 +1,5 @@ aiohttp~=3.9.3 pytest~=8.3.3 -huggingface_hub>=1.5.0,<2.0 numpy~=1.26.4 openai~=2.14.0 prometheus-client~=0.20.0