Skip to content

Commit

Permalink
Fix documentation, remove unused function, fix bucket reso for sd1.5,…
Browse files Browse the repository at this point in the history
… fix multiple datasets
  • Loading branch information
rockerBOO committed Jan 8, 2025
1 parent 1231f51 commit 556f3f1
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 28 deletions.
6 changes: 3 additions & 3 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,9 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params))
datasets.append(dataset)
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params))
datasets.append(dataset)

val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
Expand Down
25 changes: 4 additions & 21 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ def split_train_val(paths: List[str], is_training_dataset: bool, validation_spli
Shuffle the dataset based on the validation_seed or the current random seed.
For example if the split of 0.2 of 100 images.
[0:79] = 80 training images
[0:80] = 80 training images
[80:] = 20 validation images
"""
if validation_seed is not None:
print(f"Using validation seed: {validation_seed}")
logging.info(f"Using validation seed: {validation_seed}")
prevstate = random.getstate()
random.seed(validation_seed)
random.shuffle(paths)
Expand Down Expand Up @@ -5900,8 +5900,8 @@ def save_sd_model_on_train_end_common(
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)


def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor:
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
timesteps = timesteps.long().to(device)
return timesteps

Expand Down Expand Up @@ -5964,23 +5964,6 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
return result


def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
"""
Add noise to the latents according to the noise magnitude at each timestep
(this is the forward diffusion process)
"""
if args.ip_noise_gamma:
if args.ip_noise_gamma_random_strength:
strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
else:
strength = args.ip_noise_gamma
noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

return noisy_latents


def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None
):
Expand Down
5 changes: 1 addition & 4 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,7 @@ def generate_step_logs(
return logs

def assert_extra_args(self, args, train_dataset_group):
# train_dataset_group.verify_bucket_reso_steps(64)
# TODO: Number of bucket reso steps may differ for each model, so a static number won't work
# and prevents models like SD1.5 with 64
pass
train_dataset_group.verify_bucket_reso_steps(32)

def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
Expand Down

0 comments on commit 556f3f1

Please sign in to comment.