-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# CI workflow: installs the pinned requirements and runs the pytest suite on every push.
name: Python package

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v3
      # - uses: pre-commit/action@v3.0.0
      #   name: Run pre-commit checks (pylint/yapf/isort)
      #   env:
      #     SKIP: insert-license
      #   with:
      #     extra_args: --hook-stage push --all-files
      - uses: actions/setup-python@v4
        with:
          # Read the version from the matrix so that adding an entry there actually takes effect.
          python-version: ${{ matrix.python-version }}
          cache: "pip" # caching pip dependencies
      - name: install packages
        run: |
          # Use the interpreter provisioned by setup-python (first on PATH), not the system one,
          # so the upgraded pip is the same pip that installs the requirements below.
          python -m pip install --upgrade pip
          pip install --no-deps -r images/requirements.txt
      # - name: ssh access
      #   uses: lhotari/action-upterm@v1
      #   with:
      #     limit-access-to-actor: true
      #     limit-access-to-users: arashd
      - name: run tests
        run: |
          # Environment variables are reset in between steps.
          # The checkout is exposed as a package named "tml" by symlinking it under a directory
          # that is put on PYTHONPATH.
          mkdir -p /tmp/github_testing
          ln -s "$GITHUB_WORKSPACE" /tmp/github_testing/tml
          export PYTHONPATH="/tmp/github_testing:$PYTHONPATH"
          pytest -vv
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Mac | ||
.DS_Store | ||
|
||
# Vim | ||
*.py.swp | ||
This comment has been minimized.
Sorry, something went wrong. |
||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
.hypothesis | ||
|
||
venv |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
repos: | ||
- repo: https://github.com/pausan/cblack | ||
rev: release-22.3.0 | ||
hooks: | ||
- id: cblack | ||
name: cblack | ||
description: "Black: The uncompromising Python code formatter - 2 space indent fork" | ||
entry: cblack . -l 100 | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v2.3.0 | ||
hooks: | ||
- id: trailing-whitespace | ||
- id: end-of-file-fixer | ||
- id: check-yaml | ||
- id: check-added-large-files | ||
- id: check-merge-conflict |
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
A few files here (where it is specifically noted in comments) are based on code from torchrec but | ||
adapted for our use. Torchrec license is below: | ||
|
||
|
||
BSD 3-Clause License | ||
|
||
Copyright (c) Meta Platforms, Inc. and affiliates. | ||
All rights reserved. | ||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
|
||
* Redistributions of source code must retain the above copyright notice, this | ||
list of conditions and the following disclaimer. | ||
|
||
* Redistributions in binary form must reproduce the above copyright notice, | ||
this list of conditions and the following disclaimer in the documentation | ||
and/or other materials provided with the distribution. | ||
|
||
* Neither the name of the copyright holder nor the names of its | ||
contributors may be used to endorse or promote products derived from | ||
this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
This project open-sources some of the ML models used at Twitter. | ||
|
||
Currently these are: | ||
|
||
1. The "For You" Heavy Ranker (projects/home/recap). | ||
|
||
2. TwHIN embeddings (projects/twhin) https://arxiv.org/abs/2202.05387 | ||
|
||
|
||
This project can be run inside a Python virtualenv. We have only tried this on Linux machines, and because we use torchrec it works best with an Nvidia GPU. To set up, run | ||
|
||
`./images/init_venv.sh` (Linux only). | ||
|
||
The READMEs of each project contain instructions about how to run each project. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""Extension of torchrec.dataset.utils.Batch to cover any dataset. | ||
""" | ||
# flake8: noqa | ||
from __future__ import annotations | ||
from typing import Dict | ||
import abc | ||
from dataclasses import dataclass | ||
import dataclasses | ||
|
||
import torch | ||
from torchrec.streamable import Pipelineable | ||
|
||
|
||
class BatchBase(Pipelineable, abc.ABC):
  """Abstract base for a batch exposed as a mapping of feature name -> feature value.

  Subclasses implement `as_dict`; every other method here is written in terms of it.
  Feature values may be `None` (e.g. an unset optional feature): such entries are passed
  through untouched by `to` / `pin_memory`, skipped by `record_stream` and `batch_size`,
  and printed as `None` by `__repr__`.
  """

  @abc.abstractmethod
  def as_dict(self) -> Dict:
    """Return the mapping of feature name -> feature value for this batch."""
    raise NotImplementedError

  def to(self, device: torch.device, non_blocking: bool = False):
    """Return a new batch of the same class with every feature moved to `device`.

    Args:
      device: target device, forwarded to each feature's own `.to()`.
      non_blocking: forwarded to each feature's own `.to()`.

    Returns:
      A new instance built as `self.__class__(**moved_features)`.
    """
    args = {
      feature_name: (
        # None has no .to(); keep it as-is so optional features survive the move.
        None
        if feature_value is None
        else feature_value.to(device=device, non_blocking=non_blocking)
      )
      for feature_name, feature_value in self.as_dict().items()
    }
    return self.__class__(**args)

  def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
    """Call `record_stream(stream)` on every feature value that is not `None`."""
    for feature_value in self.as_dict().values():
      if feature_value is not None:
        feature_value.record_stream(stream)

  def pin_memory(self):
    """Return a new batch of the same class with `pin_memory()` applied to every feature.

    `None` feature values are carried over unchanged.
    """
    args = {
      feature_name: None if feature_value is None else feature_value.pin_memory()
      for feature_name, feature_value in self.as_dict().items()
    }
    return self.__class__(**args)

  def __repr__(self) -> str:
    """Return one `name: <size>,` line per feature."""

    def obj2str(v):
      if v is None:
        return "None"
      # Objects without `.size()` are expected to expose `length_per_key()` instead.
      return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"

    return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])

  @property
  def batch_size(self) -> int:
    """Return the size of dim 0 of the first feature that is a `torch.Tensor`.

    Raises:
      Exception: if no feature value is a `torch.Tensor`.
    """
    for tensor in self.as_dict().values():
      # The isinstance check also rejects None, so unset features are skipped.
      if isinstance(tensor, torch.Tensor):
        return tensor.shape[0]
    raise Exception("Could not determine batch size from tensors.")
|
||
|
||
# NOTE(review): @dataclass generates __repr__/__eq__ by default when the class body does not
# define them itself, which presumably takes precedence over BatchBase.__repr__ for this class
# and its generated subclasses - confirm that is intended.
@dataclass
class DataclassBatch(BatchBase):
  """Batch whose features are the fields of a dataclass."""

  @classmethod
  def feature_names(cls):
    """Return the dataclass field names registered on `cls`, as a list."""
    return [field_name for field_name in cls.__dataclass_fields__]

  def as_dict(self):
    """Return a dict of field name -> current attribute value for fields present on `self`."""
    features = {}
    for field_name in self.feature_names():
      if hasattr(self, field_name):
        features[field_name] = getattr(self, field_name)
    return features

  @staticmethod
  def from_schema(name: str, schema):
    """Create a DataclassBatch subclass with one torch.Tensor field per schema column.

    Args:
      name: name given to the generated class.
      schema: object whose `names` attribute iterates over the column names.

    Returns:
      A new dataclass type deriving from DataclassBatch; every field defaults to None.
    """
    tensor_fields = [
      (column, torch.Tensor, dataclasses.field(default=None)) for column in schema.names
    ]
    return dataclasses.make_dataclass(
      cls_name=name,
      fields=tensor_fields,
      bases=(DataclassBatch,),
    )

  @staticmethod
  def from_fields(name: str, fields: dict):
    """Create a DataclassBatch subclass from a mapping of field name -> field type.

    Args:
      name: name given to the generated class.
      fields: mapping of field name to its annotated type.

    Returns:
      A new dataclass type deriving from DataclassBatch; every field defaults to None.
    """
    typed_fields = []
    for field_name, field_type in fields.items():
      typed_fields.append((field_name, field_type, dataclasses.field(default=None)))
    return dataclasses.make_dataclass(
      cls_name=name,
      fields=typed_fields,
      bases=(DataclassBatch,),
    )
|
||
|
||
class DictionaryBatch(BatchBase, dict):
  """Batch stored directly as a `dict` of feature name -> feature value."""

  def as_dict(self) -> Dict:
    """Return this object itself; it already is the feature mapping."""
    return self
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot |
5 comments
on commit 78c3235
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello there
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
last
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is twitter
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
root@321 password
k