mtmd: model: unlimited-ocr: converter + parity test (#24969)

This commit is contained in:
Saba Fallah 2026-06-24 18:20:22 +02:00 committed by GitHub
parent fb401045cc
commit 894bb27af3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 61 additions and 8 deletions

View File

@ -46,6 +46,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"DbrxForCausalLM": "dbrx", "DbrxForCausalLM": "dbrx",
"DeciLMForCausalLM": "deci", "DeciLMForCausalLM": "deci",
"DeepseekForCausalLM": "deepseek", "DeepseekForCausalLM": "deepseek",
"DeepseekOCRForCausalLM": "deepseek",
"DeepseekV2ForCausalLM": "deepseek", "DeepseekV2ForCausalLM": "deepseek",
"DeepseekV3ForCausalLM": "deepseek", "DeepseekV3ForCausalLM": "deepseek",
"DeepseekV32ForCausalLM": "deepseek", "DeepseekV32ForCausalLM": "deepseek",
@ -233,6 +234,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"UMT5ForConditionalGeneration": "t5", "UMT5ForConditionalGeneration": "t5",
"UMT5Model": "t5", "UMT5Model": "t5",
"UltravoxModel": "ultravox", "UltravoxModel": "ultravox",
"UnlimitedOCRForCausalLM": "deepseek",
"VLlama3ForCausalLM": "llama", "VLlama3ForCausalLM": "llama",
"VoxtralForConditionalGeneration": "llama", "VoxtralForConditionalGeneration": "llama",
"WavTokenizerDec": "wavtokenizer", "WavTokenizerDec": "wavtokenizer",
@ -299,6 +301,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
"StepVLForConditionalGeneration": "step3", "StepVLForConditionalGeneration": "step3",
"Step3p7ForConditionalGeneration": "step3", "Step3p7ForConditionalGeneration": "step3",
"UltravoxModel": "ultravox", "UltravoxModel": "ultravox",
"UnlimitedOCRForCausalLM": "deepseek",
"VoxtralForConditionalGeneration": "ultravox", "VoxtralForConditionalGeneration": "ultravox",
"YoutuVLForConditionalGeneration": "youtuvl", "YoutuVLForConditionalGeneration": "youtuvl",
} }

View File

@ -14,7 +14,7 @@ from .base import MmprojModel, ModelBase, TextModel, gguf, logger
from .qwen import QwenModel from .qwen import QwenModel
@ModelBase.register("DeepseekOCRForCausalLM") @ModelBase.register("DeepseekOCRForCausalLM", "UnlimitedOCRForCausalLM")
class DeepseekOCRVisionModel(MmprojModel): class DeepseekOCRVisionModel(MmprojModel):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -205,6 +205,8 @@ class DeepseekModel(TextModel):
@ModelBase.register( @ModelBase.register(
"DeepseekV2ForCausalLM", "DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM", "DeepseekV3ForCausalLM",
"DeepseekOCRForCausalLM",
"UnlimitedOCRForCausalLM",
"KimiVLForConditionalGeneration", "KimiVLForConditionalGeneration",
"KimiK25ForConditionalGeneration", "KimiK25ForConditionalGeneration",
"YoutuForCausalLM", "YoutuForCausalLM",
@ -224,7 +226,7 @@ class DeepseekV2Model(TextModel):
self.origin_hf_arch = hparams.get('architectures', [None])[0] self.origin_hf_arch = hparams.get('architectures', [None])[0]
# special handling for Deepseek OCR # special handling for Deepseek OCR
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM"): if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM", "UnlimitedOCRForCausalLM"):
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
self.gguf_writer.add_architecture() self.gguf_writer.add_architecture()
@ -350,6 +352,12 @@ class DeepseekV2Model(TextModel):
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
# Unlimited-OCR sliding window; written for metadata, the decoder ignores it (full MHA)
if is_ocr:
sliding_window = hparams.get("sliding_window_size") or hparams.get("sliding_window")
if sliding_window:
self.gguf_writer.add_sliding_window(sliding_window)
if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None: if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] # [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul # note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul

View File

@ -9,6 +9,7 @@ its output, and holds them against the HF model's scores.
import argparse import argparse
import logging import logging
import re
import subprocess import subprocess
import sys import sys
import unicodedata import unicodedata
@ -28,6 +29,12 @@ class ModelSpec:
mmproj_arg: str mmproj_arg: str
model_default: str model_default: str
mmproj_default: str mmproj_default: str
prompt: str = "Free OCR. "
n_predict: int = 512
n_ctx: int | None = None
# Unlimited-OCR's "document parsing" prompt emits <|det|> grounding markup that
# the HF reference strips in result.md; drop it before scoring to match.
strip_grounding: bool = False
@dataclass @dataclass
@ -63,6 +70,20 @@ MODELS = {
model_default="gguf_models/deepseek-ai/deepseek-ocr-2-bf16.gguf", model_default="gguf_models/deepseek-ai/deepseek-ocr-2-bf16.gguf",
mmproj_default="gguf_models/deepseek-ai/mmproj-deepseek-ocr-2-bf16.gguf", mmproj_default="gguf_models/deepseek-ai/mmproj-deepseek-ocr-2-bf16.gguf",
), ),
"unlimited": ModelSpec(
key="unlimited", label="Unlimited-OCR",
model_arg="--llama-model-unlimited", mmproj_arg="--mmproj-unlimited",
model_default="gguf_models/baidu/unlimited-ocr-bf16.gguf",
mmproj_default="gguf_models/baidu/mmproj-unlimited-ocr-bf16.gguf",
# "Free OCR." immediately emits EOS on this checkpoint; the HF reference
# (demo/unlimited_ocr_scores.py) uses "document parsing.", which grounds.
prompt="document parsing.",
# Grounding emits ~3x the tokens of plain OCR, so it needs a larger budget
# and context to reach the article body the ground truth covers.
n_predict=4096,
n_ctx=16384,
strip_grounding=True,
),
} }
CASES = [ CASES = [
@ -82,9 +103,26 @@ CASES = [
# is one pixel off and lands at ~0.69 instead. # is one pixel off and lands at ~0.69 instead.
hf_cer=0.7761, hf_chrf=28.70, cer_tol=0.12, chrf_tol=8.0, hf_cer=0.7761, hf_chrf=28.70, cer_tol=0.12, chrf_tol=8.0,
), ),
TestCase(
model_key="unlimited", label="single-view scan",
image="tools/mtmd/test-1.jpeg",
ground_truth="tools/mtmd/tests/test-1-ground-truth.txt",
# HF reference: Unlimited-OCR scoring (gundam, bf16) on this image/ground-truth.
# Decoder runs full MHA, not R-SWA; the band absorbs that gap + bf16 variance.
hf_cer=0.1869, hf_chrf=75.23, cer_tol=0.06, chrf_tol=6.0,
),
] ]
GROUNDING_TAG_RE = re.compile(r"<\|(ref|det)\|>.*?<\|/\1\|>", re.DOTALL)
def strip_grounding(text: str) -> str:
"""Drop <|ref|>..<|/ref|> / <|det|>..<|/det|> grounding markup, matching the
cleaned result.md the HF reference scores against."""
return GROUNDING_TAG_RE.sub("", text)
def arg_dest(flag: str) -> str: def arg_dest(flag: str) -> str:
return flag.lstrip("-").replace("-", "_") return flag.lstrip("-").replace("-", "_")
@ -129,19 +167,19 @@ def compute_chrf(expected: str, ocr_out: str) -> float:
return CHRF().sentence_score(ocr_out, [expected]).score return CHRF().sentence_score(ocr_out, [expected]).score
def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str: def run_mtmd_cli(spec: "ModelSpec", model_path, mmproj_path, image_path, bin_path) -> str:
"""Run mtmd-cli on the image and return its output.""" """Run mtmd-cli on the image and return its output."""
cmd = [ cmd = [
str(bin_path), str(bin_path),
"-m", str(model_path), "-m", str(model_path),
"--mmproj", str(mmproj_path), "--mmproj", str(mmproj_path),
"--image", str(image_path), "--image", str(image_path),
"-p", "Free OCR. ", "-p", spec.prompt,
"--chat-template", "deepseek-ocr", "--chat-template", "deepseek-ocr",
"--temp", "0", "--temp", "0",
"--flash-attn", "off", # match the HF "eager" attention reference "--flash-attn", "off", # match the HF "eager" attention reference
"--no-warmup", "--no-warmup",
"-n", "512", # cap loops on hard images (KV would otherwise fill) "-n", str(spec.n_predict), # cap loops on hard images (KV would otherwise fill)
# HF decodes with no_repeat_ngram_size; llama.cpp's analog is DRY. # HF decodes with no_repeat_ngram_size; llama.cpp's analog is DRY.
# Default DRY breakers include "\n", so they are cleared below. # Default DRY breakers include "\n", so they are cleared below.
"--dry-multiplier", "0.8", "--dry-multiplier", "0.8",
@ -150,6 +188,8 @@ def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str:
"--dry-penalty-last-n", "-1", "--dry-penalty-last-n", "-1",
"--dry-sequence-breaker", "none", "--dry-sequence-breaker", "none",
] ]
if spec.n_ctx is not None:
cmd += ["-c", str(spec.n_ctx)]
logger.debug(f" command: {' '.join(cmd)}") logger.debug(f" command: {' '.join(cmd)}")
try: try:
@ -164,6 +204,8 @@ def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str:
raise RuntimeError(f"llama-mtmd-cli failed with code {result.returncode}") raise RuntimeError(f"llama-mtmd-cli failed with code {result.returncode}")
output = result.stdout.decode("utf-8", errors="replace").strip() output = result.stdout.decode("utf-8", errors="replace").strip()
if spec.strip_grounding:
output = strip_grounding(output)
if not output: if not output:
raise RuntimeError("llama-mtmd-cli produced no output on stdout") raise RuntimeError("llama-mtmd-cli produced no output on stdout")
logger.info(f" output: {len(output)} chars") logger.info(f" output: {len(output)} chars")
@ -193,7 +235,7 @@ def evaluate(case: "TestCase", expected: str, ocr_out: str) -> bool:
logger.info("") logger.info("")
logger.info("=" * 60) logger.info("=" * 60)
logger.info("Free OCR evaluation:") logger.info("OCR evaluation:")
logger.info("=" * 60) logger.info("=" * 60)
logger.info(f" CER {cer:>7.4f} (HF {case.hf_cer:.4f}, <= {case.cer_max:>7.4f} -> {verdict(cer_pass)})") logger.info(f" CER {cer:>7.4f} (HF {case.hf_cer:.4f}, <= {case.cer_max:>7.4f} -> {verdict(cer_pass)})")
logger.info(f" chrF (0-100) {chrf:>7.2f} (HF {case.hf_chrf:.2f}, >= {case.chrf_min:>7.2f} -> {verdict(chrf_pass)})") logger.info(f" chrF (0-100) {chrf:>7.2f} (HF {case.hf_chrf:.2f}, >= {case.chrf_min:>7.2f} -> {verdict(chrf_pass)})")
@ -269,9 +311,9 @@ def main() -> int:
expected = read_expected_text(ground_truth) expected = read_expected_text(ground_truth)
logger.info(f" Image: {case.image}") logger.info(f" Image: {case.image}")
logger.info(f" Expected text: {len(expected)} chars") logger.info(f" Expected text: {len(expected)} chars")
logger.info(" Running llama.cpp 'Free OCR'") logger.info(f" Running llama.cpp prompt {model_spec.prompt!r}")
try: try:
ocr_out = run_mtmd_cli(model, mmproj, image, binary) ocr_out = run_mtmd_cli(model_spec, model, mmproj, image, binary)
except RuntimeError as e: except RuntimeError as e:
logger.error(f" Error: {e}") logger.error(f" Error: {e}")
results[title] = False results[title] = False