diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index ff8b8dfa..03dfc064 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -4,14 +4,13 @@ If you are interested in contributing to PyTorch Frame, your contributions will 1. You want to implement a new feature: - In general, we accept any features as long as they fit the scope of this package. If you are unsure about this or need help on the design/implementation of your feature, post about it in an issue. -2. You want to fix a bug: +1. You want to fix a bug: - Feel free to send a Pull Request (PR) any time you encounter a bug. Please provide a clear and concise description of what the bug was. If you are unsure about if this is a bug at all or how to fix, post about it in an issue. Once you finish implementing a feature or bug-fix, please send a PR to https://github.com/pyg-team/pytorch-frame. Your PR will be merged after one or more rounds of reviews by the [pyg-team](https://github.com/pyg-team). - ## Developing PyTorch Frame To develop PyTorch Frame on your machine, here are some tips: @@ -23,26 +22,28 @@ To develop PyTorch Frame on your machine, here are some tips: print(torch.__version__) ``` -2. Uninstall all existing PyTorch Frame installations. +1. Uninstall all existing PyTorch Frame installations. It is advised to run this command repeatedly to confirm that installations across all locations are properly removed. ```bash pip uninstall pytorch_frame ``` -3. Fork and clone the PyTorch Frame repository: +1. Fork and clone the PyTorch Frame repository: ```bash git clone https://github.com//pytorch-frame.git cd pytorch-frame -5. If you already cloned PyTorch Frame from source, update it: + ``` + +1. If you already cloned PyTorch Frame from source, update it: ```bash git pull ``` -6. Install PyTorch Frame in editable mode: +1. Install PyTorch Frame in editable mode: ```bash pip install -e ".[dev,full]" @@ -51,13 +52,13 @@ To develop PyTorch Frame on your machine, here are some tips: This mode will symlink the Python files from the current local source tree into the Python install. Hence, if you modify a Python file, you do not need to re-install PyTorch Frame again. -7. Ensure that you have a working PyTorch Frame installation by running the entire test suite with +1. Ensure that you have a working PyTorch Frame installation by running the entire test suite with ```bash pytest ``` -8. Install pre-commit hooks: +1. Install pre-commit hooks: ```bash pre-commit install @@ -91,7 +92,7 @@ Everytime you send a Pull Request, your commit will be built and checked against If you do not want to format your code manually, we recommend to use [`yapf`](https://github.com/google/yapf). -2. Ensure that the entire test suite passes and that code coverage roughly stays the same. +1. Ensure that the entire test suite passes and that code coverage roughly stays the same. Please feel encouraged to provide a test with your submitted code. To test, either run @@ -101,7 +102,7 @@ Everytime you send a Pull Request, your commit will be built and checked against (which runs a set of additional but time-consuming tests) dependening on your needs. -3. Add your feature/bugfix to the [`CHANGELOG.md`](https://github.com/pyg-team/pyotrch-frame/blob/master/CHANGELOG.md?plain=1). +1. Add your feature/bugfix to the [`CHANGELOG.md`](https://github.com/pyg-team/pyotrch-frame/blob/master/CHANGELOG.md?plain=1). If multiple PRs move towards integrating a single feature, it is advised to group them together into one bullet point. ## Building Documentation @@ -109,11 +110,11 @@ Everytime you send a Pull Request, your commit will be built and checked against To build the documentation: 1. [Build and install](#developing-pytorch-frame) PyTorch Frame from source. -2. Install [Sphinx](https://www.sphinx-doc.org/en/master/) theme via +1. Install [Sphinx](https://www.sphinx-doc.org/en/master/) theme via ```bash pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git ``` -3. Generate the documentation via: +1. Generate the documentation via: ```bash cd docs make html diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e0e67fa5..d342d5ff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,6 +33,26 @@ repos: name: Lint yaml args: [-d, '{extends: default, rules: {line-length: disable, document-start: disable, truthy: {level: error}, braces: {max-spaces-inside: 1}}}'] + - repo: https://github.com/asottile/pyupgrade + rev: v3.17.0 + hooks: + - id: pyupgrade + name: Upgrade Python syntax + args: [--py38-plus] + + - repo: https://github.com/PyCQA/autoflake + rev: v2.3.1 + hooks: + - id: autoflake + name: Remove unused imports and variables + args: [ + --remove-all-unused-imports, + --remove-unused-variables, + --remove-duplicate-keys, + --ignore-init-module-imports, + --in-place, + ] + - repo: https://github.com/google/yapf rev: v0.40.2 hooks: @@ -65,5 +85,15 @@ repos: hooks: - id: mypy name: Check types - additional_dependencies: [torch==2.2.0] + additional_dependencies: [torch==2.4.0] exclude: "^test/|^examples/|^benchmark/" + + - repo: https://github.com/executablebooks/mdformat + rev: 0.7.17 + hooks: + - id: mdformat + name: Format Markdown + additional_dependencies: + - mdformat-gfm + - mdformat_frontmatter + - mdformat_footnote diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e9c335f..9bbd9232 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,8 +3,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - -## [Unreleased] +## \[Unreleased\] ### Added @@ -21,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed CUDA synchronizations from `nn.LinearEmbeddingEncoder` ([#432](https://github.com/pyg-team/pytorch-frame/pull/432)) - Removed CUDA synchronizations from N/A imputation logic in `nn.StypeEncoder` ([#433](https://github.com/pyg-team/pytorch-frame/pull/433), [#434](https://github.com/pyg-team/pytorch-frame/pull/434)) -## [0.2.3] - 2024-07-08 +## \[0.2.3\] - 2024-07-08 ### Added @@ -33,8 +32,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Updated `ExcelFormer` implementation and related scripts ([#391](https://github.com/pyg-team/pytorch-frame/pull/391)) - -## [0.2.2] - 2024-03-04 +## \[0.2.2\] - 2024-03-04 ### Added @@ -60,14 +58,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the split of `DataFrameTextBenchmark` ([#358](https://github.com/pyg-team/pytorch-frame/pull/358)) - Fixed empty `MultiNestedTensor` col indexing ([#355](https://github.com/pyg-team/pytorch-frame/pull/355)) -## [0.2.1] - 2024-01-16 +## \[0.2.1\] - 2024-01-16 ### Added + - Support more stypes in `LinearModelEncoder` ([#325](https://github.com/pyg-team/pytorch-frame/pull/325)) - Added `stype_encoder_dict` to some models ([#319](https://github.com/pyg-team/pytorch-frame/pull/319)) - Added `HuggingFaceDatasetDict` ([#287](https://github.com/pyg-team/pytorch-frame/pull/287)) ### Changed + - Supported decoder embedding model in `examples/transformers_text.py` ([#333](https://github.com/pyg-team/pytorch-frame/pull/333)) - Removed implicit clones in `StypeEncoder` ([#286](https://github.com/pyg-team/pytorch-frame/pull/286)) @@ -76,10 +76,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed ### Fixed + - Fixed `TimestampEncoder` not applying `CyclicEncoder` to cyclic features ([#311](https://github.com/pyg-team/pytorch-frame/pull/311)) - Fixed NaN masking in `multicateogrical` stype ([#307](https://github.com/pyg-team/pytorch-frame/pull/307)) -## [0.2.0] - 2023-12-15 +## \[0.2.0\] - 2023-12-15 ### Added @@ -104,7 +105,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `Timestamp` stype ([#212](https://github.com/pyg-team/pytorch-frame/pull/212)) - Added `multicategorical` to `MultimodalTextBenchmark` ([#208](https://github.com/pyg-team/pytorch-frame/pull/208)) - Added support for saving and loading of `TensorFrame` with complex `stypes`. ([#197](https://github.com/pyg-team/pytorch-frame/pull/197)) -- Added `stype.embedding` ((#194)[https://github.com/pyg-team/pytorch-frame/pull/194]) +- Added `stype.embedding` ([#194](https://github.com/pyg-team/pytorch-frame/pull/194)) - Added `TensorFrame` concatenation of complex stypes. ([#190](https://github.com/pyg-team/pytorch-frame/pull/190)) - Added `text_tokenized` example ([#174](https://github.com/pyg-team/pytorch-frame/pull/174)) - Added Cohere embedding example ([#186](https://github.com/pyg-team/pytorch-frame/pull/186)) @@ -131,8 +132,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - No manual passing of `in_channels` to `LinearEmbeddingEncoder` for `stype.text_embedded` ([#222](https://github.com/pyg-team/pytorch-frame/pull/222)) - -## [0.1.0] - 2023-10-23 +## \[0.1.0\] - 2023-10-23 ### Added diff --git a/README.md b/README.md index 3f3184e2..b934f6c2 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,3 @@ -[testing-image]: https://github.com/pyg-team/pytorch-frame/actions/workflows/testing.yml/badge.svg -[testing-url]: https://github.com/pyg-team/pytorch-frame/actions/workflows/testing.yml -[contributing-image]: https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat -[contributing-url]: https://github.com/pyg-team/pytorch-frame/blob/master/.github/CONTRIBUTING.md -[slack-image]: https://img.shields.io/badge/slack-pyf-brightgreen -[slack-url]: https://data.pyg.org/slack.html -[pypi-image]: https://badge.fury.io/py/pytorch-frame.svg -[pypi-url]: https://pypi.python.org/pypi/pytorch-frame -[docs-image]: https://readthedocs.org/projects/pytorch-frame/badge/?version=latest -[docs-url]: https://pytorch-frame.readthedocs.io/en/latest -[arxiv-image]: https://img.shields.io/badge/arXiv-2404.00776-b31b1b.svg -[arxiv-url]: https://arxiv.org/abs/2404.00776 -
@@ -20,7 +7,7 @@ **A modular deep learning framework for building neural network models on heterogeneous tabular data.** --------------------------------------------------------------------------------- +______________________________________________________________________ [![arXiv][arxiv-image]][arxiv-url] [![PyPI Version][pypi-image]][pypi-url] @@ -39,7 +26,8 @@ PyTorch Frame democratizes deep learning research for tabular data, catering to 1. **Facilitate Deep Learning for Tabular Data:** Historically, tree-based models (e.g., GBDT) excelled at tabular learning but had notable limitations, such as integration difficulties with downstream models, and handling complex column types, such as texts, sequences, and embeddings. Deep tabular models are promising to resolve the limitations. We aim to facilitate deep learning research on tabular data by modularizing its implementation and supporting the diverse column types. -2. **Integrates with Diverse Model Architectures like Large Language Models:** PyTorch Frame supports integration with a variety of different architectures including LLMs. With any downloaded model or embedding API endpoint, you can encode your text data with embeddings and train it with deep learning models alongside other complex semantic types. We support the following (but not limited to): +1. **Integrates with Diverse Model Architectures like Large Language Models:** PyTorch Frame supports integration with a variety of different architectures including LLMs. With any downloaded model or embedding API endpoint, you can encode your text data with embeddings and train it with deep learning models alongside other complex semantic types. We support the following (but not limited to): +
@@ -71,27 +59,27 @@ PyTorch Frame democratizes deep learning research for tabular data, catering to
-* [Library Highlights](#library-highlights) -* [Architecture Overview](#architecture-overview) -* [Quick Tour](#quick-tour) -* [Implemented Deep Tabular Models](#implemented-deep-tabular-models) -* [Benchmark](#benchmark) -* [Installation](#installation) +- [Library Highlights](#library-highlights) +- [Architecture Overview](#architecture-overview) +- [Quick Tour](#quick-tour) +- [Implemented Deep Tabular Models](#implemented-deep-tabular-models) +- [Benchmark](#benchmark) +- [Installation](#installation) ## Library Highlights PyTorch Frame builds directly upon PyTorch, ensuring a smooth transition for existing PyTorch users. Key features include: -* **Diverse column types**: +- **Diverse column types**: PyTorch Frame supports learning across various column types: `numerical`, `categorical`, `multicategorical`, `text_embedded`, `text_tokenized`, `timestamp`, `image_embedded`, and `embedding`. See [here](https://pytorch-frame.readthedocs.io/en/latest/handling_advanced_stypes/handle_heterogeneous_stypes.html) for the detailed tutorial. -* **Modular model design**: +- **Modular model design**: Enables modular deep learning model implementations, promoting reusability, clear coding, and experimentation flexibility. Further details in the [architecture overview](#architecture-overview). -* **Models** +- **Models** Implements many [state-of-the-art deep tabular models](#implemented-deep-tabular-models) as well as strong GBDTs (XGBoost, CatBoost, and LightGBM) with hyper-parameter tuning. -* **Datasets**: +- **Datasets**: Comes with a collection of readily-usable tabular datasets. Also supports custom datasets to solve your own problem. We [benchmark](https://github.com/pyg-team/pytorch-frame/blob/master/benchmark) deep tabular models against GBDTs. -* **PyTorch integration**: +- **PyTorch integration**: Integrates effortlessly with other PyTorch libraries, facilitating end-to-end training of PyTorch Frame with downstream PyTorch models. For example, by integrating with [PyG](https://pyg.org/), a PyTorch library for GNNs, we can perform deep learning over relational databases. Learn more in [RelBench](https://relbench.stanford.edu/) and [example code (WIP)](https://github.com/snap-stanford/relbench/blob/main/examples/gnn.py). ## Architecture Overview @@ -104,11 +92,10 @@ Models in PyTorch Frame follow a modular design of `FeatureEncoder`, `TableConv` In essence, this modular setup empowers users to effortlessly experiment with myriad architectures: -* `Materialization` handles converting the raw pandas `DataFrame` into a `TensorFrame` that is amenable to Pytorch-based training and modeling. -* `FeatureEncoder` encodes `TensorFrame` into hidden column embeddings of size `[batch_size, num_cols, channels]`. -* `TableConv` models column-wise interactions over the hidden embeddings. -* `Decoder` generates embedding/prediction per row. - +- `Materialization` handles converting the raw pandas `DataFrame` into a `TensorFrame` that is amenable to Pytorch-based training and modeling. +- `FeatureEncoder` encodes `TensorFrame` into hidden column embeddings of size `[batch_size, num_cols, channels]`. +- `TableConv` models column-wise interactions over the hidden embeddings. +- `Decoder` generates embedding/prediction per row. ## Quick Tour @@ -118,9 +105,10 @@ In this quick tour, we showcase the ease of creating and training a deep tabular As an example, we implement a simple `ExampleTransformer` following the modular architecture of Pytorch Frame. In the example below: -* `self.encoder` maps an input `TensorFrame` to an embedding of size `[batch_size, num_cols, channels]`. -* `self.convs` iteratively transforms the embedding of size `[batch_size, num_cols, channels]` into an embedding of the same size. -* `self.decoder` pools the embedding of size `[batch_size, num_cols, channels]` into `[batch_size, out_channels]`. + +- `self.encoder` maps an input `TensorFrame` to an embedding of size `[batch_size, num_cols, channels]`. +- `self.convs` iteratively transforms the embedding of size `[batch_size, num_cols, channels]` into an embedding of the same size. +- `self.decoder` pools the embedding of size `[batch_size, num_cols, channels]` into `[batch_size, out_channels]`. ```python from torch import Tensor @@ -212,48 +200,46 @@ for epoch in range(50): We list currently supported deep tabular models: -* **[Trompt](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.Trompt.html)** from Chen *et al.*: [Trompt: Towards a Better Deep Neural Network for Tabular Data](https://arxiv.org/abs/2305.18446) (ICML 2023) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/trompt.py)] -* **[FTTransformer](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.FTTransformer.html)** from Gorishniy *et al.*: [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959) (NeurIPS 2021) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/revisiting.py)] -* **[ResNet](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.ResNet.html)** from Gorishniy *et al.*: [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959) (NeurIPS 2021) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/revisiting.py)] -* **[TabNet](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.TabNet.html)** from Arık *et al.*: [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442) (AAAI 2021) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/tabnet.py)] -* **[ExcelFormer](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.ExcelFormer.html)** from Chen *et al.*: [ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data](https://arxiv.org/abs/2301.02819) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/excelformer.py)] -* **[TabTransformer](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.TabTransformer.html)** from Huang *et al.*: [TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/abs/2012.06678) [[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/tab_transformer.py)] +- **[Trompt](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.Trompt.html)** from Chen *et al.*: [Trompt: Towards a Better Deep Neural Network for Tabular Data](https://arxiv.org/abs/2305.18446) (ICML 2023) \[[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/trompt.py)\] +- **[FTTransformer](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.FTTransformer.html)** from Gorishniy *et al.*: [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959) (NeurIPS 2021) \[[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/revisiting.py)\] +- **[ResNet](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.ResNet.html)** from Gorishniy *et al.*: [Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959) (NeurIPS 2021) \[[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/revisiting.py)\] +- **[TabNet](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.TabNet.html)** from Arık *et al.*: [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442) (AAAI 2021) \[[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/tabnet.py)\] +- **[ExcelFormer](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.ExcelFormer.html)** from Chen *et al.*: [ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data](https://arxiv.org/abs/2301.02819) \[[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/excelformer.py)\] +- **[TabTransformer](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.nn.models.TabTransformer.html)** from Huang *et al.*: [TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/abs/2012.06678) \[[**Example**](https://github.com/pyg-team/pytorch-frame/blob/master/examples/tab_transformer.py)\] In addition, we implemented `XGBoost`, `CatBoost`, and `LightGBM` [examples](https://github.com/pyg-team/pytorch-frame/blob/master/examples/tuned_gbdt.py) with hyperparameter-tuning using [Optuna](https://optuna.org/) for users who'd like to compare their model performance with `GBDTs`. - ## Benchmark We benchmark recent tabular deep learning models against GBDTs over diverse public datasets with different sizes and task types. The following chart shows the performance of various models on small regression datasets, where the row represents the model names and the column represents dataset indices (we have 13 datasets here). For more results on classification and larger datasets, please check the [benchmark documentation](https://github.com/pyg-team/pytorch-frame/blob/master/benchmark). -| Model Name | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | -|:--------------------|:----------------|:-----------------|:----------------|:----------------|:----------------|:-----------------|:----------------|:----------------|:-------------------|:----------------|:----------------|:-------------------|:----------------| -| XGBoost | **0.250±0.000** | 0.038±0.000 | 0.187±0.000 | 0.475±0.000 | 0.328±0.000 | 0.401±0.000 | **0.249±0.000** | 0.363±0.000 | 0.904±0.000 | 0.056±0.000 | 0.820±0.000 | **0.857±0.000** | 0.418±0.000 | -| CatBoost | 0.265±0.000 | 0.062±0.000 | 0.128±0.000 | 0.336±0.000 | 0.346±0.000 | 0.443±0.000 | 0.375±0.000 | 0.273±0.000 | 0.881±0.000 | 0.040±0.000 | 0.756±0.000 | 0.876±0.000 | 0.439±0.000 | -| LightGBM | 0.253±0.000 | 0.054±0.000 | **0.112±0.000** | 0.302±0.000 | 0.325±0.000 | **0.384±0.000** | 0.295±0.000 | **0.272±0.000** | **0.877±0.000** | 0.011±0.000 | **0.702±0.000** | 0.863±0.000 | **0.395±0.000** | -| Trompt | 0.261±0.003 | **0.015±0.005** | 0.118±0.001 | **0.262±0.001** | **0.323±0.001** | 0.418±0.003 | 0.329±0.009 | 0.312±0.002 | OOM | **0.008±0.001** | 0.779±0.006 | 0.874±0.004 | 0.424±0.005 | -| ResNet | 0.288±0.006 | 0.018±0.003 | 0.124±0.001 | 0.268±0.001 | 0.335±0.001 | 0.434±0.004 | 0.325±0.012 | 0.324±0.004 | 0.895±0.005 | 0.036±0.002 | 0.794±0.006 | 0.875±0.004 | 0.468±0.004 | -| FTTransformerBucket | 0.325±0.008 | 0.096±0.005 | 0.360±0.354 | 0.284±0.005 | 0.342±0.004 | 0.441±0.003 | 0.345±0.007 | 0.339±0.003 | OOM | 0.105±0.011 | 0.807±0.010 | 0.885±0.008 | 0.468±0.006 | -| ExcelFormer | 0.262±0.004 | 0.099±0.003 | 0.128±0.000 | 0.264±0.003 | 0.331±0.003 | 0.411±0.005 | 0.298±0.012 | 0.308±0.007 | OOM | 0.011±0.001 | 0.785±0.011 | 0.890±0.003 | 0.431±0.006 | -| FTTransformer | 0.335±0.010 | 0.161±0.022 | 0.140±0.002 | 0.277±0.004 | 0.335±0.003 | 0.445±0.003 | 0.361±0.018 | 0.345±0.005 | OOM | 0.106±0.012 | 0.826±0.005 | 0.896±0.007 | 0.461±0.003 | -| TabNet | 0.279±0.003 | 0.224±0.016 | 0.141±0.010 | 0.275±0.002 | 0.348±0.003 | 0.451±0.007 | 0.355±0.030 | 0.332±0.004 | 0.992±0.182 | 0.015±0.002 | 0.805±0.014 | 0.885±0.013 | 0.544±0.011 | -| TabTransformer | 0.624±0.003 | 0.229±0.003 | 0.369±0.005 | 0.340±0.004 | 0.388±0.002 | 0.539±0.003 | 0.619±0.005 | 0.351±0.001 | 0.893±0.005 | 0.431±0.001 | 0.819±0.002 | 0.886±0.005 | 0.545±0.004 | - +| Model Name | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | +| :------------------ | :-------------- | :-------------- | :-------------- | :-------------- | :-------------- | :-------------- | :-------------- | :-------------- | :-------------- | :-------------- | :-------------- | :-------------- | :-------------- | +| XGBoost | **0.250±0.000** | 0.038±0.000 | 0.187±0.000 | 0.475±0.000 | 0.328±0.000 | 0.401±0.000 | **0.249±0.000** | 0.363±0.000 | 0.904±0.000 | 0.056±0.000 | 0.820±0.000 | **0.857±0.000** | 0.418±0.000 | +| CatBoost | 0.265±0.000 | 0.062±0.000 | 0.128±0.000 | 0.336±0.000 | 0.346±0.000 | 0.443±0.000 | 0.375±0.000 | 0.273±0.000 | 0.881±0.000 | 0.040±0.000 | 0.756±0.000 | 0.876±0.000 | 0.439±0.000 | +| LightGBM | 0.253±0.000 | 0.054±0.000 | **0.112±0.000** | 0.302±0.000 | 0.325±0.000 | **0.384±0.000** | 0.295±0.000 | **0.272±0.000** | **0.877±0.000** | 0.011±0.000 | **0.702±0.000** | 0.863±0.000 | **0.395±0.000** | +| Trompt | 0.261±0.003 | **0.015±0.005** | 0.118±0.001 | **0.262±0.001** | **0.323±0.001** | 0.418±0.003 | 0.329±0.009 | 0.312±0.002 | OOM | **0.008±0.001** | 0.779±0.006 | 0.874±0.004 | 0.424±0.005 | +| ResNet | 0.288±0.006 | 0.018±0.003 | 0.124±0.001 | 0.268±0.001 | 0.335±0.001 | 0.434±0.004 | 0.325±0.012 | 0.324±0.004 | 0.895±0.005 | 0.036±0.002 | 0.794±0.006 | 0.875±0.004 | 0.468±0.004 | +| FTTransformerBucket | 0.325±0.008 | 0.096±0.005 | 0.360±0.354 | 0.284±0.005 | 0.342±0.004 | 0.441±0.003 | 0.345±0.007 | 0.339±0.003 | OOM | 0.105±0.011 | 0.807±0.010 | 0.885±0.008 | 0.468±0.006 | +| ExcelFormer | 0.262±0.004 | 0.099±0.003 | 0.128±0.000 | 0.264±0.003 | 0.331±0.003 | 0.411±0.005 | 0.298±0.012 | 0.308±0.007 | OOM | 0.011±0.001 | 0.785±0.011 | 0.890±0.003 | 0.431±0.006 | +| FTTransformer | 0.335±0.010 | 0.161±0.022 | 0.140±0.002 | 0.277±0.004 | 0.335±0.003 | 0.445±0.003 | 0.361±0.018 | 0.345±0.005 | OOM | 0.106±0.012 | 0.826±0.005 | 0.896±0.007 | 0.461±0.003 | +| TabNet | 0.279±0.003 | 0.224±0.016 | 0.141±0.010 | 0.275±0.002 | 0.348±0.003 | 0.451±0.007 | 0.355±0.030 | 0.332±0.004 | 0.992±0.182 | 0.015±0.002 | 0.805±0.014 | 0.885±0.013 | 0.544±0.011 | +| TabTransformer | 0.624±0.003 | 0.229±0.003 | 0.369±0.005 | 0.340±0.004 | 0.388±0.002 | 0.539±0.003 | 0.619±0.005 | 0.351±0.001 | 0.893±0.005 | 0.431±0.001 | 0.819±0.002 | 0.886±0.005 | 0.545±0.004 | We see that some recent deep tabular models were able to achieve competitive model performance to strong GBDTs (despite being 5--100 times slower to train). Making deep tabular models even more performant with less compute is a fruitful direction for future research. We also benchmark different text encoders on a real-world tabular dataset ([Wine Reviews](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.datasets.MultimodalTextBenchmark.html#torch_frame.datasets.MultimodalTextBenchmark)) with one text column. The following table shows the performance: -| Test Acc | Method | Model Name | Source | -|:-----------|:----------------|:-----------------------------------------------------------|:--------------| -| 0.7926 | Pre-trained | sentence-transformers/all-distilroberta-v1 (125M # params) | Hugging Face | -| 0.7998 | Pre-trained | embed-english-v3.0 (dimension size: 1024) | Cohere | -| 0.8102 | Pre-trained | text-embedding-ada-002 (dimension size: 1536) | OpenAI | -| 0.8147 | Pre-trained | voyage-01 (dimension size: 1024) | Voyage AI | -| 0.8203 | Pre-trained | intfloat/e5-mistral-7b-instruct (7B # params) | Hugging Face | -| **0.8230** | LoRA Finetune | DistilBERT (66M # params) | Hugging Face | +| Test Acc | Method | Model Name | Source | +| :--------- | :------------ | :--------------------------------------------------------- | :----------- | +| 0.7926 | Pre-trained | sentence-transformers/all-distilroberta-v1 (125M # params) | Hugging Face | +| 0.7998 | Pre-trained | embed-english-v3.0 (dimension size: 1024) | Cohere | +| 0.8102 | Pre-trained | text-embedding-ada-002 (dimension size: 1536) | OpenAI | +| 0.8147 | Pre-trained | voyage-01 (dimension size: 1024) | Voyage AI | +| 0.8203 | Pre-trained | intfloat/e5-mistral-7b-instruct (7B # params) | Hugging Face | +| **0.8230** | LoRA Finetune | DistilBERT (66M # params) | Hugging Face | The benchmark script for Hugging Face text encoders is in this [file](https://github.com/pyg-team/pytorch-frame/blob/master/examples/transformers_text.py) and for the rest of text encoders is in this [file](https://github.com/pyg-team/pytorch-frame/blob/master/examples/llm_embedding.py). @@ -270,6 +256,7 @@ See [the installation guide](https://pytorch-frame.readthedocs.io/en/latest/get_ ## Cite If you use PyTorch Frame in your work, please cite our paper (Bibtex below). + ``` @article{hu2024pytorch, title={PyTorch Frame: A Modular Framework for Multi-Modal Tabular Learning}, @@ -278,3 +265,16 @@ If you use PyTorch Frame in your work, please cite our paper (Bibtex below). year={2024} } ``` + +[arxiv-image]: https://img.shields.io/badge/arXiv-2404.00776-b31b1b.svg +[arxiv-url]: https://arxiv.org/abs/2404.00776 +[contributing-image]: https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat +[contributing-url]: https://github.com/pyg-team/pytorch-frame/blob/master/.github/CONTRIBUTING.md +[docs-image]: https://readthedocs.org/projects/pytorch-frame/badge/?version=latest +[docs-url]: https://pytorch-frame.readthedocs.io/en/latest +[pypi-image]: https://badge.fury.io/py/pytorch-frame.svg +[pypi-url]: https://pypi.python.org/pypi/pytorch-frame +[slack-image]: https://img.shields.io/badge/slack-pyf-brightgreen +[slack-url]: https://data.pyg.org/slack.html +[testing-image]: https://github.com/pyg-team/pytorch-frame/actions/workflows/testing.yml/badge.svg +[testing-url]: https://github.com/pyg-team/pytorch-frame/actions/workflows/testing.yml diff --git a/benchmark/README.md b/benchmark/README.md index c92c8c19..e7329118 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -1,6 +1,7 @@ ## Benchmarking model performance across diverse `DataFrameBenchmark` datasets. First install additional dependencies: + ```bash pip install optuna pip install torchmetrics @@ -10,6 +11,7 @@ pip install lightgbm ``` Then run + ```bash # Specify the model from [TabNet, FTTransformer, ResNet, MLP, TabTransformer, # Trompt, ExcelFormer, FTTransformerBucket, XGBoost, CatBoost, LightGBM] @@ -51,32 +53,33 @@ the total time spent, including [`Optuna`](https://optuna.org/)-based hyper-para For the mapping from dataset `idx` into the actual dataset object, please see the [documentation](https://pytorch-frame.readthedocs.io/en/latest/generated/torch_frame.datasets.DataFrameBenchmark.html#torch_frame.datasets.DataFrameBenchmark). ### `task_type: binary_classification` + Metric: ROC-AUC, higher the better. #### `scale: small` Experimental setting: 20 Optuna search trials. 50 epochs of training. -| | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | dataset_13 | -|:--------------------|:-----------------------------|:-------------------------------|:--------------------------------|:-----------------------------|:-------------------------------|:------------------------------|:-----------------------------|:------------------------------|:------------------------------|:-----------------------------|:-----------------------------|:-------------------------------|:-------------------------------|:-------------------------------| -| XGBoost | **0.931±0.000 (41s)** | **1.000±0.000 (3s)** | 0.940±0.000 (389s) | **0.947±0.000 (42s)** | 0.885±0.000 (109s) | 0.966±0.000 (14s) | **0.862±0.000 (10s)** | **0.779±0.000 (79s)** | **0.984±0.000 (376s)** | 0.714±0.000 (10s) | 0.787±0.000 (9s) | 0.951±0.000 (103s) | **0.999±0.000 (434s)** | 0.925±0.000 (848s) | -| CatBoost | 0.930±0.000 (152s) | **1.000±0.000 (9s)** | 0.938±0.000 (164s) | 0.924±0.000 (29s) | 0.881±0.000 (27s) | 0.963±0.000 (48s) | 0.861±0.000 (12s) | 0.772±0.000 (10s) | 0.930±0.000 (91s) | 0.628±0.000 (10s) | **0.796±0.000 (15s)** | 0.948±0.000 (46s) | **0.998±0.000 (38s)** | 0.926±0.000 (115s) | -| LightGBM | **0.931±0.000 (15s)** | 0.999±0.000 (1s) | 0.943±0.000 (23s) | 0.943±0.000 (14s) | **0.887±0.000 (5s)** | **0.972±0.000 (11s)** | **0.862±0.000 (6s)** | 0.774±0.000 (3s) | 0.979±0.000 (41s) | **0.732±0.000** (13s) | 0.787±0.000 (3s) | 0.951±0.000 (13s) | 0.999±0.000 (10s) | **0.927±0.000 (24s)** | -| Trompt | 0.919±0.000 (9627s) | **1.000±0.000 (5341s)** | **0.945±0.000 (14679s)** | 0.942±0.001 (2752s) | 0.881±0.001 (2640s) | 0.964±0.001 (5173s) | 0.855±0.002 (4249s) | 0.778±0.002 (8789s) | 0.933±0.001 (9353s) | 0.686±0.008 (3105s) | 0.793±0.002 (8255s) | **0.952±0.001 (4876s)** | **1.000±0.000 (3558s)** | 0.916±0.001 (30002s) | -| ResNet | 0.917±0.000 (615s) | **1.000±0.000 (71s)** | 0.937±0.001 (787s) | 0.938±0.002 (230s) | 0.865±0.001 (183s) | 0.960±0.001 (349s) | 0.828±0.001 (248s) | 0.768±0.002 (205s) | 0.925±0.002 (958s) | 0.665±0.006 (140s) | **0.794±0.002 (76s)** | 0.946±0.002 (145s) | **1.000±0.000 (93s)** | 0.911±0.001 (880s) | -| MLP | 0.913±0.001 (112s) | **1.000±0.000 (45s)** | 0.934±0.001 (274s) | 0.938±0.001 (66s) | 0.863±0.002 (61s) | 0.953±0.000 (92s) | 0.830±0.001 (68s) | 0.769±0.002 (56s) | 0.903±0.002 (159s) | 0.666±0.015 (58s) | 0.789±0.001 (48s) | 0.940±0.002 (107s) | **1.000±0.000 (48s)** | 0.910±0.001 (149s) | -| FTTransformerBucket | 0.915±0.001 (690s) | **0.999±0.001 (354s)** | 0.936±0.002 (1705s) | 0.939±0.002 (484s) | 0.876±0.002 (321s) | 0.960±0.001 (746s) | 0.857±0.000 (549s) | 0.771±0.003 (654s) | 0.909±0.002 (1177s) | 0.636±0.012 (244s) | 0.788±0.002 (710s) | 0.950±0.001 (510s) | **0.999±0.000 (634s)** | 0.913±0.001 (1164s) | -| ExcelFormer | 0.918±0.001 (1587s) | **1.000±0.000 (634s)** | 0.939±0.001 (1827s) | 0.939±0.002 (378s) | 0.883±0.001 (289s) | 0.969±0.000 (678s) | 0.833±0.011 (435s) | **0.780±0.002 (938s)** | 0.940±0.003 (919s) | 0.670±0.017 (464s) | 0.794±0.003 (683s) | 0.950±0.001 (405s) | **0.999±0.000 (1169s)** | 0.919±0.001 (1798s) | -| FTTransformer | 0.918±0.001 (871s) | **1.000±0.000 (571s)** | 0.940±0.001 (1371s) | 0.936±0.001 (458s) | 0.874±0.002 (200s) | 0.959±0.001 (622s) | 0.828±0.001 (339s) | 0.773±0.002 (521s) | 0.909±0.002 (1488s) | 0.635±0.011 (392s) | 0.790±0.001 (556s) | 0.949±0.002 (374s) | **1.000±0.000 (713s)** | 0.912±0.000 (1855s) | -| TabNet | 0.911±0.001 (150s) | **1.000±0.000 (35s)** | 0.931±0.005 (254s) | 0.937±0.003 (125s) | 0.864±0.002 (52s) | 0.944±0.001 (116s) | 0.828±0.001 (79s) | 0.771±0.005 (93s) | 0.913±0.005 (177s) | 0.606±0.014 (65s) | 0.790±0.003 (41s) | 0.936±0.003 (104s) | **1.000±0.000 (64s)** | 0.910±0.001 (294s) | -| TabTransformer | 0.910±0.001 (2044s) | **1.000±0.000 (1321s)** | 0.928±0.001 (2519s) | 0.918±0.003 (134s) | 0.829±0.002 (64s) | 0.928±0.001 (105s) | 0.816±0.002 (99s) | 0.757±0.003 (645s) | 0.885±0.001 (1167s) | 0.652±0.006 (282s) | 0.780±0.002 (112s) | 0.937±0.001 (117s) | 0.996±0.000 (76s) | 0.905±0.001 (2283s) | +| | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | dataset_13 | +| :------------------ | :-------------------- | :---------------------- | :----------------------- | :-------------------- | :------------------- | :-------------------- | :-------------------- | :--------------------- | :--------------------- | :-------------------- | :-------------------- | :---------------------- | :---------------------- | :-------------------- | +| XGBoost | **0.931±0.000 (41s)** | **1.000±0.000 (3s)** | 0.940±0.000 (389s) | **0.947±0.000 (42s)** | 0.885±0.000 (109s) | 0.966±0.000 (14s) | **0.862±0.000 (10s)** | **0.779±0.000 (79s)** | **0.984±0.000 (376s)** | 0.714±0.000 (10s) | 0.787±0.000 (9s) | 0.951±0.000 (103s) | **0.999±0.000 (434s)** | 0.925±0.000 (848s) | +| CatBoost | 0.930±0.000 (152s) | **1.000±0.000 (9s)** | 0.938±0.000 (164s) | 0.924±0.000 (29s) | 0.881±0.000 (27s) | 0.963±0.000 (48s) | 0.861±0.000 (12s) | 0.772±0.000 (10s) | 0.930±0.000 (91s) | 0.628±0.000 (10s) | **0.796±0.000 (15s)** | 0.948±0.000 (46s) | **0.998±0.000 (38s)** | 0.926±0.000 (115s) | +| LightGBM | **0.931±0.000 (15s)** | 0.999±0.000 (1s) | 0.943±0.000 (23s) | 0.943±0.000 (14s) | **0.887±0.000 (5s)** | **0.972±0.000 (11s)** | **0.862±0.000 (6s)** | 0.774±0.000 (3s) | 0.979±0.000 (41s) | **0.732±0.000** (13s) | 0.787±0.000 (3s) | 0.951±0.000 (13s) | 0.999±0.000 (10s) | **0.927±0.000 (24s)** | +| Trompt | 0.919±0.000 (9627s) | **1.000±0.000 (5341s)** | **0.945±0.000 (14679s)** | 0.942±0.001 (2752s) | 0.881±0.001 (2640s) | 0.964±0.001 (5173s) | 0.855±0.002 (4249s) | 0.778±0.002 (8789s) | 0.933±0.001 (9353s) | 0.686±0.008 (3105s) | 0.793±0.002 (8255s) | **0.952±0.001 (4876s)** | **1.000±0.000 (3558s)** | 0.916±0.001 (30002s) | +| ResNet | 0.917±0.000 (615s) | **1.000±0.000 (71s)** | 0.937±0.001 (787s) | 0.938±0.002 (230s) | 0.865±0.001 (183s) | 0.960±0.001 (349s) | 0.828±0.001 (248s) | 0.768±0.002 (205s) | 0.925±0.002 (958s) | 0.665±0.006 (140s) | **0.794±0.002 (76s)** | 0.946±0.002 (145s) | **1.000±0.000 (93s)** | 0.911±0.001 (880s) | +| MLP | 0.913±0.001 (112s) | **1.000±0.000 (45s)** | 0.934±0.001 (274s) | 0.938±0.001 (66s) | 0.863±0.002 (61s) | 0.953±0.000 (92s) | 0.830±0.001 (68s) | 0.769±0.002 (56s) | 0.903±0.002 (159s) | 0.666±0.015 (58s) | 0.789±0.001 (48s) | 0.940±0.002 (107s) | **1.000±0.000 (48s)** | 0.910±0.001 (149s) | +| FTTransformerBucket | 0.915±0.001 (690s) | **0.999±0.001 (354s)** | 0.936±0.002 (1705s) | 0.939±0.002 (484s) | 0.876±0.002 (321s) | 0.960±0.001 (746s) | 0.857±0.000 (549s) | 0.771±0.003 (654s) | 0.909±0.002 (1177s) | 0.636±0.012 (244s) | 0.788±0.002 (710s) | 0.950±0.001 (510s) | **0.999±0.000 (634s)** | 0.913±0.001 (1164s) | +| ExcelFormer | 0.918±0.001 (1587s) | **1.000±0.000 (634s)** | 0.939±0.001 (1827s) | 0.939±0.002 (378s) | 0.883±0.001 (289s) | 0.969±0.000 (678s) | 0.833±0.011 (435s) | **0.780±0.002 (938s)** | 0.940±0.003 (919s) | 0.670±0.017 (464s) | 0.794±0.003 (683s) | 0.950±0.001 (405s) | **0.999±0.000 (1169s)** | 0.919±0.001 (1798s) | +| FTTransformer | 0.918±0.001 (871s) | **1.000±0.000 (571s)** | 0.940±0.001 (1371s) | 0.936±0.001 (458s) | 0.874±0.002 (200s) | 0.959±0.001 (622s) | 0.828±0.001 (339s) | 0.773±0.002 (521s) | 0.909±0.002 (1488s) | 0.635±0.011 (392s) | 0.790±0.001 (556s) | 0.949±0.002 (374s) | **1.000±0.000 (713s)** | 0.912±0.000 (1855s) | +| TabNet | 0.911±0.001 (150s) | **1.000±0.000 (35s)** | 0.931±0.005 (254s) | 0.937±0.003 (125s) | 0.864±0.002 (52s) | 0.944±0.001 (116s) | 0.828±0.001 (79s) | 0.771±0.005 (93s) | 0.913±0.005 (177s) | 0.606±0.014 (65s) | 0.790±0.003 (41s) | 0.936±0.003 (104s) | **1.000±0.000 (64s)** | 0.910±0.001 (294s) | +| TabTransformer | 0.910±0.001 (2044s) | **1.000±0.000 (1321s)** | 0.928±0.001 (2519s) | 0.918±0.003 (134s) | 0.829±0.002 (64s) | 0.928±0.001 (105s) | 0.816±0.002 (99s) | 0.757±0.003 (645s) | 0.885±0.001 (1167s) | 0.652±0.006 (282s) | 0.780±0.002 (112s) | 0.937±0.001 (117s) | 0.996±0.000 (76s) | 0.905±0.001 (2283s) | #### `scale: medium` Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 5 Optuna search trials and 25 epochs training for deep learning models. -| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -|:--------------------|:----------------------|:------------------------|:------------------------|:------------------------|:----------------------|:-------------------------|:-------------------------|:----------------------|:------------------------| +| | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | +| :------------------ | :-------------------- | :---------------------- | :---------------------- | :---------------------- | :-------------------- | :----------------------- | :----------------------- | :-------------------- | :---------------------- | | XGBoost | 0.594±0.000 (466s) | **0.955±0.000 (6340s)** | **0.653±0.000 (19s)** | **0.986±0.000 (195s)** | 0.721±0.000 (62s) | **0.998±0.000 (70626s)** | 0.868±0.000 (159s) | 0.888±0.000 (2945s) | 0.803±0.000 (371s) | | CatBoost | 0.631±0.000 (1201s) | **0.956±0.000 (2963s)** | 0.649±0.000 (26s) | **0.986±0.000 (352s)** | 0.719±0.000 (244s) | 0.987±0.000 (2561s) | 0.863±0.000 (212s) | 0.896±0.000 (740s) | 0.803±0.000 (140s) | | LightGBM | **0.639±0.000 (49s)** | 0.955±0.000 (126s) | 0.652±0.000 (7s) | **0.986±0.000 (99s)** | **0.723±0.000 (16s)** | 0.997±0.000 (172s) | 0.881±0.000 (83s) | **0.914±0.000 (86s)** | **0.809±0.000 (76s)** | @@ -91,11 +94,10 @@ Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM #### `scale: large` - Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 3 Optuna search trials and 10 epochs training for deep learning models. -| | 0 | -|:--------------------|:------------------------| +| | dataset_0 | +| :------------------ | :---------------------- | | XGBoost | 0.792±0.000 (28889s) | | CatBoost | 0.788±0.000 (240s) | | LightGBM | 0.831±0.000 (167s) | @@ -109,75 +111,75 @@ Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM | TabTransformer | 0.790±0.002 (457s) | ### `task_type: regression` + Metric: RMSE, lower the better. #### `scale: small` Experimental setting: 20 Optuna search trials. 50 epochs of training. -| | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | -|:--------------------|:------------------------------|:-------------------------------|:-------------------------------|:--------------------------------|:-------------------------------|:------------------------------|:------------------------------|:------------------------------|:-------------------------------|:-------------------------------|:------------------------------|:----------------------------|:-------------------------------| -| XGBoost | **0.250±0.000 (22s)** | 0.038±0.000 (1011s) | 0.187±0.000 (19s) | 0.475±0.000 (439s) | 0.328±0.000 (32s) | 0.401±0.000 (375s) | **0.249±0.000 (340s)** | 0.363±0.000 (378s) | 0.904±0.000 (2400s) | 0.056±0.000 (250s) | 0.820±0.000 (721s) | **0.857±0.000 (487s)** | 0.418±0.000 (46s) | -| CatBoost | 0.265±0.000 (116s) | 0.062±0.000 (129s) | 0.128±0.000 (97s) | 0.336±0.000 (103s) | 0.346±0.000 (110s) | 0.443±0.000 (97s) | 0.375±0.000 (46s) | 0.273±0.000 (693s) | 0.881±0.000 (660s) | 0.040±0.000 (80s) | 0.756±0.000 (44s) | 0.876±0.000 (110s) | 0.439±0.000 (101s) | -| LightGBM | 0.253±0.000 (38s) | 0.054±0.000 (24s) | **0.112±0.000 (10s)** | 0.302±0.000 (30s) | 0.325±0.000 (30s) | **0.384±0.000 (23s)** | 0.295±0.000 (15s) | **0.272±0.000 (26s)** | **0.877±0.000 (16s)** | 0.011±0.000 (12s) | **0.702±0.000 (13s)** | 0.863±0.000 (5s) | **0.395±0.000 (40s)** | -| Trompt | 0.261±0.003 (8390s) | **0.015±0.005 (3792s)** | 0.118±0.001 (3836s) | **0.262±0.001 (10037s)** | **0.323±0.001 (9255s)** | 0.418±0.003 (9071s) | 0.329±0.009 (2977s) | 0.312±0.002 (21967s) | OOM | **0.008±0.001 (1889s)** | 0.779±0.006 (775s) | 0.874±0.004 (3723s) | 0.424±0.005 (3185s) | -| ResNet | 0.288±0.006 (220s) | 0.018±0.003 (187s) | 0.124±0.001 (135s) | 0.268±0.001 (330s) | 0.335±0.001 (471s) | 0.434±0.004 (345s) | 0.325±0.012 (178s) | 0.324±0.004 (365s) | 0.895±0.005 (142s) | 0.036±0.002 (172s) | 0.794±0.006 (120s) | 0.875±0.004 (122s) | 0.468±0.004 (303s) | -| MLP | 0.300±0.002 (108s) | 0.141±0.015 (76s) | 0.125±0.001 (44s) | 0.272±0.002 (69s) | 0.348±0.001 (103s) | 0.435±0.002 (33s) | 0.331±0.008 (43s) | 0.380±0.004 (125s) | 0.893±0.002 (69s) | 0.017±0.001 (48s) | 0.784±0.007 (29s) | 0.881±0.005 (30s) | 0.467±0.003 (92s) | -| FTTransformerBucket | 0.325±0.008 (619s) | 0.096±0.005 (290s) | 0.360±0.354 (332s) | 0.284±0.005 (768s) | 0.342±0.004 (757s) | 0.441±0.003 (835s) | 0.345±0.007 (191s) | 0.339±0.003 (3321s) | OOM | 0.105±0.011 (199s) | 0.807±0.010 (156s) | 0.885±0.008 (820s) | 0.468±0.006 (706s) | -| ExcelFormer | 0.262±0.004 (770s) | 0.099±0.003 (490s) | 0.128±0.000 (362s) | 0.264±0.003 (796s) | 0.331±0.003 (1121s) | 0.411±0.005 (469s) | 0.298±0.012 (222s) | 0.308±0.007 (5522s) | OOM | 0.011±0.001 (227) | 0.785±0.011 (314s) | 0.890±0.003 (1186s) | 0.431±0.006 (682s) | -| FTTransformer | 0.335±0.010 (338s) | 0.161±0.022 (370s) | 0.140±0.002 (244s) | 0.277±0.004 (516s) | 0.335±0.003 (973s) | 0.445±0.003 (599s) | 0.361±0.018 (286s) | 0.345±0.005 (2443s) | OOM | 0.106±0.012 (150s) | 0.826±0.005 (121s) | 0.896±0.007 (832s) | 0.461±0.003 (647s) | -| TabNet | 0.279±0.003 (68s) | 0.224±0.016 (53s) | 0.141±0.010 (34s) | 0.275±0.002 (61s) | 0.348±0.003 (110s) | 0.451±0.007 (82s) | 0.355±0.030 (49s) | 0.332±0.004 (168s) | 0.992±0.182 (53s) | 0.015±0.002 (57s) | 0.805±0.014 (27s) | 0.885±0.013 (46s) | 0.544±0.011 (112s) | -| TabTransformer | 0.624±0.003 (1225s) | 0.229±0.003 (1200s) | 0.369±0.005 (52s) | 0.340±0.004 (163s) | 0.388±0.002 (1137s) | 0.539±0.003 (100s) | 0.619±0.005 (73s) | 0.351±0.001 (125s) | 0.893±0.005 (389s) | 0.431±0.001 (489s) | 0.819±0.002 (52s) | 0.886±0.005 (46s) | 0.545±0.004 (95s) | +| | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | +| :------------------ | :-------------------- | :---------------------- | :-------------------- | :----------------------- | :---------------------- | :-------------------- | :--------------------- | :-------------------- | :-------------------- | :---------------------- | :-------------------- | :--------------------- | :-------------------- | +| XGBoost | **0.250±0.000 (22s)** | 0.038±0.000 (1011s) | 0.187±0.000 (19s) | 0.475±0.000 (439s) | 0.328±0.000 (32s) | 0.401±0.000 (375s) | **0.249±0.000 (340s)** | 0.363±0.000 (378s) | 0.904±0.000 (2400s) | 0.056±0.000 (250s) | 0.820±0.000 (721s) | **0.857±0.000 (487s)** | 0.418±0.000 (46s) | +| CatBoost | 0.265±0.000 (116s) | 0.062±0.000 (129s) | 0.128±0.000 (97s) | 0.336±0.000 (103s) | 0.346±0.000 (110s) | 0.443±0.000 (97s) | 0.375±0.000 (46s) | 0.273±0.000 (693s) | 0.881±0.000 (660s) | 0.040±0.000 (80s) | 0.756±0.000 (44s) | 0.876±0.000 (110s) | 0.439±0.000 (101s) | +| LightGBM | 0.253±0.000 (38s) | 0.054±0.000 (24s) | **0.112±0.000 (10s)** | 0.302±0.000 (30s) | 0.325±0.000 (30s) | **0.384±0.000 (23s)** | 0.295±0.000 (15s) | **0.272±0.000 (26s)** | **0.877±0.000 (16s)** | 0.011±0.000 (12s) | **0.702±0.000 (13s)** | 0.863±0.000 (5s) | **0.395±0.000 (40s)** | +| Trompt | 0.261±0.003 (8390s) | **0.015±0.005 (3792s)** | 0.118±0.001 (3836s) | **0.262±0.001 (10037s)** | **0.323±0.001 (9255s)** | 0.418±0.003 (9071s) | 0.329±0.009 (2977s) | 0.312±0.002 (21967s) | OOM | **0.008±0.001 (1889s)** | 0.779±0.006 (775s) | 0.874±0.004 (3723s) | 0.424±0.005 (3185s) | +| ResNet | 0.288±0.006 (220s) | 0.018±0.003 (187s) | 0.124±0.001 (135s) | 0.268±0.001 (330s) | 0.335±0.001 (471s) | 0.434±0.004 (345s) | 0.325±0.012 (178s) | 0.324±0.004 (365s) | 0.895±0.005 (142s) | 0.036±0.002 (172s) | 0.794±0.006 (120s) | 0.875±0.004 (122s) | 0.468±0.004 (303s) | +| MLP | 0.300±0.002 (108s) | 0.141±0.015 (76s) | 0.125±0.001 (44s) | 0.272±0.002 (69s) | 0.348±0.001 (103s) | 0.435±0.002 (33s) | 0.331±0.008 (43s) | 0.380±0.004 (125s) | 0.893±0.002 (69s) | 0.017±0.001 (48s) | 0.784±0.007 (29s) | 0.881±0.005 (30s) | 0.467±0.003 (92s) | +| FTTransformerBucket | 0.325±0.008 (619s) | 0.096±0.005 (290s) | 0.360±0.354 (332s) | 0.284±0.005 (768s) | 0.342±0.004 (757s) | 0.441±0.003 (835s) | 0.345±0.007 (191s) | 0.339±0.003 (3321s) | OOM | 0.105±0.011 (199s) | 0.807±0.010 (156s) | 0.885±0.008 (820s) | 0.468±0.006 (706s) | +| ExcelFormer | 0.262±0.004 (770s) | 0.099±0.003 (490s) | 0.128±0.000 (362s) | 0.264±0.003 (796s) | 0.331±0.003 (1121s) | 0.411±0.005 (469s) | 0.298±0.012 (222s) | 0.308±0.007 (5522s) | OOM | 0.011±0.001 (227) | 0.785±0.011 (314s) | 0.890±0.003 (1186s) | 0.431±0.006 (682s) | +| FTTransformer | 0.335±0.010 (338s) | 0.161±0.022 (370s) | 0.140±0.002 (244s) | 0.277±0.004 (516s) | 0.335±0.003 (973s) | 0.445±0.003 (599s) | 0.361±0.018 (286s) | 0.345±0.005 (2443s) | OOM | 0.106±0.012 (150s) | 0.826±0.005 (121s) | 0.896±0.007 (832s) | 0.461±0.003 (647s) | +| TabNet | 0.279±0.003 (68s) | 0.224±0.016 (53s) | 0.141±0.010 (34s) | 0.275±0.002 (61s) | 0.348±0.003 (110s) | 0.451±0.007 (82s) | 0.355±0.030 (49s) | 0.332±0.004 (168s) | 0.992±0.182 (53s) | 0.015±0.002 (57s) | 0.805±0.014 (27s) | 0.885±0.013 (46s) | 0.544±0.011 (112s) | +| TabTransformer | 0.624±0.003 (1225s) | 0.229±0.003 (1200s) | 0.369±0.005 (52s) | 0.340±0.004 (163s) | 0.388±0.002 (1137s) | 0.539±0.003 (100s) | 0.619±0.005 (73s) | 0.351±0.001 (125s) | 0.893±0.005 (389s) | 0.431±0.001 (489s) | 0.819±0.002 (52s) | 0.886±0.005 (46s) | 0.545±0.004 (95s) | #### `scale: medium` Experimental setting: 20 Optuna search trials for XGBoost and CatBoost. 5 Optuna search trials and 25 epochs training for deep learning models. -| | 0 | 1 | 2 | 3 | 4 | 5 | -|:--------------------|:-------------------------|:-------------------------|:-----------------------|:------------------------|:-----------------------|:---------------------| -| XGBoost | 0.663±0.000 (18528s) | **0.014±0.000 (380s)** | 0.089±0.000 (2441s) | 0.140±0.000 (1632s) | 0.539±0.000 (22047s) | 0.900±0.000 (1420s) | -| CatBoost | 0.669±0.000 (2037s) | 0.018±0.000 (649s) | 0.092±0.000 (391s) | 0.145±0.000 (271s) | 0.549±0.000 (1347s) | 0.898±0.000 (122s) | -| LightGBM | **0.660±0.000 (199s)** | 0.015±0.000 (86s) | **0.085±0.000 (39s)** | 0.141±0.000 (35s) | **0.524±0.000 (148s)** | **0.895±0.000 (7s)** | -| Trompt | OOM | **0.014±0.000 (19976s)** | 0.092±0.001 (4060s) | **0.140±0.000 (3487s)** | 0.537±0.000 (26520s) | 0.901±0.000 (2333s) | -| ResNet | 0.676±0.000 (894s) | 0.016±0.000 (548s) | 0.101±0.001 (176s) | 0.147±0.000 (503s) | 0.555±0.003 (1121s) | 0.903±0.000 (116s) | -| MLP | 0.680±0.001 (907s) | 0.016±0.000 (1015s) | 0.105±0.000 (254s) | **0.140±0.000 (313s)** | 0.558±0.001 (1756s) | 0.905±0.001 (240s) | -| FTTransformerBucket | 0.738±0.029 (17223s) | 0.023±0.000 (2573s) | 0.113±0.002 (645s) | 0.147±0.000 (970s) | 0.545±0.000 (3009s) | 0.908±0.000 (360s) | -| ExcelFormer | **0.667±0.000 (35946s)** | **0.015±0.001 (2677s)** | 0.090±0.001 (603s) | 0.142±0.000 (1162s) | 0.526±0.001 (2403s) | 0.901±0.003 (330s) | -| FTTransformer | 0.673±0.000 (18524s) | 0.056±0.003 (3348s) | 0.119±0.003 (396s) | **0.141±0.000 (1049s)** | 0.561±0.001 (2403s) | 0.907±0.002 (302s) | -| TabNet | 0.683±0.001 (521s) | 0.024±0.001 (437s) | 0.115±0.003 (72s) | **0.140±0.000 (319s)** | 0.549±0.001 (760s) | 0.899±0.001 (37s) | -| TabTransformer | OOM | 0.799±0.000 (2829s) | 0.148±0.000 (720s) | 0.708±0.000 (182s) | 0.755±0.000 (4008s) | 0.964±0.000 (599s) | - +| | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | +| :------------------ | :----------------------- | :----------------------- | :-------------------- | :---------------------- | :--------------------- | :------------------- | +| XGBoost | 0.663±0.000 (18528s) | **0.014±0.000 (380s)** | 0.089±0.000 (2441s) | 0.140±0.000 (1632s) | 0.539±0.000 (22047s) | 0.900±0.000 (1420s) | +| CatBoost | 0.669±0.000 (2037s) | 0.018±0.000 (649s) | 0.092±0.000 (391s) | 0.145±0.000 (271s) | 0.549±0.000 (1347s) | 0.898±0.000 (122s) | +| LightGBM | **0.660±0.000 (199s)** | 0.015±0.000 (86s) | **0.085±0.000 (39s)** | 0.141±0.000 (35s) | **0.524±0.000 (148s)** | **0.895±0.000 (7s)** | +| Trompt | OOM | **0.014±0.000 (19976s)** | 0.092±0.001 (4060s) | **0.140±0.000 (3487s)** | 0.537±0.000 (26520s) | 0.901±0.000 (2333s) | +| ResNet | 0.676±0.000 (894s) | 0.016±0.000 (548s) | 0.101±0.001 (176s) | 0.147±0.000 (503s) | 0.555±0.003 (1121s) | 0.903±0.000 (116s) | +| MLP | 0.680±0.001 (907s) | 0.016±0.000 (1015s) | 0.105±0.000 (254s) | **0.140±0.000 (313s)** | 0.558±0.001 (1756s) | 0.905±0.001 (240s) | +| FTTransformerBucket | 0.738±0.029 (17223s) | 0.023±0.000 (2573s) | 0.113±0.002 (645s) | 0.147±0.000 (970s) | 0.545±0.000 (3009s) | 0.908±0.000 (360s) | +| ExcelFormer | **0.667±0.000 (35946s)** | **0.015±0.001 (2677s)** | 0.090±0.001 (603s) | 0.142±0.000 (1162s) | 0.526±0.001 (2403s) | 0.901±0.003 (330s) | +| FTTransformer | 0.673±0.000 (18524s) | 0.056±0.003 (3348s) | 0.119±0.003 (396s) | **0.141±0.000 (1049s)** | 0.561±0.001 (2403s) | 0.907±0.002 (302s) | +| TabNet | 0.683±0.001 (521s) | 0.024±0.001 (437s) | 0.115±0.003 (72s) | **0.140±0.000 (319s)** | 0.549±0.001 (760s) | 0.899±0.001 (37s) | +| TabTransformer | OOM | 0.799±0.000 (2829s) | 0.148±0.000 (720s) | 0.708±0.000 (182s) | 0.755±0.000 (4008s) | 0.964±0.000 (599s) | #### `scale: large` Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 3 Optuna search trials and 10 epochs training for deep learning models. -| | 0 | -|:--------------------|:------------------------| -| XGBoost | 0.966±0.000 (19327s) | -| CatBoost | 0.971±0.000 (223s) | -| LightGBM | **0.965±0.000 (67s)** | -| Trompt | 0.970±0.000 (12358s) | -| ResNet | 0.970±0.000 (672s) | -| MLP | 0.973±0.000 (223s) | -| FTTransformerBucket | 0.970±0.000 (2071s) | -| ExcelFormer | 0.969±0.000 (1785s) | -| FTTransformer | 0.971±0.000 (2918s) | -| TabNet | 0.970±0.000 (323s) | -| TabTransformer | 0.984±0.000 (318s) | +| | dataset_0 | +| :------------------ | :-------------------- | +| XGBoost | 0.966±0.000 (19327s) | +| CatBoost | 0.971±0.000 (223s) | +| LightGBM | **0.965±0.000 (67s)** | +| Trompt | 0.970±0.000 (12358s) | +| ResNet | 0.970±0.000 (672s) | +| MLP | 0.973±0.000 (223s) | +| FTTransformerBucket | 0.970±0.000 (2071s) | +| ExcelFormer | 0.969±0.000 (1785s) | +| FTTransformer | 0.971±0.000 (2918s) | +| TabNet | 0.970±0.000 (323s) | +| TabTransformer | 0.984±0.000 (318s) | ### `task_type: multiclass_classification` -Metric: Accuracy, the higher the better. +Metric: Accuracy, the higher the better. #### `scale: medium` Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 5 Optuna search trials and 25 epochs training for deep learning models. -*: Too slow which takes more than a day for a single trial. +\*: Too slow which takes more than a day for a single trial. -| | 0 | 1 | 2 | -|:--------------------|:-----------------------|:------------------------|:------------------------| +| | dataset_0 | dataset_1 | dataset_2 | +| :------------------ | :--------------------- | :---------------------- | :---------------------- | | XGBoost | Too slow\* | Too slow\* | Too slow\* | | CatBoost | Too slow\* | Too slow\* | Too slow\* | | LightGBM | Too slow\* | Too slow\* | Too slow\* | diff --git a/benchmark/encoder/README.md b/benchmark/encoder/README.md index 2d0434fa..51494086 100644 --- a/benchmark/encoder/README.md +++ b/benchmark/encoder/README.md @@ -3,30 +3,35 @@ ## Usage Exemplary command: + ``` python encoder_benchmark.py --stype-kv categorical embedding --stype-kv numerical linear ``` + It will create a dataset that will contain categorical and numerical columns and will use for them embedding and linear encoders, respectively. Arguments: -**--stype-kv**: Specify the stype(s) and corresponding encoder(s) to run. -**--num-rows**: The number of rows in the dataset (default is *8192*). -**--out-channels**: The number of output channels (default is *128*). -**--with-nan**: If specified, the dataset will include NaN values. -**--runs**: The number of runs for the benchmark (default is *1000*). -**--warmup-size**: The size of the warmup stage (default is *200()). -**--torch-profile**: If specified, torch profiling will be enabled. -**--line-profile**: If specified, line profiling will be enabled. -**--line-profile-level**: The level of line profiling (default is *'encode_forward'*). -**--device**: The device to run the benchmark on (default is *'cpu'*). + +- **--stype-kv**: Specify the stype(s) and corresponding encoder(s) to run. +- **--num-rows**: The number of rows in the dataset (default is `8192`). +- **--out-channels**: The number of output channels (default is `128`). +- **--with-nan**: If specified, the dataset will include NaN values. +- **--runs**: The number of runs for the benchmark (default is `1000`). +- **--warmup-size**: The size of the warmup stage (default is `200`). +- **--torch-profile**: If specified, torch profiling will be enabled. +- **--line-profile**: If specified, line profiling will be enabled. +- **--line-profile-level**: The level of line profiling (default is `'encode_forward'`). +- **--device**: The device to run the benchmark on (default is `'cpu'`). No matter if any profiler is used, benchmark always outputs a latency (single run execution time), e.g.: + ``` Latency: 0.034277s ``` Torch profiler produces a table of operations sorted by execution time, e.g.: + ``` ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls @@ -45,6 +50,7 @@ Self CPU time total: 4.268s ``` Line profiler shows how many percent was spent on each of method lines, e.g.: + ``` Total time: 1.03661 s File: {PF_BASE_PATH}/pytorch-frame/torch_frame/nn/encoder/stype_encoder.py diff --git a/docs/README.md b/docs/README.md index 7d3eda98..09282686 100644 --- a/docs/README.md +++ b/docs/README.md @@ -3,11 +3,11 @@ To build the documentation: 1. [Build and install](https://github.com/pyg-team/pytorch-frame/blob/master/.github/CONTRIBUTING.md) PyTorch Frame from source. -2. Install [Sphinx](https://www.sphinx-doc.org/en/master/) theme via +1. Install [Sphinx](https://www.sphinx-doc.org/en/master/) theme via ``` pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git ``` -3. Generate the documentation file via: +1. Generate the documentation file via: ``` cd docs make html diff --git a/examples/transformers_text.py b/examples/transformers_text.py index bd75a0e8..67f3c5a8 100644 --- a/examples/transformers_text.py +++ b/examples/transformers_text.py @@ -2,7 +2,6 @@ import argparse import os.path as osp -from typing import List import torch import torch.nn.functional as F @@ -121,7 +120,7 @@ def __init__(self, model: str, device: torch.device): self.model = AutoModel.from_pretrained(model).to(device) self.pooling = "mean" - def __call__(self, sentences: List[str]) -> Tensor: + def __call__(self, sentences: list[str]) -> Tensor: if self.model_name == "intfloat/e5-mistral-7b-instruct": sentences = [(f"Instruct: Retrieve relevant knowledge and " f"embeddings.\nQuery: {sentence}") @@ -226,7 +225,7 @@ def forward(self, feat: dict[str, MultiNestedTensor]) -> Tensor: # Return value has the shape [batch_size, 1, text_model_out_channels] return mean_pooling(out.last_hidden_state, mask) - def tokenize(self, sentences: List[str]) -> TextTokenizationOutputs: + def tokenize(self, sentences: list[str]) -> TextTokenizationOutputs: # Tokenize batches of sentences return self.tokenizer(sentences, truncation=True, padding=True, return_tensors='pt') diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 02102c3a..fd5614bd 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -5,7 +5,7 @@ import os.path as osp from abc import ABC from collections import defaultdict -from typing import Any, Dict +from typing import Any import pandas as pd import torch @@ -437,7 +437,7 @@ def canonicalize_and_validate_col_to_pattern( self, col_to_pattern: Any, col_to_pattern_name: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: canonical_col_to_pattern = canonicalize_col_to_pattern( col_to_pattern_name=col_to_pattern_name, col_to_pattern=col_to_pattern, diff --git a/torch_frame/data/multi_embedding_tensor.py b/torch_frame/data/multi_embedding_tensor.py index b2031b18..1fe9d844 100644 --- a/torch_frame/data/multi_embedding_tensor.py +++ b/torch_frame/data/multi_embedding_tensor.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence, Union +from typing import Sequence import torch from torch import Tensor @@ -199,7 +199,7 @@ def _single_index_select( def fillna_col( self, col_index: int, - fill_value: Union[int, float, Tensor], + fill_value: int | float | Tensor, ) -> None: values_index = slice(self.offset[col_index], self.offset[col_index + 1]) diff --git a/torch_frame/data/multi_nested_tensor.py b/torch_frame/data/multi_nested_tensor.py index aa2789e6..91cf23ce 100644 --- a/torch_frame/data/multi_nested_tensor.py +++ b/torch_frame/data/multi_nested_tensor.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence, Union, cast +from typing import Sequence, cast import torch from torch import Tensor @@ -268,7 +268,7 @@ def _single_index_select(self, index: int, dim: int) -> MultiNestedTensor: def fillna_col( self, col_index: int, - fill_value: Union[int, float, Tensor], + fill_value: int | float | Tensor, ) -> None: start_idx = torch.arange( col_index, diff --git a/torch_frame/data/multi_tensor.py b/torch_frame/data/multi_tensor.py index 0e0bf147..ba2dd309 100644 --- a/torch_frame/data/multi_tensor.py +++ b/torch_frame/data/multi_tensor.py @@ -1,7 +1,7 @@ from __future__ import annotations import copy -from typing import Any, Callable, Sequence, TypeVar, Union +from typing import Any, Callable, Sequence, TypeVar import torch from torch import Tensor @@ -336,7 +336,7 @@ def _single_index_select(self, index: int, dim: int) -> _MultiTensor: def fillna_col( self, col_index: int, - fill_value: Union[int, float, Tensor], + fill_value: int | float | Tensor, ): """Fill the :obj:`index`-th column in :obj:`MultiTensor` with fill_value in-place. diff --git a/torch_frame/nn/conv/ft_transformer_convs.py b/torch_frame/nn/conv/ft_transformer_convs.py index 4b837bcb..21b4aad3 100644 --- a/torch_frame/nn/conv/ft_transformer_convs.py +++ b/torch_frame/nn/conv/ft_transformer_convs.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional, Tuple - import torch from torch import Tensor from torch.nn import ( @@ -38,7 +36,7 @@ class FTTransformerConvs(TableConv): def __init__( self, channels: int, - feedforward_channels: Optional[int] = None, + feedforward_channels: int | None = None, # Arguments for Transformer num_layers: int = 3, nhead: int = 8, @@ -70,7 +68,7 @@ def reset_parameters(self): if p.dim() > 1: torch.nn.init.xavier_uniform_(p) - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: r"""CLS-token augmented Transformer convolution. Args: diff --git a/torch_frame/typing.py b/torch_frame/typing.py index c440a803..a2e49159 100644 --- a/torch_frame/typing.py +++ b/torch_frame/typing.py @@ -25,7 +25,7 @@ class Metric(Enum): MAE = 'mae' R2 = 'r2' - def supports_task_type(self, task_type: 'TaskType') -> bool: + def supports_task_type(self, task_type: TaskType) -> bool: return self in task_type.supported_metrics