diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 997034b..e4c0258 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -106,11 +106,14 @@ jobs: # Wait for server to be ready echo "Waiting for server to start..." sleep 15 - + # Test server health curl -s http://localhost:8000/health || echo "Server health check failed" env: OPTILLM_API_KEY: optillm + # Bound generation for the small test model, which does not reliably emit + # an EOS token and would otherwise ramble up to the 4096 default per call. + OPTILLM_MAX_TOKENS: "128" HF_TOKEN: ${{ secrets.HF_TOKEN }} - name: Run integration tests (server required) @@ -217,6 +220,9 @@ jobs: curl -s http://localhost:8000/health || echo "Server health check failed" env: OPTILLM_API_KEY: optillm + # Bound generation for the small test model, which does not reliably emit + # an EOS token and would otherwise ramble up to the 4096 default per call. + OPTILLM_MAX_TOKENS: "128" HF_TOKEN: ${{ secrets.HF_TOKEN }} - name: Run conversation logging tests diff --git a/optillm/__init__.py b/optillm/__init__.py index e6c30c8..aea5fc3 100644 --- a/optillm/__init__.py +++ b/optillm/__init__.py @@ -1,5 +1,5 @@ # Version information -__version__ = "0.3.19" +__version__ = "0.3.20" import os as _os diff --git a/optillm/inference.py b/optillm/inference.py index 3b61a87..de4ebbb 100644 --- a/optillm/inference.py +++ b/optillm/inference.py @@ -3,7 +3,7 @@ import logging import numpy as np from typing import Dict, List, Optional, Tuple, Any, Union -from dataclasses import dataclass +from dataclasses import dataclass, field from collections import OrderedDict, defaultdict import torch.nn.functional as F import torch.nn as nn @@ -89,6 +89,27 @@ def count_reasoning_tokens(text: str, tokenizer=None) -> int: MLX_AVAILABLE = False logger.debug("MLX framework not available - falling back to PyTorch") + +# Hard ceiling of 4096 by default. 
# Default budget of newly generated tokens for local inference. It is only a
# default: it applies when a request (or an approach's internal call) sends no
# max_tokens. It can be overridden -- typically lowered -- via the
# OPTILLM_MAX_TOKENS environment variable, so that small local models that do
# not reliably emit an EOS token (e.g. the dhara test model) do not ramble up
# to the full default on every call.
DEFAULT_MAX_NEW_TOKENS = 4096


def _default_max_new_tokens() -> int:
    """Return the default ``max_new_tokens`` for local generation.

    The ``OPTILLM_MAX_TOKENS`` environment variable is read on every call, so
    a change to the environment takes effect on the next call.

    Returns:
        ``DEFAULT_MAX_NEW_TOKENS`` when the variable is unset, empty/blank, or
        not an integer; otherwise the parsed value, clamped to at least 1.
    """
    raw = os.environ.get("OPTILLM_MAX_TOKENS")
    # An exported-but-empty variable means "not configured", not "invalid":
    # fall back quietly instead of logging a warning on every call.
    if raw is None or not raw.strip():
        return DEFAULT_MAX_NEW_TOKENS
    try:
        value = int(raw)
    except ValueError:
        logger.warning("Ignoring invalid OPTILLM_MAX_TOKENS=%r; using %d", raw, DEFAULT_MAX_NEW_TOKENS)
        return DEFAULT_MAX_NEW_TOKENS
    # Zero and negative values are clamped to the smallest usable budget.
    return max(1, value)
def _resolve_eos_token_ids(self) -> Optional[Union[int, List[int]]]:
    """Resolve the effective end-of-sequence token id(s) for generation.

    The model's own ``generation_config.eos_token_id`` is preferred. Chat
    models commonly set it to the chat-turn end token (e.g. ``<|im_end|>``),
    which can differ from the tokenizer's ``eos_token_id`` (often the
    base-model ``<|end_of_text|>``). Passing only the tokenizer eos to
    ``generate`` in that case means the model never stops on its real
    turn-end token and rambles up to ``max_new_tokens``.

    The tokenizer eos is merged in as a fallback so a model that only emits
    the base eos still terminates.

    Returns:
        A plain int when exactly one distinct id is found; a list of ints
        (generation-config ids first, duplicates removed, order preserved)
        when there are several; ``None`` when no integer id is available.
    """
    candidates: List[int] = []

    # getattr with a default also covers a model without a generation_config
    # (gen_cfg is None -> gc_eos is None).
    gen_cfg = getattr(self.current_model, "generation_config", None)
    gc_eos = getattr(gen_cfg, "eos_token_id", None)
    if isinstance(gc_eos, int):
        candidates.append(gc_eos)
    elif isinstance(gc_eos, (list, tuple)):
        # Non-integer entries are skipped rather than failing the call.
        candidates.extend(int(x) for x in gc_eos if isinstance(x, int))

    tok_eos = self.tokenizer.eos_token_id
    if isinstance(tok_eos, int):
        candidates.append(tok_eos)

    # dict.fromkeys removes duplicates while keeping first-seen order.
    resolved = list(dict.fromkeys(candidates))
    if not resolved:
        return None
    return resolved[0] if len(resolved) == 1 else resolved
