Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tools/export_onnx.py script #2357

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 12 additions & 52 deletions tools/export_onnx.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,14 @@
import torch
from infer.lib.infer_pack.models_onnx import SynthesizerTrnMsNSFsidM
import argparse
from infer.modules.onnx.export import export_onnx as eo

# Usage Example:
# python tools/export_onnx.py --input_model ~/models/my-model.pth --output_model ~/models/my-model.onnx
if __name__ == "__main__":
MoeVS = True # 模型是否为MoeVoiceStudio(原MoeSS)使用

ModelPath = "Shiroha/shiroha.pth" # 模型路径
ExportedPath = "model.onnx" # 输出路径
hidden_channels = 256 # hidden_channels,为768Vec做准备
cpt = torch.load(ModelPath, map_location="cpu")
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
print(*cpt["config"])

test_phone = torch.rand(1, 200, hidden_channels) # hidden unit
test_phone_lengths = torch.tensor([200]).long() # hidden unit 长度(貌似没啥用)
test_pitch = torch.randint(size=(1, 200), low=5, high=255) # 基频(单位赫兹)
test_pitchf = torch.rand(1, 200) # nsf基频
test_ds = torch.LongTensor([0]) # 说话人ID
test_rnd = torch.rand(1, 192, 200) # 噪声(加入随机因子)

device = "cpu" # 导出时设备(不影响使用模型)

net_g = SynthesizerTrnMsNSFsidM(
*cpt["config"], is_half=False
) # fp32导出(C++要支持fp16必须手动将内存重新排列所以暂时不用fp16)
net_g.load_state_dict(cpt["weight"], strict=False)
input_names = ["phone", "phone_lengths", "pitch", "pitchf", "ds", "rnd"]
output_names = [
"audio",
]
# net_g.construct_spkmixmap(n_speaker) 多角色混合轨道导出
torch.onnx.export(
net_g,
(
test_phone.to(device),
test_phone_lengths.to(device),
test_pitch.to(device),
test_pitchf.to(device),
test_ds.to(device),
test_rnd.to(device),
),
ExportedPath,
dynamic_axes={
"phone": [1],
"pitch": [1],
"pitchf": [1],
"rnd": [2],
},
do_constant_folding=False,
opset_version=16,
verbose=False,
input_names=input_names,
output_names=output_names,
)
parser = argparse.ArgumentParser()
parser.add_argument("--input_model", type=str, help="input model path", default="Shiroha/shiroha.pth")
parser.add_argument("--output_model", type=str, help="output Onnx model path", default="model.onnx")
args = parser.parse_args()

ModelPath = args.input_model
ExportedPath = args.output_model
eo(ModelPath, ExportedPath)