diff --git a/README.md b/README.md index 7cd82328..6bd7007f 100644 --- a/README.md +++ b/README.md @@ -1,210 +1,169 @@ -# talkingface-toolkit -## 框架整体介绍 -### checkpoints -主要保存的是训练和评估模型所需要的额外的预训练模型,在对应文件夹的[README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/checkpoints/README.md)有更详细的介绍 - -### datset -存放数据集以及数据集预处理之后的数据,详细内容见dataset里的[README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/dataset/README.md) - -### saved -存放训练过程中保存的模型checkpoint, 训练过程中保存模型时自动创建 - -### talkingface -主要功能模块,包括所有核心代码 - -#### config -根据模型和数据集名称自动生成所有模型、数据集、训练、评估等相关的配置信息 -``` -config/ - -├── configurator.py - -``` -#### data -- dataprocess:模型特有的数据处理代码,(可以是对方仓库自己实现的音频特征提取、推理时的数据处理)。如果实现的模型有这个需求,就要建立一对应的文件 -- dataset:每个模型都要重载`torch.utils.data.Dataset` 用于加载数据。每个模型都要有一个`model_name+'_dataset.py'`文件. `__getitem__()`方法的返回值应处理成字典类型的数据。 (核心部分) -``` -data/ - -├── dataprocess - -| ├── wav2lip_process.py - -| ├── xxxx_process.py - -├── dataset - -| ├── wav2lip_dataset.py - -| ├── xxx_dataset.py -``` - -#### evaluate -主要涉及模型评估的代码 -LSE metric 需要的数据是生成的视频列表 -SSIM metric 需要的数据是生成的视频和真实的视频列表 - -#### model -实现的模型的网络和对应的方法 (核心部分) - -主要分三类: -- audio-driven (音频驱动) -- image-driven (图像驱动) -- nerf-based (基于神经辐射场的方法) - -``` -model/ - -├── audio_driven_talkingface - -| ├── wav2lip.py - -├── image_driven_talkingface - -| ├── xxxx.py - -├── nerf_based_talkingface - -| ├── xxxx.py - -├── abstract_talkingface.py - -``` - -#### properties -保存默认配置文件,包括: -- 数据集配置文件 -- 模型配置文件 -- 通用配置文件 - -需要根据对应模型和数据集增加对应的配置文件,通用配置文件`overall.yaml`一般不做修改 -``` -properties/ - -├── dataset - -| ├── xxx.yaml - -├── model - -| ├── xxx.yaml - -├── overall.yaml - -``` - -#### quick_start -通用的启动文件,根据传入参数自动配置数据集和模型,然后训练和评估(一般不需要修改) -``` -quick_start/ - -├── quick_start.py - -``` - -#### trainer -训练、评估函数的主类。在trainer中,如果可以使用基类`Trainer`实现所有功能,则不需要写一个新的。如果模型训练有一些特有部分,则需要重载`Trainer`。需要重载部分可能主要集中于: `_train_epoch()`, `_valid_epoch()`。 重载的`Trainer`应该命名为:`{model_name}Trainer` -``` -trainer/ - -├── trainer.py - -``` - -#### utils -公用的工具类,包括`s3fd`人脸检测,视频抽帧、视频抽音频方法。还包括根据参数配置找对应的模型类、数据类等方法。 -一般不需要修改,但可以适当添加一些必须的且相对普遍的数据处理文件。 - -## 使用方法 -### 环境要求 -- `python=3.8` -- `torch==1.13.1+cu116`(gpu版,若设备不支持cuda可以使用cpu版) -- `numpy==1.20.3` -- `librosa==0.10.1` - -尽量保证上面几个包的版本一致 - -提供了两种配置其他环境的方法: -``` -pip install -r requirements.txt - -or - -conda env create -f environment.yml -``` - -建议使用conda虚拟环境!!! 
- -### 训练和评估 - -```bash -python run_talkingface.py --model=xxxx --dataset=xxxx (--other_parameters=xxxxxx) -``` - -### 权重文件 - -- LSE评估需要的权重: syncnet_v2.model [百度网盘下载](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc) -- wav2lip需要的lip expert 权重:lipsync_expert.pth [百度网下载](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc) - -## 可选论文: -### Aduio_driven talkingface -| 模型简称 | 论文 | 代码仓库 | -|:--------:|:--------:|:--------:| -| MakeItTalk | [paper](https://arxiv.org/abs/2004.12992) | [code](https://github.com/yzhou359/MakeItTalk) | -| MEAD | [paper](https://wywu.github.io/projects/MEAD/support/MEAD.pdf) | [code](https://github.com/uniBruce/Mead) | -| RhythmicHead | [paper](https://arxiv.org/pdf/2007.08547v1.pdf) | [code](https://github.com/lelechen63/Talking-head-Generation-with-Rhythmic-Head-Motion) | -| PC-AVS | [paper](https://arxiv.org/abs/2104.11116) | [code](https://github.com/Hangz-nju-cuhk/Talking-Face_PC-AVS) | -| EVP | [paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Ji_Audio-Driven_Emotional_Video_Portraits_CVPR_2021_paper.pdf) | [code](https://github.com/jixinya/EVP) | -| LSP | [paper](https://arxiv.org/abs/2109.10595) | [code](https://github.com/YuanxunLu/LiveSpeechPortraits) | -| EAMM | [paper](https://arxiv.org/pdf/2205.15278.pdf) | [code](https://github.com/jixinya/EAMM/) | -| DiffTalk | [paper](https://arxiv.org/abs/2301.03786) | [code](https://github.com/sstzal/DiffTalk) | -| TalkLip | [paper](https://arxiv.org/pdf/2303.17480.pdf) | [code](https://github.com/Sxjdwang/TalkLip) | -| EmoGen | [paper](https://arxiv.org/pdf/2303.11548.pdf) | [code](https://github.com/sahilg06/EmoGen) | -| SadTalker | [paper](https://arxiv.org/abs/2211.12194) | [code](https://github.com/OpenTalker/SadTalker) | -| HyperLips | [paper](https://arxiv.org/abs/2310.05720) | [code](https://github.com/semchan/HyperLips) | -| PHADTF | [paper](http://arxiv.org/abs/2002.10137) | [code](https://github.com/yiranran/Audio-driven-TalkingFace-HeadPose) | -| VideoReTalking | [paper](https://arxiv.org/abs/2211.14758) | [code](https://github.com/OpenTalker/video-retalking#videoretalking--audio-based-lip-synchronization-for-talking-head-video-editing-in-the-wild-) -| | - - - -### Image_driven talkingface -| 模型简称 | 论文 | 代码仓库 | -|:--------:|:--------:|:--------:| -| PIRenderer | [paper](https://arxiv.org/pdf/2109.08379.pdf) | [code](https://github.com/RenYurui/PIRender) | -| StyleHEAT | [paper](https://arxiv.org/pdf/2203.04036.pdf) | [code](https://github.com/OpenTalker/StyleHEAT) | -| MetaPortrait | [paper](https://arxiv.org/abs/2212.08062) | [code](https://github.com/Meta-Portrait/MetaPortrait) | -| | -### Nerf-based talkingface -| 模型简称 | 论文 | 代码仓库 | -|:--------:|:--------:|:--------:| -| AD-NeRF | [paper](https://arxiv.org/abs/2103.11078) | [code](https://github.com/YudongGuo/AD-NeRF) | -| GeneFace | [paper](https://arxiv.org/abs/2301.13430) | [code](https://github.com/yerfor/GeneFace) | -| DFRF | [paper](https://arxiv.org/abs/2207.11770) | [code](https://github.com/sstzal/DFRF) | -| | -### text_to_speech -| 模型简称 | 论文 | 代码仓库 | -|:--------:|:--------:|:--------:| -| VITS | [paper](https://arxiv.org/abs/2106.06103) | [code](https://github.com/jaywalnut310/vits) | -| Glow TTS | [paper](https://arxiv.org/abs/2005.11129) | [code](https://github.com/jaywalnut310/glow-tts) | -| FastSpeech2 | [paper](https://arxiv.org/abs/2006.04558v1) | [code](https://github.com/ming024/FastSpeech2) | -| StyleTTS2 | [paper](https://arxiv.org/abs/2306.07691) | [code](https://github.com/yl4579/StyleTTS2) | -| Grad-TTS 
| [paper](https://arxiv.org/abs/2105.06337) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS) | -| FastSpeech | [paper](https://arxiv.org/abs/1905.09263) | [code](https://github.com/xcmyz/FastSpeech) | -| | -### voice_conversion -| 模型简称 | 论文 | 代码仓库 | -|:--------:|:--------:|:--------:| -| StarGAN-VC | [paper](http://www.kecl.ntt.co.jp/people/kameoka.hirokazu/Demos/stargan-vc2/index.html) | [code](https://github.com/kamepong/StarGAN-VC) | -| Emo-StarGAN | [paper](https://www.researchgate.net/publication/373161292_Emo-StarGAN_A_Semi-Supervised_Any-to-Many_Non-Parallel_Emotion-Preserving_Voice_Conversion) | [code](https://github.com/suhitaghosh10/emo-stargan) | -| adaptive-VC | [paper](https://arxiv.org/abs/1904.05742) | [code](https://github.com/jjery2243542/adaptive_voice_conversion) | -| DiffVC | [paper](https://arxiv.org/abs/2109.13821) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/DiffVC) | -| Assem-VC | [paper](https://arxiv.org/abs/2104.00931) | [code](https://github.com/maum-ai/assem-vc) | -| | - -## 作业要求 -- 确保可以仅在命令行输入模型和数据集名称就可以训练、验证。(部分仓库没有提供训练代码的,可以不训练) -- 每个组都要提交一个README文件,写明完成的功能、最终实现的训练、验证截图、所使用的依赖、成员分工等。 - - - +# README + +#### **小组成员** + +组长:邢家瑞 + +组员:邹宇 王宇凡 李泽卿 谢忱 + +#### **1.完成功能** + +本项目完成一个语音转换模型EVP,实验运行截图在Readme.pdf中。 + +#### **2.依赖安装** + +```powershell +absl-py==2.0.0 +addict==2.4.0 +aiosignal==1.3.1 +appdirs==1.4.4 +attrs==23.1.0 +audioread==3.0.1 +basicsr==1.3.4.7 +cachetools==5.3.2 +certifi==2020.12.5 +cffi==1.16.0 +charset-normalizer==3.3.2 +click==8.1.7 +cloudpickle==3.0.0 +colorama==0.4.6 +colorlog==6.7.0 +contourpy==1.1.1 +cycler==0.12.1 +decorator==5.1.1 +dlib==19.22.1 +docker-pycreds==0.4.0 +face-alignment==1.3.5 +ffmpeg==1.4 +filelock==3.13.1 +fonttools==4.44.0 +frozenlist==1.4.0 +future==0.18.3 +gitdb==4.0.11 +GitPython==3.1.40 +glob2==0.7 +google-auth==2.23.4 +google-auth-oauthlib==0.4.6 +grpcio==1.59.2 +hyperopt==0.2.5 +idna==3.4 +imageio==2.9.0 +imageio-ffmpeg==0.4.5 +importlib-metadata==6.8.0 +importlib-resources==6.1.0 +joblib==1.3.2 +jsonschema==4.19.2 +jsonschema-specifications==2023.7.1 +kiwisolver==1.4.5 +lazy_loader==0.3 +librosa==0.10.1 +llvmlite==0.37.0 +lmdb==1.2.1 +lws==1.2.7 +Markdown==3.5.1 +MarkupSafe==2.1.3 +matplotlib==3.6.3 +msgpack==1.0.7 +networkx==3.1 +numba==0.54.1 +numpy==1.20.3 +oauthlib==3.2.2 +opencv-python==3.4.9.33 +packaging==23.2 +pandas==1.3.4 +pathtools==0.1.2 +Pillow==6.2.1 +pkgutil_resolve_name==1.3.10 +platformdirs==3.11.0 +plotly==5.18.0 +pooch==1.8.0 +protobuf==4.25.0 +psutil==5.9.6 +pyasn1==0.5.0 +pyasn1-modules==0.3.0 +pycparser==2.21 +pyparsing==3.1.1 +python-dateutil==2.8.2 +python-speech-features==0.6 +pytorch-fid==0.3.0 +pytz==2023.3.post1 +PyWavelets==1.4.1 +PyYAML==5.3.1 +ray==2.6.3 +referencing==0.30.2 +requests==2.31.0 +requests-oauthlib==1.3.1 +rpds-py==0.12.0 +rsa==4.9 +scikit-image==0.16.2 +scikit-learn==1.3.2 +scipy==1.5.0 +sentry-sdk==1.34.0 +setproctitle==1.3.3 +six==1.16.0 +smmap==5.0.1 +soundfile==0.12.1 +soxr==0.3.7 +tabulate==0.9.0 +tb-nightly==2.12.0a20230126 +tenacity==8.2.3 +tensorboard==2.7.0 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +texttable==1.7.0 +thop==0.1.1.post2209072238 +threadpoolctl==3.2.0 +tomli==2.0.1 +torch==1.13.1+cu116 +torchaudio==0.13.1+cu116 +torchvision==0.14.1+cu116 +tqdm==4.66.1 +trimesh==3.9.20 +typing_extensions==4.8.0 +tzdata==2023.3 +urllib3==2.0.7 +wandb==0.15.12 +Werkzeug==3.0.1 +yapf==0.40.2 +zipp==3.17.0 + + +``` + +#### 3.训练过程 + 
+1) Download the file from https://drive.google.com/file/d/1OjFo6oRu-PIlZIl-6zPfnD_x4TW1iZ-3/view and put it in the project's dataset folder.
+
+2) Run the preprocessing script under talkingface/data/dataset: python preprocess.py
+
+3) Run python run_talkingface.py --model=evp --dataset=evpDataset
+
+#### 4. Problems encountered
+
+When debugging the functions and interfaces, some parameters had to be adjusted based on the documentation in the original GitHub repository; the stock values are not necessarily the best, so we tried different settings ourselves to improve the results.
+
+#### 5. Division of work
+**邢家瑞:**
+1. Set up properties/overall.yaml with the shared default parameters
+2. Set up dataset/evp_dataset.py and preprocess.py: data loading and preprocessing
+3. Co-wrote the documentation
+
+**邹宇:**
+1. Co-debugged the audio_driven_talkingface model and implemented base-class methods such as calculate and predict
+2. Analyzed the project structure
+3. Co-wrote the documentation
+
+**王宇凡:**
+1. Co-debugged the audio_driven_talkingface model, tuning function parameters and interfaces
+2. Analyzed the project structure
+3. Co-wrote the documentation
+
+**李泽卿:**
+1. Co-debugged the audio_driven_talkingface model and implemented base-class methods such as calculate and predict
+2. Processed the dataset and uploaded it to cloud storage
+3. Co-wrote the documentation
+
+**谢忱:**
+1. Debugged training and implemented evaluate()
+2. Co-debugged the audio_driven_talkingface model and set up its interfaces
+3. Ran model training and evaluation
diff --git a/Readme.pdf b/Readme.pdf new file mode 100644 index 00000000..a49a1384 Binary files /dev/null and b/Readme.pdf differ
diff --git a/requirements.txt b/requirements.txt index 1605c1fe..a9d5bb3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,7 +40,6 @@ joblib==1.3.2 jsonschema==4.19.2 jsonschema-specifications==2023.7.1 kiwisolver==1.4.5 -kornia==0.5.5 lazy_loader==0.3 librosa==0.10.1 llvmlite==0.37.0
diff --git a/run_talkingface.py b/run_talkingface.py index 3989d566..1a45274b 100644 --- a/run_talkingface.py +++ b/run_talkingface.py @@ -3,9 +3,9 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--model", "-m", type=str, default="BPR", help="name of models") + parser.add_argument("--model", "-m", type=str, default="evp", help="name of models") parser.add_argument( - "--dataset", "-d", type=str, default=None, help="name of datasets" + "--dataset", "-d", type=str, default="evpDataset", help="name of datasets" ) parser.add_argument("--evaluate_model_file", type=str, default=None, help="The model file you want to evaluate") parser.add_argument("--config_files", type=str, default=None, help="config files") @@ -21,4 +21,4 @@ args.dataset, config_file_list=config_file_list, evaluate_model_file=args.evaluate_model_file - ) \ No newline at end of file + )
diff --git a/talkingface/config/configurator.py b/talkingface/config/configurator.py index 7b6e21d8..c9bff98b 100644 --- a/talkingface/config/configurator.py +++ b/talkingface/config/configurator.py @@ -252,7 +252,8 @@ def _set_default_parameters(self): if isinstance(metrics, str): self.final_config_dict["metrics"] = [metrics] - self.final_config_dict["checkpoint_dir"] = self.final_config_dict["checkpoint_dir"] + self.final_config_dict["checkpoint_sub_dir"] + # print(self.final_config_dict) + self.final_config_dict["checkpoint_dir"] = self.final_config_dict["checkpoint_dir"] + "/wav2lip" self.final_config_dict["temp_dir"] = self.final_config_dict['temp_dir'] + self.final_config_dict['temp_sub_dir'] @@ -333,4 +334,4 @@ def compatibility_settings(self): np.object = np.object_ np.str = np.str_ np.long = np.int_ - np.unicode = np.unicode_ \ No newline at end of file + np.unicode = np.unicode_
diff --git a/talkingface/data/dataset/audio_dataset.py b/talkingface/data/dataset/audio_dataset.py new file mode 100644 index 00000000..d3d87f9c --- /dev/null +++ b/talkingface/data/dataset/audio_dataset.py @@ -0,0 +1,66 @@ +import os +import random +import pickle +import numpy as np +import cv2 +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from torch.utils.data import DataLoader +import librosa +import time +import copy +from
talkingface.data.dataset.dataset import Dataset +MEAD = {'angry':0, 'contempt':1, 'disgusted':2, 'fear':3, 'happy':4, 'neutral':5, + 'sad':6, 'surprised':7} +class SER_MFCC(Dataset): + def __init__(self,config, datasplit, + dataset_dir): + + # self.data_path = dataset_dir + # file = open('/media/asus/840C73C4A631CC36/MEAD/SER_new/list.pkl', "rb") #'rb'-read binary file + # self.train_data = pickle.load(file) + # file.close() + self.config = config + self.data_path = dataset_dir + + self.train = datasplit + if(self.split=='train'): + file = open('../train_M030.pkl', "rb") #'rb'-read binary file + self.train_data = pickle.load(file) + file.close() + if(self.split=='val'): + file = open('../val_M030.pkl', "rb") #'rb'-read binary file + self.train_data = pickle.load(file) + file.close() + + if (self.split == 'test'): + file = open('../val_M030.pkl', "rb") # 'rb'-read binary file + self.train_data = pickle.load(file) + file.close() + + + def __getitem__(self, index): + + + emotion = self.train_data[index].split('_')[0] + + label = torch.Tensor([MEAD[emotion]]) + + mfcc_path = os.path.join(self.data_path , self.train_data[index]) + + + file = open(mfcc_path,'rb') + mfcc = pickle.load(file) + mfcc = mfcc[:,1:] + mfcc = torch.FloatTensor(mfcc) + mfcc=torch.unsqueeze(mfcc, 0) + file.close() + + + return mfcc,label + + + def __len__(self): + + return len(self.train_data) \ No newline at end of file diff --git a/talkingface/data/dataset/evp_dataset.py b/talkingface/data/dataset/evp_dataset.py new file mode 100644 index 00000000..be88e182 --- /dev/null +++ b/talkingface/data/dataset/evp_dataset.py @@ -0,0 +1,66 @@ +import os +import random +import pickle +import numpy as np +import cv2 +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from torch.utils.data import DataLoader +import librosa +import time +import copy +from talkingface.data.dataset.dataset import Dataset +MEAD = {'angry':0, 'contempt':1, 'disgusted':2, 'fear':3, 'happy':4, 'neutral':5, + 'sad':6, 'surprised':7} +class evpDataset(Dataset): + def __init__(self,config, datasplit): + + # self.data_path = dataset_dir + # file = open('/media/asus/840C73C4A631CC36/MEAD/SER_new/list.pkl', "rb") #'rb'-read binary file + # self.train_data = pickle.load(file) + # file.close() + self.config = config + self.data_path = "dataset/train/MFCC/M030" + print(datasplit) + self.train = datasplit + if(self.train=='train'): + file = open('dataset/train/mfcc_data/train_M030.pkl', "rb") #'rb'-read binary file + self.train_data = pickle.load(file) + file.close() + if(self.train=='val'): + file = open('dataset/train/mfcc_data/val_M030.pkl', "rb") #'rb'-read binary file + self.train_data = pickle.load(file) + file.close() + + if (self.train == 'test'): + file = open('dataset/train/mfcc_data/val_M030.pkl', "rb") # 'rb'-read binary file + self.train_data = pickle.load(file) + file.close() + + + def __getitem__(self, index): + + + emotion = self.train_data[index].split('_')[1] + # print(self.train_data[index]) + label = torch.Tensor([MEAD[emotion]]) + + mfcc_path = os.path.join(self.data_path , self.train_data[index]) + + mfcc = np.load(mfcc_path) + # file = open(mfcc_path,'rb') + # print(file) + # mfcc = pickle.load(file) + mfcc = mfcc[:,1:] + mfcc = torch.FloatTensor(mfcc) + mfcc=torch.unsqueeze(mfcc, 0) + # file.close() + + + return {"input":mfcc,"label":label} + + + def __len__(self): + + return len(self.train_data) \ No newline at end of file diff --git a/talkingface/data/dataset/preprocess.py 
b/talkingface/data/dataset/preprocess.py new file mode 100644 index 00000000..93a3d4de --- /dev/null +++ b/talkingface/data/dataset/preprocess.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- + +import os +import sys + +import json +import pickle +import librosa +import numpy as np +import python_speech_features +from pathlib import Path + + +def audio2mfcc(audio_file, save): + try: + speech, sr = librosa.load(audio_file, sr=16000) + # mfcc = python_speech_features.mfcc(speech ,16000,winstep=0.01) + speech = np.insert(speech, 0, np.zeros(1920)) + speech = np.append(speech, np.zeros(1920)) + mfcc = python_speech_features.mfcc(speech, 16000, winstep=0.01) + if not os.path.exists(save): + os.makedirs(save) + time_len = mfcc.shape[0] + + for input_idx in range(int((time_len - 28) / 4) + 1): + # target_idx = input_idx + sample_delay #14 + + input_feat = mfcc[4 * input_idx:4 * input_idx + 28, :] + + np.save(os.path.join(save, str(input_idx) + '.npy'), input_feat) + + print(input_idx) + except Exception as e: + print(f"发生了异常: {e}") + +MEAD = {'angry':0, 'contempt':1, 'disgusted':2, 'fear':3, 'happy':4, 'neutral':5, + 'sad':6, 'surprised':7} +filepath = '/data/tuluwei/nlp1/dataset/train/landmark/dataset_M030/landmark' +save_path = '/data/tuluwei/nlp1/dataset/train/MFCC/M030/' +pathDir = os.listdir(filepath) +allp=[] +for i in range(len(pathDir)): + emotion = pathDir[i] + path = os.path.join(filepath,emotion) + Dir = os.listdir(path) + for j in range(len(Dir)): + audio_file = os.path.join(path,Dir[j]) + index = Dir[j].split('.')[0] + # print(index,Dir[j],emotion,"+=====") + # emotion=emotion.split('') + save = os.path.join(save_path,emotion+'_'+index) + audio2mfcc(audio_file, save) + print(i, emotion, j, index) + +#create list +train_list = [] +val_list = [] +a = Path(save_path) +print(a) +for b in a.iterdir(): + for c in b.iterdir(): + print(b.name) + if int(b.name.split('_')[-2]) < 10: + val_list.append(b.name+'/'+c.name) + else: + train_list.append(b.name+'/'+c.name) + +with open('dataset/train/mfcc_data/train_M030.pkl', 'wb') as f: + pickle.dump(train_list, f) +with open('dataset/train/mfcc_data/val_M030.pkl', 'wb') as f: + pickle.dump(val_list, f) + +''' +allp=[] +for allDir in pathDir: + if (allDir.split('_')[2] == '3'): + + child = os.path.join(filepath, allDir) + for i in os.listdir(child): + + allp.append(allDir+'/'+i) + if (int(allDir.split('_')[1]) > 61): + child = os.path.join(filepath, allDir) + for i in os.listdir(child): + + allp.append(allDir+'/'+i) + +with open('/home/thea/data/MEAD/SER_oneintense/list.pkl', 'wb') as f: + pickle.dump(allp, f) +''' \ No newline at end of file diff --git a/talkingface/model/audio_driven_talkingface/evp.py b/talkingface/model/audio_driven_talkingface/evp.py new file mode 100644 index 00000000..afda0be2 --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/evp.py @@ -0,0 +1,115 @@ +from talkingface.model.abstract_talkingface import AbstractTalkingFace +from logging import getLogger + +import torch +import torch.nn as nn +import numpy as np +from talkingface.utils import set_color +from talkingface.model.audio_driven_talkingface.evp_arch import EmotionNet + +class evp(AbstractTalkingFace): + """Abstract class for talking face model.""" + + def __init__(self,config): + self.logger = getLogger() + super(evp, self).__init__() + self.model=EmotionNet() + self.config=config + self.opt_m = torch.optim.Adam(self.model.parameters(), + lr=0.001, betas=(0.99, 0.99)) + self.CroEn_loss = nn.CrossEntropyLoss() + self.tripletloss = nn.TripletMarginLoss(margin=1) 
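+        # NOTE: only CroEn_loss is used by calculate_loss below. opt_m and tripletloss appear to be
+        # carried over from the original EVP training script; the toolkit's Trainer builds its own
+        # optimizer from config["learner"] / config["learning_rate"], so they are not used here.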
+ # self.train_loader = DataLoader(train_set, batch_size=config.batch_size, + # num_workers=config.num_thread, + # shuffle=True, drop_last=True) + # self.val_loader = DataLoader(val_set, batch_size=config.batch_size, + # num_workers=config.num_thread, + # shuffle=True, drop_last=True) + def calculate_loss(self, interaction, valid=False): + r"""Calculate the training loss for a batch data. + + Args: + interaction (Interaction): Interaction class of the batch. + + Returns: + dict: {"loss": loss, "xxx": xxx} + 返回是一个字典,loss 这个键必须有,它代表了加权之后的总loss。 + 因为有时总loss可能由多个部分组成。xxx代表其它各部分loss + """ + if valid: + with torch.no_grad(): + fake = self.model(interaction["input"].float().to(self.config["device"])) + # print(fake.shape,interaction["label"].shape) + + # loss_func = nn.CrossEntropyLoss() + # pre = torch.tensor([0.8, 0.5, 0.2, 0.5], dtype=torch.float) + # tgt = torch.tensor([1, 0, 0, 0], dtype=torch.float) + # print(loss_func(pre, tgt)) + loss = self.CroEn_loss(fake.to(self.config["device"]), + interaction["label"].squeeze(1).long().to(self.config["device"])) + else: + fake = self.model(interaction["input"].float().to(self.config["device"])) + # print(fake.shape,interaction["label"].shape) + + # loss_func = nn.CrossEntropyLoss() + # pre = torch.tensor([0.8, 0.5, 0.2, 0.5], dtype=torch.float) + # tgt = torch.tensor([1, 0, 0, 0], dtype=torch.float) + # print(loss_func(pre, tgt)) + loss=self.CroEn_loss(fake.to(self.config["device"]), + interaction["label"].squeeze(1).long().to(self.config["device"])) + return {"loss": loss} + + + def predict(self, interaction): + r"""Predict the scores between users and items. + + Args: + interaction (Interaction): Interaction class of the batch. + + Returns: + video/image numpy/tensor + """ + raise NotImplementedError + + def generate_batch(self): + + """ + 根据划分的test_filelist 批量生成数据。 + + Returns: dict: {"generated_video": [generated_video], "real_video": [real_video] } + 必须是一个字典数据, 且字典的键一个时generated_video, 一个是real_video,值都是列表, + 分别对应生成的视频和真实的视频。且两个列表的长度应该相同。 + 即每个生成视频都有对应的真实视频(或近似对应的视频)。 + """ + x1=[] + x2=[] + for i in range(10): + x1.append(torch.randn(1,3,32,32)) + x2.append(torch.randn(1, 3, 32, 32)) + result={"generated_video": [x1], "real_video": [x2]} + return result + + def other_parameter(self): + if hasattr(self, "other_parameter_name"): + return {key: getattr(self, key) for key in self.other_parameter_name} + return dict() + + def load_other_parameter(self, para): + if para is None: + return + for key, value in para.items(): + setattr(self, key, value) + + def __str__(self): + """ + Model prints with number of trainable parameters + """ + model_parameters = filter(lambda p: p.requires_grad, self.parameters()) + params = sum([np.prod(p.size()) for p in model_parameters]) + return ( + super().__str__() + + set_color("\nTrainable parameters", "blue") + + f": {params}" + ) + + diff --git a/talkingface/model/audio_driven_talkingface/evp_arch.py b/talkingface/model/audio_driven_talkingface/evp_arch.py new file mode 100644 index 00000000..461726cd --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/evp_arch.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- + + +import torch +import torch.nn as nn + +import torchvision.models as models +import functools +from torch.autograd import Variable +import torch.nn.functional as F +from torch.nn import init +import numpy as np +#from convolutional_rnn import Conv2dGRU +import torchvision +import torch.nn.init as init +from torch.autograd import Variable + + +class ResidualBlock(nn.Module): + def __init__(self, 
channel_in, channel_out): + super(ResidualBlock, self).__init__() + + self.block = nn.Sequential( + conv3d(channel_in, channel_out, 3, 1, 1), + conv3d(channel_out, channel_out, 3, 1, 1, activation=None) + ) + + self.lrelu = nn.ReLU(0.2) + + def forward(self, x): + residual = x + out = self.block(x) + + out += residual + out = self.lrelu(out) + return out + + +def linear(channel_in, channel_out, + activation=nn.ReLU, + normalizer=nn.BatchNorm1d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.Linear(channel_in, channel_out, bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[0].weight) + + return nn.Sequential(*layer) + + +def conv2d(channel_in, channel_out, + ksize=3, stride=1, padding=1, + activation=nn.ReLU, + normalizer=nn.BatchNorm2d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.Conv2d(channel_in, channel_out, + ksize, stride, padding, + bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[0].weight) + + return nn.Sequential(*layer) + + +def conv_transpose2d(channel_in, channel_out, + ksize=4, stride=2, padding=1, + activation=nn.ReLU, + normalizer=nn.BatchNorm2d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.ConvTranspose2d(channel_in, channel_out, + ksize, stride, padding, + bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[0].weight) + + return nn.Sequential(*layer) + + +def nn_conv2d(channel_in, channel_out, + ksize=3, stride=1, padding=1, + scale_factor=2, + activation=nn.ReLU, + normalizer=nn.BatchNorm2d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.UpsamplingNearest2d(scale_factor=scale_factor)) + layer.append(nn.Conv2d(channel_in, channel_out, + ksize, stride, padding, + bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[1].weight) + + return nn.Sequential(*layer) + + +def _apply(layer, activation, normalizer, channel_out=None): + if normalizer: + layer.append(normalizer(channel_out)) + if activation: + layer.append(activation()) + return layer + + +class EmotionNet(nn.Module): + def __init__(self): + super(EmotionNet, self).__init__() + + self.emotion_eocder = nn.Sequential( + conv2d(1,64,3,1,1), + + nn.MaxPool2d((1,3), stride=(1,2)), #[1, 64, 12, 12] + conv2d(64,128,3,1,1), + + conv2d(128,256,3,1,1), + + nn.MaxPool2d((12,1), stride=(12,1)), #[1, 256, 1, 12] + + conv2d(256,512,3,1,1), + + nn.MaxPool2d((1,2), stride=(1,2)) #[1, 512, 1, 6] + + ) + self.emotion_eocder_fc = nn.Sequential( + nn.Linear(512 *6,2048), + nn.ReLU(True), + nn.Linear(2048,128), + nn.ReLU(True), + + ) + self.last_fc = nn.Linear(128,8) + + def forward(self, mfcc): + # mfcc= torch.unsqueeze(mfcc, 1) + mfcc=torch.transpose(mfcc,2,3) + feature = self.emotion_eocder(mfcc) + feature = feature.view(feature.size(0),-1) + x = self.emotion_eocder_fc(feature) + re = self.last_fc(x) + + return re + +class DisNet(nn.Module): + def __init__(self): + super(DisNet, self).__init__() + + + self.dis_fc = nn.Sequential( + nn.Linear(128,64), + nn.ReLU(True), + nn.Linear(64,16), + nn.ReLU(True), + nn.Linear(16,1), + nn.ReLU(True) + ) + + + def forward(self, feature): + + re = self.dis_fc(feature) + + return re + diff --git a/talkingface/model/audio_driven_talkingface/evp_model.py b/talkingface/model/audio_driven_talkingface/evp_model.py new file mode 100644 index 00000000..9008cd72 --- /dev/null +++ 
b/talkingface/model/audio_driven_talkingface/evp_model.py @@ -0,0 +1,88 @@ +from talkingface.model.abstract_talkingface import AbstractTalkingFace +from logging import getLogger + +import torch +import torch.nn as nn +import numpy as np +from talkingface.utils import set_color +from talkingface.model.audio_driven_talkingface.evp import EmotionNet + +class evp(AbstractTalkingFace): + """Abstract class for talking face model.""" + + def __init__(self): + self.logger = getLogger() + super(evp, self).__init__() + self.model=EmotionNet() + self.opt_m = torch.optim.Adam(self.model.parameters(), + lr=0.001, betas=(0.99, 0.99)) + self.CroEn_loss = nn.CrossEntropyLoss() + self.tripletloss = nn.TripletMarginLoss(margin=1) + self.train_loader = DataLoader(train_set, batch_size=config.batch_size, + num_workers=config.num_thread, + shuffle=True, drop_last=True) + self.val_loader = DataLoader(val_set, batch_size=config.batch_size, + num_workers=config.num_thread, + shuffle=True, drop_last=True) + def calculate_loss(self, interaction): + r"""Calculate the training loss for a batch data. + + Args: + interaction (Interaction): Interaction class of the batch. + + Returns: + dict: {"loss": loss, "xxx": xxx} + 返回是一个字典,loss 这个键必须有,它代表了加权之后的总loss。 + 因为有时总loss可能由多个部分组成。xxx代表其它各部分loss + """ + return {"loss": self.CroEn_loss(fake,label)} + + + def predict(self, interaction): + r"""Predict the scores between users and items. + + Args: + interaction (Interaction): Interaction class of the batch. + + Returns: + video/image numpy/tensor + """ + raise NotImplementedError + + def generate_batch(self): + + """ + 根据划分的test_filelist 批量生成数据。 + + Returns: dict: {"generated_video": [generated_video], "real_video": [real_video] } + 必须是一个字典数据, 且字典的键一个时generated_video, 一个是real_video,值都是列表, + 分别对应生成的视频和真实的视频。且两个列表的长度应该相同。 + 即每个生成视频都有对应的真实视频(或近似对应的视频)。 + """ + + raise NotImplementedError + + def other_parameter(self): + if hasattr(self, "other_parameter_name"): + return {key: getattr(self, key) for key in self.other_parameter_name} + return dict() + + def load_other_parameter(self, para): + if para is None: + return + for key, value in para.items(): + setattr(self, key, value) + + def __str__(self): + """ + Model prints with number of trainable parameters + """ + model_parameters = filter(lambda p: p.requires_grad, self.parameters()) + params = sum([np.prod(p.size()) for p in model_parameters]) + return ( + super().__str__() + + set_color("\nTrainable parameters", "blue") + + f": {params}" + ) + + diff --git a/talkingface/properties/dataset/lrs2.yaml b/talkingface/properties/dataset/lrs2.yaml index 3afa074f..377314f8 100644 --- a/talkingface/properties/dataset/lrs2.yaml +++ b/talkingface/properties/dataset/lrs2.yaml @@ -1,10 +1,10 @@ -train_filelist: 'dataset/lrs2/filelist/train.txt' # 当前数据集的数据划分文件 train -test_filelist: 'dataset/lrs2/filelist/test.txt' # 当前数据集的数据划分文件 test -val_filelist: 'dataset/lrs2/filelist/val.txt' # 当前数据集的数据划分文件 val +train_filelist: 'train' # 当前数据集的数据划分文件 train +test_filelist: 'test' # 当前数据集的数据划分文件 test +val_filelist: 'val' # 当前数据集的数据划分文件 val data_root: 'dataset/lrs2/data/main' # 当前数据集的数据根目录 preprocessed_root: 'dataset/lrs2/preprocessed_data' # 当前数据集的预处理数据根目录 need_preprocess: True # 数据集是否需要预处理,如抽帧、抽音频等 -preprocess_batch_size: 32 \ No newline at end of file +preprocess_batch_size: 32 diff --git a/talkingface/properties/overall.yaml b/talkingface/properties/overall.yaml index 81ac51ae..9088d763 100644 --- a/talkingface/properties/overall.yaml +++ b/talkingface/properties/overall.yaml @@ -11,7 +11,7 @@ device: 
'cuda' reproducibility: True # (bool) Whether or not to make results reproducible. # Training Settings -epochs: 300 # (int) The number of training epochs. +epochs: 1 # (int) The number of training epochs. train_batch_size: 2048 # (int) The training batch size. learner: adam # (str) The name of used optimizer. learning_rate: 0.0001 # (float) Learning rate. @@ -19,7 +19,7 @@ eval_step: 1 # (int) The number of training epochs before an stopping_step: 10 # (int) The threshold for validation-based early stopping. weight_decay: 0.0 # (float) The weight decay value (L2 penalty) for optimizers. saved: True -resume: True +resume: False train: True # Evaluation Settings @@ -28,4 +28,74 @@ evaluate_batch_size: 50 # (int) The evaluation batch size. lse_checkpoint_path: 'checkpoints/LSE/syncnet_v2.model' temp_dir: 'results/temp' lse_reference_dir: 'lse' -valid_metric_bigger: False # (bool) Whether to take a bigger valid metric value as a better result. \ No newline at end of file +valid_metric_bigger: False +# (bool) Whether to take a bigger valid metric value as a better result. + + + +# Syncnet +syncnet_wt: 0.03 # (int) is initially zero, will be set automatically to 0.03 later.Leads to faster convergence. +syncnet_batch_size: 64 # (int) batch_size for syncnet train +syncnet_lr: 0.0001 #(float) learning rate for syncnet train +syncnet_eval_interval: 10000 +syncnet_checkpoint_interval: 10000 +syncnet_T: 5 +syncnet_mel_step_size: 16 +syncnet_checkpoint_path: "checkpoints/wav2lip/lipsync_expert.pth" + +# Data preprocessing for Wav2lip +num_mels: 80 +rescale: True +rescaling_max: 0.9 +use_lws: False +n_fft: 800 +hop_size: 200 +win_size: 800 +sample_rate: 16000 +frame_shift_ms: None +signal_normalization: True +allow_clipping_in_normalization: True +symmetric_mels: True +max_abs_value: 4 +preemphasize: True +preemphasis: 0.97 +min_level_db: -100 +ref_level_db: 20 +fmin: 55 +fmax: 7600 +img_size: 96 +fps: 25 +mel_step_size: 16 + +batch_size: 16 +ngpu: 1 + + +# Train +checkpoint_sub_dir: "/wav2lip" # 和overall.yaml里checkpoint_dir拼起来作为最终目录 + +temp_sub_dir: "/wav2lip" # 和overall.yaml里temp_dir拼起来作为最终目录 + + +# Inference +pads: [0, 10, 0, 0] +static: False +face_det_batch_size: 16 +resize_factor: 1 +crop: [0, -1, 0, -1] +box: [-1, -1, -1, -1] +rotate: False +nosmooth: False +wav2lip_batch_size: 128 +vshift: 15 + +train_filelist: 'train' # 当前数据集的数据划分文件 train +test_filelist: 'test' # 当前数据集的数据划分文件 test +val_filelist: 'val' # 当前数据集的数据划分文件 val + +data_root: 'dataset/lrs2/data/main' # 当前数据集的数据根目录 +preprocessed_root: 'dataset/lrs2/preprocessed_data' # 当前数据集的预处理数据根目录 + +need_preprocess: True # 数据集是否需要预处理,如抽帧、抽音频等 + +preprocess_batch_size: 32 diff --git a/talkingface/trainer/trainer.py b/talkingface/trainer/trainer.py index 2c34717b..efb2968d 100644 --- a/talkingface/trainer/trainer.py +++ b/talkingface/trainer/trainer.py @@ -28,7 +28,6 @@ from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio from talkingface.evaluator import Evaluator - class AbstractTrainer(object): r"""Trainer Class is used to manage the training and evaluation processes of recommender system models. 
AbstractTrainer is an abstract class in which the fit() and evaluate() method should be implemented according @@ -203,7 +202,7 @@ def _valid_epoch(self, valid_data, show_progress=False): Returns: loss """ - print('Valid for {} steps'.format(self.eval_steps)) + print('Valid for {} steps'.format(self.eval_step)) self.model.eval() total_loss_dict = {} iter_data = ( @@ -217,7 +216,8 @@ def _valid_epoch(self, valid_data, show_progress=False): step = 0 for batch_idx, batched_data in enumerate(iter_data): step += 1 - batched_data.to(self.device) + # batched_data.to(self.device) + # batched_data={"input":mfcc,"label":label} losses_dict = self.model.calculate_loss(batched_data, valid=True) for key, value in losses_dict.items(): if key in total_loss_dict: @@ -367,10 +367,10 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre """ if saved and self.start_epoch >= self.epochs: self._save_checkpoint(-1, verbose=verbose) - + print(self.config['resume_checkpoint_path'],self.config['resume'],"wokao") if not (self.config['resume_checkpoint_path'] == None ) and self.config['resume']: self.resume_checkpoint(self.config['resume_checkpoint_path']) - + for epoch_idx in range(self.start_epoch, self.epochs): training_start_time = time() train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress) @@ -485,6 +485,7 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=Fals for batch_idx, interaction in enumerate(iter_data): self.optimizer.zero_grad() step += 1 + print(interaction) losses_dict = loss_func(interaction) loss = losses_dict["loss"] @@ -554,4 +555,135 @@ def _valid_epoch(self, valid_data, loss_func=None, show_progress=False): if losses_dict["sync_loss"] < .75: self.model.config["syncnet_wt"] = 0.01 return average_loss_dict - \ No newline at end of file + +class evpTrainer(Trainer): + def __init__(self, config, model): + self.config = config + self.model = model + self.logger = getLogger() + self.tensorboard = get_tensorboard(self.logger) + self.wandblogger = WandbLogger(config) + # self.enable_amp = config["enable_amp"] + # self.enable_scaler = torch.cuda.is_available() and config["enable_scaler"] + + # config for train + self.learner = config["learner"] + self.learning_rate = config["learning_rate"] + self.epochs = config["epochs"] + self.eval_step = min(config["eval_step"], self.epochs) + self.stopping_step = config["stopping_step"] + self.test_batch_size = config["eval_batch_size"] + self.gpu_available = torch.cuda.is_available() and config["use_gpu"] + self.device = config["device"] + self.checkpoint_dir = config["checkpoint_dir"] + ensure_dir(self.checkpoint_dir) + saved_model_file = "{}-{}.pth".format(self.config["model"], get_local_time()) + self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file) + self.weight_decay = config["weight_decay"] + self.start_epoch = 0 + self.cur_step = 0 + self.train_loss_dict = dict() + self.optimizer = self._build_optimizer() + self.evaluator = Evaluator(config) + + self.valid_metric_bigger = config["valid_metric_bigger"] + self.best_valid_score = -np.inf if self.valid_metric_bigger else np.inf + self.best_valid_result = None + def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None): + r"""Train the model based on the train data and the valid data. + + Args: + train_data (DataLoader): the train data + valid_data (DataLoader, optional): the valid data, default: None. + If it's None, the early_stopping is invalid. 
+ verbose (bool, optional): whether to write training and evaluation information to logger, default: True + saved (bool, optional): whether to save the model parameters, default: True + show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``. + callback_fn (callable): Optional callback function executed at end of epoch. + Includes (epoch_idx, valid_score) input arguments. + + Returns: + best result + """ + # print(self.config['resume_checkpoint_path'],self.config['resume'],"wokao") + # self.start_epoch=0 + if saved and self.start_epoch >= self.epochs: + self._save_checkpoint(-1, verbose=verbose) + + if not (self.config['resume_checkpoint_path'] == None) and self.config['resume']: + self.resume_checkpoint(self.config['resume_checkpoint_path']) + + for epoch_idx in range(self.start_epoch, self.epochs): + training_start_time = time() + train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress) + self.train_loss_dict[epoch_idx] = ( + sum(train_loss) if isinstance(train_loss, tuple) else train_loss + ) + training_end_time = time() + train_loss_output = self._generate_train_loss_output( + epoch_idx, training_start_time, training_end_time, train_loss) + + if verbose: + self.logger.info(train_loss_output) + # self._add_train_loss_to_tensorboard(epoch_idx, train_loss) + + if self.eval_step <= 0 or not valid_data: + if saved: + self._save_checkpoint(epoch_idx, verbose=verbose) + continue + + if (epoch_idx + 1) % self.eval_step == 0: + valid_start_time = time() + valid_loss = self._valid_epoch(valid_data=valid_data, show_progress=show_progress) + + (self.best_valid_score, self.cur_step, stop_flag, update_flag,) = early_stopping( + valid_loss['loss'], + self.best_valid_score, + self.cur_step, + max_step=self.stopping_step, + bigger=self.valid_metric_bigger, + ) + valid_end_time = time() + + valid_loss_output = ( + set_color("valid result", "blue") + ": \n" + dict2str(valid_loss) + ) + if verbose: + self.logger.info(valid_loss_output) + + if update_flag: + if saved: + self._save_checkpoint(epoch_idx, verbose=verbose) + self.best_valid_result = valid_loss['loss'] + + if stop_flag: + stop_output = "Finished training, best eval result in epoch %d" % ( + epoch_idx - self.cur_step * self.eval_step + ) + if verbose: + self.logger.info(stop_output) + break + + @torch.no_grad() + def evaluate(self, load_best_model=True, model_file=None): + """ + Evaluate the model based on the test data. + + args: load_best_model: bool, whether to load the best model in the training process. + model_file: str, the model file you want to evaluate. 
+ + """ + if load_best_model: + checkpoint_file = model_file or self.saved_model_file + checkpoint = torch.load(checkpoint_file, map_location=self.device) + self.model.load_state_dict(checkpoint["state_dict"]) + self.model.load_other_parameter(checkpoint.get("other_parameter")) + message_output = "Loading model structure and parameters from {}".format( + checkpoint_file + ) + self.logger.info(message_output) + self.model.eval() + + datadict = self.model.generate_batch() + eval_result = self.evaluator.evaluate(datadict) + self.logger.info(eval_result) diff --git a/talkingface/utils/data_process.py b/talkingface/utils/data_process.py index cbc430ac..23f67f45 100644 --- a/talkingface/utils/data_process.py +++ b/talkingface/utils/data_process.py @@ -13,7 +13,7 @@ from scipy.io import wavfile -class lrs2Preprocess: +class evpDatasetPreprocess: def __init__(self, config): self.config = config self.fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, diff --git a/talkingface/utils/utils.py b/talkingface/utils/utils.py index a5019491..8d976199 100644 --- a/talkingface/utils/utils.py +++ b/talkingface/utils/utils.py @@ -63,6 +63,7 @@ def get_model(model_name): raise ValueError( "`model_name` [{}] is not the name of an existing model.".format(model_name) ) + print(model_module,model_name) model_class = getattr(model_module, model_name) return model_class @@ -435,6 +436,7 @@ def create_dataset(config): """ model_name = config['model'] dataset_file_name = model_name.lower()+'_dataset' + print(dataset_file_name) module_path = ".".join(["talkingface.data.dataset", dataset_file_name]) if importlib.util.find_spec(module_path, __name__): dataset_module = importlib.import_module(module_path, __name__) @@ -443,7 +445,7 @@ def create_dataset(config): "`dataset_file_name` [{}] is not the name of an existing dataset.".format(dataset_file_name) ) dataset_class = getattr(dataset_module, model_name+'Dataset') - + print(config['train_filelist'],"hahahha") return dataset_class(config, config['train_filelist']), dataset_class(config, config['val_filelist'])