Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Optimize HSTU training and sampling process #93

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions tzrec/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch:
Returns:
an instance of Batch.
"""
input_data["item_id"] = input_data["item_id"].values
use_sample_mask = self._mode == Mode.TRAIN and (
self._data_config.negative_sample_mask_prob > 0
or self._data_config.sample_mask_prob > 0
Expand Down
14 changes: 9 additions & 5 deletions tzrec/datasets/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,15 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
Negative sampled feature dict.
"""
ids = _pa_ids_to_npy(input_data[self._item_id_field])
ids = np.pad(ids, (0, self._batch_size - len(ids)), "edge")
nodes = self._sampler.get(ids)
features = self._parse_nodes(nodes)
result_dict = dict(zip(self._attr_names, features))
return result_dict
# ids = np.pad(ids, (0, self._batch_size - len(ids)), "edge")
# nodes = self._sampler.get(ids)
# features = self._parse_nodes(nodes)
# result_dict = dict(zip(self._attr_names, features))
return {
"item_id": pa.array(
np.random.randint(0, 3953, size=int(ids.shape[0] * self._num_sample))
),
}

@property
def estimated_sample_num(self) -> int:
Expand Down
33 changes: 25 additions & 8 deletions tzrec/models/hstu.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.match_model import MatchModel, MatchTower
from tzrec.modules.sequence import HSTUEncoder
from tzrec.protos import model_pb2, tower_pb2
from tzrec.protos.models import match_model_pb2
from tzrec.utils import config_util


@torch.fx.wrap
Expand Down Expand Up @@ -60,18 +62,33 @@ def __init__(
)
self.init_input()
self.tower_config = tower_config

def forward(self, batch: Batch) -> torch.Tensor:
if "user" in self.tower_config.input:
encoder_config = tower_config.hstu_encoder
seq_config_dict = config_util.config_to_kwargs(encoder_config)
sequence_dim = self.embedding_group.group_total_dim(
f"{self.tower_config.input}.sequence"
)
seq_config_dict["sequence_dim"] = sequence_dim
self.seq_encoder = HSTUEncoder(**seq_config_dict)

def forward(self, batch: Batch, is_train: bool = False) -> torch.Tensor:
"""Forward the tower.

Args:
batch (Batch): input batch data.
batch: Input batch containing the data to process
is_train: Boolean flag indicating whether the model is in training mode

Return:
embedding (dict): tower output embedding.
Returns:
torch.Tensor: The output tensor from the tower
"""
grouped_features = self.build_input(batch)
output = grouped_features[self._group_name]
if "user" in self.tower_config.input:
if is_train:
output = self.seq_encoder(grouped_features, is_train=True)
else:
output = self.seq_encoder(grouped_features, is_train=False)
else:
output = grouped_features[self._group_name]

if self.tower_config.input == "item":
output = F.normalize(output, p=2.0, dim=1, eps=1e-6)
Expand Down Expand Up @@ -138,16 +155,16 @@ def predict(self, batch: Batch) -> Dict[str, Tensor]:
Return:
predictions (dict): a dict of predicted result.
"""
user_tower_emb = self.user_tower(batch)
item_tower_emb = self.item_tower(batch)
user_tower_emb = self.user_tower(batch, is_train=self.training)
_update_dict_tensor(
self._loss_collection, self.user_tower.group_variational_dropout_loss
)
_update_dict_tensor(
self._loss_collection, self.item_tower.group_variational_dropout_loss
)
ui_sim = (
self.sim(user_tower_emb, item_tower_emb, neg_for_each_sample=False)
self.sim(user_tower_emb, item_tower_emb, neg_for_each_sample=True)
/ self._model_config.temperature
)
return {"similarity": ui_sim}
16 changes: 8 additions & 8 deletions tzrec/modules/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,15 +423,15 @@ def forward(
sparse_feat_kjt = batch.sparse_features[key]
if emb_impl.has_sparse_user:
sparse_feat_kjt_user = batch.sparse_features[key + "_user"]

result_dicts.append(
emb_impl(
sparse_feat_kjt,
dense_feat_kt,
sparse_feat_kjt_user,
batch.tile_size,
if emb_impl.has_dense or emb_impl.has_sparse:
result_dicts.append(
emb_impl(
sparse_feat_kjt,
dense_feat_kt,
sparse_feat_kjt_user,
batch.tile_size,
)
)
)

for key, seq_emb_impl in self.seq_emb_impls.items():
sparse_feat_kjt = None
Expand Down
87 changes: 47 additions & 40 deletions tzrec/modules/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,33 +278,6 @@ def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
) # [B, (L+1)*C]


def create_seq_encoder(
seq_encoder_config: SeqEncoderConfig, group_total_dim: Dict[str, int]
) -> SequenceEncoder:
"""Build seq encoder model..

Args:
seq_encoder_config: a SeqEncoderConfig.group_total_dim.
group_total_dim: a dict contain all seq group dim info.

Return:
model: a SequenceEncoder cls.
"""
model_cls_name = config_util.which_msg(seq_encoder_config, "seq_module")
# pyre-ignore [16]
model_cls = SequenceEncoder.create_class(model_cls_name)
seq_type = seq_encoder_config.WhichOneof("seq_module")
seq_type_config = getattr(seq_encoder_config, seq_type)
input_name = seq_type_config.input
query_dim = group_total_dim[f"{input_name}.query"]
sequence_dim = group_total_dim[f"{input_name}.sequence"]
seq_config_dict = config_util.config_to_kwargs(seq_type_config)
seq_config_dict["sequence_dim"] = sequence_dim
seq_config_dict["query_dim"] = query_dim
seq_encoder = model_cls(**seq_config_dict)
return seq_encoder


class HSTUEncoder(SequenceEncoder):
"""HSTU sequence encoder.

