From 956c10b860b152fe89e36b5ae6bdc931b7588774 Mon Sep 17 00:00:00 2001 From: Volodymyr Date: Mon, 21 Oct 2024 14:42:13 +0000 Subject: [PATCH] Fix tools/export_onnx.py file --- tools/export_onnx.py | 64 +++++++++----------------------------------- 1 file changed, 12 insertions(+), 52 deletions(-) diff --git a/tools/export_onnx.py b/tools/export_onnx.py index 822e09ee1..b3894db51 100644 --- a/tools/export_onnx.py +++ b/tools/export_onnx.py @@ -1,54 +1,14 @@ -import torch -from infer.lib.infer_pack.models_onnx import SynthesizerTrnMsNSFsidM +import argparse +from infer.modules.onnx.export import export_onnx as eo +# Usage Example: +# python tools/export_onnx.py --input_model ~/models/my-model.pth --output_model ~/models/my-model.onnx if __name__ == "__main__": - MoeVS = True # 模型是否为MoeVoiceStudio(原MoeSS)使用 - - ModelPath = "Shiroha/shiroha.pth" # 模型路径 - ExportedPath = "model.onnx" # 输出路径 - hidden_channels = 256 # hidden_channels,为768Vec做准备 - cpt = torch.load(ModelPath, map_location="cpu") - cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk - print(*cpt["config"]) - - test_phone = torch.rand(1, 200, hidden_channels) # hidden unit - test_phone_lengths = torch.tensor([200]).long() # hidden unit 长度(貌似没啥用) - test_pitch = torch.randint(size=(1, 200), low=5, high=255) # 基频(单位赫兹) - test_pitchf = torch.rand(1, 200) # nsf基频 - test_ds = torch.LongTensor([0]) # 说话人ID - test_rnd = torch.rand(1, 192, 200) # 噪声(加入随机因子) - - device = "cpu" # 导出时设备(不影响使用模型) - - net_g = SynthesizerTrnMsNSFsidM( - *cpt["config"], is_half=False - ) # fp32导出(C++要支持fp16必须手动将内存重新排列所以暂时不用fp16) - net_g.load_state_dict(cpt["weight"], strict=False) - input_names = ["phone", "phone_lengths", "pitch", "pitchf", "ds", "rnd"] - output_names = [ - "audio", - ] - # net_g.construct_spkmixmap(n_speaker) 多角色混合轨道导出 - torch.onnx.export( - net_g, - ( - test_phone.to(device), - test_phone_lengths.to(device), - test_pitch.to(device), - test_pitchf.to(device), - test_ds.to(device), - test_rnd.to(device), - ), - ExportedPath, - dynamic_axes={ - "phone": [1], - "pitch": [1], - "pitchf": [1], - "rnd": [2], - }, - do_constant_folding=False, - opset_version=16, - verbose=False, - input_names=input_names, - output_names=output_names, - ) + parser = argparse.ArgumentParser() + parser.add_argument("--input_model", type=str, help="input model path", default="Shiroha/shiroha.pth") + parser.add_argument("--output_model", type=str, help="output Onnx model path", default="model.onnx") + args = parser.parse_args() + + ModelPath = args.input_model + ExportedPath = args.output_model + eo(ModelPath, ExportedPath)