Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Jan 24, 2025
1 parent 842e927 commit 537f41d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 23 deletions.
23 changes: 4 additions & 19 deletions src/olmo_core/distributed/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,7 @@

from olmo_core.aliases import PathOrStr
from olmo_core.config import StrEnum
from olmo_core.io import (
clear_directory,
dir_is_empty,
file_exists,
is_url,
join_path,
normalize_path,
)
from olmo_core.io import clear_directory, dir_is_empty, is_url, normalize_path
from olmo_core.utils import gc_cuda, get_element_size, wait_for

from ..utils import barrier, get_fs_local_rank, is_distributed
Expand Down Expand Up @@ -250,9 +243,9 @@ def load_model_and_optim_state(
reader = RemoteFileSystemReader(
dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir
)
metadata = reader.read_metadata()

if key_mapping is not None:
metadata = reader.read_metadata()
for current_key, original_key in key_mapping.items():
if f"model.{original_key}" not in metadata.state_dict_metadata:
continue
Expand Down Expand Up @@ -280,7 +273,6 @@ def load_model_and_optim_state(
)

if key_mapping is not None:
metadata = reader.read_metadata()
for current_key, original_key in key_mapping.items():
if f"model.{original_key}" not in metadata.state_dict_metadata:
continue
Expand Down Expand Up @@ -579,15 +571,8 @@ def get_checkpoint_metadata(dir: PathOrStr) -> Metadata:
:param dir: The path/URL to the checkpoint.
"""
dir = normalize_path(dir)
try:
storage_reader = RemoteFileSystemReader(dir)
return storage_reader.read_metadata()
except FileNotFoundError as exc:
msg = f"'{dir}' does not appear to contain a state dict checkpoint."
suggested_dir = join_path(dir, "model_and_optim")
if file_exists(join_path(suggested_dir, ".metadata")):
msg += f" Did you mean to use '{suggested_dir}'?"
raise FileNotFoundError(msg) from exc
storage_reader = RemoteFileSystemReader(dir)
return storage_reader.read_metadata()


def _prepare_env_for_save(
Expand Down
17 changes: 13 additions & 4 deletions src/olmo_core/distributed/checkpoint/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
from olmo_core.distributed.utils import do_n_at_a_time
from olmo_core.exceptions import OLMoCheckpointError
from olmo_core.io import (
file_exists,
get_bytes_range,
init_client,
is_url,
join_path,
normalize_path,
resource_path,
upload,
Expand Down Expand Up @@ -378,10 +380,17 @@ def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Fut

def read_metadata(self) -> Metadata:
if self._metadata is None:
with resource_path(self.path, ".metadata", local_cache=self.work_dir).open(
"rb"
) as metadata_file:
metadata = pickle.load(metadata_file)
try:
with resource_path(self.path, ".metadata", local_cache=self.work_dir).open(
"rb"
) as metadata_file:
metadata = pickle.load(metadata_file)
except FileNotFoundError as exc:
msg = f"'{dir}' does not appear to contain a distributed state dict/checkpoint."
suggested_dir = join_path(self.path, "model_and_optim")
if file_exists(join_path(suggested_dir, ".metadata")):
msg += f" Did you mean to use '{suggested_dir}'?"
raise FileNotFoundError(msg) from exc

if getattr(metadata, "storage_meta", None) is None:
metadata.storage_meta = StorageMeta()
Expand Down

0 comments on commit 537f41d

Please sign in to comment.