From 096dad9dad7b3db109dc2feeb4250a715090427f Mon Sep 17 00:00:00 2001
From: Dinghao Zhou
Date: Fri, 24 May 2024 17:39:15 +0800
Subject: [PATCH 1/8] [transformer] fix w2vbert attention init value (#2539)

---
 wenet/transformer/attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index b8550bccf..3ac4883aa 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -525,8 +525,8 @@ def __init__(self,
         super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
                          value_bias, use_sdpa, None, None)
         # TODO(Mddct): 64 8 1 as args
-        self.max_right_rel_pos = 64
-        self.max_left_rel_pos = 8
+        self.max_right_rel_pos = 8
+        self.max_left_rel_pos = 64
         self.rel_k_embed = torch.nn.Embedding(
             self.max_left_rel_pos + self.max_right_rel_pos + 1, self.d_k)

From 9c68e78bc6cf58afb6e15c69c87ac733eba2528c Mon Sep 17 00:00:00 2001
From: Dinghao Zhou
Date: Tue, 28 May 2024 12:54:37 +0800
Subject: [PATCH 2/8] fix th_accuracy when th_accuracy is missing (#2541)

---
 wenet/utils/executor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py
index 45f2739e2..999a23e17 100644
--- a/wenet/utils/executor.py
+++ b/wenet/utils/executor.py
@@ -140,7 +140,7 @@ def cv(self, model, cv_data_loader, configs):
                 num_seen_utts += num_utts
                 total_acc.append(_dict['th_accuracy'].item(
-                ) if _dict['th_accuracy'] is not None else 0.0)
+                ) if _dict.get('th_accuracy', None) is not None else 0.0)
                 for loss_name, loss_value in _dict.items():
                     if loss_value is not None and "loss" in loss_name \
                             and torch.isfinite(loss_value):

From 97ffee4481bb2a8f51619a145672a413301e7f68 Mon Sep 17 00:00:00 2001
From: Xingchen Song(宋星辰)
Date: Fri, 31 May 2024 23:28:30 +0800
Subject: [PATCH 3/8] [doc]: refine installation (#2546)

* [doc]: refine installation

* [doc]: refine installation

* [doc]: refine installation
---
 README.md | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/README.md b/README.md
index 95824b25a..e477e4951 100644
--- a/README.md
+++ b/README.md
@@ -59,9 +59,25 @@ git clone https://github.com/wenet-e2e/wenet.git
 conda create -n wenet python=3.10
 conda activate wenet
 conda install conda-forge::sox
+```
+
+- Install CUDA: please follow this [link](https://icefall.readthedocs.io/en/latest/installation/index.html#id1). It's recommended to install CUDA 12.1.
+- Install torch and torchaudio. It's recommended to use 2.2.2+cu121:
+
+``` sh
+pip install torch==2.2.2+cu121 torchaudio==2.2.2+cu121 -f https://download.pytorch.org/whl/torch_stable.html
+```
+
+- Install other python packages
+
+``` sh
 pip install -r requirements.txt
 pre-commit install  # for clean and tidy code
+```
+
+- Frequently Asked Questions (FAQs)
+
+``` sh
 # If you encounter sox compatibility issues
 RuntimeError: set_buffer_size requires sox extension which is not available.
 # ubuntu

From e197305c41c0f6bc196eab1f1c165fcd1c98f3a6 Mon Sep 17 00:00:00 2001
From: Yonnie1331
Date: Sat, 1 Jun 2024 15:33:05 +0900
Subject: [PATCH 4/8] mark forward as unused when jit exporting transducer (#2548)

---
 wenet/transducer/transducer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/wenet/transducer/transducer.py b/wenet/transducer/transducer.py
index 224077337..ff730cd50 100644
--- a/wenet/transducer/transducer.py
+++ b/wenet/transducer/transducer.py
@@ -90,6 +90,7 @@ def __init__(
             normalize_length=length_normalized_loss,
         )

+    @torch.jit.unused
     def forward(
         self,
         batch: dict,

From 76a49d7437d1320a6785bf26d90c922cf3aa7287 Mon Sep 17 00:00:00 2001
From: Lucky Wong
Date: Mon, 3 Jun 2024 14:06:49 +0800
Subject: [PATCH 5/8] Use Hamming window for Paraformer frontend. (#2549)

Co-authored-by: Huang Lekai
---
 wenet/cli/paraformer_model.py                      | 3 ++-
 wenet/dataset/processor.py                         | 6 ++++--
 .../convert_paraformer_to_wenet_config_and_ckpt.py | 1 +
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/wenet/cli/paraformer_model.py b/wenet/cli/paraformer_model.py
index a80c360f0..a43814a3a 100644
--- a/wenet/cli/paraformer_model.py
+++ b/wenet/cli/paraformer_model.py
@@ -40,7 +40,8 @@ def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
                             frame_length=25,
                             frame_shift=10,
                             energy_floor=0.0,
-                            sample_frequency=self.resample_rate)
+                            sample_frequency=self.resample_rate,
+                            window_type="hamming")
         feats = feats.unsqueeze(0)
         feats_lens = torch.tensor([feats.size(1)],
                                   dtype=torch.int64,

diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py
index 4d3a80961..4de0a29cf 100644
--- a/wenet/dataset/processor.py
+++ b/wenet/dataset/processor.py
@@ -231,7 +231,8 @@ def compute_fbank(sample,
                   num_mel_bins=23,
                   frame_length=25,
                   frame_shift=10,
-                  dither=0.0):
+                  dither=0.0,
+                  window_type="povey"):
     """ Extract fbank

         Args:
                       frame_shift=frame_shift,
                       dither=dither,
                       energy_floor=0.0,
-                      sample_frequency=sample_rate)
+                      sample_frequency=sample_rate,
+                      window_type=window_type)
     sample['feat'] = mat
     return sample

diff --git a/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py b/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
index 6dee02b08..859613391 100644
--- a/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
+++ b/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
@@ -140,6 +140,7 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str,
     configs['dataset_conf']['fbank_conf']['frame_shift'] = 10
     configs['dataset_conf']['fbank_conf']['frame_length'] = 25
     configs['dataset_conf']['fbank_conf']['dither'] = 0.1
+    configs['dataset_conf']['fbank_conf']['window_type'] = 'hamming'
     configs['dataset_conf']['spec_sub'] = False
     configs['dataset_conf']['spec_trim'] = False
     configs['dataset_conf']['shuffle'] = True

From f28c86f7df6787e6f0876f991d4013d5f5ce5c5c Mon Sep 17 00:00:00 2001
From: Dinghao Zhou
Date: Tue, 4 Jun 2024 10:17:05 +0800
Subject: [PATCH 6/8] [transformer] w2vbert attention reduce memory (#2550)

* [transformer] w2vbert attention reduce memory

* fix l r

* align result between fairseq2 and wenet
---
 wenet/transformer/attention.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index 3ac4883aa..ea234665e 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -530,11 +530,16 @@ def __init__(self,
         self.rel_k_embed = torch.nn.Embedding(
             self.max_left_rel_pos + self.max_right_rel_pos + 1, self.d_k)

-    def _relative_indices(self, length: int, device: torch.device):
-        indices = torch.arange(length, device=device).unsqueeze(0)
+    def _relative_indices(self, keys: torch.Tensor) -> torch.Tensor:
+        # (S, 1)
+        indices = torch.arange(keys.size(2), device=keys.device).unsqueeze(0)
+
+        # (S, S)
         rel_indices = indices - indices.transpose(0, 1)
+
         rel_indices = torch.clamp(rel_indices, -self.max_left_rel_pos,
                                   self.max_right_rel_pos)
+
         return rel_indices + self.max_left_rel_pos

     def forward(
@@ -550,14 +555,9 @@ def forward(
         q, k, v = self.forward_qkv(query, key, value)
         k, v, new_cache = self._update_kv_and_cache(k, v, cache)

-        rel_k = self.rel_k_embed(
-            self._relative_indices(k.size(2), query.device))  # (t2, t2, d_k)
-        rel_k = rel_k[-q.size(2):]  # (t1, t2, d_k)
-        # b,h,t1,dk
-        rel_k = rel_k.unsqueeze(0).unsqueeze(0)  # (1, 1, t1, t2, d_k)
-        q_expand = q.unsqueeze(3)  # (batch, h, t1, 1, d_k)
-        rel_att_weights = (rel_k * q_expand).sum(-1).squeeze(
-            -1)  # (batch, h, t1, t2)
+        rel_k = self.rel_k_embed(self._relative_indices(k))  # (t2, t2, d_k)
+        rel_k = rel_k[-q.size(2):]
+        rel_att_weights = torch.einsum("bhld,lrd->bhlr", q, rel_k)

         if not self.use_sdpa:
             scores = (torch.matmul(q, k.transpose(-2, -1)) +

From fcb4b9817361cb45c66b9135a8729838961532d3 Mon Sep 17 00:00:00 2001
From: xu-gaopeng
Date: Tue, 4 Jun 2024 19:55:24 +0800
Subject: [PATCH 7/8] update StackNFramesSubsampling (#2551)

---
 wenet/transformer/subsampling.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/wenet/transformer/subsampling.py b/wenet/transformer/subsampling.py
index 5619e8cf7..1d252b940 100644
--- a/wenet/transformer/subsampling.py
+++ b/wenet/transformer/subsampling.py
@@ -385,6 +385,6 @@ def forward(
             new_mask = ~make_pad_mask(seq_len, max_len=s // self.stride)
             x = x.view(b, s // self.stride, self.idim * self.stride)
             _, pos_emb = self.pos_enc_class(x, offset)
-            x = self.norm(x)
-            x = self.out(x)
+        x = self.norm(x)
+        x = self.out(x)
         return x, pos_emb, new_mask.unsqueeze(1)

From 509d05d73d836edafce29d5dfce9d4184097f11f Mon Sep 17 00:00:00 2001
From: Zaili Wang <109502517+ZailiWang@users.noreply.github.com>
Date: Wed, 5 Jun 2024 21:04:32 +0800
Subject: [PATCH 8/8] upgrade IPEX runtime to r2.3 (#2538)

---
 runtime/core/cmake/ipex.cmake  | 23 +++++++++++++----------
 runtime/ipex/CMakeLists.txt    |  6 +++---
 runtime/ipex/README.md         | 22 ++++++++++++++------
 runtime/ipex/docker/Dockerfile |  4 ++--
 4 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/runtime/core/cmake/ipex.cmake b/runtime/core/cmake/ipex.cmake
index 14542f42e..33a1147bc 100644
--- a/runtime/core/cmake/ipex.cmake
+++ b/runtime/core/cmake/ipex.cmake
@@ -4,12 +4,15 @@ if(NOT ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
   message(FATAL_ERROR "Intel Extension For PyTorch supports only Linux for now")
 endif()

+set(TORCH_VERSION "2.3.0")
+set(IPEX_VERSION "2.3.0")
+
 if(CXX11_ABI)
-  set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcpu.zip")
-  set(URL_HASH "SHA256=137a842d1cf1e9196b419390133a1623ef92f8f84dc7a072f95ada684f394afd")
+  set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${TORCH_VERSION}%2Bcpu.zip")
+  set(URL_HASH "SHA256=f60009d2a74b6c8bdb174e398c70d217b7d12a4d3d358cd1db0690b32f6e193b")
 else()
-  set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.0.1%2Bcpu.zip")
-  set(URL_HASH "SHA256=90d50350fd24ce5cf9dfbf47888d0cfd9f943eb677f481b86fe1b8e90f7fda5d")
+  set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-${TORCH_VERSION}%2Bcpu.zip")
+  set(URL_HASH "SHA256=6b78aff4e586991bb2e040c02b2cfd73bc740059b9d12bcc1c1d7b3c86d2ab88")
 endif()
 FetchContent_Declare(libtorch
   URL ${LIBTORCH_URL}
@@ -19,13 +22,13 @@ FetchContent_MakeAvailable(libtorch)
 find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)

 if(CXX11_ABI)
-  set(LIBIPEX_URL "https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-cxx11-abi-2.0.100%2Bcpu.run")
-  set(URL_HASH "SHA256=f172d9ebc2ca0c39cc93bb395721194f79767e1bc3f82b13e1edc07d1530a600")
-  set(LIBIPEX_SCRIPT_NAME "libintel-ext-pt-cxx11-abi-2.0.100%2Bcpu.run")
+  set(LIBIPEX_URL "https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-cxx11-abi-${IPEX_VERSION}%2Bcpu.run")
+  set(URL_HASH "SHA256=8aa3c7c37f5cc2cba450947ca04f565fccb86c3bb98f592142375cfb9016f0d6")
+  set(LIBIPEX_SCRIPT_NAME "libintel-ext-pt-cxx11-abi-${IPEX_VERSION}%2Bcpu.run")
 else()
-  set(LIBIPEX_URL "https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-2.0.100%2Bcpu.run")
-  set(URL_HASH "SHA256=8392f965dd9b8f6c0712acbb805c7e560e4965a0ade279b47a5f5a8363888268")
-  set(LIBIPEX_SCRIPT_NAME "libintel-ext-pt-2.0.100%2Bcpu.run")
+  set(LIBIPEX_URL "https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/libipex/cpu/libintel-ext-pt-${IPEX_VERSION}%2Bcpu.run")
+  set(URL_HASH "SHA256=fecb6244a6cd38ca2d73a45272a6ad8527d1ec2caca512d919daa80adb621814")
+  set(LIBIPEX_SCRIPT_NAME "libintel-ext-pt-${IPEX_VERSION}%2Bcpu.run")
 endif()
 FetchContent_Declare(intel_ext_pt
   URL ${LIBIPEX_URL}

diff --git a/runtime/ipex/CMakeLists.txt b/runtime/ipex/CMakeLists.txt
index c51ff02f9..0b46931c8 100644
--- a/runtime/ipex/CMakeLists.txt
+++ b/runtime/ipex/CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

 project(wenet VERSION 0.1)

-option(CXX11_ABI "whether to use CXX11_ABI libtorch" OFF)
+option(CXX11_ABI "whether to use CXX11_ABI libtorch" ON)
 option(GRAPH_TOOLS "whether to build TLG graph tools" OFF)
 option(BUILD_TESTING "whether to build unit test" ON)

@@ -21,7 +21,7 @@ set(FETCHCONTENT_BASE_DIR ${fc_base})

 list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -Ofast -mavx2 -mfma -pthread -fPIC")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -Ofast -mavx2 -mfma -pthread -fPIC")

 # Include all dependency
 include(ipex)
@@ -30,7 +30,7 @@ include_directories(
   ${CMAKE_CURRENT_SOURCE_DIR}
   ${CMAKE_CURRENT_SOURCE_DIR}/kaldi
 )
-  include(wetextprocessing)
+include(wetextprocessing)

 # Build all libraries
 add_subdirectory(utils)

diff --git a/runtime/ipex/README.md b/runtime/ipex/README.md
index a75535120..c1dabc476 100644
--- a/runtime/ipex/README.md
+++ b/runtime/ipex/README.md
@@ -1,6 +1,6 @@
 ## WeNet Server (x86) ASR Demo With Intel® Extension for PyTorch\* Optimization

-[Intel® Extension for PyTorch\*](https://github.com/intel/intel-extension-for-pytorch) (IPEX) extends [PyTorch\*](https://pytorch.org/) with up-to-date optimization features for extra performance boost on Intel hardware. The optimizations take advantage of AVX-512, Vector Neural Network Instructions (AVX512 VNNI) and Intel® Advanced Matrix Extensions (Intel® AMX) on Intel CPUs as well as Intel Xe Matrix Extensions (XMX) AI engines on Intel discrete GPUs. 
+[Intel® Extension for PyTorch\*](https://github.com/intel/intel-extension-for-pytorch) (IPEX) extends [PyTorch\*](https://pytorch.org/) with up-to-date optimization features for extra performance boost on Intel hardware. The optimizations take advantage of AVX-512, Vector Neural Network Instructions (AVX512 VNNI) and Intel® Advanced Matrix Extensions (Intel® AMX) on Intel CPUs as well as Intel Xe Matrix Extensions (XMX) AI engines on Intel discrete GPUs.

 In the following we are introducing how to accelerate WeNet model inference performance on Intel® CPU machines with the adoption of Intel® Extension for PyTorch\*. The adoption mainly includes the export of pretrained models with IPEX optimization, as well as the buildup of WeNet runtime executables with IPEX C++ SDK. The buildup can be processed from local source code, or directly build and run a docker container in which the runtime binaries are ready.
@@ -39,7 +39,8 @@ docker run --rm -v $PWD/docker_resource:/home/wenet/runtime/ipex/docker_resource
 ```

 * Step 4. Test in docker container
-```
+
+```sh
 cd /home/wenet/runtime/ipex
 export GLOG_logtostderr=1
 export GLOG_v=2
@@ -57,15 +58,18 @@ model_dir=docker_resource/model

 * Step 1. Environment Setup.

 WeNet code cloning and default dependencies installation
+
 ``` sh
 git clone https://github.com/wenet-e2e/wenet
 cd wenet
 pip install -r requirements.txt
 ```
+
 Upgrading of PyTorch and TorchAudio, followed by the installation of IPEX
+
 ``` sh
-pip install torch==2.0.1 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cpu --force-reinstall
-pip install intel_extension_for_pytorch==2.0.100
+pip install torch==2.3.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cpu --force-reinstall
+pip install intel_extension_for_pytorch==2.3.0
 ```

 Installation of related tools: Intel® OpenMP and TCMalloc

 based on the package manager of your system.

 * Step 3. Export the pretrained model with IPEX optimization.

 For exporting FP32 runtime model
+
 ``` sh
 source examples/aishell/s0/path.sh
 export OMP_NUM_THREADS=1
 python wenet/bin/export_ipex.py \
   --config \
   --checkpoint \
   --output_file
 ```
+
 If you have an Intel® 4th Generation Xeon (Sapphire Rapids) server, you can export a BF16 runtime model and get better performance by virtue of [AMX instructions](https://en.wikipedia.org/wiki/Advanced_Matrix_Extensions)
+
 ``` sh
 source examples/aishell/s0/path.sh
 export OMP_NUM_THREADS=1
 python wenet/bin/export_ipex.py \
   --config \
   --checkpoint \
   --output_file \
   --dtype bf16
 ```
+
 And for exporting int8 quantized runtime model
+
 ``` sh
 source examples/aishell/s0/path.sh
 export OMP_NUM_THREADS=1
 python wenet/bin/export_ipex.py \

 ipexrun --no-python \
   --model_path $model_dir/ \
   --unit_path $model_dir/units.txt 2>&1 | tee log.txt
 ```

-NOTE: Please refer [IPEX Launch Script Usage Guide](https://intel.github.io/intel-extension-for-pytorch/cpu/2.0.100+cpu/tutorials/performance_tuning/launch_script.html) for usage of advanced features.
-For advanced usage of WeNet, such as building Web/RPC/HTTP services, please refer [LibTorch Tutorial](../libtorch#advanced-usage). The difference is that the executables should be invoked via IPEX launch script `ipexrun`.
\ No newline at end of file
+NOTE: Please refer [IPEX Launch Script Usage Guide](https://intel.github.io/intel-extension-for-pytorch/cpu/2.3.0+cpu/tutorials/performance_tuning/launch_script.html) for usage of advanced features.
+
+For advanced usage of WeNet, such as building Web/RPC/HTTP services, please refer [LibTorch Tutorial](../libtorch#advanced-usage). The difference is that the executables should be invoked via IPEX launch script `ipexrun`.

diff --git a/runtime/ipex/docker/Dockerfile b/runtime/ipex/docker/Dockerfile
index 854a33ceb..184872ba4 100644
--- a/runtime/ipex/docker/Dockerfile
+++ b/runtime/ipex/docker/Dockerfile
@@ -2,8 +2,8 @@ FROM ubuntu:22.04
 ENV DEBIAN_FRONTEND=noninteractive

 RUN apt-get update && apt-get install -y git cmake wget build-essential python-is-python3 python3-pip google-perftools
-RUN pip install torch==2.0.1 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cpu
-RUN pip install intel_extension_for_pytorch==2.0.100 pyyaml six intel-openmp
+RUN pip install torch==2.3.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cpu
+RUN pip install intel_extension_for_pytorch==2.3.0 pyyaml six intel-openmp
 RUN ln -s /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 /usr/lib/x86_64-linux-gnu/libtcmalloc.so
 RUN git clone https://github.com/wenet-e2e/wenet.git /home/wenet
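An editor's sketch, not part of the patch series: a standalone check of the attention changes in PATCH 1/8 and PATCH 6/8. The toy sizes, seed, and local variable names are assumptions; the clamped-index and einsum logic mirrors the diffed code. The old path materializes a (batch, head, t1, t2, d_k) intermediate before reducing over d_k, while torch.einsum contracts d_k directly — that is the memory saving the commit refers to. The cache-related rel_k[-q.size(2):] slice is omitted here because query and key lengths are equal:

``` python
import torch

torch.manual_seed(0)
B, H, T, d_k = 2, 4, 16, 32              # toy sizes (assumed)
max_left, max_right = 64, 8              # init values after PATCH 1/8

rel_k_embed = torch.nn.Embedding(max_left + max_right + 1, d_k)

# Clamped relative indices, as in the rewritten _relative_indices(keys).
indices = torch.arange(T).unsqueeze(0)            # (1, T)
rel_indices = indices - indices.transpose(0, 1)   # (T, T), entry [i, j] = j - i
rel_indices = torch.clamp(rel_indices, -max_left, max_right) + max_left

q = torch.randn(B, H, T, d_k)
rel_k = rel_k_embed(rel_indices)                  # (T, T, d_k)

# Old path: broadcast to (B, H, T, T, d_k), then reduce over d_k.
old = (rel_k.unsqueeze(0).unsqueeze(0) * q.unsqueeze(3)).sum(-1)

# New path: einsum contracts d_k without the 5-D intermediate.
new = torch.einsum("bhld,lrd->bhlr", q, rel_k)

assert torch.allclose(old, new, atol=1e-5)
print(new.shape)  # torch.Size([2, 4, 16, 16])
```

Both paths yield the same (batch, head, t1, t2) bias that is added to the q·kᵀ scores; only the size of the intermediate differs.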
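Similarly, a quick illustrative check for PATCH 5/8, which threads a window_type argument through to torchaudio's Kaldi-compatible fbank. The one-second random waveform and num_mel_bins=80 are assumptions, not taken from the patches; the remaining arguments are the ones wenet/cli/paraformer_model.py passes:

``` python
import torch
import torchaudio.compliance.kaldi as kaldi

wav = torch.randn(1, 16000)  # stand-in for 1 s of 16 kHz audio (assumed)

for window_type in ("povey", "hamming"):
    feats = kaldi.fbank(wav,
                        num_mel_bins=80,
                        frame_length=25,
                        frame_shift=10,
                        energy_floor=0.0,
                        sample_frequency=16000,
                        window_type=window_type)
    # Same (num_frames, num_mel_bins) shape; only the analysis window differs.
    print(window_type, feats.shape)
```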