[Tool]ImgDataset2WebDatasetMS (#130)
* Create ImgDataset2WebDatasetMS.py

* pre-commit

Signed-off-by: lawrence-cj <[email protected]>

* add an example in README.md;

Signed-off-by: lawrence-cj <[email protected]>

* pre_commit

Signed-off-by: lawrence-cj <[email protected]>

---------

Signed-off-by: lawrence-cj <[email protected]>
Co-authored-by: lawrence-cj <[email protected]>
Pevernow and lawrence-cj authored Jan 12, 2025
1 parent cb29d1c commit e8f6f32
Showing 2 changed files with 94 additions and 0 deletions.
22 changes: 22 additions & 0 deletions README.md
@@ -277,6 +277,8 @@ where each line of [`asset/samples_mini.txt`](asset/samples_mini.txt) contains a

- 32GB VRAM is required for training both the 0.6B and 1.6B models

### 1). Train with image-text pairs in directory

We provide a training example here; you can also select a config file that matches your data structure from the [config files dir](configs/sana_config).

To launch Sana training, you first need to prepare your data in the following format. See [here](asset/example_data) for an example of the data structure.
@@ -313,6 +315,26 @@ bash train_scripts/train.sh \
--train.train_batch_size=8
```
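The converter added in this commit expects every `foo.png` to sit next to a same-named `foo.txt` caption file. Before launching training, the pairing can be verified with a small stdlib-only sketch (`check_caption_pairs` is a hypothetical helper, not part of the repo):

```python
import os


def check_caption_pairs(data_dir):
    """Report any .png file in data_dir that lacks a matching .txt caption."""
    pngs = {f[:-4] for f in os.listdir(data_dir) if f.lower().endswith(".png")}
    txts = {f[:-4] for f in os.listdir(data_dir) if f.lower().endswith(".txt")}
    missing = sorted(pngs - txts)
    for name in missing:
        print(f"Missing caption: {name}.png")
    return not missing  # True when every image has a caption
```

Running it on your data directory before training avoids silent skips during conversion.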

### 2). Train with image-text pairs in WebDataset format

We also provide a conversion script that turns image-text pairs in a directory into the required WebDataset format; see the [data conversion scripts](asset/data_conversion_scripts) for more details.

```bash
python tools/convert_ImgDataset_to_WebDatasetMS_format.py
```
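The resulting tar contains one `.json` metadata member per image, holding the `file_name`, `prompt`, `width`, and `height` fields written by the converter. A quick stdlib-only sanity check of the archive can be sketched as follows (`list_samples` is a hypothetical helper, not part of the repo):

```python
import json
import tarfile


def list_samples(tar_path):
    """Return (file_name, width, height) for each JSON metadata member in the tar."""
    samples = []
    with tarfile.open(tar_path) as tar:
        for member in tar.getmembers():
            if member.name.endswith(".json"):
                meta = json.load(tar.extractfile(member))
                samples.append((meta["file_name"], meta["width"], meta["height"]))
    return samples
```

If the list is empty or shorter than expected, the conversion step likely skipped images with missing captions.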

Then Sana's training can be launched via

```bash
# Example of training Sana 0.6B with 512x512 resolution from scratch
bash train_scripts/train.sh \
configs/sana_config/512ms/Sana_600M_img512.yaml \
--data.data_dir="[asset/example_data_tar]" \
--data.type=SanaWebDatasetMS \
--model.multi_scale=true \
--train.train_batch_size=32
```

# 💻 4. Metric toolkit

Refer to [Toolkit Manual](asset/docs/metrics_toolkit.md).
72 changes: 72 additions & 0 deletions tools/convert_ImgDataset_to_WebDatasetMS_format.py
@@ -0,0 +1,72 @@
# @Author: Pevernow ([email protected])
# @Date: 2025/1/5
# @License: (Follow the main project)
import json
import os
import tarfile

from PIL import Image, PngImagePlugin

PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024 # Increase maximum size for text chunks


def process_data(input_dir, output_tar_name="output.tar"):
    """
    Processes a directory containing PNG files, generates corresponding JSON files,
    and packages all files into a TAR file. It also counts the number of processed
    PNG images, and saves the height and width of each PNG file to the JSON.

    Args:
        input_dir (str): The input directory containing PNG files.
        output_tar_name (str): The name of the output TAR file (default is "output.tar").
    """
    png_count = 0
    json_files_created = []

    for filename in os.listdir(input_dir):
        if filename.lower().endswith(".png"):
            png_count += 1
            base_name = filename[:-4]  # Remove the ".png" extension
            txt_filename = os.path.join(input_dir, base_name + ".txt")
            json_filename = base_name + ".json"
            json_filepath = os.path.join(input_dir, json_filename)
            png_filepath = os.path.join(input_dir, filename)

            if os.path.exists(txt_filename):
                try:
                    # Get the dimensions of the PNG image
                    with Image.open(png_filepath) as img:
                        width, height = img.size

                    with open(txt_filename, encoding="utf-8") as f:
                        caption_content = f.read().strip()

                    data = {"file_name": filename, "prompt": caption_content, "width": width, "height": height}

                    with open(json_filepath, "w", encoding="utf-8") as outfile:
                        json.dump(data, outfile, indent=4, ensure_ascii=False)

                    print(f"Generated: {json_filename}")
                    json_files_created.append(json_filepath)

                except Exception as e:
                    print(f"Error processing file {filename}: {e}")
            else:
                print(f"Warning: No corresponding TXT file found for {filename}.")

    # Create a TAR file and include all files
    with tarfile.open(output_tar_name, "w") as tar:
        for item in os.listdir(input_dir):
            item_path = os.path.join(input_dir, item)
            tar.add(item_path, arcname=item)  # arcname keeps the file's relative path inside the tar

    print(f"\nAll files have been packaged into: {output_tar_name}")
    print(f"Number of PNG images processed: {png_count}")


if __name__ == "__main__":
    input_directory = input("Please enter the directory path containing PNG and TXT files: ")
    output_tar_filename = (
        input("Please enter the name of the output TAR file (default is output.tar): ") or "output.tar"
    )
    process_data(input_directory, output_tar_filename)
