
Commit e2918e1
Merge branch 'wenet-e2e:main' into main
ZailiWang authored Jun 5, 2024
2 parents 16cecf5 + 509d05d commit e2918e1
Showing 8 changed files with 39 additions and 18 deletions.
16 changes: 16 additions & 0 deletions README.md
@@ -59,9 +59,25 @@ git clone https://github.com/wenet-e2e/wenet.git
 conda create -n wenet python=3.10
 conda activate wenet
 conda install conda-forge::sox
 ```

+- Install CUDA: please follow this [link](https://icefall.readthedocs.io/en/latest/installation/index.html#id1). It's recommended to install CUDA 12.1.
+- Install torch and torchaudio. It's recommended to use 2.2.2+cu121:
+
+``` sh
+pip install torch==2.2.2+cu121 torchaudio==2.2.2+cu121 -f https://download.pytorch.org/whl/torch_stable.html
+```
+
 - Install other python packages

 ``` sh
 pip install -r requirements.txt
+pre-commit install  # for clean and tidy code
+```
+
+- Frequently Asked Questions (FAQs)
+
+``` sh
+# If you encounter sox compatibility issues
+RuntimeError: set_buffer_size requires sox extension which is not available.
+# ubuntu
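
A quick way to confirm that the pinned wheels from the pip command above landed correctly (a minimal sanity-check sketch, not part of the README diff):

``` python
# Sanity check for the CUDA/torch install steps above.
import torch
import torchaudio

print(torch.__version__)          # expect 2.2.2+cu121
print(torchaudio.__version__)     # expect 2.2.2+cu121
print(torch.cuda.is_available())  # expect True once CUDA 12.1 is set up
```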
3 changes: 2 additions & 1 deletion wenet/cli/paraformer_model.py
@@ -40,7 +40,8 @@ def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
                               frame_length=25,
                               frame_shift=10,
                               energy_floor=0.0,
-                              sample_frequency=self.resample_rate)
+                              sample_frequency=self.resample_rate,
+                              window_type="hamming")
         feats = feats.unsqueeze(0)
         feats_lens = torch.tensor([feats.size(1)],
                                   dtype=torch.int64,
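
The new window_type argument is forwarded to torchaudio's Kaldi-compatible fbank. A minimal standalone sketch of the call this hunk changes (the waveform and num_mel_bins here are hypothetical stand-ins):

``` python
# Hamming-window fbank via torchaudio's Kaldi frontend, as requested by the
# hunk above; "povey" is torchaudio's default window_type.
import torch
import torchaudio.compliance.kaldi as kaldi

waveform = torch.randn(1, 16000)  # hypothetical 1 s of 16 kHz audio
feats = kaldi.fbank(waveform,
                    num_mel_bins=80,  # assumed bin count
                    frame_length=25,
                    frame_shift=10,
                    energy_floor=0.0,
                    sample_frequency=16000,
                    window_type="hamming")
print(feats.shape)  # (num_frames, 80)
```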
6 changes: 4 additions & 2 deletions wenet/dataset/processor.py
@@ -231,7 +231,8 @@ def compute_fbank(sample,
                   num_mel_bins=23,
                   frame_length=25,
                   frame_shift=10,
-                  dither=0.0):
+                  dither=0.0,
+                  window_type="povey"):
     """ Extract fbank

         Args:
@@ -253,7 +254,8 @@ def compute_fbank(sample,
                   frame_shift=frame_shift,
                   dither=dither,
                   energy_floor=0.0,
-                  sample_frequency=sample_rate)
+                  sample_frequency=sample_rate,
+                  window_type=window_type)
     sample['feat'] = mat
     return sample
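
With the extra parameter, callers of compute_fbank can pick the analysis window per model; the converter change in the next file pins it to 'hamming' for Paraformer. A hedged usage sketch (the sample layout follows processor.py's dict convention; the key/label values are placeholders):

``` python
# Hypothetical call into the extended compute_fbank.
import torch
from wenet.dataset.processor import compute_fbank

sample = {'key': 'utt1', 'label': [1, 2],
          'wav': torch.randn(1, 16000), 'sample_rate': 16000}
sample = compute_fbank(sample, num_mel_bins=80, window_type='hamming')
print(sample['feat'].shape)  # (num_frames, 80)
```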

1 change: 1 addition & 0 deletions
@@ -140,6 +140,7 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str,
     configs['dataset_conf']['fbank_conf']['frame_shift'] = 10
     configs['dataset_conf']['fbank_conf']['frame_length'] = 25
     configs['dataset_conf']['fbank_conf']['dither'] = 0.1
+    configs['dataset_conf']['fbank_conf']['window_type'] = 'hamming'
     configs['dataset_conf']['spec_sub'] = False
     configs['dataset_conf']['spec_trim'] = False
     configs['dataset_conf']['shuffle'] = True
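
A hedged round-trip check of where this value ends up (the train.yaml path is an assumption about the converter's output):

``` python
# The exported config should now pin the frontend window type; this is the
# key that dataset/processor.py's compute_fbank reads back at training time.
import yaml

with open('train.yaml') as f:  # hypothetical output path of the converter
    configs = yaml.safe_load(f)
assert configs['dataset_conf']['fbank_conf']['window_type'] == 'hamming'
```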
1 change: 1 addition & 0 deletions wenet/transducer/transducer.py
@@ -90,6 +90,7 @@ def __init__(
             normalize_length=length_normalized_loss,
         )

+    @torch.jit.unused
     def forward(
         self,
         batch: dict,
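
Marking forward with @torch.jit.unused lets torch.jit.script() export the module's decoding entry points without compiling the dict-typed training path. A minimal sketch of the mechanism (Toy and greedy_search are invented names, not wenet code):

``` python
# @torch.jit.unused replaces the method with a stub that raises if called
# from TorchScript, so scripting no longer compiles the training-only code.
import torch

class Toy(torch.nn.Module):

    @torch.jit.unused
    def forward(self, batch: dict) -> torch.Tensor:
        return batch['feats'].sum()  # training-only path, left uncompiled

    @torch.jit.export
    def greedy_search(self, x: torch.Tensor) -> torch.Tensor:
        return x.argmax(-1)

scripted = torch.jit.script(Toy())
print(scripted.greedy_search(torch.randn(2, 3)))  # works; forward is skipped
```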
24 changes: 12 additions & 12 deletions wenet/transformer/attention.py
@@ -525,16 +525,21 @@ def __init__(self,
         super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
                          value_bias, use_sdpa, None, None)
         # TODO(Mddct): 64 8 1 as args
-        self.max_right_rel_pos = 64
-        self.max_left_rel_pos = 8
+        self.max_right_rel_pos = 8
+        self.max_left_rel_pos = 64
         self.rel_k_embed = torch.nn.Embedding(
             self.max_left_rel_pos + self.max_right_rel_pos + 1, self.d_k)

-    def _relative_indices(self, length: int, device: torch.device):
-        indices = torch.arange(length, device=device).unsqueeze(0)
+    def _relative_indices(self, keys: torch.Tensor) -> torch.Tensor:
+        # (1, S)
+        indices = torch.arange(keys.size(2), device=keys.device).unsqueeze(0)
+
+        # (S, S)
         rel_indices = indices - indices.transpose(0, 1)
+
         rel_indices = torch.clamp(rel_indices, -self.max_left_rel_pos,
                                   self.max_right_rel_pos)
+
         return rel_indices + self.max_left_rel_pos

     def forward(
Expand All @@ -550,14 +555,9 @@ def forward(
         q, k, v = self.forward_qkv(query, key, value)
         k, v, new_cache = self._update_kv_and_cache(k, v, cache)

-        rel_k = self.rel_k_embed(
-            self._relative_indices(k.size(2), query.device))  # (t2, t2, d_k)
-        rel_k = rel_k[-q.size(2):]  # (t1, t2, d_k)
-        # b,h,t1,dk
-        rel_k = rel_k.unsqueeze(0).unsqueeze(0)  # (1, 1, t1, t2, d_k)
-        q_expand = q.unsqueeze(3)  # (batch, h, t1, 1, d_k)
-        rel_att_weights = (rel_k * q_expand).sum(-1).squeeze(
-            -1)  # (batch, h, t1, t2)
+        rel_k = self.rel_k_embed(self._relative_indices(k))  # (t2, t2, d_k)
+        rel_k = rel_k[-q.size(2):]
+        rel_att_weights = torch.einsum("bhld,lrd->bhlr", q, rel_k)

         if not self.use_sdpa:
             scores = (torch.matmul(q, k.transpose(-2, -1)) +
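
The einsum is a drop-in replacement for the old broadcast-multiply-and-sum, without materializing the 5-D intermediate (it also drops the old stray squeeze(-1)). A quick equivalence check with placeholder shapes:

``` python
# "bhld,lrd->bhlr" contracts q (B, H, L, D) with rel_k (L, R, D) over D,
# exactly like the removed unsqueeze/multiply/sum formulation.
import torch

B, H, L, R, D = 2, 4, 5, 5, 8
q = torch.randn(B, H, L, D)
rel_k = torch.randn(L, R, D)

new = torch.einsum("bhld,lrd->bhlr", q, rel_k)
old = (rel_k.unsqueeze(0).unsqueeze(0) * q.unsqueeze(3)).sum(-1)  # (B,H,L,R)
assert torch.allclose(new, old, atol=1e-6)
```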
4 changes: 2 additions & 2 deletions wenet/transformer/subsampling.py
@@ -385,6 +385,6 @@ def forward(
             new_mask = ~make_pad_mask(seq_len, max_len=s // self.stride)
             x = x.view(b, s // self.stride, self.idim * self.stride)
             _, pos_emb = self.pos_enc_class(x, offset)
-            x = self.norm(x)
-            x = self.out(x)
+        x = self.norm(x)
+        x = self.out(x)
         return x, pos_emb, new_mask.unsqueeze(1)
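
The reordered lines appear to dedent norm/out past the end of an enclosing torch.no_grad() block; assuming that reading, a minimal sketch of why the move matters:

``` python
# Layers applied under torch.no_grad() record no autograd graph, so their
# parameters never receive gradients; applying them outside restores training.
import torch

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

with torch.no_grad():
    y = layer(x)
print(y.requires_grad)  # False: backward() would never reach layer.weight

y = layer(x)            # outside no_grad
y.sum().backward()
print(layer.weight.grad is not None)  # True
```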
2 changes: 1 addition & 1 deletion wenet/utils/executor.py
@@ -140,7 +140,7 @@ def cv(self, model, cv_data_loader, configs):

             num_seen_utts += num_utts
             total_acc.append(_dict['th_accuracy'].item(
-            ) if _dict['th_accuracy'] is not None else 0.0)
+            ) if _dict.get('th_accuracy', None) is not None else 0.0)
             for loss_name, loss_value in _dict.items():
                 if loss_value is not None and "loss" in loss_name \
                         and torch.isfinite(loss_value):
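
Using dict.get makes the accuracy guard tolerate models whose step stats omit 'th_accuracy'; the condition is evaluated before the .item() branch, so no KeyError can be raised. A minimal sketch:

``` python
# With dict.get, a missing 'th_accuracy' falls back to 0.0 instead of
# raising KeyError as _dict['th_accuracy'] would.
import torch

for _dict in ({'th_accuracy': torch.tensor(0.9, dtype=torch.float64)},
              {'loss': torch.tensor(1.2)}):
    acc = (_dict['th_accuracy'].item()
           if _dict.get('th_accuracy', None) is not None else 0.0)
    print(acc)  # 0.9 then 0.0
```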
