diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py index e5764a04b..b6c0ca6ed 100644 --- a/src/lighteval/utils/cache_management.py +++ b/src/lighteval/utils/cache_management.py @@ -202,6 +202,8 @@ def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID def get_sampling_method(self, sample: dict) -> str: if len(sample.get("logprobs", [])) > 0: + if len(sample.get("text", [])) == 0: + return SamplingMethod.PERPLEXITY return SamplingMethod.LOGPROBS if len(sample.get("text", [])) > 0: return SamplingMethod.GENERATIVE @@ -261,7 +263,10 @@ def get_samples_to_process_and_cache( return docs_not_cached, set(tasks_with_cached_samples) def get_samples_from_cache( - self, docs: List[Doc], task_ids: List[TaskID] | set[TaskID], sampling_method: SamplingMethod + self, + docs: List[Doc], + task_ids: List[TaskID] | set[TaskID], + sampling_method: SamplingMethod, ) -> List[dict | ModelResponse]: """Get cached samples for the given docs. Warning: Assumes all docs and task_names provided are stored in cache, will fail otherwise.