4096) if generation_params else self.model_config.max_new_tokens, + "max_new_tokens": generation_params.get("max_new_tokens", _default_max_new_tokens()) if generation_params else self.model_config.max_new_tokens, "do_sample": generation_params.get("temperature", 1.0) > 0 if generation_params else self.model_config.do_sample, "temperature": generation_params.get("temperature", 1.0) if generation_params else self.model_config.temperature, "top_p": generation_params.get("top_p", 1.0) if generation_params else self.model_config.top_p, "num_return_sequences": n, "pad_token_id": self.tokenizer.pad_token_id, - "eos_token_id": self.tokenizer.eos_token_id, + "eos_token_id": self._resolve_eos_token_ids(), } # Add optional parameters if specified @@ -1900,7 +1951,7 @@ def create( # Use directly available parameters for entropy decoding entropy_params = { - "max_new_tokens": max_tokens if max_tokens is not None else 4096, + "max_new_tokens": max_tokens if max_tokens is not None else _default_max_new_tokens(), "temperature": temperature, "top_p": top_p, "top_k": top_k, @@ -2046,7 +2097,7 @@ def create( "temperature": temperature, "top_p": top_p, "num_return_sequences": n, - "max_new_tokens": max_tokens if max_tokens is not None else 4096, + "max_new_tokens": max_tokens if max_tokens is not None else _default_max_new_tokens(), "presence_penalty": presence_penalty, "frequency_penalty": frequency_penalty, "stop_sequences": [stop] if isinstance(stop, str) else stop, diff --git a/pyproject.toml b/pyproject.toml index 6ac8d1b..d534dd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "optillm" -version = "0.3.19" +version = "0.3.20" description = "An optimizing inference proxy for LLMs." 
class TestGenerationConfigDefaults(unittest.TestCase):
    """Unit tests for env-configurable max_new_tokens and EOS resolution.

    These exercise the guards that keep a small local model from rambling up to
    the 4096-token default when it does not reliably emit an EOS token (e.g. a
    ChatML model whose tokenizer EOS differs from the chat-turn end token). No
    model is loaded.
    """

    ENV_VAR = "OPTILLM_MAX_TOKENS"

    def setUp(self):
        # Snapshot the caller's value so these tests never leak a changed (or
        # deleted) OPTILLM_MAX_TOKENS into tests that run later in the process.
        self._saved_env = os.environ.pop(self.ENV_VAR, None)

    def tearDown(self):
        os.environ.pop(self.ENV_VAR, None)
        if self._saved_env is not None:
            os.environ[self.ENV_VAR] = self._saved_env

    @staticmethod
    def _fake_pipeline(generation_eos, tokenizer_eos):
        """Build a stand-in exposing only the attributes the resolver reads."""
        from types import SimpleNamespace

        return SimpleNamespace(
            current_model=SimpleNamespace(
                generation_config=SimpleNamespace(eos_token_id=generation_eos)
            ),
            tokenizer=SimpleNamespace(eos_token_id=tokenizer_eos),
        )

    def test_default_max_new_tokens_env_override(self):
        from optillm.inference import _default_max_new_tokens, DEFAULT_MAX_NEW_TOKENS

        # setUp removed the variable, so the built-in default applies.
        self.assertEqual(_default_max_new_tokens(), DEFAULT_MAX_NEW_TOKENS)

        os.environ[self.ENV_VAR] = "128"
        self.assertEqual(_default_max_new_tokens(), 128)

        # Invalid value falls back to the default rather than raising.
        os.environ[self.ENV_VAR] = "not-a-number"
        self.assertEqual(_default_max_new_tokens(), DEFAULT_MAX_NEW_TOKENS)

        # Non-positive is clamped to a usable minimum.
        os.environ[self.ENV_VAR] = "0"
        self.assertEqual(_default_max_new_tokens(), 1)

    def test_resolve_eos_prefers_generation_config(self):
        from optillm.inference import InferencePipeline

        # tokenizer EOS (<|end_of_text|>=1) differs from the chat end token
        # (<|im_end|>=49154); both must be honoured, generation_config first.
        fake = self._fake_pipeline(49154, 1)
        self.assertEqual(InferencePipeline._resolve_eos_token_ids(fake), [49154, 1])

    def test_resolve_eos_dedupes_list(self):
        from optillm.inference import InferencePipeline

        fake = self._fake_pipeline([100, 200], 200)
        self.assertEqual(InferencePipeline._resolve_eos_token_ids(fake), [100, 200])

    def test_resolve_eos_falls_back_to_tokenizer(self):
        from optillm.inference import InferencePipeline

        fake = self._fake_pipeline(None, 7)
        # A single id is returned as a plain int, not a list.
        self.assertEqual(InferencePipeline._resolve_eos_token_ids(fake), 7)

    def test_resolve_eos_returns_none_without_any_id(self):
        from optillm.inference import InferencePipeline

        fake = self._fake_pipeline(None, None)
        self.assertIsNone(InferencePipeline._resolve_eos_token_ids(fake))

    def test_resolve_eos_handles_missing_generation_config(self):
        from types import SimpleNamespace
        from optillm.inference import InferencePipeline

        # The model object exposes no generation_config attribute at all.
        fake = SimpleNamespace(
            current_model=SimpleNamespace(),
            tokenizer=SimpleNamespace(eos_token_id=3),
        )
        self.assertEqual(InferencePipeline._resolve_eos_token_ids(fake), 3)


if __name__ == "__main__":
    # Run tests
    unittest.main(verbosity=2)