4 changes: 2 additions & 2 deletions docs/source/using-the-python-api.mdx
@@ -12,9 +12,9 @@ import lighteval
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.vllm.vllm_model import VLLMModelConfig
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
from lighteval.utils.imports import is_accelerate_available
from lighteval.utils.imports import is_package_available

if is_accelerate_available():
if is_package_available("accelerate"):
from datetime import timedelta
from accelerate import Accelerator, InitProcessGroupKwargs
accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
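This PR swaps the per-package helpers (`is_accelerate_available`, `is_nanotron_available`, and so on) for a single string-keyed check. As a rough sketch only, assuming it wraps `importlib` (the real helper in `lighteval.utils.imports` may cache results or enforce version constraints), the new call behaves like this:

```python
# Illustrative sketch only; lighteval's actual is_package_available may differ.
import importlib.util


def is_package_available(package_name: str) -> bool:
    """Return True if the named package can be imported in this environment."""
    return importlib.util.find_spec(package_name) is not None


if is_package_available("accelerate"):
    from accelerate import Accelerator  # only imported when the package is present
```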
10 changes: 5 additions & 5 deletions src/lighteval/logging/evaluation_tracker.py
@@ -43,13 +43,13 @@
TaskConfigLogger,
VersionsLogger,
)
from lighteval.utils.imports import NO_TENSORBOARDX_WARN_MSG, is_nanotron_available, is_tensorboardX_available
from lighteval.utils.imports import is_package_available, not_installed_error_message
from lighteval.utils.utils import obj_to_markdown


logger = logging.getLogger(__name__)

if is_nanotron_available():
if is_package_available("nanotron"):
from nanotron.config import GeneralArgs # type: ignore

try:
@@ -659,11 +659,11 @@ def recreate_metadata_card(self, repo_id: str) -> None: # noqa: C901
def push_to_tensorboard( # noqa: C901
self, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
):
if not is_tensorboardX_available:
logger.warning(NO_TENSORBOARDX_WARN_MSG)
if not is_package_available("tensorboardX"):
logger.warning(not_installed_error_message("tensorboardX"))
return

if not is_nanotron_available():
if not is_package_available("nanotron"):
logger.warning("You cannot push results to tensorboard without having nanotron installed. Skipping")
return

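The old guard `if not is_tensorboardX_available:` tested the function object instead of calling it, so the condition was always false and the warning never fired; calling `is_package_available("tensorboardX")` removes that failure mode. The companion `not_installed_error_message` helper is presumably just a formatter along these lines (a sketch, not the actual wording):

```python
# Hypothetical sketch of the message helper; the exact wording in lighteval may differ.
def not_installed_error_message(package_name: str) -> str:
    return (
        f"This feature requires `{package_name}`, which is not installed. "
        f"Install it with `pip install {package_name}` or the matching lighteval extra."
    )
```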
4 changes: 2 additions & 2 deletions src/lighteval/logging/info_loggers.py
@@ -34,13 +34,13 @@
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.imports import is_nanotron_available
from lighteval.utils.imports import is_package_available


logger = logging.getLogger(__name__)


if is_nanotron_available():
if is_package_available("nanotron"):
pass


11 changes: 5 additions & 6 deletions src/lighteval/main_nanotron.py
@@ -32,11 +32,13 @@
reasoning_tags,
remove_reasoning_tags,
)
from lighteval.utils.imports import requires


SEED = 1234


@requires("nanotron")
def nanotron(
checkpoint_config_path: Annotated[
str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.")
@@ -45,12 +47,9 @@ def nanotron(
remove_reasoning_tags: remove_reasoning_tags.type = remove_reasoning_tags.default,
reasoning_tags: reasoning_tags.type = reasoning_tags.default,
):
"""Evaluate models using nanotron as backend."""
from lighteval.utils.imports import NO_NANOTRON_ERROR_MSG, is_nanotron_available

if not is_nanotron_available():
raise ImportError(NO_NANOTRON_ERROR_MSG)

"""
Evaluate models using nanotron as backend.
"""
from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs, get_config_from_dict, get_config_from_file

from lighteval.logging.evaluation_tracker import EvaluationTracker
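The `@requires("nanotron")` decorator replaces the in-body availability check and the now-removed `NO_NANOTRON_ERROR_MSG` import. A minimal sketch for plain functions, assuming the check happens at call time, is shown below; the real decorator also handles classes (`LiteLLMClient`, `ModelClient`, and `SGLangModel` later in this diff) and likely emits a richer message:

```python
# Minimal sketch, assuming availability is checked when the wrapped function is called.
import functools
import importlib.util


def requires(package_name: str):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if importlib.util.find_spec(package_name) is None:
                raise ImportError(
                    f"`{func.__name__}` requires `{package_name}`, which is not installed."
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
```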
5 changes: 2 additions & 3 deletions src/lighteval/metrics/imports/data_stats_metric.py
@@ -29,7 +29,7 @@
from multiprocessing import Pool

from lighteval.metrics.imports.data_stats_utils import Fragments
from lighteval.utils.imports import NO_SPACY_ERROR_MSG, is_spacy_available
from lighteval.utils.imports import raise_if_package_not_available


logger = logging.getLogger(__name__)
@@ -70,8 +70,7 @@ def __init__(self, n_gram=3, n_workers=24, case=False, tokenize=True):
tokenize (bool): whether to tokenize the input; otherwise assumes that the input
is a string of space-separated tokens.
"""
if not is_spacy_available():
raise ImportError(NO_SPACY_ERROR_MSG)
raise_if_package_not_available("spacy")
import spacy

self.n_gram = n_gram
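`raise_if_package_not_available` is the imperative counterpart used where a decorator would be awkward (here, and in `llm_as_judge.py` and `model_loader.py` below). A rough sketch follows, assuming it builds on the same availability check; the tokenizer diff further down also passes an `Extras` group, so the real signature is broader than this:

```python
# Rough sketch; the real helper also appears to accept extras groups and keyword context.
import importlib.util


def raise_if_package_not_available(package_name: str) -> None:
    if importlib.util.find_spec(package_name) is None:
        raise ImportError(f"Please install `{package_name}` to use this feature.")
```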
8 changes: 4 additions & 4 deletions src/lighteval/metrics/utils/extractive_match_utils.py
@@ -34,12 +34,12 @@
from lighteval.tasks.requests import Doc
from lighteval.tasks.templates.utils.formulation import ChoicePrefix, get_prefix
from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS
from lighteval.utils.imports import requires_latex2sympy2_extended
from lighteval.utils.imports import requires
from lighteval.utils.language import Language
from lighteval.utils.timeout import timeout


@requires_latex2sympy2_extended
@requires("latex2sympy2_extended")
def latex_normalization_config_default_factory():
from latex2sympy2_extended.latex2sympy2 import NormalizationConfig

@@ -373,7 +373,7 @@ def get_target_type_order(target_type: ExtractionTarget) -> int:

# Small cache, to catch repeated calls with invalid parsing
@lru_cache(maxsize=20)
@requires_latex2sympy2_extended
@requires("latex2sympy2_extended")
def parse_latex_with_timeout(latex: str, timeout_seconds: int):
from latex2sympy2_extended.latex2sympy2 import latex2sympy

@@ -428,7 +428,7 @@ def convert_to_pct(number: Number):
return sympy.Mul(number, sympy.Rational(1, 100), evaluate=False)


@requires_latex2sympy2_extended
@requires("latex2sympy2_extended")
@lru_cache(maxsize=20)
def extract_latex(
match: re.Match, latex_config: LatexExtractionConfig, timeout_seconds: int
13 changes: 5 additions & 8 deletions src/lighteval/metrics/utils/linguistic_tokenizers.py
@@ -18,10 +18,8 @@
from typing import Callable, Iterator

from lighteval.utils.imports import (
NO_SPACY_TOKENIZER_ERROR_MSG,
NO_STANZA_TOKENIZER_ERROR_MSG,
can_load_spacy_tokenizer,
can_load_stanza_tokenizer,
Extras,
raise_if_package_not_available,
)
from lighteval.utils.language import Language

@@ -102,8 +100,8 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:
class SpaCyTokenizer(WordTokenizer):
def __init__(self, spacy_language: str, config=None):
super().__init__()
if not can_load_spacy_tokenizer(spacy_language):
raise ImportError(NO_SPACY_TOKENIZER_ERROR_MSG)
raise_if_package_not_available(Extras.MULTILINGUAL, language=spacy_language)

self.spacy_language = spacy_language
self.config = config
self._tokenizer = None
@@ -140,8 +138,7 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:
class StanzaTokenizer(WordTokenizer):
def __init__(self, stanza_language: str, **stanza_kwargs):
super().__init__()
if not can_load_stanza_tokenizer():
raise ImportError(NO_STANZA_TOKENIZER_ERROR_MSG)
raise_if_package_not_available("stanza")
self.stanza_language = stanza_language
self.stanza_kwargs = stanza_kwargs
self._tokenizer = None
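The tokenizer guards now key off an extras group rather than a single package. The members of `Extras` are not visible in this diff, so the following is purely illustrative of how a group-aware guard could work; the names and package list are assumptions:

```python
# Purely illustrative: the Extras members and the packages they bundle are assumptions.
import importlib.util
from enum import Enum


class Extras(Enum):
    MULTILINGUAL = ("spacy", "stanza")  # hypothetical package list for the extra


def raise_if_extra_not_available(extra: Extras, **context) -> None:
    missing = [pkg for pkg in extra.value if importlib.util.find_spec(pkg) is None]
    if missing:
        raise ImportError(f"Missing {missing} for extra '{extra.name}' (context: {context}).")
```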
11 changes: 4 additions & 7 deletions src/lighteval/metrics/utils/llm_as_judge.py
@@ -33,7 +33,7 @@
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
from lighteval.utils.imports import raise_if_package_not_available
from lighteval.utils.utils import as_list


@@ -131,8 +131,7 @@ def __lazy_load_client(self): # noqa: C901
# Both "openai" and "tgi" backends use the OpenAI-compatible API
# They are handled separately to allow for backend-specific validation and setup
case "openai" | "tgi":
if not is_openai_available():
raise RuntimeError("OpenAI backend is not available.")
raise_if_package_not_available("openai")
if self.client is None:
from openai import OpenAI

@@ -142,13 +141,11 @@
return self.__call_api_parallel

case "litellm":
if not is_litellm_available():
raise RuntimeError("litellm is not available.")
raise_if_package_not_available("litellm")
return self.__call_litellm

case "vllm":
if not is_vllm_available():
raise RuntimeError("vllm is not available.")
raise_if_package_not_available("vllm")
if self.pipe is None:
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
4 changes: 2 additions & 2 deletions src/lighteval/metrics/utils/math_comparison.py
@@ -51,7 +51,7 @@
from sympy.core.function import UndefinedFunction
from sympy.core.relational import Relational

from lighteval.utils.imports import requires_latex2sympy2_extended
from lighteval.utils.imports import requires
from lighteval.utils.timeout import timeout


@@ -308,7 +308,7 @@ def is_equation(expr: Basic | MatrixBase) -> bool:
return False


@requires_latex2sympy2_extended
@requires("latex2sympy2_extended")
def is_assignment_relation(expr: Basic | MatrixBase) -> bool:
from latex2sympy2_extended.latex2sympy2 import is_expr_of_only_symbols

5 changes: 3 additions & 2 deletions src/lighteval/models/endpoints/litellm_model.py
@@ -32,12 +32,12 @@
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc
from lighteval.utils.cache_management import SampleCache, cached
from lighteval.utils.imports import is_litellm_available
from lighteval.utils.imports import is_package_available, requires


logger = logging.getLogger(__name__)

if is_litellm_available():
if is_package_available("litellm"):
import litellm
from litellm import encode
from litellm.caching.caching import Cache
@@ -110,6 +110,7 @@ class LiteLLMModelConfig(ModelConfig):
concurrent_requests: int = 10


@requires("litellm")
class LiteLLMClient(LightevalModel):
_DEFAULT_MAX_LENGTH: int = 4096

7 changes: 3 additions & 4 deletions src/lighteval/models/endpoints/tgi_model.py
@@ -32,10 +32,10 @@
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModel
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.utils.cache_management import SampleCache
from lighteval.utils.imports import NO_TGI_ERROR_MSG, is_tgi_available
from lighteval.utils.imports import is_package_available, requires


if is_tgi_available():
if is_package_available("tgi"):
from text_generation import AsyncClient
else:
from unittest.mock import Mock
@@ -99,12 +99,11 @@ class TGIModelConfig(ModelConfig):

# inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite
# the client functions, since they use a different client.
@requires("tgi")
class ModelClient(InferenceEndpointModel):
_DEFAULT_MAX_LENGTH: int = 4096

def __init__(self, config: TGIModelConfig) -> None:
if not is_tgi_available():
raise ImportError(NO_TGI_ERROR_MSG)
headers = (
{} if config.inference_server_auth is None else {"Authorization": f"Bearer {config.inference_server_auth}"}
)
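The module-level fallback for a missing `tgi` backend is truncated in this view. A common shape for the pattern, which keeps the module importable and the `AsyncClient` name resolvable while the class-level `@requires("tgi")` defers the hard failure to instantiation, is roughly (assumed, not taken from the diff):

```python
# Assumed shape of the truncated else branch: keep the module importable even
# when the optional client library is absent.
import importlib.util

if importlib.util.find_spec("text_generation") is not None:
    from text_generation import AsyncClient
else:
    from unittest.mock import Mock

    AsyncClient = Mock()  # placeholder; real use is guarded by @requires("tgi")
```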
26 changes: 5 additions & 21 deletions src/lighteval/models/model_loader.py
@@ -43,16 +43,7 @@
from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
from lighteval.models.transformers.vlm_transformers_model import VLMTransformersModel, VLMTransformersModelConfig
from lighteval.models.vllm.vllm_model import AsyncVLLMModel, VLLMModel, VLLMModelConfig
from lighteval.utils.imports import (
NO_LITELLM_ERROR_MSG,
NO_SGLANG_ERROR_MSG,
NO_TGI_ERROR_MSG,
NO_VLLM_ERROR_MSG,
is_litellm_available,
is_sglang_available,
is_tgi_available,
is_vllm_available,
)
from lighteval.utils.imports import raise_if_package_not_available, requires


logger = logging.getLogger(__name__)
@@ -101,19 +92,15 @@ def load_model( # noqa: C901
return load_inference_providers_model(config=config)


@requires("tgi")
def load_model_with_tgi(config: TGIModelConfig):
if not is_tgi_available():
raise ImportError(NO_TGI_ERROR_MSG)

logger.info(f"Load model from inference server: {config.inference_server_address}")
model = ModelClient(config=config)
return model


@requires("litellm")
def load_litellm_model(config: LiteLLMModelConfig):
if not is_litellm_available():
raise ImportError(NO_LITELLM_ERROR_MSG)

model = LiteLLMClient(config)
return model

@@ -163,8 +150,7 @@ def load_model_with_accelerate_or_default(
elif isinstance(config, DeltaModelConfig):
model = DeltaModel(config=config)
elif isinstance(config, VLLMModelConfig):
if not is_vllm_available():
raise ImportError(NO_VLLM_ERROR_MSG)
raise_if_package_not_available("vllm")
if config.is_async:
model = AsyncVLLMModel(config=config)
else:
@@ -185,8 +171,6 @@ def load_inference_providers_model(config: InferenceProvidersModelConfig):
return InferenceProvidersClient(config=config)


@requires("sglang")
def load_sglang_model(config: SGLangModelConfig):
if not is_sglang_available():
raise ImportError(NO_SGLANG_ERROR_MSG)

return SGLangModel(config=config)
4 changes: 2 additions & 2 deletions src/lighteval/models/nanotron/nanotron_model.py
@@ -50,7 +50,7 @@
Doc,
)
from lighteval.utils.cache_management import SampleCache, cached
from lighteval.utils.imports import is_nanotron_available
from lighteval.utils.imports import is_package_available
from lighteval.utils.parallelism import find_executable_batch_size
from lighteval.utils.utils import as_list

@@ -62,7 +62,7 @@

TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]

if is_nanotron_available():
if is_package_available("nanotron"):
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import GeneralArgs, ModelArgs, TokenizerArgs
5 changes: 3 additions & 2 deletions src/lighteval/models/sglang/sglang_model.py
@@ -35,12 +35,12 @@
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc
from lighteval.utils.cache_management import SampleCache, cached
from lighteval.utils.imports import is_sglang_available
from lighteval.utils.imports import is_package_available, requires


logger = logging.getLogger(__name__)

if is_sglang_available():
if is_package_available("sglang"):
from sglang import Engine
from sglang.srt.hf_transformers_utils import get_tokenizer

@@ -138,6 +138,7 @@ class SGLangModelConfig(ModelConfig):
override_chat_template: bool = None


@requires("sglang")
class SGLangModel(LightevalModel):
def __init__(
self,