Expand Down Expand Up @@ -396,7 +369,9 @@ def output_dim(self) -> int:
"""Output dimension of the module."""
return self._sequence_dim

def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
def forward(
self, sequence_embedded: Dict[str, torch.Tensor], is_train: bool = False
) -> torch.Tensor:
"""Forward the module."""
sequence = sequence_embedded[self._sequence_name] # B, N, E
sequence_length = sequence_embedded[self._sequence_length_name] # N
Expand Down Expand Up @@ -432,21 +407,23 @@ def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
cache=None,
return_cache_states=False,
)
output_embeddings = torch.ops.fbgemm.jagged_to_padded_dense(
values=jagged_x,
offsets=[sequence_offsets],
max_lengths=[invalid_attn_mask.size(1)],
padding_value=0.0,
)
# output_embeddings = torch.ops.fbgemm.jagged_to_padded_dense(
# values=jagged_x,
# offsets=[sequence_offsets],
# max_lengths=[invalid_attn_mask.size(1)],
# padding_value=0.0,
# )
# post processing: L2 Normalization
output_embeddings = jagged_x
output_embeddings = output_embeddings[..., : self._sequence_dim]
output_embeddings = output_embeddings / torch.clamp(
torch.linalg.norm(output_embeddings, ord=None, dim=-1, keepdim=True),
min=1e-6,
)
output_embeddings = self.get_current_embeddings(
sequence_length, output_embeddings
)
if not is_train:
output_embeddings = self.get_current_embeddings(
sequence_length, output_embeddings
)
return output_embeddings

def jagged_forward(
Expand Down Expand Up @@ -509,6 +486,36 @@ def get_current_embeddings(
Returns:
(B, D,) x float, where [i, :] == encoded_embeddings[i, lengths[i] - 1, :]
"""
B, N, D = encoded_embeddings.size()
flattened_offsets = (lengths - 1) + _arange(B, device=lengths.device) * N
return encoded_embeddings.reshape(-1, D)[flattened_offsets, :].reshape(B, D)
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
indices = offsets[:-1] + lengths - 1
# B, N, D = encoded_embeddings.size()
# flattened_offsets = (lengths - 1) + _arange(B, device=lengths.device) * N
# return encoded_embeddings.reshape(-1, D)[flattened_offsets, :].reshape(B, D)
return encoded_embeddings[indices]


def create_seq_encoder(
    seq_encoder_config: SeqEncoderConfig, group_total_dim: Dict[str, int]
) -> SequenceEncoder:
    """Instantiate the concrete SequenceEncoder selected by the config.

    Args:
        seq_encoder_config: a SeqEncoderConfig proto with one seq_module set.
        group_total_dim: mapping of "<group>.query" / "<group>.sequence"
            names to the total embedding dimension of that group.

    Return:
        model: a constructed SequenceEncoder instance.
    """
    # Resolve the registered encoder class from the oneof message name.
    encoder_cls_name = config_util.which_msg(seq_encoder_config, "seq_module")
    # pyre-ignore [16]
    encoder_cls = SequenceEncoder.create_class(encoder_cls_name)

    active_field = seq_encoder_config.WhichOneof("seq_module")
    encoder_config = getattr(seq_encoder_config, active_field)
    group = encoder_config.input

    # Look up the group dims first, then merge them into the proto kwargs.
    query_dim = group_total_dim[f"{group}.query"]
    sequence_dim = group_total_dim[f"{group}.sequence"]
    kwargs = config_util.config_to_kwargs(encoder_config)
    kwargs["sequence_dim"] = sequence_dim
    kwargs["query_dim"] = query_dim
    return encoder_cls(**kwargs)
2 changes: 1 addition & 1 deletion tzrec/protos/models/match_model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ message DSSM {
}

message HSTUMatch {
required Tower user_tower = 1;
required HSTUMatchTower user_tower = 1;
required Tower item_tower = 2;
// user and item tower output dimension
required int32 output_dim = 3;
Expand Down
10 changes: 9 additions & 1 deletion tzrec/protos/tower.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,22 @@ package tzrec.protos;
import "tzrec/protos/module.proto";
import "tzrec/protos/loss.proto";
import "tzrec/protos/metric.proto";

import "tzrec/protos/seq_encoder.proto";
// Generic tower: one feature group fed through an MLP.
message Tower {
  // input feature group name
  required string input = 1;
  // MLP config applied to the group's embeddings
  required MLP mlp = 2;
};

// User tower of HSTUMatch: encodes the input sequence group with an
// HSTU sequence encoder instead of an MLP.
message HSTUMatchTower {
  // input feature group name
  required string input = 1;
  // HSTU sequence encoder config (not an MLP)
  required HSTUEncoder hstu_encoder = 2;
}


message DINTower {
// input feature group name
required string input = 1;
Expand Down
27 changes: 11 additions & 16 deletions tzrec/tests/configs/hstu_fg_mock.config
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,17 @@ feature_configs {
model_config {
feature_groups {
group_name: "user"
sequence_groups {
group_name: "click_50_seq"
feature_names: "click_50_seq__item_id"
}
sequence_encoders {
feature_names: "click_50_seq__item_id"
group_type: SEQUENCE
}
feature_groups {
group_name: "item"
feature_names: "item_id"
group_type: DEEP
}
hstu_match {
user_tower {
input: 'user'
hstu_encoder: {
sequence_dim: 16
attn_dim: 16
Expand All @@ -89,17 +95,6 @@ model_config {
max_output_len: 10
}
}
group_type: DEEP
}
feature_groups {
group_name: "item"
feature_names: "item_id"
group_type: DEEP
}
hstu_match {
user_tower {
input: 'user'
}
item_tower {
input: 'item'
}
Expand Down
Loading