Skip to content

Commit

Permalink
Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
twitter-team committed Mar 31, 2023
0 parents commit 78c3235
Show file tree
Hide file tree
Showing 111 changed files with 11,876 additions and 0 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Python package

on: [push]

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]

steps:
- uses: actions/checkout@v3
# - uses: pre-commit/[email protected]
# name: Run pre-commit checks (pylint/yapf/isort)
# env:
# SKIP: insert-license
# with:
# extra_args: --hook-stage push --all-files
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: "pip" # caching pip dependencies
- name: install packages
run: |
/usr/bin/python -m pip install --upgrade pip
pip install --no-deps -r images/requirements.txt
# - name: ssh access
# uses: lhotari/action-upterm@v1
# with:
# limit-access-to-actor: true
# limit-access-to-users: arashd
- name: run tests
run: |
# Environment variables are reset in between steps.
mkdir /tmp/github_testing
ln -s $GITHUB_WORKSPACE /tmp/github_testing/tml
export PYTHONPATH="/tmp/github_testing:$PYTHONPATH"
pytest -vv
35 changes: 35 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Mac
.DS_Store

# Vim
*.py.swp

This comment has been minimized.

Copy link
@801lang

801lang Oct 27, 2024

k


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
.hypothesis

venv
16 changes: 16 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
repos:
- repo: https://github.com/pausan/cblack
rev: release-22.3.0
hooks:
- id: cblack
name: cblack
description: "Black: The uncompromising Python code formatter - 2 space indent fork"
entry: cblack . -l 100
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-merge-conflict
661 changes: 661 additions & 0 deletions COPYING

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions LICENSE.torchrec
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
A few files here (where it is specifically noted in comments) are based on code from torchrec but
adapted for our use. Torchrec license is below:


BSD 3-Clause License

Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
This project open sources some of the ML models used at Twitter.

Currently these are:

1. The "For You" Heavy Ranker (projects/home/recap).

2. TwHIN embeddings (projects/twhin) https://arxiv.org/abs/2202.05387


This project can be run inside a python virtualenv. We have only tried this on Linux machines and because we use torchrec it works best with an Nvidia GPU. To setup run

`./images/init_venv.sh` (Linux only).

The READMEs of each project contain instructions about how to run each project.
Empty file added common/__init__.py
Empty file.
85 changes: 85 additions & 0 deletions common/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
"""
# flake8: noqa
from __future__ import annotations
from typing import Dict
import abc
from dataclasses import dataclass
import dataclasses

import torch
from torchrec.streamable import Pipelineable


class BatchBase(Pipelineable, abc.ABC):
@abc.abstractmethod
def as_dict(self) -> Dict:
raise NotImplementedError

def to(self, device: torch.device, non_blocking: bool = False):
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
return self.__class__(**args)

def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for feature_value in self.as_dict().values():
feature_value.record_stream(stream)

def pin_memory(self):
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory()
return self.__class__(**args)

def __repr__(self) -> str:
def obj2str(v):
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"

return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])

@property
def batch_size(self) -> int:
for tensor in self.as_dict().values():
if tensor is None:
continue
if not isinstance(tensor, torch.Tensor):
continue
return tensor.shape[0]
raise Exception("Could not determine batch size from tensors.")


@dataclass
class DataclassBatch(BatchBase):
@classmethod
def feature_names(cls):
return list(cls.__dataclass_fields__.keys())

def as_dict(self):
return {
feature_name: getattr(self, feature_name)
for feature_name in self.feature_names()
if hasattr(self, feature_name)
}

@staticmethod
def from_schema(name: str, schema):
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
return dataclasses.make_dataclass(
cls_name=name,
fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
bases=(DataclassBatch,),
)

@staticmethod
def from_fields(name: str, fields: dict):
return dataclasses.make_dataclass(
cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
bases=(DataclassBatch,),
)


class DictionaryBatch(BatchBase, dict):
def as_dict(self) -> Dict:
return self
1 change: 1 addition & 0 deletions common/checkpointing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot
Loading

5 comments on commit 78c3235

@tronez
Copy link

@tronez tronez commented on 78c3235 Apr 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello there

@IgorKowalczyk
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice :shipit:

@parsakhaz
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

last

@GoombaProgrammer
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is twitter

@ehShahid
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

root@321 password

Please sign in to comment.