dflash: refactor draft model conversion (#25110)

* dflash: refactor draft model conversion

* apply fix for eagle3 convert
This commit is contained in:
Ruixiang Wang 2026-06-28 20:31:48 +02:00 committed by GitHub
parent c818263f2a
commit fa72bc6826
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 31 additions and 17 deletions

View File

@ -73,7 +73,7 @@ class LlamaModel(TextModel):
target_num_layers = target_config["num_hidden_layers"]
target_layers = [2, target_num_layers // 2, target_num_layers - 3]
logger.info(f"EAGLE-3: target_layers = {target_layers} (target model has {target_num_layers} layers)")
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", target_layers)
self.gguf_writer.add_target_layers(target_layers)
# target_hidden_size: prefer eagle3 config, fallback to target config
if eagle3_raw_config.get("target_hidden_size") is not None:
@ -83,12 +83,12 @@ class LlamaModel(TextModel):
target_hidden_size = target_config["hidden_size"]
src = "target model config"
logger.info(f"EAGLE-3: target_hidden_size = {target_hidden_size} (from {src})")
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size)
self.gguf_writer.add_target_hidden_size(target_hidden_size)
# norm_before_residual (RedHat-style eagle3 specific)
norm_before_residual = eagle3_raw_config.get("norm_before_residual", False)
logger.info(f"EAGLE-3: norm_before_residual = {norm_before_residual}")
self.gguf_writer.add_bool(f"{self.gguf_writer.arch}.norm_before_residual", norm_before_residual)
self.gguf_writer.add_norm_before_residual(norm_before_residual)
def set_vocab(self):
# eagle3: use tokenizer from target model if provided

View File

@ -643,21 +643,21 @@ class DFlashModel(Qwen3Model):
super().set_vocab()
self.dir_model = original_dir
mask_token_id = self.hparams.get("dflash_config", {}).get("mask_token_id")
if mask_token_id is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)
def set_gguf_parameters(self):
super().set_gguf_parameters()
block_size = self.hparams.get("block_size", 16)
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.block_size", block_size)
self.gguf_writer.add_block_size(block_size)
dflash_config = self.hparams.get("dflash_config", {})
target_layer_ids = dflash_config.get("target_layer_ids", [])
if target_layer_ids:
extract_layer_ids = [i + 1 for i in target_layer_ids]
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", extract_layer_ids)
mask_token_id = dflash_config.get("mask_token_id", None)
if mask_token_id is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)
self.gguf_writer.add_target_layers(extract_layer_ids)
use_sliding_window = self.hparams.get("use_sliding_window", False)
sliding_window = self.hparams.get("sliding_window")
@ -667,13 +667,9 @@ class DFlashModel(Qwen3Model):
self.gguf_writer.add_sliding_window(sliding_window)
self.gguf_writer.add_sliding_window_pattern(is_swa)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name == "fc.weight":
yield (name, data_torch)
return
if name == "hidden_norm.weight":
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ENC_OUTPUT_NORM), data_torch)
return
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, gen = item
if not name.startswith("model."):
name = "model." + name
yield from super().modify_tensors(data_torch, name, bid)
return super().filter_tensors((name, gen))

View File

@ -156,6 +156,7 @@ class Keys:
DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
TARGET_LAYERS = "{arch}.target_layers"
TARGET_HIDDEN_SIZE = "{arch}.target_hidden_size"
BLOCK_SIZE = "{arch}.block_size"
NORM_BEFORE_RESIDUAL = "{arch}.norm_before_residual"
class Attention:

View File

@ -940,6 +940,18 @@ class GGUFWriter:
def add_sliding_window(self, value: int) -> None:
self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
def add_block_size(self, value: int) -> None:
self.add_uint32(Keys.LLM.BLOCK_SIZE.format(arch=self.arch), value)
def add_target_layers(self, value: Sequence[int]) -> None:
self.add_array(Keys.LLM.TARGET_LAYERS.format(arch=self.arch), value)
def add_target_hidden_size(self, value: int) -> None:
self.add_uint32(Keys.LLM.TARGET_HIDDEN_SIZE.format(arch=self.arch), value)
def add_norm_before_residual(self, value: bool) -> None:
self.add_bool(Keys.LLM.NORM_BEFORE_RESIDUAL.format(arch=self.arch), value)
def add_attention_scale(self, value: float) -> None:
self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)

View File

@ -1283,6 +1283,11 @@ class TensorNameMap:
MODEL_TENSOR.ENC_OUTPUT_NORM: (
"encoder.final_layer_norm", # t5
"layer_norm", # neobert
"model.hidden_norm", # dflash
),
MODEL_TENSOR.FC: (
"model.fc", # dflash
),
MODEL_TENSOR.CLS: (