diff --git a/self_hosting_machinery/finetune/scripts/aux/file_sets_context.py b/self_hosting_machinery/finetune/scripts/aux/file_sets_context.py
index dc27db51..42723c76 100644
--- a/self_hosting_machinery/finetune/scripts/aux/file_sets_context.py
+++ b/self_hosting_machinery/finetune/scripts/aux/file_sets_context.py
@@ -1,14 +1,16 @@
+import hashlib
+import json
+import os.path
 import random
 from pathlib import Path
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional
 
 import jsonlines
 
 from self_hosting_machinery.finetune.utils import traces
-from self_hosting_machinery.scripts import env
-
 from self_hosting_machinery.scripts.env import (TRAIN_UNFILTERED_FILEPATH, TEST_UNFILTERED_FILEPATH,
-                                                TRAIN_FILTERED_FILEPATH, TEST_FILTERED_FILEPATH)
+                                                TRAIN_FILTERED_FILEPATH, TEST_FILTERED_FILEPATH,
+                                                LOSS_PER_HASH_DB_FILEPATH)
 
 __all__ = ['FileSetsContext']
 
@@ -17,12 +19,34 @@ class FileSetsContext:
     TRAIN_FILES_MIN_NUMBER_WITH_TEST_SET = 4
     TRAIN_FILES_MIN_NUMBER_WITHOUT_TEST_SET = 7
     TEST_FILES_COUNT_WARNING = 64
+    MAX_CACHED_LOSS_ROWS = 1_000_000
 
     def __init__(self, autoselect_test_files_num: int):
         self._check_prerequisites()
         self.autoselect_test_files_num = autoselect_test_files_num
         self.train_files: List[Dict[str, Any]] = list(jsonlines.open(TRAIN_UNFILTERED_FILEPATH))
         self.test_files: List[Dict[str, Any]] = list(jsonlines.open(TEST_UNFILTERED_FILEPATH))
+        try:
+            hash_db = list(jsonlines.open(LOSS_PER_HASH_DB_FILEPATH))
+            self.loss_per_hash_db = {(item["hash"], item["model"]): item for item in
+                                     hash_db[-FileSetsContext.MAX_CACHED_LOSS_ROWS:]}
+        except Exception:
+            self.loss_per_hash_db = dict()
+            Path(LOSS_PER_HASH_DB_FILEPATH).touch()
+
+    def get_loss_by_content(self, model_name: str, content: str) -> Optional[float]:
+        h = hashlib.sha1(content.encode("utf-8")).hexdigest()
+        return self.loss_per_hash_db[(h, model_name)]["loss"] if (h, model_name) in self.loss_per_hash_db else None
+
+    def add_content_loss_pair(self, model_name: str, content: str, loss: float):
+        row = {
+            "hash": hashlib.sha1(content.encode("utf-8")).hexdigest(),
+            "model": model_name,
+            "loss": loss
+        }
+        self.loss_per_hash_db[(row["hash"], row["model"])] = row
+        with open(LOSS_PER_HASH_DB_FILEPATH, "a") as f:
+            f.write(f"{json.dumps(row)}\n")
 
     def _check_prerequisites(self):
         if not Path(TRAIN_UNFILTERED_FILEPATH).exists():
@@ -42,29 +66,6 @@ def _check_prerequisites(self):
             traces.log(f"Manually selected test set contains {len(test_files)} files. "
                        f"It could heavily slow down the training process on the next stage")
 
-    def is_up_to_date(self) -> bool:
-        unfiltered_train, filtered_train = (
-            Path(TRAIN_UNFILTERED_FILEPATH), Path(TRAIN_FILTERED_FILEPATH)
-        )
-        unfiltered_test, filtered_test = (
-            Path(TEST_UNFILTERED_FILEPATH), Path(TEST_FILTERED_FILEPATH)
-        )
-        how_to_filter = Path(env.CONFIG_HOW_TO_FILTER)
-        how_to_filetypes = Path(env.CONFIG_HOW_TO_FILETYPES)
-
-        try:
-            has_updates = [
-                unfiltered_train.lstat().st_mtime > filtered_train.lstat().st_mtime,
-                unfiltered_test.lstat().st_mtime > filtered_test.lstat().st_mtime,
-            ]
-            if how_to_filter.exists():
-                has_updates.append(how_to_filter.lstat().st_mtime > filtered_train.lstat().st_mtime)
-            if how_to_filetypes.exists():
-                has_updates.append(how_to_filetypes.lstat().st_mtime > filtered_train.lstat().st_mtime)
-        except OSError:
-            return False
-        return not any(has_updates)
-
     def dump_filtered(
             self,
             files: List[Dict[str, Any]]
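A minimal sketch of how the new loss-per-hash cache is meant to round-trip, assuming the unfiltered train/test sets already exist on disk (so the constructor's prerequisite check passes). The class and method names come from the diff above; the model name, content, and loss literals are invented for illustration.

from self_hosting_machinery.finetune.scripts.aux.file_sets_context import FileSetsContext

ctx = FileSetsContext(autoselect_test_files_num=3)
content = "def hello():\n    return 'world'\n"

# Cold cache: nothing stored yet for this (content, model) pair.
assert ctx.get_loss_by_content(model_name="some-model", content=content) is None

# Record a loss: kept in the in-memory dict and appended to LOSS_PER_HASH_DB_FILEPATH.
ctx.add_content_loss_pair(model_name="some-model", content=content, loss=1.37)

# Warm cache, even in a later process: the forward pass can be skipped.
assert ctx.get_loss_by_content(model_name="some-model", content=content) == 1.37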
" f"It could heavily slow down the training process on the next stage") - def is_up_to_date(self) -> bool: - unfiltered_train, filtered_train = ( - Path(TRAIN_UNFILTERED_FILEPATH), Path(TRAIN_FILTERED_FILEPATH) - ) - unfiltered_test, filtered_test = ( - Path(TEST_UNFILTERED_FILEPATH), Path(TEST_FILTERED_FILEPATH) - ) - how_to_filter = Path(env.CONFIG_HOW_TO_FILTER) - how_to_filetypes = Path(env.CONFIG_HOW_TO_FILETYPES) - - try: - has_updates = [ - unfiltered_train.lstat().st_mtime > filtered_train.lstat().st_mtime, - unfiltered_test.lstat().st_mtime > filtered_test.lstat().st_mtime, - ] - if how_to_filter.exists(): - has_updates.append(how_to_filter.lstat().st_mtime > filtered_train.lstat().st_mtime) - if how_to_filetypes.exists(): - has_updates.append(how_to_filetypes.lstat().st_mtime > filtered_train.lstat().st_mtime) - except OSError: - return False - return not any(has_updates) - def dump_filtered( self, files: List[Dict[str, Any]] diff --git a/self_hosting_machinery/finetune/scripts/finetune_filter.py b/self_hosting_machinery/finetune/scripts/finetune_filter.py index ffeda946..3eb0b66c 100644 --- a/self_hosting_machinery/finetune/scripts/finetune_filter.py +++ b/self_hosting_machinery/finetune/scripts/finetune_filter.py @@ -22,6 +22,10 @@ from self_hosting_machinery.finetune.utils.finetune_utils import (get_finetune_config, get_finetune_filter_config) +class InvalidLossValueException(Exception): + pass + + def _log_everywhere(message): logging.info(message) traces.log(message) @@ -52,6 +56,7 @@ def force_include_exclude_filter( @torch.inference_mode() def loss_based_filter( model_context: ModelContext, + dataset_context: FileSetsContext, files_status_context: FilesStatusContext, status_tracker: FinetuneFilterStatusTracker, *, @@ -66,17 +71,30 @@ def _get_file_loss(file) -> float: encoding=model_context.encoding ) for data in map(to_cuda, ds): - logits = model_context.forward(input=data['input']) - loss = model_context.loss( - logits=logits.to(torch.float32), - labels=data['labels'], - mask=data['mask'], - ).item() + content = model_context.encoding.decode(data['input'][0]) + maybe_loss = dataset_context.get_loss_by_content( + model_name=model_context.model_name, + content=content + ) + if maybe_loss is not None: + loss = maybe_loss + else: + logits = model_context.forward(input=data['input']) + loss = model_context.loss( + logits=logits.to(torch.float32), + labels=data['labels'], + mask=data['mask'], + ).item() + dataset_context.add_content_loss_pair( + model_name=model_context.model_name, + content=content, + loss=loss + ) if not (math.isnan(loss) or math.isinf(loss)): file_losses.append(loss) if len(file_losses) == 0: - raise Exception("small file") + raise InvalidLossValueException("small file") return sum(file_losses) / len(file_losses) @@ -87,7 +105,7 @@ def _get_file_loss(file) -> float: for file in train_files: try: file_loss = _get_file_loss(file) - except Exception as e: + except InvalidLossValueException as e: files_status_context.reject_file(file, reason=str(e)) continue @@ -129,6 +147,7 @@ def finetune_filter( _log_everywhere("Running perplexity based filter...") loss_based_filter( model_context=model_context, + dataset_context=dataset_context, files_status_context=file_status_context, status_tracker=status_tracker, filter_loss_threshold=finetune_filter_cfg['filter_loss_threshold'] @@ -161,10 +180,6 @@ def catch_sigusr1(signum, frame): file_sets_context = FileSetsContext( autoselect_test_files_num=finetune_filter_cfg.get("autoselect_test_files_num", 3) ) - if 
diff --git a/self_hosting_machinery/scripts/env.py b/self_hosting_machinery/scripts/env.py
index 1d853fde..e15b5825 100644
--- a/self_hosting_machinery/scripts/env.py
+++ b/self_hosting_machinery/scripts/env.py
@@ -18,6 +18,7 @@
 TRAIN_FILTERED_FILEPATH = os.path.join(DIR_UNPACKED, "train_set_filtered.jsonl")
 TEST_UNFILTERED_FILEPATH = os.path.join(DIR_UNPACKED, "test_set.jsonl")
 TEST_FILTERED_FILEPATH = os.path.join(DIR_UNPACKED, "test_set_filtered.jsonl")
+LOSS_PER_HASH_DB_FILEPATH = os.path.join(DIR_UNPACKED, "loss_per_hash_db.json")
 
 CONFIG_ENUM_GPUS = os.path.join(DIR_CONFIG, "gpus_enum_result.out")
 CONFIG_BUSY_GPUS = os.path.join(DIR_CONFIG, "gpus_busy_result.out")
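For reference, LOSS_PER_HASH_DB_FILEPATH is JSON Lines despite the .json extension: add_content_loss_pair appends one object per line, e.g. (values invented for illustration):

{"hash": "2ef7bde608ce5404e97d5f042f95f89f1c232871", "model": "some-model", "loss": 2.413}

On load, __init__ keeps only the last MAX_CACHED_LOSS_ROWS rows, and the dict comprehension keeps the last occurrence of each (hash, model) key, so a db file that grows across many runs stays bounded in memory.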