107 changes: 82 additions & 25 deletions src/lighteval/utils/cache_management.py
@@ -39,7 +39,6 @@
from lighteval.tasks.requests import Doc, SamplingMethod
from lighteval.utils.utils import as_list


logger = logging.getLogger(__name__)


@@ -58,7 +57,9 @@ def __str__(self):
return f"{self.task_name} ({self.task_hash}, {self.sampling_method.name})"

def __hash__(self):
return int.from_bytes(hashlib.sha256(str(self).encode()).digest(), byteorder="big")
return int.from_bytes(
hashlib.sha256(str(self).encode()).digest(), byteorder="big"
)
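# Illustrative sketch (not part of the diff): how a TaskID with hypothetical
# name and hash values stringifies and hashes under the two methods above.
from lighteval.tasks.requests import SamplingMethod
from lighteval.utils.cache_management import TaskID

tid = TaskID("lighteval|gsm8k|0", "0123456789abcdef", SamplingMethod.GENERATIVE)
print(str(tid))        # lighteval|gsm8k|0 (0123456789abcdef, GENERATIVE)
print(tid.__hash__())  # sha256 of that string, interpreted as one large integer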


class SampleCache:
@@ -84,7 +85,9 @@ def __init__(self, model_config: ModelConfig):
self.model_hash = self.get_model_hash(model_config)

self.cache_dir = (
Path(os.path.expanduser(self.model_config.cache_dir)) / self.model_config.model_name / self.model_hash
Path(os.path.expanduser(self.model_config.cache_dir))
/ self.model_config.model_name
/ self.model_hash
)
self.cache_dir.mkdir(parents=True, exist_ok=True)

@@ -115,10 +118,14 @@ def _load_cached_indices(self) -> dict:
# cache_file.parts gives all the subfolders of the url, up to the file name
# last 3 are task_name/task_hash/file_name.parquet, so we take -3 and -2
task_name, task_hash = cache_file.parts[-3:-1]
sampling_method = SamplingMethod[cache_file.stem] # removes the file extension
sampling_method = SamplingMethod[
cache_file.stem
] # removes the file extension
task_id = TaskID(task_name, task_hash, sampling_method)

full_dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
full_dataset = load_dataset(
"parquet", data_files=str(cache_file), split="train"
)
sample_ids = []
for row in full_dataset:
try:
Expand Down Expand Up @@ -169,7 +176,9 @@ def _get_task_hash(self, full_task_name: str) -> str:
task_configs: list[LightevalTaskConfig] = sorted(
self.registry.task_to_configs[f"{task_suite}|{task_name}"]
)
config_str = "|".join([task_config.__str__(lite=True) for task_config in task_configs])
config_str = "|".join(
[task_config.__str__(lite=True) for task_config in task_configs]
)
task_hash = hashlib.sha256(config_str.encode()).hexdigest()[:16]
self._task_hashes[full_task_name] = task_hash
return self._task_hashes[full_task_name]
@@ -183,7 +192,12 @@ def get_cache_path(self, task_id: TaskID) -> Path:
Returns:
Path: Path to the cache file for the given task and sample type
"""
return self.cache_dir / task_id.task_name / task_id.task_hash / f"{task_id.sampling_method.name}.parquet"
return (
self.cache_dir
/ task_id.task_name
/ task_id.task_hash
/ f"{task_id.sampling_method.name}.parquet"
)
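# Illustrative sketch (not part of the diff): where a cache file lands on disk,
# combining the base directory built in __init__ with the per-task path above.
# All concrete values (cache_dir, model name, hashes, task name) are hypothetical.
import os
from pathlib import Path

cache_dir = "~/.cache/huggingface/lighteval"      # model_config.cache_dir
model_name = "meta-llama/Llama-3.1-8B-Instruct"   # model_config.model_name
model_hash = "9f2c4a1b0d3e5f67"                   # get_model_hash(model_config)
task_name, task_hash = "lighteval|gsm8k|0", "0123456789abcdef"
base = Path(os.path.expanduser(cache_dir)) / model_name / model_hash  # __init__
cache_file = base / task_name / task_hash / "GENERATIVE.parquet"      # get_cache_path
print(cache_file)
# e.g. /home/user/.cache/huggingface/lighteval/meta-llama/Llama-3.1-8B-Instruct/
#      9f2c4a1b0d3e5f67/lighteval|gsm8k|0/0123456789abcdef/GENERATIVE.parquet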

def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID:
"""Returns a unique task identifier. Depends on the task name,
@@ -202,12 +216,16 @@ def get_task_id(self, task_name: str, sampling_method: SamplingMethod) -> TaskID

def get_sampling_method(self, sample: dict) -> SamplingMethod | None:
if len(sample.get("logprobs", [])) > 0:
if len(sample.get("text", [])) == 0:
return SamplingMethod.PERPLEXITY
return SamplingMethod.LOGPROBS
if len(sample.get("text", [])) > 0:
return SamplingMethod.GENERATIVE
return None
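# Illustrative sketch (not part of the diff): which SamplingMethod the dispatch
# above assigns to a few hypothetical dumped samples. Since `self` is unused in
# get_sampling_method, the unbound method is called here purely for illustration.
from lighteval.tasks.requests import SamplingMethod
from lighteval.utils.cache_management import SampleCache

assert SampleCache.get_sampling_method(None, {"text": ["42"], "logprobs": []}) == SamplingMethod.GENERATIVE
assert SampleCache.get_sampling_method(None, {"text": [], "logprobs": [-0.3, -1.2]}) == SamplingMethod.PERPLEXITY
assert SampleCache.get_sampling_method(None, {"text": ["A"], "logprobs": [-0.1]}) == SamplingMethod.LOGPROBS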

def _load_sample(self, sample: pd.core.series.Series | dict) -> Union[dict, ModelResponse]:
def _load_sample(
self, sample: pd.core.series.Series | dict
) -> Union[dict, ModelResponse]:
"""Load a sample from cached data based on sample type.

Args:
@@ -261,7 +279,10 @@ def get_samples_to_process_and_cache(
return docs_not_cached, set(tasks_with_cached_samples)

def get_samples_from_cache(
self, docs: List[Doc], task_ids: List[TaskID] | set[TaskID], sampling_method: SamplingMethod
self,
docs: List[Doc],
task_ids: List[TaskID] | set[TaskID],
sampling_method: SamplingMethod,
) -> List[dict | ModelResponse]:
"""Get cached samples for the given docs.
Warning: Assumes all docs and task_names provided are stored in cache, will fail otherwise.
@@ -277,11 +298,15 @@ def get_samples_from_cache(
continue
cache_file = self.get_cache_path(task_id)
try:
dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
dataset = load_dataset(
"parquet", data_files=str(cache_file), split="train"
)
dataset_df = dataset.to_pandas().set_index("sample_id")
task_datasets[task_id] = dataset_df
except Exception as e:
logger.warning(f"Error loading prediction cache for {str(task_id)}: {e}")
logger.warning(
f"Error loading prediction cache for {str(task_id)}: {e}"
)

# Build results list
results = []
@@ -311,7 +336,11 @@ def cache_samples( # noqa C901
sample = self._dump_sample(result)

processed_data[task_id].append({"sample_id": doc.id, "sample": sample})
processed_data = {task_id: task_data for task_id, task_data in processed_data.items() if task_data}
processed_data = {
task_id: task_data
for task_id, task_data in processed_data.items()
if task_data
}

# Concatenate it with existing data and save to file
for task_id, task_data in processed_data.items():
@@ -325,32 +354,49 @@
existing_samples = {}
if cache_file.exists():
try:
existing_dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
existing_dataset = load_dataset(
"parquet", data_files=str(cache_file), split="train"
)
existing_data = existing_dataset.to_list()
except KeyError:
logger.info(f"No data was cached for {str(task_id)}")
except Exception as e:
logger.error(f"Error loading existing prediction cache for {str(task_id)}: {e}")
logger.error(
f"Error loading existing prediction cache for {str(task_id)}: {e}"
)

existing_samples = {(row["sample_id"], sampling_method) for row in existing_data}
if any((row["sample_id"], sampling_method) in existing_samples for row in task_data):
existing_samples = {
(row["sample_id"], sampling_method) for row in existing_data
}
if any(
(row["sample_id"], sampling_method) in existing_samples
for row in task_data
):
logger.warning(
"Unexpected behavior: You have reprocessed already cached items - we will ignore the new version."
)

# Merge with new data (new data overwrites existing)
# We look at id + sampling method
new_data = [row for row in task_data if (row["sample_id"], sampling_method) not in existing_samples]
new_data = [
row
for row in task_data
if (row["sample_id"], sampling_method) not in existing_samples
]
all_samples = existing_data + new_data

# Save updated dataset
dataset = Dataset.from_list(all_samples)
dataset.to_parquet(str(cache_file))

logger.info(f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}.")
logger.info(
f"Cached {len(all_samples)} samples of {str(task_id)} at {str(cache_file)}."
)

# Refresh cached indices after storing new samples
self.existing_indices[task_id] = [sample["sample_id"] for sample in all_samples]
self.existing_indices[task_id] = [
sample["sample_id"] for sample in all_samples
]


def cached(sampling_method: SamplingMethod = None): # noqa C901
Expand Down Expand Up @@ -381,12 +427,16 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901
cache: SampleCache = self._cache

# Extract task names
task_ids = {cache.get_task_id(doc.task_name, sampling_method) for doc in docs}
task_ids = {
cache.get_task_id(doc.task_name, sampling_method) for doc in docs
}

# 1) Identify which samples must be processed because they are not cached
docs_not_cached: List[Doc]
tasks_with_cached_samples: Set[TaskID]
docs_not_cached, tasks_with_cached_samples = cache.get_samples_to_process_and_cache(docs, sampling_method)
docs_not_cached, tasks_with_cached_samples = (
cache.get_samples_to_process_and_cache(docs, sampling_method)
)

# Log cache statistics
cached_count = len(docs) - len(docs_not_cached)
Expand All @@ -399,7 +449,8 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901
new_results = []
if docs_not_cached:
tasks_needing_sample_processing = {
cache.get_task_id(doc.task_name, sampling_method) for doc in docs_not_cached
cache.get_task_id(doc.task_name, sampling_method)
for doc in docs_not_cached
}
logger.info(
f"Cache: Starting to process {len(docs_not_cached)}/{len(docs)} samples (not found in cache) for tasks {','.join(str(t) for t in tasks_needing_sample_processing)}"
Expand All @@ -415,15 +466,21 @@ def wrapper(self, docs: Union[Doc, List[Doc]], *args, **kwargs): # noqa C901
)

# 3) Create final results by pulling from newly saved file cache
final_cached_results = cache.get_samples_from_cache(docs, task_ids, sampling_method)
final_cached_results = cache.get_samples_from_cache(
docs, task_ids, sampling_method
)

# 4) We only keep samples with the correct sampling method
final_results = [
s for s in final_cached_results if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method
s
for s in final_cached_results
if cache.get_sampling_method(cache._dump_sample(s)) == sampling_method
]

if any(r is None for r in final_results):
raise ValueError("Problem while loading and aggregating items from cache.")
raise ValueError(
"Problem while loading and aggregating items from cache."
)

return final_results
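
To show the call flow this wrapper implements, here is a minimal sketch of how a backend might opt into caching. The MyBackend class and its _run_generation helper are hypothetical; only SampleCache, cached, Doc, and SamplingMethod come from the code above.

from typing import List

from lighteval.tasks.requests import Doc, SamplingMethod
from lighteval.utils.cache_management import SampleCache, cached


class MyBackend:  # hypothetical model backend
    def __init__(self, model_config):
        # The decorator reads `self._cache`, so the backend has to build one
        # from its ModelConfig.
        self._cache = SampleCache(model_config)

    @cached(SamplingMethod.GENERATIVE)
    def greedy_until(self, docs: List[Doc]):
        # Only docs missing from the parquet cache reach this body; the wrapper
        # caches the new responses, then returns results for every input doc
        # by re-reading the cache.
        return [self._run_generation(doc) for doc in docs]  # hypothetical helper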
