From e8f6f32cb9a56811e0308b5d40cde997b72c3b13 Mon Sep 17 00:00:00 2001 From: Pevernow <3450354617@qq.com> Date: Sun, 12 Jan 2025 15:18:03 +0800 Subject: [PATCH] [Tool]ImgDataset2WebDatasetMS (#130) * Create ImgDataset2WebDatasetMS.py * pre-commit Signed-off-by: lawrence-cj * add an example in README.md; Signed-off-by: lawrence-cj * pre_commit Signed-off-by: lawrence-cj --------- Signed-off-by: lawrence-cj Co-authored-by: lawrence-cj --- README.md | 22 ++++++ ...nvert_ImgDataset_to_WebDatasetMS_format.py | 72 +++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 tools/convert_ImgDataset_to_WebDatasetMS_format.py diff --git a/README.md b/README.md index edaf08a..a14afda 100644 --- a/README.md +++ b/README.md @@ -277,6 +277,8 @@ where each line of [`asset/samples_mini.txt`](asset/samples_mini.txt) contains a - 32GB VRAM is required for both 0.6B and 1.6B model's training +### 1). Train with image-text pairs in directory + We provide a training example here and you can also select your desired config file from [config files dir](configs/sana_config) based on your data structure. To launch Sana training, you will first need to prepare data in the following formats. [Here](asset/example_data) is an example for the data structure for reference. @@ -313,6 +315,26 @@ bash train_scripts/train.sh \ --train.train_batch_size=8 ``` +### 2). Train with image-text pairs in directory + +We also provide conversion scripts to convert your data to the required format. You can refer to the [data conversion scripts](asset/data_conversion_scripts) for more details. 
# @Author: Pevernow (wzy3450354617@gmail.com)
# @Date: 2025/1/5
# @License: (Follow the main project)
import json
import os
import tarfile

from PIL import Image, PngImagePlugin

# Some generators embed very large tEXt metadata chunks in PNGs; raise the
# Pillow limit so Image.open does not reject them.
PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024


def process_data(input_dir, output_tar_name="output.tar"):
    """Convert an image/caption directory into a WebDatasetMS-style tar.

    For every ``*.png`` in ``input_dir`` that has a matching ``*.txt`` caption
    file, writes a sibling ``*.json`` containing the file name, the caption
    (as ``prompt``) and the image width/height, then packages *every* file in
    ``input_dir`` (PNG, TXT and generated JSON) into a single tar archive.

    Args:
        input_dir (str): Directory containing paired ``.png`` / ``.txt`` files.
        output_tar_name (str): Path of the tar file to create
            (default ``"output.tar"``).
    """
    png_count = 0
    json_files_created = []

    for filename in os.listdir(input_dir):
        # Case-insensitive match so ".PNG" files are also picked up.
        if not filename.lower().endswith(".png"):
            continue
        png_count += 1
        base_name = filename[:-4]  # Strip the 4-char ".png" extension.
        txt_filename = os.path.join(input_dir, base_name + ".txt")
        json_filename = base_name + ".json"
        json_filepath = os.path.join(input_dir, json_filename)
        png_filepath = os.path.join(input_dir, filename)

        if not os.path.exists(txt_filename):
            # Name the offending image so the user can fix the dataset.
            print(f"Warning: No corresponding TXT file found for {filename}.")
            continue

        try:
            # Record the image dimensions so multi-scale training can bucket
            # samples without re-opening every image.
            with Image.open(png_filepath) as img:
                width, height = img.size

            with open(txt_filename, encoding="utf-8") as f:
                caption_content = f.read().strip()

            data = {"file_name": filename, "prompt": caption_content, "width": width, "height": height}

            with open(json_filepath, "w", encoding="utf-8") as outfile:
                json.dump(data, outfile, indent=4, ensure_ascii=False)

            print(f"Generated: {json_filename}")
            json_files_created.append(json_filepath)

        except Exception as e:
            # Best-effort conversion: report which file failed and keep going.
            print(f"Error processing file {filename}: {e}")

    # Package everything in the directory (images, captions and the JSON
    # sidecars generated above) into one tar file.
    with tarfile.open(output_tar_name, "w") as tar:
        for item in os.listdir(input_dir):
            item_path = os.path.join(input_dir, item)
            tar.add(item_path, arcname=item)  # arcname keeps paths relative inside the tar

    print(f"\nAll files have been packaged into: {output_tar_name}")
    print(f"Number of PNG images processed: {png_count}")


if __name__ == "__main__":
    input_directory = input("Please enter the directory path containing PNG and TXT files: ")
    output_tar_filename = (
        input("Please enter the name of the output TAR file (default is output.tar): ") or "output.tar"
    )
    process_data(input_directory, output_tar_filename)