# batch_eval.py -- from hkchengrex/MMAudio

import logging
import os
from pathlib import Path

import hydra
import torch
import torch.distributed as distributed
import torchaudio
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from tqdm import tqdm

from mmaudio.data.data_setup import setup_eval_dataset
from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.utils.features_utils import FeaturesUtils
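
# TF32 speeds up fp32 matmuls/convolutions on Ampere+ GPUs at a small precision
# cost, which is acceptable for inference.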
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
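
# LOCAL_RANK and WORLD_SIZE are injected by the torch.distributed launcher
# (e.g. torchrun); running this script without such a launcher raises a KeyError.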
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
log = logging.getLogger()
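

# inference_mode() disables autograd for the whole run; hydra.main() builds cfg
# from config/eval_config.yaml, merged with any command-line overrides.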
@torch.inference_mode()
@hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
def main(cfg: DictConfig):
    device = 'cuda'
    torch.cuda.set_device(local_rank)

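    # resolve the requested model variant and download its weights if missing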
    if cfg.model not in all_model_cfg:
        raise ValueError(f'Unknown model variant: {cfg.model}')
    model: ModelConfig = all_model_cfg[cfg.model]
    model.download_if_needed()
    seq_cfg = model.seq_cfg

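    # outputs go under the Hydra run directory, in a per-dataset
    # (optionally suffixed) subfolder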
    run_dir = Path(HydraConfig.get().run.dir)
    if cfg.output_name is None:
        output_dir = run_dir / cfg.dataset
    else:
        output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
    output_dir.mkdir(parents=True, exist_ok=True)

    # load a pretrained model
    seq_cfg.duration = cfg.duration_s
    net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
    log.info(f'Loaded weights from {model.model_path}')
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
    log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
    log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
    log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')

    # misc setup
    rng = torch.Generator(device=device)
    rng.manual_seed(cfg.seed)
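    # flow-matching sampler: cfg.sampling.method selects the ODE integrator and
    # cfg.sampling.num_steps the number of solver steps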
    fm = FlowMatching(cfg.sampling.min_sigma,
                      inference_mode=cfg.sampling.method,
                      num_steps=cfg.sampling.num_steps)

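    # conditioning feature extractors and decoders (visual conditions, Synchformer
    # sync features, VAE latent decoder, BigVGAN vocoder); the VAE encoder is
    # skipped since evaluation only decodes latents (need_vae_encoder=False)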
    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=False)
    feature_utils = feature_utils.to(device).eval()

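    # optionally compile the hot paths (condition preprocessing and flow prediction)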
    if cfg.compile:
        net.preprocess_conditions = torch.compile(net.preprocess_conditions)
        net.predict_flow = torch.compile(net.predict_flow)
        feature_utils.compile()

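    # build the evaluation dataset/loader; with the process group initialized in
    # __main__, each rank presumably processes its own shard of the data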
    dataset, loader = setup_eval_dataset(cfg.dataset, cfg)

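    # run generation under bfloat16 autocast when cfg.amp is enabled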
    with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
        for batch in tqdm(loader):
            audios = generate(batch.get('clip_video', None),
                              batch.get('sync_video', None),
                              batch.get('caption', None),
                              feature_utils=feature_utils,
                              net=net,
                              fm=fm,
                              rng=rng,
                              cfg_strength=cfg.cfg_strength,
                              clip_batch_size_multiplier=64,
                              sync_batch_size_multiplier=64)
            audios = audios.float().cpu()
            names = batch['name']
            for audio, name in zip(audios, names):
                torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)


def distributed_setup():
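    # NOTE: the returned rank/world size are unused at the call site; main() reads
    # the LOCAL_RANK/WORLD_SIZE environment variables at import time instead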
    distributed.init_process_group(backend="nccl")
    local_rank = distributed.get_rank()
    world_size = distributed.get_world_size()
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size


if __name__ == '__main__':
    distributed_setup()
    main()

    # clean-up
    distributed.destroy_process_group()
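
# Illustrative launch (an assumption, not taken from the repo docs): the script
# expects a torch.distributed launcher to set LOCAL_RANK/WORLD_SIZE, e.g.
#   torchrun --nproc_per_node=2 batch_eval.py model=small_16k dataset=audiocaps duration_s=8.0
# where model/dataset/duration_s are Hydra overrides for config/eval_config.yaml.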