GPU Filtering improvements #245

Merged · 1 commit · Dec 21, 2023
55 changes: 28 additions & 27 deletions self_hosting_machinery/finetune/scripts/aux/file_sets_context.py
@@ -1,14 +1,16 @@
import hashlib
import json
import os.path
import random
from pathlib import Path
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional

import jsonlines

from self_hosting_machinery.finetune.utils import traces
from self_hosting_machinery.scripts import env

from self_hosting_machinery.scripts.env import (TRAIN_UNFILTERED_FILEPATH, TEST_UNFILTERED_FILEPATH,
TRAIN_FILTERED_FILEPATH, TEST_FILTERED_FILEPATH)
TRAIN_FILTERED_FILEPATH, TEST_FILTERED_FILEPATH,
LOSS_PER_HASH_DB_FILEPATH)

__all__ = ['FileSetsContext']

@@ -17,12 +19,34 @@ class FileSetsContext:
TRAIN_FILES_MIN_NUMBER_WITH_TEST_SET = 4
TRAIN_FILES_MIN_NUMBER_WITHOUT_TEST_SET = 7
TEST_FILES_COUNT_WARNING = 64
MAX_CACHED_LOSS_ROWS = 1_000_000

def __init__(self, autoselect_test_files_num: int):
self._check_prerequisites()
self.autoselect_test_files_num = autoselect_test_files_num
self.train_files: List[Dict[str, Any]] = list(jsonlines.open(TRAIN_UNFILTERED_FILEPATH))
self.test_files: List[Dict[str, Any]] = list(jsonlines.open(TEST_UNFILTERED_FILEPATH))
try:
hash_db = list(jsonlines.open(LOSS_PER_HASH_DB_FILEPATH))
self.loss_per_hash_db = {(item["hash"], item["model"]): item for item in
hash_db[-FileSetsContext.MAX_CACHED_LOSS_ROWS:]}
except Exception:
self.loss_per_hash_db = dict()
Path(LOSS_PER_HASH_DB_FILEPATH).touch()

def get_loss_by_content(self, model_name: str, content: str) -> Optional[float]:
h = hashlib.sha1(content.encode("utf-8")).hexdigest()
return self.loss_per_hash_db[(h, model_name)]["loss"] if (h, model_name) in self.loss_per_hash_db else None

def add_content_loss_pair(self, model_name: str, content: str, loss: float):
row = {
"hash": hashlib.sha1(content.encode("utf-8")).hexdigest(),
"model": model_name,
"loss": loss
}
self.loss_per_hash_db[(row["hash"], row["model"])] = row
with open(LOSS_PER_HASH_DB_FILEPATH, "a") as f:
f.write(f"{json.dumps(row)}\n")

def _check_prerequisites(self):
if not Path(TRAIN_UNFILTERED_FILEPATH).exists():
@@ -42,29 +66,6 @@ def _check_prerequisites(self):
traces.log(f"Manually selected test set contains {len(test_files)} files. "
f"It could heavily slow down the training process on the next stage")

def is_up_to_date(self) -> bool:
unfiltered_train, filtered_train = (
Path(TRAIN_UNFILTERED_FILEPATH), Path(TRAIN_FILTERED_FILEPATH)
)
unfiltered_test, filtered_test = (
Path(TEST_UNFILTERED_FILEPATH), Path(TEST_FILTERED_FILEPATH)
)
how_to_filter = Path(env.CONFIG_HOW_TO_FILTER)
how_to_filetypes = Path(env.CONFIG_HOW_TO_FILETYPES)

try:
has_updates = [
unfiltered_train.lstat().st_mtime > filtered_train.lstat().st_mtime,
unfiltered_test.lstat().st_mtime > filtered_test.lstat().st_mtime,
]
if how_to_filter.exists():
has_updates.append(how_to_filter.lstat().st_mtime > filtered_train.lstat().st_mtime)
if how_to_filetypes.exists():
has_updates.append(how_to_filetypes.lstat().st_mtime > filtered_train.lstat().st_mtime)
except OSError:
return False
return not any(has_updates)

def dump_filtered(
self,
files: List[Dict[str, Any]]
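The new members of `FileSetsContext` above add up to a small content-addressed loss cache: each loss value is keyed by the SHA-1 of the file content plus the model name, held in an in-memory dict, and appended to a JSON Lines file so it survives across filtering runs (only the most recent `MAX_CACHED_LOSS_ROWS` rows are reloaded). A minimal standalone sketch of that pattern — the file name and row cap below are illustrative, not the repository's actual values:

```python
import hashlib
import json
from pathlib import Path
from typing import Dict, Optional, Tuple

# Illustrative values; the real path and cap live in env.py and FileSetsContext.
DB_PATH = Path("loss_per_hash_db.jsonl")
MAX_ROWS = 1_000_000


class LossCache:
    """Content-addressed loss cache keyed by (sha1(content), model_name)."""

    def __init__(self) -> None:
        self._db: Dict[Tuple[str, str], float] = {}
        if DB_PATH.exists():
            # Reload only the most recent rows, mirroring MAX_CACHED_LOSS_ROWS.
            for line in DB_PATH.read_text().splitlines()[-MAX_ROWS:]:
                item = json.loads(line)
                self._db[(item["hash"], item["model"])] = item["loss"]
        else:
            DB_PATH.touch()

    @staticmethod
    def _key(model_name: str, content: str) -> Tuple[str, str]:
        return hashlib.sha1(content.encode("utf-8")).hexdigest(), model_name

    def get(self, model_name: str, content: str) -> Optional[float]:
        return self._db.get(self._key(model_name, content))

    def add(self, model_name: str, content: str, loss: float) -> None:
        key = self._key(model_name, content)
        self._db[key] = loss
        # Append-only JSON Lines: one {"hash", "model", "loss"} object per line.
        with DB_PATH.open("a") as f:
            f.write(json.dumps({"hash": key[0], "model": key[1], "loss": loss}) + "\n")


if __name__ == "__main__":
    cache = LossCache()
    cache.add("some-model", "def f():\n    return 1\n", 2.37)
    print(cache.get("some-model", "def f():\n    return 1\n"))     # 2.37
    print(cache.get("another-model", "def f():\n    return 1\n"))  # None (different model)
```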
39 changes: 27 additions & 12 deletions self_hosting_machinery/finetune/scripts/finetune_filter.py
@@ -22,6 +22,10 @@
from self_hosting_machinery.finetune.utils.finetune_utils import (get_finetune_config, get_finetune_filter_config)


class InvalidLossValueException(Exception):
pass


def _log_everywhere(message):
logging.info(message)
traces.log(message)
@@ -52,6 +56,7 @@ def force_include_exclude_filter(
@torch.inference_mode()
def loss_based_filter(
model_context: ModelContext,
dataset_context: FileSetsContext,
files_status_context: FilesStatusContext,
status_tracker: FinetuneFilterStatusTracker,
*,
@@ -66,17 +71,30 @@ def _get_file_loss(file) -> float:
encoding=model_context.encoding
)
for data in map(to_cuda, ds):
logits = model_context.forward(input=data['input'])
loss = model_context.loss(
logits=logits.to(torch.float32),
labels=data['labels'],
mask=data['mask'],
).item()
content = model_context.encoding.decode(data['input'][0])
maybe_loss = dataset_context.get_loss_by_content(
model_name=model_context.model_name,
content=content
)
if maybe_loss is not None:
loss = maybe_loss
else:
logits = model_context.forward(input=data['input'])
loss = model_context.loss(
logits=logits.to(torch.float32),
labels=data['labels'],
mask=data['mask'],
).item()
dataset_context.add_content_loss_pair(
model_name=model_context.model_name,
content=content,
loss=loss
)
if not (math.isnan(loss) or math.isinf(loss)):
file_losses.append(loss)

if len(file_losses) == 0:
raise Exception("small file")
raise InvalidLossValueException("small file")

return sum(file_losses) / len(file_losses)
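Because the cache key is the SHA-1 of the decoded batch text plus the model name, a later run of the filter over the same dataset (for example after changing `filter_loss_threshold`) can serve most losses from the JSONL database instead of repeating the forward pass. A compressed sketch of that lookup-or-compute pattern, with the GPU work replaced by a stub:

```python
import hashlib
from typing import Callable, Dict, Tuple

_loss_cache: Dict[Tuple[str, str], float] = {}


def cached_loss(model_name: str, content: str, compute: Callable[[str], float]) -> float:
    """Serve the loss from the cache when this (content, model) pair was scored before."""
    key = (hashlib.sha1(content.encode("utf-8")).hexdigest(), model_name)
    if key not in _loss_cache:
        _loss_cache[key] = compute(content)  # stands in for the forward pass + loss on GPU
    return _loss_cache[key]


if __name__ == "__main__":
    forward_calls = 0

    def fake_forward(text: str) -> float:
        global forward_calls
        forward_calls += 1
        return len(text) / 10.0

    cached_loss("some-model", "print('hello')", fake_forward)
    cached_loss("some-model", "print('hello')", fake_forward)  # cache hit, no "GPU" work
    print(forward_calls)  # 1
```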

Expand All @@ -87,7 +105,7 @@ def _get_file_loss(file) -> float:
for file in train_files:
try:
file_loss = _get_file_loss(file)
except Exception as e:
except InvalidLossValueException as e:
files_status_context.reject_file(file, reason=str(e))
continue

@@ -129,6 +147,7 @@ def finetune_filter(
_log_everywhere("Running perplexity based filter...")
loss_based_filter(
model_context=model_context,
dataset_context=dataset_context,
files_status_context=file_status_context,
status_tracker=status_tracker,
filter_loss_threshold=finetune_filter_cfg['filter_loss_threshold']
@@ -161,10 +180,6 @@ def catch_sigusr1(signum, frame):
file_sets_context = FileSetsContext(
autoselect_test_files_num=finetune_filter_cfg.get("autoselect_test_files_num", 3)
)
if file_sets_context.is_up_to_date():
logging.info("Train set filtering: nothing changed since last time, quit")
return

traces.log(textwrap.fill(
f"This filter calculates perplexity for each file and filters out "
f"files with perplexity larger than {finetune_filter_cfg['filter_loss_threshold']:.3f}.\n"
1 change: 1 addition & 0 deletions self_hosting_machinery/scripts/env.py
@@ -18,6 +18,7 @@
TRAIN_FILTERED_FILEPATH = os.path.join(DIR_UNPACKED, "train_set_filtered.jsonl")
TEST_UNFILTERED_FILEPATH = os.path.join(DIR_UNPACKED, "test_set.jsonl")
TEST_FILTERED_FILEPATH = os.path.join(DIR_UNPACKED, "test_set_filtered.jsonl")
LOSS_PER_HASH_DB_FILEPATH = os.path.join(DIR_UNPACKED, "loss_per_hash_db.json")

CONFIG_ENUM_GPUS = os.path.join(DIR_CONFIG, "gpus_enum_result.out")
CONFIG_BUSY_GPUS = os.path.join(DIR_CONFIG, "gpus_busy_result.out")
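`LOSS_PER_HASH_DB_FILEPATH` points at an append-only JSON Lines file (one `{"hash", "model", "loss"}` object per line, despite the `.json` suffix) stored alongside the other unpacked dataset artifacts. A small inspection snippet — the concrete path below is an assumption, since the real one is built from `DIR_UNPACKED`:

```python
import json
from collections import Counter
from pathlib import Path

# Assumed location for illustration; the real path is DIR_UNPACKED + "loss_per_hash_db.json".
db_path = Path("unpacked/loss_per_hash_db.json")

rows_per_model: Counter = Counter()
for line in db_path.read_text().splitlines():
    row = json.loads(line)  # one {"hash", "model", "loss"} object per line
    rows_per_model[row["model"]] += 1

for model, n in rows_per_model.most_common():
    print(f"{model}: {n} cached losses")
```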