From 9b2ca83bef0147e0f0e061d9c465eb5c37039bce Mon Sep 17 00:00:00 2001
From: Duc Hoang
Date: Tue, 4 Nov 2025 11:20:57 -0800
Subject: [PATCH 1/2] fix perplexity metric issues

---
 src/lighteval/utils/cache_management.py | 107 ++++++++++++++++++------
 1 file changed, 82 insertions(+), 25 deletions(-)

diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py
index e5764a04b..75ba2c50c 100644
--- a/src/lighteval/utils/cache_management.py
+++ b/src/lighteval/utils/cache_management.py
@@ -39,7 +39,6 @@
 from lighteval.tasks.requests import Doc, SamplingMethod
 from lighteval.utils.utils import as_list
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -58,7 +57,9 @@ def __str__(self):
         return f"{self.task_name} ({self.task_hash}, {self.sampling_method.name})"
 
     def __hash__(self):
-        return int.from_bytes(hashlib.sha256(str(self).encode()).digest(), byteorder="big")
+        return int.from_bytes(
+            hashlib.sha256(str(self).encode()).digest(), byteorder="big"
+        )
 
 
 class SampleCache:
@@ -84,7 +85,9 @@ def __init__(self, model_config: ModelConfig):
         self.model_hash = self.get_model_hash(model_config)
 
         self.cache_dir = (
-            Path(os.path.expanduser(self.model_config.cache_dir)) / self.model_config.model_name / self.model_hash
+            Path(os.path.expanduser(self.model_config.cache_dir))
+            / self.model_config.model_name
+            / self.model_hash
         )
         self.cache_dir.mkdir(parents=True, exist_ok=True)
 
@@ -115,10 +118,14 @@ def _load_cached_indices(self) -> dict:
             # cache_file.parts gives all the subfolders of the url, up to the file name
             # last 3 are task_name/task_hash/file_name.parquet, so we take -3 and -2
             task_name, task_hash = cache_file.parts[-3:-1]
-            sampling_method = SamplingMethod[cache_file.stem]  # removes the file extension
+            sampling_method = SamplingMethod[
+                cache_file.stem
+            ]  # removes the file extension
             task_id = TaskID(task_name, task_hash, sampling_method)
 
-            full_dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
+            full_dataset = load_dataset(
+                "parquet", data_files=str(cache_file), split="train"
+            )
             sample_ids = []
             for row in full_dataset:
                 try:
@@ -169,7 +176,9 @@ def _get_task_hash(self, full_task_name: str) -> str:
             task_configs: list[LightevalTaskConfig] = sorted(
                 self.registry.task_to_configs[f"{task_suite}|{task_name}"]
             )
-            config_str = "|".join([task_config.__str__(lite=True) for task_config in task_configs])
+            config_str = "|".join(
+                [task_config.__str__(lite=True) for task_config in task_configs]
+            )
             task_hash = hashlib.sha256(config_str.encode()).hexdigest()[:16]
             self._task_hashes[full_task_name] = task_hash
         return self._task_hashes[full_task_name]
@@ -183,7 +192,12 @@ def get_cache_path(self, task_id: TaskID) -> Path:
         Returns:
             Path: Path to the cache file for the given task and sample type
         """
-        return self.cache_dir / task_id.task_name / task_id.task_hash / f"{task_id.sampling_method.name}.parquet"
+        return (
+            self.cache_dir
+            / task_id.task_name
+            / task_id.task_hash
+            / f"{task_id.sampling_method.name}.parquet"
+        )
 
     def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID:
         """Returns a unique task indentifier. Depends on the task name,
@@ -202,12 +216,16 @@ def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID
 
     def get_sampling_method(self, sample: dict) -> str:
         if len(sample.get("logprobs", [])) > 0:
+            if len(sample.get("text", [])) == 0:
+                return SamplingMethod.PERPLEXITY
             return SamplingMethod.LOGPROBS
         if len(sample.get("text", [])) > 0:
             return SamplingMethod.GENERATIVE
         return None
 
-    def _load_sample(self, sample: pd.core.series.Series | dict) -> Union[dict, ModelResponse]:
+    def _load_sample(
+        self, sample: pd.core.series.Series | dict
+    ) -> Union[dict, ModelResponse]:
         """Load a sample from cached data based on sample type.
 
         Args:
@@ -261,7 +279,10 @@ def get_samples_to_process_and_cache(
         return docs_not_cached, set(tasks_with_cached_samples)
 
     def get_samples_from_cache(
-        self, docs: List[Doc], task_ids: List[TaskID] | set[TaskID], sampling_method: SamplingMethod
+        self,
+        docs: List[Doc],
+        task_ids: List[TaskID] | set[TaskID],
+        sampling_method: SamplingMethod,
     ) -> List[dict | ModelResponse]:
         """Get cached samples for the given docs.
         Warning: Assumes all docs and task_names provided are stored in cache, will fail otherwise.
@@ -277,11 +298,15 @@ def get_samples_from_cache(
                 continue
             cache_file = self.get_cache_path(task_id)
             try:
-                dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
+                dataset = load_dataset(
+                    "parquet", data_files=str(cache_file), split="train"
+                )
                 dataset_df = dataset.to_pandas().set_index("sample_id")
                 task_datasets[task_id] = dataset_df
             except Exception as e:
-                logger.warning(f"Error loading prediction cache for {str(task_id)}: {e}")
+                logger.warning(
+                    f"Error loading prediction cache for {str(task_id)}: {e}"
+                )
 
         # Build results list
         results = []
@@ -311,7 +336,11 @@ def cache_samples(  # noqa C901
             sample = self._dump_sample(result)
             processed_data[task_id].append({"sample_id": doc.id, "sample": sample})
 
-        processed_data = {task_id: task_data for task_id, task_data in processed_data.items() if task_data}
+        processed_data = {
+            task_id: task_data
+            for task_id, task_data in processed_data.items()
+            if task_data
+        }
 
         # Concatenate it with existing data and save to file
         for task_id, task_data in processed_data.items():
@@ -325,32 +354,49 @@ def cache_samples(  # noqa C901
             existing_samples = {}
             if cache_file.exists():
                 try:
-                    existing_dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
+                    existing_dataset = load_dataset(
+                        "parquet", data_files=str(cache_file), split="train"
+                    )
                     existing_data = existing_dataset.to_list()
                 except KeyError:
                     logger.info(f"No data was cached for {str(task_id)}")
                 except Exception as e:
-                    logger.error(f"Error loading existing prediction cache for {str(task_id)}: {e}")
+                    logger.error(
+                        f"Error loading existing prediction cache for {str(task_id)}: {e}"
+                    )
 
-            existing_samples = {(row["sample_id"], sampling_method) for row in existing_data}
-            if any((row["sample_id"], sampling_method) in existing_samples for row in task_data):
+            existing_samples = {
+                (row["sample_id"], sampling_method) for row in existing_data
+            }
+            if any(
+                (row["sample_id"], sampling_method) in existing_samples
+                for row in task_data
+            ):
                 logger.warning(
                     "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version."
                 )
 
             # Merge with new data (new data overwrites existing)
             # We look at id + sampling method
-            new_data = [row for row in task_data if (row["sample_id"], sampling_method) not in existing_samples]
+            new_data = [
+                row
+                for row in task_data
+                if (row["sample_id"], sampling_method) not in existing_samples
+            ]
             all_samples = existing_data + new_data
 
             # Save updated dataset
             dataset = Dataset.from_list(all_samples)
             dataset.to_parquet(str(cache_file))
 
-            logger.info(f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}.")
+            logger.info(
+                f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}."
+            )
 
             # Refresh cached indices after storing new samples
-            self.existing_indices[task_id] = [sample["sample_id"] for sample in all_samples]
+            self.existing_indices[task_id] = [
+                sample["sample_id"] for sample in all_samples
+            ]
 
 
 def cached(sampling_method: SamplingMethod = None):  # noqa C901
@@ -381,12 +427,16 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs):  # noqa C901
             cache: SampleCache = self._cache
 
             # Extract task names
-            task_ids = {cache.get_task_id(doc.task_name, sampling_method) for doc in docs}
+            task_ids = {
+                cache.get_task_id(doc.task_name, sampling_method) for doc in docs
+            }
 
             # 1) Identify which samples must be processed because they are not cached
             docs_not_cached: List[Doc]
             tasks_with_cached_samples: Set[TaskID]
-            docs_not_cached, tasks_with_cached_samples = cache.get_samples_to_process_and_cache(docs, sampling_method)
+            docs_not_cached, tasks_with_cached_samples = (
+                cache.get_samples_to_process_and_cache(docs, sampling_method)
+            )
 
             # Log cache statistics
             cached_count = len(docs) - len(docs_not_cached)
@@ -399,7 +449,8 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs):  # noqa C901
             new_results = []
             if docs_not_cached:
                 tasks_needing_sample_processing = {
-                    cache.get_task_id(doc.task_name, sampling_method) for doc in docs_not_cached
+                    cache.get_task_id(doc.task_name, sampling_method)
+                    for doc in docs_not_cached
                 }
                 logger.info(
                     f"Cache: Starting to process {len(docs_not_cached)}/{len(docs)} samples (not found in cache) for tasks {','.join(str(t) for t in tasks_needing_sample_processing)}"
@@ -415,15 +466,21 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs):  # noqa C901
                 )
 
             # 3) Create final results by pulling from newly saved file cache
-            final_cached_results = cache.get_samples_from_cache(docs, task_ids, sampling_method)
+            final_cached_results = cache.get_samples_from_cache(
+                docs, task_ids, sampling_method
+            )
 
             # 4) We only keep samples with the correct sampling method
             final_results = [
-                s for s in final_cached_results if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method
+                s
+                for s in final_cached_results
+                if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method
             ]
 
             if any(r is None for r in final_results):
-                raise ValueError("Problem while loading and aggregating items from cache.")
+                raise ValueError(
+                    "Problem while loading and aggregating items from cache."
+                )
 
             return final_results
 

From 8256008dc8ecc1e503a1ec9b2194c591a7894f15 Mon Sep 17 00:00:00 2001
From: Duc Hoang
Date: Thu, 6 Nov 2025 16:54:52 -0800
Subject: [PATCH 2/2] fixed styling

---
 src/lighteval/utils/cache_management.py | 100 ++++++------------------
 1 file changed, 24 insertions(+), 76 deletions(-)

diff --git a/src/lighteval/utils/cache_management.py b/src/lighteval/utils/cache_management.py
index 75ba2c50c..b6c0ca6ed 100644
--- a/src/lighteval/utils/cache_management.py
+++ b/src/lighteval/utils/cache_management.py
@@ -39,6 +39,7 @@
 from lighteval.tasks.requests import Doc, SamplingMethod
 from lighteval.utils.utils import as_list
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -57,9 +58,7 @@ def __str__(self):
         return f"{self.task_name} ({self.task_hash}, {self.sampling_method.name})"
 
     def __hash__(self):
-        return int.from_bytes(
-            hashlib.sha256(str(self).encode()).digest(), byteorder="big"
-        )
+        return int.from_bytes(hashlib.sha256(str(self).encode()).digest(), byteorder="big")
 
 
 class SampleCache:
@@ -85,9 +84,7 @@ def __init__(self, model_config: ModelConfig):
         self.model_hash = self.get_model_hash(model_config)
 
         self.cache_dir = (
-            Path(os.path.expanduser(self.model_config.cache_dir))
-            / self.model_config.model_name
-            / self.model_hash
+            Path(os.path.expanduser(self.model_config.cache_dir)) / self.model_config.model_name / self.model_hash
         )
         self.cache_dir.mkdir(parents=True, exist_ok=True)
 
@@ -118,14 +115,10 @@ def _load_cached_indices(self) -> dict:
             # cache_file.parts gives all the subfolders of the url, up to the file name
             # last 3 are task_name/task_hash/file_name.parquet, so we take -3 and -2
             task_name, task_hash = cache_file.parts[-3:-1]
-            sampling_method = SamplingMethod[
-                cache_file.stem
-            ]  # removes the file extension
+            sampling_method = SamplingMethod[cache_file.stem]  # removes the file extension
             task_id = TaskID(task_name, task_hash, sampling_method)
 
-            full_dataset = load_dataset(
-                "parquet", data_files=str(cache_file), split="train"
-            )
+            full_dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
             sample_ids = []
             for row in full_dataset:
                 try:
@@ -176,9 +169,7 @@ def _get_task_hash(self, full_task_name: str) -> str:
             task_configs: list[LightevalTaskConfig] = sorted(
                 self.registry.task_to_configs[f"{task_suite}|{task_name}"]
             )
-            config_str = "|".join(
-                [task_config.__str__(lite=True) for task_config in task_configs]
-            )
+            config_str = "|".join([task_config.__str__(lite=True) for task_config in task_configs])
             task_hash = hashlib.sha256(config_str.encode()).hexdigest()[:16]
             self._task_hashes[full_task_name] = task_hash
         return self._task_hashes[full_task_name]
@@ -192,12 +183,7 @@ def get_cache_path(self, task_id: TaskID) -> Path:
         Returns:
             Path: Path to the cache file for the given task and sample type
         """
-        return (
-            self.cache_dir
-            / task_id.task_name
-            / task_id.task_hash
-            / f"{task_id.sampling_method.name}.parquet"
-        )
+        return self.cache_dir / task_id.task_name / task_id.task_hash / f"{task_id.sampling_method.name}.parquet"
 
     def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID:
         """Returns a unique task indentifier. Depends on the task name,
@@ -223,9 +209,7 @@ def get_sampling_method(self, sample: dict) -> str:
             return SamplingMethod.GENERATIVE
         return None
 
-    def _load_sample(
-        self, sample: pd.core.series.Series | dict
-    ) -> Union[dict, ModelResponse]:
+    def _load_sample(self, sample: pd.core.series.Series | dict) -> Union[dict, ModelResponse]:
         """Load a sample from cached data based on sample type.
 
         Args:
@@ -298,15 +282,11 @@ def get_samples_from_cache(
                 continue
             cache_file = self.get_cache_path(task_id)
             try:
-                dataset = load_dataset(
-                    "parquet", data_files=str(cache_file), split="train"
-                )
+                dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
                 dataset_df = dataset.to_pandas().set_index("sample_id")
                 task_datasets[task_id] = dataset_df
             except Exception as e:
-                logger.warning(
-                    f"Error loading prediction cache for {str(task_id)}: {e}"
-                )
+                logger.warning(f"Error loading prediction cache for {str(task_id)}: {e}")
 
         # Build results list
         results = []
@@ -336,11 +316,7 @@ def cache_samples(  # noqa C901
             sample = self._dump_sample(result)
             processed_data[task_id].append({"sample_id": doc.id, "sample": sample})
 
-        processed_data = {
-            task_id: task_data
-            for task_id, task_data in processed_data.items()
-            if task_data
-        }
+        processed_data = {task_id: task_data for task_id, task_data in processed_data.items() if task_data}
 
         # Concatenate it with existing data and save to file
         for task_id, task_data in processed_data.items():
@@ -354,49 +330,32 @@ def cache_samples(  # noqa C901
             existing_samples = {}
             if cache_file.exists():
                 try:
-                    existing_dataset = load_dataset(
-                        "parquet", data_files=str(cache_file), split="train"
-                    )
+                    existing_dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
                     existing_data = existing_dataset.to_list()
                 except KeyError:
                     logger.info(f"No data was cached for {str(task_id)}")
                 except Exception as e:
-                    logger.error(
-                        f"Error loading existing prediction cache for {str(task_id)}: {e}"
-                    )
+                    logger.error(f"Error loading existing prediction cache for {str(task_id)}: {e}")
 
-            existing_samples = {
-                (row["sample_id"], sampling_method) for row in existing_data
-            }
-            if any(
-                (row["sample_id"], sampling_method) in existing_samples
-                for row in task_data
-            ):
+            existing_samples = {(row["sample_id"], sampling_method) for row in existing_data}
+            if any((row["sample_id"], sampling_method) in existing_samples for row in task_data):
                 logger.warning(
                     "Unexpected behavior: You have reprocessed already cached items - we will ignore the new version."
                 )
 
             # Merge with new data (new data overwrites existing)
             # We look at id + sampling method
-            new_data = [
-                row
-                for row in task_data
-                if (row["sample_id"], sampling_method) not in existing_samples
-            ]
+            new_data = [row for row in task_data if (row["sample_id"], sampling_method) not in existing_samples]
             all_samples = existing_data + new_data
 
             # Save updated dataset
             dataset = Dataset.from_list(all_samples)
             dataset.to_parquet(str(cache_file))
 
-            logger.info(
-                f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}."
-            )
+            logger.info(f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}.")
 
             # Refresh cached indices after storing new samples
-            self.existing_indices[task_id] = [
-                sample["sample_id"] for sample in all_samples
-            ]
+            self.existing_indices[task_id] = [sample["sample_id"] for sample in all_samples]
 
 
 def cached(sampling_method: SamplingMethod = None):  # noqa C901
@@ -427,16 +386,12 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs):  # noqa C901
             cache: SampleCache = self._cache
 
             # Extract task names
-            task_ids = {
-                cache.get_task_id(doc.task_name, sampling_method) for doc in docs
-            }
+            task_ids = {cache.get_task_id(doc.task_name, sampling_method) for doc in docs}
 
             # 1) Identify which samples must be processed because they are not cached
             docs_not_cached: List[Doc]
             tasks_with_cached_samples: Set[TaskID]
-            docs_not_cached, tasks_with_cached_samples = (
-                cache.get_samples_to_process_and_cache(docs, sampling_method)
-            )
+            docs_not_cached, tasks_with_cached_samples = cache.get_samples_to_process_and_cache(docs, sampling_method)
 
             # Log cache statistics
             cached_count = len(docs) - len(docs_not_cached)
@@ -449,8 +404,7 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs):  # noqa C901
             new_results = []
             if docs_not_cached:
                 tasks_needing_sample_processing = {
-                    cache.get_task_id(doc.task_name, sampling_method)
-                    for doc in docs_not_cached
+                    cache.get_task_id(doc.task_name, sampling_method) for doc in docs_not_cached
                 }
                 logger.info(
                     f"Cache: Starting to process {len(docs_not_cached)}/{len(docs)} samples (not found in cache) for tasks {','.join(str(t) for t in tasks_needing_sample_processing)}"
@@ -466,21 +420,15 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs):  # noqa C901
                 )
 
             # 3) Create final results by pulling from newly saved file cache
-            final_cached_results = cache.get_samples_from_cache(
-                docs, task_ids, sampling_method
-            )
+            final_cached_results = cache.get_samples_from_cache(docs, task_ids, sampling_method)
 
             # 4) We only keep samples with the correct sampling method
             final_results = [
-                s
-                for s in final_cached_results
-                if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method
+                s for s in final_cached_results if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method
             ]
 
             if any(r is None for r in final_results):
-                raise ValueError(
-                    "Problem while loading and aggregating items from cache."
-                )
+                raise ValueError("Problem while loading and aggregating items from cache.")
 
             return final_results
 
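
Note on the behavioral change in PATCH 1/2: get_sampling_method decides how a cached sample is classified, and it previously mapped any sample carrying logprobs to LOGPROBS, so a perplexity sample (logprobs present, no generated text) could not match a PERPLEXITY request in the wrapper's final filtering step (step 4 above). The snippet below is a minimal, self-contained sketch of the patched dispatch; it uses a stand-in enum rather than lighteval.tasks.requests.SamplingMethod so it runs outside the repo, and it is illustrative only, not the project code verbatim.

# Sketch only: mirrors the dispatch added in PATCH 1/2, with a stand-in enum
# instead of lighteval.tasks.requests.SamplingMethod.
from enum import Enum, auto


class SamplingMethod(Enum):
    GENERATIVE = auto()
    LOGPROBS = auto()
    PERPLEXITY = auto()


def get_sampling_method(sample: dict) -> SamplingMethod | None:
    if len(sample.get("logprobs", [])) > 0:
        # Logprobs without generated text now map to PERPLEXITY instead of
        # falling through to LOGPROBS (the pre-patch behavior).
        if len(sample.get("text", [])) == 0:
            return SamplingMethod.PERPLEXITY
        return SamplingMethod.LOGPROBS
    if len(sample.get("text", [])) > 0:
        return SamplingMethod.GENERATIVE
    return None


assert get_sampling_method({"logprobs": [-0.4], "text": []}) is SamplingMethod.PERPLEXITY
assert get_sampling_method({"logprobs": [-0.4], "text": ["hi"]}) is SamplingMethod.LOGPROBS
assert get_sampling_method({"logprobs": [], "text": ["hi"]}) is SamplingMethod.GENERATIVE

With this classification in place, a perplexity request only reuses cache entries whose stored shape actually matches a perplexity response, while generative and logprob entries keep their previous classification.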