mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 23:50:20 -05:00
mtmd: model: unlimited-ocr: converter + parity test (#24969)
This commit is contained in:
parent
fb401045cc
commit
894bb27af3
@ -46,6 +46,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
|||||||
"DbrxForCausalLM": "dbrx",
|
"DbrxForCausalLM": "dbrx",
|
||||||
"DeciLMForCausalLM": "deci",
|
"DeciLMForCausalLM": "deci",
|
||||||
"DeepseekForCausalLM": "deepseek",
|
"DeepseekForCausalLM": "deepseek",
|
||||||
|
"DeepseekOCRForCausalLM": "deepseek",
|
||||||
"DeepseekV2ForCausalLM": "deepseek",
|
"DeepseekV2ForCausalLM": "deepseek",
|
||||||
"DeepseekV3ForCausalLM": "deepseek",
|
"DeepseekV3ForCausalLM": "deepseek",
|
||||||
"DeepseekV32ForCausalLM": "deepseek",
|
"DeepseekV32ForCausalLM": "deepseek",
|
||||||
@ -233,6 +234,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
|||||||
"UMT5ForConditionalGeneration": "t5",
|
"UMT5ForConditionalGeneration": "t5",
|
||||||
"UMT5Model": "t5",
|
"UMT5Model": "t5",
|
||||||
"UltravoxModel": "ultravox",
|
"UltravoxModel": "ultravox",
|
||||||
|
"UnlimitedOCRForCausalLM": "deepseek",
|
||||||
"VLlama3ForCausalLM": "llama",
|
"VLlama3ForCausalLM": "llama",
|
||||||
"VoxtralForConditionalGeneration": "llama",
|
"VoxtralForConditionalGeneration": "llama",
|
||||||
"WavTokenizerDec": "wavtokenizer",
|
"WavTokenizerDec": "wavtokenizer",
|
||||||
@ -299,6 +301,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
|
|||||||
"StepVLForConditionalGeneration": "step3",
|
"StepVLForConditionalGeneration": "step3",
|
||||||
"Step3p7ForConditionalGeneration": "step3",
|
"Step3p7ForConditionalGeneration": "step3",
|
||||||
"UltravoxModel": "ultravox",
|
"UltravoxModel": "ultravox",
|
||||||
|
"UnlimitedOCRForCausalLM": "deepseek",
|
||||||
"VoxtralForConditionalGeneration": "ultravox",
|
"VoxtralForConditionalGeneration": "ultravox",
|
||||||
"YoutuVLForConditionalGeneration": "youtuvl",
|
"YoutuVLForConditionalGeneration": "youtuvl",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from .base import MmprojModel, ModelBase, TextModel, gguf, logger
|
|||||||
from .qwen import QwenModel
|
from .qwen import QwenModel
|
||||||
|
|
||||||
|
|
||||||
@ModelBase.register("DeepseekOCRForCausalLM")
|
@ModelBase.register("DeepseekOCRForCausalLM", "UnlimitedOCRForCausalLM")
|
||||||
class DeepseekOCRVisionModel(MmprojModel):
|
class DeepseekOCRVisionModel(MmprojModel):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -205,6 +205,8 @@ class DeepseekModel(TextModel):
|
|||||||
@ModelBase.register(
|
@ModelBase.register(
|
||||||
"DeepseekV2ForCausalLM",
|
"DeepseekV2ForCausalLM",
|
||||||
"DeepseekV3ForCausalLM",
|
"DeepseekV3ForCausalLM",
|
||||||
|
"DeepseekOCRForCausalLM",
|
||||||
|
"UnlimitedOCRForCausalLM",
|
||||||
"KimiVLForConditionalGeneration",
|
"KimiVLForConditionalGeneration",
|
||||||
"KimiK25ForConditionalGeneration",
|
"KimiK25ForConditionalGeneration",
|
||||||
"YoutuForCausalLM",
|
"YoutuForCausalLM",
|
||||||
@ -224,7 +226,7 @@ class DeepseekV2Model(TextModel):
|
|||||||
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
||||||
|
|
||||||
# special handling for Deepseek OCR
|
# special handling for Deepseek OCR
|
||||||
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM"):
|
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM", "UnlimitedOCRForCausalLM"):
|
||||||
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
|
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
|
||||||
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||||
self.gguf_writer.add_architecture()
|
self.gguf_writer.add_architecture()
|
||||||
@ -350,6 +352,12 @@ class DeepseekV2Model(TextModel):
|
|||||||
|
|
||||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||||
|
|
||||||
|
# Unlimited-OCR sliding window; written for metadata, the decoder ignores it (full MHA)
|
||||||
|
if is_ocr:
|
||||||
|
sliding_window = hparams.get("sliding_window_size") or hparams.get("sliding_window")
|
||||||
|
if sliding_window:
|
||||||
|
self.gguf_writer.add_sliding_window(sliding_window)
|
||||||
|
|
||||||
if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
|
if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
|
||||||
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
||||||
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
|
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
|
||||||
|
|||||||
@ -9,6 +9,7 @@ its output, and holds them against the HF model's scores.
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import unicodedata
|
import unicodedata
|
||||||
@ -28,6 +29,12 @@ class ModelSpec:
|
|||||||
mmproj_arg: str
|
mmproj_arg: str
|
||||||
model_default: str
|
model_default: str
|
||||||
mmproj_default: str
|
mmproj_default: str
|
||||||
|
prompt: str = "Free OCR. "
|
||||||
|
n_predict: int = 512
|
||||||
|
n_ctx: int | None = None
|
||||||
|
# Unlimited-OCR's "document parsing" prompt emits <|det|> grounding markup that
|
||||||
|
# the HF reference strips in result.md; drop it before scoring to match.
|
||||||
|
strip_grounding: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -63,6 +70,20 @@ MODELS = {
|
|||||||
model_default="gguf_models/deepseek-ai/deepseek-ocr-2-bf16.gguf",
|
model_default="gguf_models/deepseek-ai/deepseek-ocr-2-bf16.gguf",
|
||||||
mmproj_default="gguf_models/deepseek-ai/mmproj-deepseek-ocr-2-bf16.gguf",
|
mmproj_default="gguf_models/deepseek-ai/mmproj-deepseek-ocr-2-bf16.gguf",
|
||||||
),
|
),
|
||||||
|
"unlimited": ModelSpec(
|
||||||
|
key="unlimited", label="Unlimited-OCR",
|
||||||
|
model_arg="--llama-model-unlimited", mmproj_arg="--mmproj-unlimited",
|
||||||
|
model_default="gguf_models/baidu/unlimited-ocr-bf16.gguf",
|
||||||
|
mmproj_default="gguf_models/baidu/mmproj-unlimited-ocr-bf16.gguf",
|
||||||
|
# "Free OCR." immediately emits EOS on this checkpoint; the HF reference
|
||||||
|
# (demo/unlimited_ocr_scores.py) uses "document parsing.", which grounds.
|
||||||
|
prompt="document parsing.",
|
||||||
|
# Grounding emits ~3x the tokens of plain OCR, so it needs a larger budget
|
||||||
|
# and context to reach the article body the ground truth covers.
|
||||||
|
n_predict=4096,
|
||||||
|
n_ctx=16384,
|
||||||
|
strip_grounding=True,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
CASES = [
|
CASES = [
|
||||||
@ -82,9 +103,26 @@ CASES = [
|
|||||||
# is one pixel off and lands at ~0.69 instead.
|
# is one pixel off and lands at ~0.69 instead.
|
||||||
hf_cer=0.7761, hf_chrf=28.70, cer_tol=0.12, chrf_tol=8.0,
|
hf_cer=0.7761, hf_chrf=28.70, cer_tol=0.12, chrf_tol=8.0,
|
||||||
),
|
),
|
||||||
|
TestCase(
|
||||||
|
model_key="unlimited", label="single-view scan",
|
||||||
|
image="tools/mtmd/test-1.jpeg",
|
||||||
|
ground_truth="tools/mtmd/tests/test-1-ground-truth.txt",
|
||||||
|
# HF reference: Unlimited-OCR scoring (gundam, bf16) on this image/ground-truth.
|
||||||
|
# Decoder runs full MHA, not R-SWA; the band absorbs that gap + bf16 variance.
|
||||||
|
hf_cer=0.1869, hf_chrf=75.23, cer_tol=0.06, chrf_tol=6.0,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# One paired grounding span: <|ref|>...<|/ref|> or <|det|>...<|/det|>.
# The \1 backreference requires the closing tag to name the same kind as the
# opening tag; the lazy .*? stops at the first such closer, and DOTALL lets a
# span run across newlines.
GROUNDING_TAG_RE = re.compile(r"<\|(ref|det)\|>.*?<\|/\1\|>", re.DOTALL)


def strip_grounding(text: str) -> str:
    """Drop <|ref|>..<|/ref|> / <|det|>..<|/det|> grounding markup, matching the
    cleaned result.md the HF reference scores against.

    Each paired span is removed together with everything between its tags.
    Unpaired or mismatched tags are left in place. The result is trimmed of
    leading and trailing whitespace: deleting a span at either end of the text
    exposes the whitespace that sat next to it, and trimming here makes an
    output that held nothing but grounding markup collapse to the empty string
    so an emptiness check on the result can catch it.
    """
    return GROUNDING_TAG_RE.sub("", text).strip()
|
||||||
|
|
||||||
|
|
||||||
def arg_dest(flag: str) -> str:
|
def arg_dest(flag: str) -> str:
|
||||||
return flag.lstrip("-").replace("-", "_")
|
return flag.lstrip("-").replace("-", "_")
|
||||||
|
|
||||||
@ -129,19 +167,19 @@ def compute_chrf(expected: str, ocr_out: str) -> float:
|
|||||||
return CHRF().sentence_score(ocr_out, [expected]).score
|
return CHRF().sentence_score(ocr_out, [expected]).score
|
||||||
|
|
||||||
|
|
||||||
def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str:
|
def run_mtmd_cli(spec: "ModelSpec", model_path, mmproj_path, image_path, bin_path) -> str:
|
||||||
"""Run mtmd-cli on the image and return its output."""
|
"""Run mtmd-cli on the image and return its output."""
|
||||||
cmd = [
|
cmd = [
|
||||||
str(bin_path),
|
str(bin_path),
|
||||||
"-m", str(model_path),
|
"-m", str(model_path),
|
||||||
"--mmproj", str(mmproj_path),
|
"--mmproj", str(mmproj_path),
|
||||||
"--image", str(image_path),
|
"--image", str(image_path),
|
||||||
"-p", "Free OCR. ",
|
"-p", spec.prompt,
|
||||||
"--chat-template", "deepseek-ocr",
|
"--chat-template", "deepseek-ocr",
|
||||||
"--temp", "0",
|
"--temp", "0",
|
||||||
"--flash-attn", "off", # match the HF "eager" attention reference
|
"--flash-attn", "off", # match the HF "eager" attention reference
|
||||||
"--no-warmup",
|
"--no-warmup",
|
||||||
"-n", "512", # cap loops on hard images (KV would otherwise fill)
|
"-n", str(spec.n_predict), # cap loops on hard images (KV would otherwise fill)
|
||||||
# HF decodes with no_repeat_ngram_size; llama.cpp's analog is DRY.
|
# HF decodes with no_repeat_ngram_size; llama.cpp's analog is DRY.
|
||||||
# Default DRY breakers include "\n", so they are cleared below.
|
# Default DRY breakers include "\n", so they are cleared below.
|
||||||
"--dry-multiplier", "0.8",
|
"--dry-multiplier", "0.8",
|
||||||
@ -150,6 +188,8 @@ def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str:
|
|||||||
"--dry-penalty-last-n", "-1",
|
"--dry-penalty-last-n", "-1",
|
||||||
"--dry-sequence-breaker", "none",
|
"--dry-sequence-breaker", "none",
|
||||||
]
|
]
|
||||||
|
if spec.n_ctx is not None:
|
||||||
|
cmd += ["-c", str(spec.n_ctx)]
|
||||||
logger.debug(f" command: {' '.join(cmd)}")
|
logger.debug(f" command: {' '.join(cmd)}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -164,6 +204,8 @@ def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str:
|
|||||||
raise RuntimeError(f"llama-mtmd-cli failed with code {result.returncode}")
|
raise RuntimeError(f"llama-mtmd-cli failed with code {result.returncode}")
|
||||||
|
|
||||||
output = result.stdout.decode("utf-8", errors="replace").strip()
|
output = result.stdout.decode("utf-8", errors="replace").strip()
|
||||||
|
if spec.strip_grounding:
|
||||||
|
output = strip_grounding(output)
|
||||||
if not output:
|
if not output:
|
||||||
raise RuntimeError("llama-mtmd-cli produced no output on stdout")
|
raise RuntimeError("llama-mtmd-cli produced no output on stdout")
|
||||||
logger.info(f" output: {len(output)} chars")
|
logger.info(f" output: {len(output)} chars")
|
||||||
@ -193,7 +235,7 @@ def evaluate(case: "TestCase", expected: str, ocr_out: str) -> bool:
|
|||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
logger.info("Free OCR evaluation:")
|
logger.info("OCR evaluation:")
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
logger.info(f" CER {cer:>7.4f} (HF {case.hf_cer:.4f}, <= {case.cer_max:>7.4f} -> {verdict(cer_pass)})")
|
logger.info(f" CER {cer:>7.4f} (HF {case.hf_cer:.4f}, <= {case.cer_max:>7.4f} -> {verdict(cer_pass)})")
|
||||||
logger.info(f" chrF (0-100) {chrf:>7.2f} (HF {case.hf_chrf:.2f}, >= {case.chrf_min:>7.2f} -> {verdict(chrf_pass)})")
|
logger.info(f" chrF (0-100) {chrf:>7.2f} (HF {case.hf_chrf:.2f}, >= {case.chrf_min:>7.2f} -> {verdict(chrf_pass)})")
|
||||||
@ -269,9 +311,9 @@ def main() -> int:
|
|||||||
expected = read_expected_text(ground_truth)
|
expected = read_expected_text(ground_truth)
|
||||||
logger.info(f" Image: {case.image}")
|
logger.info(f" Image: {case.image}")
|
||||||
logger.info(f" Expected text: {len(expected)} chars")
|
logger.info(f" Expected text: {len(expected)} chars")
|
||||||
logger.info(" Running llama.cpp 'Free OCR'")
|
logger.info(f" Running llama.cpp prompt {model_spec.prompt!r}")
|
||||||
try:
|
try:
|
||||||
ocr_out = run_mtmd_cli(model, mmproj, image, binary)
|
ocr_out = run_mtmd_cli(model_spec, model, mmproj, image, binary)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.error(f" Error: {e}")
|
logger.error(f" Error: {e}")
|
||||||
results[title] = False
|
results[title] = False
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user