Merge pull request #82 from deiteris/fix-merge-lab
Fix merge lab
deiteris authored Jun 15, 2024
2 parents 0f17420 + 5037bad commit bbcc0be
Showing 3 changed files with 23 additions and 15 deletions.
server/restapi/MMVC_Rest_Fileuploader.py (4 changes: 2 additions & 2 deletions)
@@ -94,10 +94,10 @@ def get_onnx(self):
             import traceback
             traceback.print_exc()
 
-    def post_merge_models(self, request: str = Form(...)):
+    async def post_merge_models(self, request: str = Form(...)):
         try:
             print(request)
-            info = self.voiceChangerManager.merge_models(request)
+            info = await self.voiceChangerManager.merge_models(request)
             json_compatible_item_data = jsonable_encoder(info)
             return JSONResponse(content=json_compatible_item_data)
         except Exception as e:
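
Note on the change above: FastAPI accepts coroutine route handlers, and only a coroutine handler can await the now-asynchronous merge_models. A minimal, hedged sketch of the pattern follows; the /merge path, the app wiring, and the stub merge_models coroutine are assumptions for illustration, not the project's actual router setup.

    # Illustrative sketch only: an async FastAPI handler awaiting an async service call.
    # The "/merge" path and the stub merge_models coroutine are assumptions for this example.
    from fastapi import FastAPI, Form
    from fastapi.encoders import jsonable_encoder
    from fastapi.responses import JSONResponse

    app = FastAPI()

    async def merge_models(request: str) -> dict:
        # Stand-in for VoiceChangerManager.merge_models, assumed to be a coroutine.
        return {"received": request}

    @app.post("/merge")
    async def post_merge(request: str = Form(...)):
        info = await merge_models(request)  # awaiting requires the handler itself to be async
        return JSONResponse(content=jsonable_encoder(info))
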
server/voice_changer/RVC/modelMerger/MergeModel.py (27 changes: 17 additions & 10 deletions)
@@ -3,10 +3,10 @@
 from collections import OrderedDict
 import torch
 from voice_changer.ModelSlotManager import ModelSlotManager
-
+from safetensors import safe_open
 from voice_changer.utils.ModelMerger import ModelMergerRequest
 from settings import ServerSettings
-
+import json
 
 def merge_model(params: ServerSettings, request: ModelMergerRequest):
     def extract(ckpt: Dict[str, Any]):
@@ -22,12 +22,19 @@ def extract(ckpt: Dict[str, Any]):
 
     def load_weight(path: str):
         print(f"Loading {path}...")
-        state_dict = torch.load(path, map_location="cpu")
-        if "model" in state_dict:
-            weight = extract(state_dict)
+        if path.endswith('.safetensors'):
+            with safe_open(path, 'pt', device='cpu') as cpt:
+                state_dict = cpt.metadata()
+                weight = { k: cpt.get_tensor(k) for k in cpt.keys() }
+            config = json.loads(state_dict['config'])
         else:
-            weight = state_dict["weight"]
-        return weight, state_dict
+            state_dict = torch.load(path, map_location='cpu')
+            if "model" in state_dict:
+                weight = extract(state_dict)
+            else:
+                weight = state_dict["weight"]
+            config = state_dict['config']
+        return weight, state_dict, config
 
     files = request.files
     if len(files) == 0:
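
For context on the safetensors branch above: tensors are read by name through safe_open, while the model config travels as a JSON string in the file's metadata. A small, self-contained sketch of that read path follows; the function name is made up, and the assumption that the config sits under a "config" metadata key mirrors the diff rather than any official API.

    # Sketch of reading an RVC-style .safetensors checkpoint as in the new branch above.
    # Assumes the model config is stored as a JSON string under the "config" metadata key.
    import json
    from safetensors import safe_open

    def load_safetensors_checkpoint(path: str):
        with safe_open(path, "pt", device="cpu") as f:
            metadata = f.metadata() or {}                     # str -> str mapping, may be None
            weight = {k: f.get_tensor(k) for k in f.keys()}   # tensor name -> torch.Tensor
        config = json.loads(metadata["config"])
        return weight, metadata, config
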
@@ -45,7 +52,7 @@ def load_weight(path: str):
 
         filename = os.path.join(params.model_dir, str(f.slotIndex), os.path.basename(slotInfo.modelFile))  # in v.1.5.3.11 and earlier, slotInfo.modelFile already includes the path from model_dir.
 
-        weight, state_dict = load_weight(filename)
+        weight, state_dict, config = load_weight(filename)
         weights.append(weight)
         alphas.append(f.strength)

@@ -64,11 +71,11 @@ def load_weight(path: str):
             merged["weight"][key] += weight[key] * alphas[i]
     print("merge done. write metadata.")
 
-    merged["config"] = state_dict["config"]
+    merged["config"] = config
     merged["params"] = state_dict["params"] if "params" in state_dict else None
     merged["version"] = state_dict["version"] if "version" in state_dict else None
     merged["sr"] = state_dict["sr"]
-    merged["f0"] = state_dict["f0"]
+    merged["f0"] = int(state_dict["f0"])
     merged["info"] = state_dict["info"]
     merged["embedder_name"] = state_dict["embedder_name"] if "embedder_name" in state_dict else None
    merged["embedder_output_layer"] = state_dict["embedder_output_layer"] if "embedder_output_layer" in state_dict else None
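
The merge loop in the last hunk (merged["weight"][key] += weight[key] * alphas[i]) is a per-tensor weighted sum across the input models. Below is a compact, hedged sketch of the same idea in isolation; the function name is made up, and normalizing by the total strength is an added assumption the project's loop does not necessarily make.

    # Illustrative alpha-weighted blend of several state dicts, mirroring the merge loop above.
    # Normalizing by the total strength is an assumption made for this standalone example.
    from typing import Dict, List
    import torch

    def blend_state_dicts(weights: List[Dict[str, torch.Tensor]], alphas: List[float]) -> Dict[str, torch.Tensor]:
        total = sum(alphas) or 1.0
        merged: Dict[str, torch.Tensor] = {}
        for key in weights[0]:
            acc = torch.zeros_like(weights[0][key], dtype=torch.float32)
            for w, a in zip(weights, alphas):
                acc += w[key].float() * (a / total)
            merged[key] = acc
        return merged
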
server/voice_changer/VoiceChangerManager.py (7 changes: 4 additions & 3 deletions)
@@ -249,16 +249,17 @@ def changeVoice(self, receivedData: AudioInOut):
     def export2onnx(self):
         return self.voiceChanger.export2onnx()
 
-    def merge_models(self, request: str):
+    async def merge_models(self, request: str):
         # self.voiceChanger.merge_models(request)
         req = json.loads(request)
         req = ModelMergerRequest(**req)
         req.files = [MergeElement(**f) for f in req.files]
-        slot = len(self.modelSlotManager.getAllSlotInfo())
+        # Slot range is 0-499
+        slot = len(self.modelSlotManager.getAllSlotInfo()) - 1
         if req.voiceChangerType == "RVC":
             merged = RVCModelMerger.merge_models(self.params, req, slot)
             loadParam = LoadModelParams(voiceChangerType="RVC", slot=slot, isSampleMode=False, sampleId="", files=[LoadModelParamFile(name=os.path.basename(merged), kind="rvcModel", dir="")], params={})
-            self.load_model(loadParam)
+            await self.load_model(loadParam)
         return self.get_info()
 
     def setEmitTo(self, emitTo: Callable[[Any], None]):
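
The slot change above reads like an off-by-one fix: if getAllSlotInfo() always reports the full fixed set of slots (0-499 per the added comment), len(...) lands one past the last valid index. A tiny illustration under that assumption, with a stand-in list in place of the real slot manager:

    # Off-by-one illustration: with slots indexed 0-499, len() is 500 and is out of range.
    all_slots = list(range(500))      # stand-in for modelSlotManager.getAllSlotInfo()
    slot = len(all_slots) - 1         # 499, the last valid slot index
    assert 0 <= slot <= 499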
