
Commit f95d532
allow custom schema
BramVanroy committed Jan 26, 2025
1 parent 0c3df50 commit f95d532
Showing 2 changed files with 8 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/datatrove/pipeline/writers/huggingface.py
@@ -11,6 +11,7 @@
     preupload_lfs_files,
 )
 from huggingface_hub.utils import HfHubHTTPError
+from pyarrow.lib import Schema
 
 from datatrove.io import DataFolderLike, get_datafolder
 from datatrove.pipeline.writers import ParquetWriter
@@ -36,6 +37,7 @@ def __init__(
         cleanup: bool = True,
         expand_metadata: bool = True,
         max_file_size: int = round(4.5 * 2**30),  # 4.5GB, leave some room for the last batch
+        schema: Schema = None,
     ):
         """
         This class is intended to upload VERY LARGE datasets. Consider using `push_to_hub` or just using a
@@ -73,6 +75,7 @@ def __init__(
             adapter=adapter,
             expand_metadata=expand_metadata,
             max_file_size=max_file_size,
+            schema=schema,
         )
         self.operations = []
         self._repo_init = False
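
The change in huggingface.py simply forwards an optional pyarrow Schema to the ParquetWriter base class. A minimal usage sketch follows; the repo id, working directory, and field names are illustrative assumptions, not taken from the commit:

    import pyarrow as pa

    from datatrove.pipeline.writers.huggingface import HuggingFaceDatasetWriter

    # Declaring the schema up front keeps every uploaded Parquet shard
    # consistent, rather than letting each shard infer its columns from
    # the first document it happens to receive.
    schema = pa.schema(
        [
            pa.field("text", pa.string()),
            pa.field("id", pa.string()),
        ]
    )

    writer = HuggingFaceDatasetWriter(
        dataset="my-org/my-dataset",         # hypothetical repo id
        local_working_dir="/tmp/hf_upload",  # hypothetical path
        schema=schema,
    )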
6 changes: 5 additions & 1 deletion src/datatrove/pipeline/writers/parquet.py
@@ -1,6 +1,8 @@
 from collections import Counter, defaultdict
 from typing import IO, Callable, Literal
 
+from pyarrow.lib import Schema
+
 from datatrove.io import DataFolderLike
 from datatrove.pipeline.writers.disk_base import DiskWriter

@@ -19,6 +21,7 @@ def __init__(
         batch_size: int = 1000,
         expand_metadata: bool = False,
         max_file_size: int = 5 * 2**30,  # 5GB
+        schema: Schema = None,
     ):
         # Validate the compression setting
         if compression not in {"snappy", "gzip", "brotli", "lz4", "zstd", None}:
@@ -40,6 +43,7 @@ def __init__(
         self._file_counter = Counter()
         self.compression = compression
         self.batch_size = batch_size
+        self.schema = schema
 
     def _on_file_switch(self, original_name, old_filename, new_filename):
         """
@@ -70,7 +74,7 @@ def _write(self, document: dict, file_handler: IO, filename: str):
         if filename not in self._writers:
             self._writers[filename] = pq.ParquetWriter(
                 file_handler,
-                schema=pa.RecordBatch.from_pylist([document]).schema,
+                schema=self.schema if self.schema is not None else pa.RecordBatch.from_pylist([document]).schema,
                 compression=self.compression,
             )
         self._batches[filename].append(document)
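
In parquet.py the stored schema takes precedence in _write: when self.schema is set, every output file uses it; otherwise the writer falls back to inferring the schema from the first document written to that file. A hedged usage sketch, where the output path and field names are illustrative assumptions:

    import pyarrow as pa

    from datatrove.pipeline.writers import ParquetWriter

    # With an explicit schema, all output files share the same column
    # types; without one, each file's schema is inferred from its first
    # document, as before this commit.
    schema = pa.schema(
        [
            pa.field("text", pa.string()),
            pa.field("id", pa.string()),
        ]
    )

    writer = ParquetWriter(
        output_folder="/tmp/parquet_out",  # hypothetical DataFolderLike path
        schema=schema,
    )

The `self.schema if self.schema is not None else ...` fallback keeps the change backward compatible: existing pipelines that never pass a schema behave exactly as before.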
