Skip to content

Commit

Permalink
support dreambooth
Browse files Browse the repository at this point in the history
  • Loading branch information
Yimi81 committed Dec 26, 2023
1 parent 4094706 commit 9958b5c
Show file tree
Hide file tree
Showing 219 changed files with 643 additions and 456 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pip install -r requirements.txt
```

### Train on a single GPU

```bash
bash src/scripts/train_text_to_image_lora.sh
bash scripts/train_text_to_image_lora.sh
```
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 29 additions & 0 deletions scripts/train_dreambooth.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# DreamBooth fine-tuning launcher: one process on GPU 0 via `accelerate launch`.
# Edit the four variables below to point at your base model, instance images,
# class (prior-preservation) images and output directory before running.
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="data/dog_dreambooth_test"
export CLASS_DIR="data/dog_dreambooth_test/path-to-class-images"
export OUTPUT_DIR="path-to-save-model"


# NOTE: every continued line ends in " \" (space, then backslash). Without the
# space the shell joins the last argument to the first token of the next line.
# Variable expansions are quoted so paths containing spaces stay one argument.
CUDA_VISIBLE_DEVICES=0 accelerate launch --num_processes 1 src/train_bash.py \
  --pretrained_model_name_or_path="$MODEL_NAME" \
  --dreambooth_data_dir="$INSTANCE_DIR" \
  --dreambooth_class_data_dir="$CLASS_DIR" \
  --output_dir="$OUTPUT_DIR" \
  --do_train \
  --finetuning_type dreambooth \
  --dataloader_num_workers 8 \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --per_device_train_batch_size=1 \
  --gradient_accumulation_steps=2 --gradient_checkpointing \
  --use_8bit_adam \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --warmup_steps=0 \
  --num_class_images=200 \
  --max_steps=800 \
  --report_to wandb \
  --mixed_precision no \
  --validation_prompt "a photo of sks dog in a bucket"
2 changes: 1 addition & 1 deletion src/difftuner/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from difftuner.data.loader import get_dataset, collate_fn, sdxl_collate_fn
from difftuner.data.loader import get_dataset, collate_fn, sdxl_collate_fn, dreambooth_collate_fn
from difftuner.data.preprocess import preprocess_dataset, sdxl_preprocess_dataset
50 changes: 45 additions & 5 deletions src/difftuner/data/loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Union
from typing import TYPE_CHECKING, Union

import torch

Expand All @@ -13,14 +13,15 @@


logger = get_logger(__name__)


def collate_fn(examples):
    """Collate preprocessed examples into one training batch.

    Args:
        examples: Sequence of mappings, each holding a ``"pixel_values"``
            tensor and an ``"input_ids"`` tensor. Tensors stored under the
            same key must share a shape so they can be stacked.

    Returns:
        A dict with two entries:
            ``"pixel_values"``: the per-example tensors stacked along a new
            leading dimension, in contiguous memory format, cast to float32.
            ``"input_ids"``: the per-example tensors stacked along a new
            leading dimension.
    """
    images = torch.stack([example["pixel_values"] for example in examples])
    # Normalise layout and dtype once here so consumers of the batch always
    # see a contiguous float32 tensor regardless of how examples were stored.
    images = images.to(memory_format=torch.contiguous_format).float()
    token_ids = torch.stack([example["input_ids"] for example in examples])
    return {"pixel_values": images, "input_ids": token_ids}


def sdxl_collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
Expand All @@ -35,14 +36,49 @@ def sdxl_collate_fn(examples):
"original_sizes": original_sizes,
"crop_top_lefts": crop_top_lefts,
}



def dreambooth_collate_fn(examples, with_prior_preservation=False):
    """Collate DreamBooth examples into one training batch.

    Args:
        examples: Non-empty sequence of mappings. Each must provide
            ``"instance_prompt_ids"`` and ``"instance_images"``. When
            ``with_prior_preservation`` is true each must also provide
            ``"class_prompt_ids"`` and ``"class_images"``. Attention masks
            (``"instance_attention_mask"`` / ``"class_attention_mask"``) are
            used only if the first example carries an instance mask.
        with_prior_preservation: If true, class examples are appended after
            the instance examples so both go through a single forward pass.

    Returns:
        A dict with ``"input_ids"`` (prompt ids concatenated along dim 0),
        ``"pixel_values"`` (images stacked along a new leading dimension,
        contiguous, float32) and, when masks are present,
        ``"attention_mask"`` (masks concatenated along dim 0). With prior
        preservation, all instance entries come first, then all class
        entries.
    """
    # The first example decides for the whole batch whether masks are present.
    has_attention_mask = "instance_attention_mask" in examples[0]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]
    attention_mask = (
        [example["instance_attention_mask"] for example in examples]
        if has_attention_mask
        else None
    )

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids.extend(example["class_prompt_ids"] for example in examples)
        pixel_values.extend(example["class_images"] for example in examples)
        if has_attention_mask:
            attention_mask.extend(
                example["class_attention_mask"] for example in examples
            )

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    batch = {
        "input_ids": torch.cat(input_ids, dim=0),
        "pixel_values": pixel_values,
    }

    if has_attention_mask:
        batch["attention_mask"] = torch.cat(attention_mask, dim=0)

    return batch

def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
max_samples = data_args.max_samples

logger.info(f"{'-'*20} Loading dataset {'-'*20}")
logger.info(f"Loading dataset")

if data_args.dataset_name is not None:
dataset = load_dataset(
Expand All @@ -51,7 +87,7 @@ def get_dataset(
cache_dir=model_args.cache_dir,
data_dir=data_args.train_data_dir
)
else:
elif data_args.train_data_dir is not None:
data_files = {}
if data_args.train_data_dir is not None:
data_files["train"] = os.path.join(data_args.train_data_dir, "**")
Expand All @@ -60,8 +96,12 @@ def get_dataset(
data_files=data_files,
cache_dir=model_args.cache_dir
)
elif data_args.dreambooth_data_dir is not None:
dataset = None

if max_samples is not None: # truncate dataset
dataset = dataset.select(range(min(len(dataset), max_samples)))

logger.info(f"Finish loading dataset")

return dataset
Loading

0 comments on commit 9958b5c

Please sign in to comment.