Skip to content

Commit

Permalink
Fix normalization of the coarse embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
pbarbarant committed May 14, 2024
1 parent bb3f44f commit f8472e3
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/fugw/scripts/coarse_to_fine.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,16 @@ def fit(

# Normalize embeddings
source_coarse_embedding = (
source_coarse_embedding / source_coarse_embedding.max()
source_coarse_embedding
/ (source_coarse_embedding @ source_coarse_embedding.T)
.norm(dim=1)
.max()
)
target_coarse_embedding = (
target_coarse_embedding / target_coarse_embedding.max()
target_coarse_embedding
/ (target_coarse_embedding @ target_coarse_embedding.T)
.norm(dim=1)
.max()
)

# Sampled weights
Expand Down

0 comments on commit f8472e3

Please sign in to comment.