diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0aceabce1..f124d4494 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `MovieLens 1M` dataset ([#397](https://github.com/pyg-team/pytorch-frame/pull/397))
 - Added light-weight MLP ([#372](https://github.com/pyg-team/pytorch-frame/pull/372))
+- Added `skorch.NeuralNet` subclasses that are compatible with PyTorch Frame ([#375](https://github.com/pyg-team/pytorch-frame/pull/375))
 - Added R^2 metric ([#403](https://github.com/pyg-team/pytorch-frame/pull/403))
 
 ### Changed
diff --git a/examples/sklearn_api.py b/examples/sklearn_api.py
new file mode 100644
index 000000000..6b195fc86
--- /dev/null
+++ b/examples/sklearn_api.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch.nn as nn
+from sklearn.datasets import load_diabetes
+from sklearn.metrics import mean_squared_error
+from sklearn.model_selection import train_test_split
+
+from torch_frame import stype
+from torch_frame.data.stats import StatType
+from torch_frame.nn import Trompt
+from torch_frame.utils.skorch import NeuralNetPytorchFrame
+
+# load the diabetes dataset
+X, y = load_diabetes(return_X_y=True, as_frame=True)
+
+# split the data into training and testing sets
+X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+
+# define the function to get the module
+def get_module(col_stats: dict[str, dict[StatType, Any]],
+               col_names_dict: dict[stype, list[str]]) -> Trompt:
+    """Build the Trompt model from the dataset's column metadata."""
+    channels = 8
+    out_channels = 1
+    num_prompts = 2
+    num_layers = 3
+    return Trompt(channels=channels, out_channels=out_channels,
+                  num_prompts=num_prompts, num_layers=num_layers,
+                  col_stats=col_stats, col_names_dict=col_names_dict,
+                  stype_encoder_dicts=None)
+
+
+# wrap the function in a NeuralNetPytorchFrame
+# NeuralNetClassifierPytorchFrame and NeuralNetBinaryClassifierPytorchFrame
+# are also available
+net = NeuralNetPytorchFrame(
+    module=get_module,
+    criterion=nn.MSELoss(),
+    max_epochs=10,
+    verbose=1,
+    lr=0.0001,
+    batch_size=30,
+)
+
+# fit the model
+net.fit(X_train, y_train)
+
+# predict on the test set
+y_pred = net.predict(X_test)
+
+# calculate the mean squared error
+mse = mean_squared_error(y_test, y_pred)
+print(mse)
diff --git a/pyproject.toml b/pyproject.toml
index a611688ec..b5150ed63 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,6 +56,7 @@ full=[
     "lightgbm",
     "datasets",
     "torchmetrics",
+    "skorch",
 ]
 
 [project.urls]
diff --git a/test/utils/test_skorch.py b/test/utils/test_skorch.py
new file mode 100644
index 000000000..391bb210a
--- /dev/null
+++ b/test/utils/test_skorch.py
@@ -0,0 +1,209 @@
+from __future__ import annotations
+
+from typing import Any
+
+import pandas as pd
+import pytest
+import torch
+import torch.nn as nn
+from sklearn.datasets import load_diabetes, load_iris
+from sklearn.metrics import accuracy_score, mean_squared_error
+from sklearn.model_selection import train_test_split
+
+from torch_frame import TaskType, stype
+from torch_frame.config.text_embedder import TextEmbedderConfig
+from torch_frame.data.dataset import Dataset
+from torch_frame.data.stats import StatType
+from torch_frame.datasets.fake import FakeDataset
+from torch_frame.nn.models.mlp import MLP
+from torch_frame.testing.text_embedder import HashTextEmbedder
+from torch_frame.utils.skorch import (
+    NeuralNetBinaryClassifierPytorchFrame,
+    NeuralNetClassifierPytorchFrame,
+    NeuralNetPytorchFrame,
+)
+
+
+class EnsureDtypeLoss(nn.Module):
+    """Cast and squeeze ``input`` and ``target`` before calling ``loss``."""
+    def __init__(self, loss: nn.Module, dtype_input: torch.dtype = torch.float,
+                 dtype_target: torch.dtype = torch.float):
+        super().__init__()
+        self.loss = loss
+        self.dtype_input = dtype_input
+        self.dtype_target = dtype_target
+
+    def forward(self, input, target):
+        return self.loss(
+            input.to(dtype=self.dtype_input).squeeze(),
+            target.to(dtype=self.dtype_target).squeeze())
+
+
+@pytest.mark.parametrize('cls', ["mlp"])
+@pytest.mark.parametrize(
+    'stypes',
+    [
+        [stype.numerical],
+        [stype.categorical],
+        # [stype.text_embedded],
+        # [stype.numerical, stype.numerical, stype.text_embedded],
+    ])
+@pytest.mark.parametrize('task_type_and_loss_cls', [
+    (TaskType.REGRESSION, nn.MSELoss),
+    (TaskType.BINARY_CLASSIFICATION, nn.BCEWithLogitsLoss),
+    (TaskType.MULTICLASS_CLASSIFICATION, nn.CrossEntropyLoss),
+])
+@pytest.mark.parametrize('pass_dataset', [False, True])
+@pytest.mark.parametrize('module_as_function', [False, True])
+def test_skorch_torchframe_dataset(cls, stypes, task_type_and_loss_cls,
+                                   pass_dataset: bool,
+                                   module_as_function: bool):
+    """Fit and predict with each wrapper on a fake dataset."""
+    task_type, loss_cls = task_type_and_loss_cls
+    loss = loss_cls()
+    loss = EnsureDtypeLoss(
+        loss, dtype_target=torch.long
+        if task_type == TaskType.MULTICLASS_CLASSIFICATION else torch.float)
+
+    # initialize dataset
+    dataset: Dataset = FakeDataset(
+        num_rows=30,
+        # with_nan=True,
+        stypes=stypes,
+        create_split=True,
+        task_type=task_type,
+        col_to_text_embedder_cfg=TextEmbedderConfig(
+            text_embedder=HashTextEmbedder(8)),
+    )
+    dataset.materialize()
+    train_dataset, val_dataset, test_dataset = dataset.split()
+    if not pass_dataset:
+        df_train = pd.concat([train_dataset.df, val_dataset.df])
+        X_train, y_train = df_train.drop(
+            columns=[dataset.target_col, dataset.split_col]), df_train[
+                dataset.target_col]
+        df_test = test_dataset.df
+        X_test, _ = df_test.drop(
+            columns=[dataset.target_col, dataset.split_col]), df_test[
+                dataset.target_col]
+
+        # never use dataset again
+        # we assume that only dataframes are available
+        del train_dataset, val_dataset, test_dataset
+
+    if cls == "mlp":
+        if module_as_function:
+
+            def get_module(col_stats: dict[str, dict[StatType, Any]],
+                           col_names_dict: dict[stype, list[str]]) -> MLP:
+                channels = 8
+                out_channels = 1
+                if task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                    out_channels = dataset.num_classes
+                num_layers = 3
+                return MLP(
+                    channels=channels,
+                    out_channels=out_channels,
+                    num_layers=num_layers,
+                    col_stats=col_stats,
+                    col_names_dict=col_names_dict,
+                    normalization="layer_norm",
+                )
+
+            module = get_module
+            kwargs = {}
+        else:
+            module = MLP
+            kwargs = {
+                "channels":
+                8,
+                "out_channels":
+                dataset.num_classes
+                if task_type == TaskType.MULTICLASS_CLASSIFICATION else 1,
+                "num_layers":
+                3,
+                "normalization":
+                "layer_norm",
+            }
+            kwargs = {f"module__{k}": v for k, v in kwargs.items()}
+    else:
+        raise NotImplementedError
+    kwargs.update({
+        "module": module,
+        "criterion": loss,
+        "max_epochs": 2,
+        "verbose": 1,
+        "batch_size": 3,
+    })
+
+    if task_type == TaskType.REGRESSION:
+        net = NeuralNetPytorchFrame(**kwargs)
+    elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+        net = NeuralNetClassifierPytorchFrame(**kwargs)
+    elif task_type == TaskType.BINARY_CLASSIFICATION:
+        net = NeuralNetBinaryClassifierPytorchFrame(**kwargs)
+    else:
+        raise NotImplementedError
+
+    if pass_dataset:
+        net.fit(dataset)
+        _ = net.predict(test_dataset)
+    else:
+        net.fit(X_train, y_train)
+        _ = net.predict(X_test)
+
+
+@pytest.mark.parametrize(
+    'task_type', [TaskType.MULTICLASS_CLASSIFICATION, TaskType.REGRESSION])
+def test_sklearn_only(task_type) -> None:
+    """Fit and predict with pandas data frames from scikit-learn only."""
+    if task_type == TaskType.MULTICLASS_CLASSIFICATION:
+        X, y = load_iris(return_X_y=True, as_frame=True)
+        num_classes = 3
+    else:
+        X, y = load_diabetes(return_X_y=True, as_frame=True)
+
+    X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+    def get_module(col_stats: dict[str, dict[StatType, Any]],
+                   col_names_dict: dict[stype, list[str]]) -> MLP:
+        channels = 8
+        out_channels = 1
+        if task_type == TaskType.MULTICLASS_CLASSIFICATION:
+            out_channels = num_classes
+        num_layers = 3
+        return MLP(
+            channels=channels,
+            out_channels=out_channels,
+            num_layers=num_layers,
+            col_stats=col_stats,
+            col_names_dict=col_names_dict,
+            normalization="layer_norm",
+        )
+
+    # use the classifier wrapper only for the classification task
+    if task_type == TaskType.MULTICLASS_CLASSIFICATION:
+        net_cls = NeuralNetClassifierPytorchFrame
+        criterion = nn.CrossEntropyLoss()
+    else:
+        net_cls = NeuralNetPytorchFrame
+        criterion = nn.MSELoss()
+    net = net_cls(
+        module=get_module,
+        criterion=criterion,
+        max_epochs=2,
+        verbose=1,
+        lr=0.0001,
+        batch_size=3,
+    )
+    net.fit(X_train, y_train)
+    y_pred = net.predict(X_test)
+
+    if task_type == TaskType.MULTICLASS_CLASSIFICATION:
+        assert y_pred.shape == (len(y_test), num_classes)
+        acc = accuracy_score(y_test, y_pred.argmax(-1))
+        print(acc)
+    else:
+        assert y_pred.shape == (len(y_test), 1)
+        mse = mean_squared_error(y_test, y_pred)
+        print(mse)
diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py
index fd5614bdf..946d4e718 100644
--- a/torch_frame/data/dataset.py
+++ b/torch_frame/data/dataset.py
@@ -7,6 +7,7 @@
 from collections import defaultdict
 from typing import Any
 
+import numpy as np
 import pandas as pd
 import torch
 from torch import Tensor
@@ -733,8 +734,8 @@ def get_split(self, split: str) -> Dataset:
         if split not in ["train", "val", "test"]:
             raise ValueError(f"The split named '{split}' is not available. "
                              f"Needs to be either 'train', 'val', or 'test'.")
-        indices = self.df.index[self.df[self.split_col] ==
-                                SPLIT_TO_NUM[split]].tolist()
+        indices = np.where(
+            self.df[self.split_col] == SPLIT_TO_NUM[split])[0].tolist()
         return self[indices]
 
     def split(self) -> tuple[Dataset, Dataset, Dataset]:
diff --git a/torch_frame/utils/skorch.py b/torch_frame/utils/skorch.py
new file mode 100644
index 000000000..672bf43f8
--- /dev/null
+++ b/torch_frame/utils/skorch.py
@@ -0,0 +1,386 @@
+from __future__ import annotations
+
+import importlib
+import warnings
+from functools import wraps
+from typing import Any, Callable
+
+import numpy as np
+import skorch.utils
+import torch
+import torch.nn as nn
+from numpy.typing import ArrayLike, NDArray
+from pandas import DataFrame
+from sklearn.model_selection import train_test_split
+from skorch import NeuralNet
+from torch import Tensor
+
+import torch_frame
+from torch_frame import stype
+from torch_frame.config import (
+    ImageEmbedderConfig,
+    TextEmbedderConfig,
+    TextTokenizerConfig,
+)
+from torch_frame.data.dataset import Dataset
+from torch_frame.data.loader import DataLoader
+from torch_frame.data.stats import StatType
+from torch_frame.data.tensor_frame import TensorFrame
+from torch_frame.typing import IndexSelectType
+from torch_frame.utils import infer_df_stype
+
+
+# TODO: make it more safe
+def _patch_skorch_support_tensorframe() -> None:
+    """Patch skorch.utils.to_tensor to support TensorFrame
+    as it raises an error when TensorFrame is passed.
+
+    skorch.net is reloaded after patching, presumably so that it rebinds
+    the patched function.
+    """
+    original_to_tensor = skorch.utils.to_tensor
+
+    @wraps(original_to_tensor)
+    def to_tensor(X, device, accept_sparse=False):
+        if isinstance(X, TensorFrame):
+            return X
+        return original_to_tensor(X, device, accept_sparse)
+
+    skorch.utils.to_tensor = to_tensor
+
+    importlib.reload(skorch.net)
+
+
+_patch_skorch_support_tensorframe()
+
+
+class NeuralNetPytorchFrameDataLoader(DataLoader):
+    """Custom DataLoader for NeuralNetPytorchFrame.
+
+    Converts the index to a tensor and separates the input and target tensors.
+    """
+    def __init__(self, dataset: Dataset | TensorFrame, *args,
+                 device: torch.device, **kwargs):
+        super().__init__(dataset, *args, **kwargs)
+        self.device = device
+
+    def collate_fn(  # type: ignore
+            self, index: IndexSelectType) -> tuple[TensorFrame, Tensor | None]:
+        # `as_tensor` does not copy (or warn) if `index` already is a tensor
+        index = torch.as_tensor(index)
+        res = super().collate_fn(index).to(self.device)
+        return res, res.y
+
+
+class NeuralNetPytorchFrame(NeuralNet):
+    def __init__(
+        self,
+        # NeuralNet parameters
+        module: type[nn.Module] | nn.Module
+        | Callable[[dict[str, dict[StatType, Any]], dict[stype, list[str]]],
+                   nn.Module],
+        criterion,
+        optimizer=torch.optim.SGD,
+        lr=0.01,
+        max_epochs=10,
+        batch_size=128,
+        iterator_train=None,
+        iterator_valid=None,
+        dataset=None,
+        train_split=None,
+        callbacks=None,
+        predict_nonlinearity="auto",
+        warm_start=False,
+        verbose=1,
+        device="cpu",
+        compile=False,
+        use_caching="auto",
+        # torch_frame.Dataset parameters
+        col_to_stype: dict[str, torch_frame.stype] | None = None,
+        target_col: str | None = "target_col",
+        split_col: str | None = "split_col",
+        col_to_sep: str | None | dict[str, str | None] = None,
+        col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
+        | TextEmbedderConfig | None = None,
+        col_to_text_tokenizer_cfg: dict[str, TextTokenizerConfig]
+        | TextTokenizerConfig | None = None,
+        col_to_image_embedder_cfg: dict[str, ImageEmbedderConfig]
+        | ImageEmbedderConfig | None = None,
+        col_to_time_format: str | None | dict[str, str | None] = None,
+        # other NeuralNet parameters
+        **kwargs,
+    ) -> None:
+        """`skorch.NeuralNet` with `torch_frame` support.
+
+        Additional parameters are **ONLY** used
+        when creating a dummy torch_frame.data.dataset.Dataset
+        if pandas.DataFrame is passed as X in `fit` or `predict` methods.
+
+        `iterator_train`, `iterator_valid` and `dataset` are not supported:
+        passing them only triggers a warning. `train_split` is only used to
+        create the split column if a pandas.DataFrame without one is passed
+        to `fit`.
+
+        Parameters
+        ----------
+        col_to_stype (Dict[str, torch_frame.stype], optional): A dictionary
+            that maps each column in the data frame to a semantic type. It
+            overrides the semantic types inferred from the data frame.
+            (default: :obj:`None`)
+        target_col (str, optional): The column used as target.
+            (default: :obj:`"target_col"`)
+        split_col (str, optional): The column that stores the pre-defined split
+            information. The column should only contain :obj:`0`, :obj:`1`, or
+            :obj:`2`. (default: :obj:`"split_col"`)
+        col_to_sep (Union[str, Dict[str, Optional[str]]]): A dictionary or a
+            string/:obj:`None` specifying the separator/delimiter for the
+            multi-categorical columns. If a string/:obj:`None` is specified,
+            then the same separator will be used throughout all the
+            multi-categorical columns. Note that if :obj:`None` is specified,
+            it assumes a multi-category is given as a :obj:`list` of
+            categories. If a dictionary is given, we use a separator specified
+            for each column. (default: :obj:`None`)
+        col_to_text_embedder_cfg (TextEmbedderConfig or dict, optional):
+            A text embedder configuration or a dictionary of configurations
+            specifying :obj:`text_embedder` that embeds texts into vectors and
+            :obj:`batch_size` that specifies the mini-batch size for
+            :obj:`text_embedder`. (default: :obj:`None`)
+        col_to_text_tokenizer_cfg (TextTokenizerConfig or dict, optional):
+            A text tokenizer configuration or dictionary of configurations
+            specifying :obj:`text_tokenizer` that maps sentences into a
+            list of dictionary of tensors. Each element in the list
+            corresponds to each sentence, keys are input arguments to
+            the model such as :obj:`input_ids`, and values are tensors
+            such as tokens. :obj:`batch_size` specifies the mini-batch
+            size for :obj:`text_tokenizer`. (default: :obj:`None`)
+        col_to_image_embedder_cfg (ImageEmbedderConfig or dict, optional):
+            An image embedder configuration or a dictionary of configurations
+            that is passed on to torch_frame.data.dataset.Dataset.
+            (default: :obj:`None`)
+        col_to_time_format (Union[str, Dict[str, Optional[str]]], optional): A
+            dictionary or a string specifying the format for the timestamp
+            columns. See the `strftime` documentation of the `datetime`
+            module for more information on formats. If a string is specified,
+            then the same format will be used throughout all the timestamp
+            columns. If a dictionary is given, we use a different format
+            specified for each column. If not specified, pandas's internal
+            to_datetime function will be used to auto parse time columns.
+            (default: :obj:`None`)
+        """
+        super().__init__(
+            module=module,
+            criterion=criterion,
+            optimizer=optimizer,
+            lr=lr,
+            max_epochs=max_epochs,
+            batch_size=batch_size,
+            iterator_train=self.iterator_train_valid,  # changed
+            iterator_valid=self.iterator_train_valid,  # changed
+            dataset=self.create_dataset,  # changed
+            train_split=self.split_dataset,  # changed
+            callbacks=callbacks,
+            predict_nonlinearity=predict_nonlinearity,
+            warm_start=warm_start,
+            verbose=verbose,
+            device=device,
+            compile=compile,
+            use_caching=use_caching,
+            **kwargs,
+        )
+        # additional parameters used when creating a dummy
+        # torch_frame.data.dataset.Dataset
+        self.col_to_stype = col_to_stype
+        self.target_col = target_col
+        self.split_col = split_col
+        self.col_to_sep = col_to_sep
+        self.col_to_text_embedder_cfg = col_to_text_embedder_cfg
+        self.col_to_text_tokenizer_cfg = col_to_text_tokenizer_cfg
+        self.col_to_image_embedder_cfg = col_to_image_embedder_cfg
+        self.col_to_time_format = col_to_time_format
+        # keep the user-provided split function; it is only used in `fit`
+        # to create the split column of a pandas.DataFrame
+        self.train_split_original = train_split or (
+            lambda x: train_test_split(x, test_size=0.2))
+        # 0.2 is the default test_size in train_test_split in skorch
+        for name, v in zip(
+            ["iterator_train", "iterator_valid", "dataset"],
+            [iterator_train, iterator_valid, dataset],
+        ):
+            if v is not None:
+                warnings.warn(
+                    "NeuralNetPytorchFrame does not support"
+                    f" specifying {name}, "
+                    "consider overriding the methods instead", UserWarning,
+                    stacklevel=2)
+
+    def create_dataset(self, df: DataFrame, _: Any) -> Dataset:
+        """Create a materialized Dataset from :obj:`df` (skorch API).
+
+        The configuration is copied from :obj:`self.dataset_`, which is set
+        in `fit` and `predict`. The second argument is ignored.
+        """
+        dataset_ = Dataset(
+            df,
+            self.dataset_.col_to_stype,
+            split_col=self.dataset_.split_col,
+            target_col=self.dataset_.target_col,
+            col_to_sep=self.dataset_.col_to_sep,
+            col_to_text_embedder_cfg=self.dataset_.col_to_text_embedder_cfg,
+            col_to_text_tokenizer_cfg=self.dataset_.col_to_text_tokenizer_cfg,
+            col_to_image_embedder_cfg=self.dataset_.col_to_image_embedder_cfg,
+            col_to_time_format=self.dataset_.col_to_time_format,
+        )
+        dataset_.materialize()
+        return dataset_
+
+    def split_dataset(self,
+                      dataset: Dataset) -> tuple[TensorFrame, TensorFrame]:
+        """Return the train and validation TensorFrames (skorch API)."""
+        train_dataset, val_dataset = dataset.split()[:2]
+        return train_dataset.tensor_frame, val_dataset.tensor_frame
+
+    def iterator_train_valid(self, dataset: Dataset,
+                             **kwargs: Any) -> DataLoader:
+        """Create the data loader on this net's device (skorch API)."""
+        return NeuralNetPytorchFrameDataLoader(dataset, device=self.device,
+                                               **kwargs)
+
+    def initialize_module(self):
+        """Initialize the module with the dataset metadata (skorch API).
+
+        `col_stats` and `col_names_dict` of :obj:`self.dataset_` are passed
+        as keyword arguments if the module is a class or an instance, and
+        as positional arguments if it is a function.
+        """
+        col_stats = self.dataset_.col_stats
+        col_names_dict = self.dataset_.tensor_frame.col_names_dict
+        # if module, behave like the original NeuralNet
+        if isinstance(self.module, (nn.Module, type)):
+            self.module__col_stats = col_stats
+            self.module__col_names_dict = col_names_dict
+            return super().initialize_module()
+        # assume that self.module is a function
+        self.module_ = self.module(col_stats, col_names_dict)
+        return self
+
+    def fit(self, X: Dataset | DataFrame, y: ArrayLike | None = None,
+            **fit_params):
+        """Fit the module on a Dataset or a pandas.DataFrame.
+
+        If a pandas.DataFrame is passed, a Dataset is created from it:
+        :obj:`y` (if given) is stored in :obj:`target_col`, and
+        :obj:`split_col` (if missing) is created with the split function,
+        using :obj:`0` for training rows and :obj:`1` for validation rows.
+        Columns are only added to a copy of :obj:`X`.
+        A Dataset is used as is, i.e., it is not materialized here.
+
+        Args:
+            X (Dataset or DataFrame): The training data.
+            y (ArrayLike, optional): The target values. Only used if
+                :obj:`X` is a pandas.DataFrame. (default: :obj:`None`)
+            **fit_params: Additional parameters that are passed to the split
+                function and to :meth:`skorch.NeuralNet.fit`.
+        """
+        if isinstance(X, DataFrame):
+            # create target_col if not exists
+            if y is not None:
+                X = X.copy()
+                X[self.target_col] = y
+            elif self.target_col not in X:
+                warnings.warn(
+                    f"target_col {self.target_col}"
+                    " not found in X and y is None", UserWarning, stacklevel=2)
+
+            # create split_col if not exists
+            if self.split_col not in X:
+                if y is None:
+                    X = X.copy()
+                # first split the data with the split function
+                X_train, _ = self.train_split_original(X, **fit_params)
+                # 0 (train) if index is in X_train, otherwise 1 (val)
+                X[self.split_col] = (~X.index.isin(X_train.index)).astype(int)
+
+            # col_to_stype
+            col_to_stype = {
+                k: v
+                for k, v in infer_df_stype(X).items()
+                if k not in (self.split_col, )
+            }
+            if self.col_to_stype is not None:
+                col_to_stype.update(self.col_to_stype)
+
+            self.dataset_ = Dataset(
+                X,
+                # do not include split_col
+                col_to_stype=col_to_stype,
+                split_col=self.split_col,
+                target_col=self.target_col,
+                col_to_sep=self.col_to_sep,
+                col_to_text_embedder_cfg=self.col_to_text_embedder_cfg,
+                col_to_text_tokenizer_cfg=self.col_to_text_tokenizer_cfg,
+                col_to_image_embedder_cfg=self.col_to_image_embedder_cfg,
+                col_to_time_format=self.col_to_time_format,
+            )
+            # materialize the dataset to add col_stats and col_names_dict
+            # in initialize_module()
+            self.dataset_.materialize()
+        else:
+            self.dataset_ = X
+
+        return super().fit(self.dataset_.df, None, **fit_params)
+
+    def predict(self, X: Dataset | DataFrame) -> NDArray[Any]:
+        """Predict on a Dataset or a pandas.DataFrame.
+
+        A pandas.DataFrame is wrapped in a Dataset without target and split
+        columns that reuses the semantic types of :obj:`self.dataset_`.
+        Note that :obj:`self.dataset_` is replaced in both cases.
+        """
+        if isinstance(X, DataFrame):
+            self.dataset_ = Dataset(
+                X,
+                {
+                    k: v
+                    for k, v in self.dataset_.col_to_stype.items()
+                    if k not in (self.target_col, )
+                },
+                split_col=None,
+                target_col=None,
+                col_to_sep=self.col_to_sep,
+                col_to_text_embedder_cfg=self.col_to_text_embedder_cfg,
+                col_to_text_tokenizer_cfg=self.col_to_text_tokenizer_cfg,
+                col_to_image_embedder_cfg=self.col_to_image_embedder_cfg,
+                col_to_time_format=self.col_to_time_format,
+            )
+            # no need to materialize probably
+        else:
+            self.dataset_ = X
+        return super().predict(self.dataset_.df)
+
+
+# TODO: make this behave more like NeuralNetClassifier
+class NeuralNetClassifierPytorchFrame(NeuralNetPytorchFrame):
+    """NeuralNetPytorchFrame that records the class labels in `fit`."""
+    def fit(self, X: Dataset | DataFrame, y: ArrayLike | None = None,
+            **fit_params):
+        """Fit the module and set :obj:`classes` if it is not set yet.
+
+        :obj:`classes` defaults to the unique values of the target column.
+        """
+        fit_result = super().fit(X, y, **fit_params)
+        # `classes` may be an array, whose truth value is ambiguous, so
+        # compare with None instead of relying on `or`
+        if getattr(self, "classes", None) is None:
+            target = self.dataset_.df[self.dataset_.target_col]
+            self.classes = target.unique()
+        return fit_result
+
+
+class NeuralNetBinaryClassifierPytorchFrame(NeuralNetPytorchFrame):
+    """NeuralNetPytorchFrame for binary classification."""
+    # the two class labels; `classes` matches the attribute name used by
+    # NeuralNetClassifierPytorchFrame
+    classes = np.array([0, 1])
+    # misnomer (holds the labels, not their count) kept for backward
+    # compatibility
+    num_classes = np.array([0, 1])