From 556f3f1696eadcc16ee77425243b732a84c7a2aa Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Wed, 8 Jan 2025 13:41:15 -0500
Subject: [PATCH] Fix documentation, remove unused function, fix bucket reso
 for sd1.5, fix multiple datasets

---
 library/config_util.py |  6 +++---
 library/train_util.py  | 25 ++++---------------------
 train_network.py       |  5 +----
 3 files changed, 8 insertions(+), 28 deletions(-)

diff --git a/library/config_util.py b/library/config_util.py
index 63d28c969..de1e154a1 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -481,9 +481,9 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
             subset_klass = FineTuningSubset
             dataset_klass = FineTuningDataset
 
-            subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
-            dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params))
-            datasets.append(dataset)
+        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
+        dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params))
+        datasets.append(dataset)
 
     val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
     for dataset_blueprint in dataset_group_blueprint.datasets:
diff --git a/library/train_util.py b/library/train_util.py
index b8894752e..62aae37ef 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -152,11 +152,11 @@ def split_train_val(paths: List[str], is_training_dataset: bool, validation_spli
     Shuffle the dataset based on the validation_seed or the current random seed.
     For example if the split of 0.2 of 100 images.
-    [0:79] = 80 training images
+    [0:80] = 80 training images
     [80:] = 20 validation images
     """
     if validation_seed is not None:
-        print(f"Using validation seed: {validation_seed}")
+        logging.info(f"Using validation seed: {validation_seed}")
         prevstate = random.getstate()
         random.seed(validation_seed)
         random.shuffle(paths)
@@ -5900,8 +5900,8 @@ def save_sd_model_on_train_end_common(
         huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
 
 
-def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor:
-    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
+def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
+    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
     timesteps = timesteps.long().to(device)
     return timesteps
 
@@ -5964,23 +5964,6 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
     return result
 
 
-def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
-    """
-    Add noise to the latents according to the noise magnitude at each timestep
-    (this is the forward diffusion process)
-    """
-    if args.ip_noise_gamma:
-        if args.ip_noise_gamma_random_strength:
-            strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
-        else:
-            strength = args.ip_noise_gamma
-        noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
-    else:
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-    return noisy_latents
-
-
 def conditional_loss(
     model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None
 ):
diff --git a/train_network.py b/train_network.py
index 199f589b0..7dbd12e88 100644
--- a/train_network.py
+++ b/train_network.py
@@ -125,10 +125,7 @@ def generate_step_logs(
         return logs
 
     def assert_extra_args(self, args, train_dataset_group):
-        # train_dataset_group.verify_bucket_reso_steps(64)
-        # TODO: Number of bucket reso steps may differ for each model, so a static number won't work
-        # and prevents models like SD1.5 with 64
-        pass
+        train_dataset_group.verify_bucket_reso_steps(32)
 
     def load_target_model(self, args, weight_dtype, accelerator):
         text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
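
Notes on the changes:

1) config_util.py: the re-indentation in generate_dataset_group_by_blueprint
moves the subset/dataset construction out of the final `else:` branch and up
to the body of the enclosing `for` loop, so a dataset is built and appended
for every blueprint (ControlNet, DreamBooth, and fine-tuning alike) rather
than only the fine-tuning case. This is the "fix multiple datasets" item in
the subject. A self-contained sketch of the scoping bug, with invented names:

    def build_datasets(kinds):
        datasets = []
        for kind in kinds:
            if kind == "controlnet":
                label = "ControlNetDataset"
            elif kind == "dreambooth":
                label = "DreamBoothDataset"
            else:
                label = "FineTuningDataset"
                # Old (buggy) placement: the append below lived here, under
                # `else:`, so non-fine-tuning kinds were silently skipped.
            datasets.append(label)  # fixed placement: once per blueprint
        return datasets

    assert build_datasets(["controlnet", "dreambooth", "ft"]) == [
        "ControlNetDataset", "DreamBoothDataset", "FineTuningDataset"
    ]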
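2) train_util.py, split_train_val: the docstring fix reflects Python's
half-open slice semantics. `paths[0:80]` covers indices 0 through 79, which
is 80 items, so the old `[0:79]` label undercounted by one. A quick check;
the exact rounding rule for the split index is an assumption here, since the
hunk does not show it:

    paths = [f"img_{i}.png" for i in range(100)]
    validation_split = 0.2
    split = int(len(paths) * (1.0 - validation_split))  # assumed rule: 80
    train, val = paths[0:split], paths[split:]
    assert len(train) == 80 and len(val) == 20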
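3) train_util.py, get_timesteps: the default device argument is dropped and
the randint draw is pinned to the CPU generator before being moved to the
caller's device. A plausible motivation (an assumption; the patch does not
state one) is reproducibility: with a fixed seed, the sampled timesteps are
identical whether training runs on CPU, CUDA, or MPS, because the draw no
longer depends on the target device's RNG.

    import torch

    def get_timesteps(min_timestep, max_timestep, b_size, device):
        # Draw on the CPU RNG, then move; mirrors the patched function.
        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
        return timesteps.long().to(device)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(0)
    a = get_timesteps(0, 1000, 4, torch.device("cpu"))
    torch.manual_seed(0)
    b = get_timesteps(0, 1000, 4, device)
    assert torch.equal(a, b.cpu())  # same draws regardless of target device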
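4) train_network.py: restoring train_dataset_group.verify_bucket_reso_steps(32)
re-enables the bucket resolution check that the commented-out code had
disabled entirely. Assuming the check asserts that each dataset's configured
bucket_reso_steps is divisible by the given minimum (an assumption based on
the method's name, not shown in this hunk), a floor of 32 is compatible with
both SD1.5's customary 64 and the finer 32 used by other models, addressing
the removed TODO's concern that a single hard-coded value was too
restrictive. A sketch of the assumed semantics:

    def verify_bucket_reso_steps(bucket_reso_steps: int, min_steps: int) -> None:
        # Assumed behavior: reject step sizes not divisible by the minimum.
        assert bucket_reso_steps % min_steps == 0, (
            f"bucket_reso_steps ({bucket_reso_steps}) must be a multiple of {min_steps}"
        )

    verify_bucket_reso_steps(64, 32)  # SD1.5-style buckets pass
    verify_bucket_reso_steps(32, 32)  # passes
    # verify_bucket_reso_steps(32, 64) would raise, illustrating why a
    # hard-coded floor of 64 was too strict for some configurations.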