From 0b9426f262ec64d03ee0a9c87ef8f3119672117d Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Sat, 16 Mar 2024 12:52:07 +0900 Subject: [PATCH] style: format code --- torch_frame/utils/skorch.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/torch_frame/utils/skorch.py b/torch_frame/utils/skorch.py index a6300992..d16f1f42 100644 --- a/torch_frame/utils/skorch.py +++ b/torch_frame/utils/skorch.py @@ -1,26 +1,11 @@ -import skorch.utils - -# TODO: make it more safe -old_to_tensor = skorch.utils.to_tensor - -def to_tensor(X, device, accept_sparse=False): - if isinstance(X, TensorFrame): - return X - return old_to_tensor(X, device, accept_sparse) - -skorch.utils.to_tensor = to_tensor import importlib -importlib.reload(skorch.net) - from typing import Any -import pandas as pd +import skorch.utils import torch -import torch.nn as nn from numpy.typing import ArrayLike from pandas import DataFrame -from skorch import NeuralNet, NeuralNetClassifier -from skorch.dataset import Dataset as SkorchDataset +from skorch import NeuralNet from torch import Tensor import torch_frame @@ -29,12 +14,26 @@ def to_tensor(X, device, accept_sparse=False): TextEmbedderConfig, TextTokenizerConfig, ) -from torch_frame.data.dataset import DataFrameToTensorFrameConverter, Dataset +from torch_frame.data.dataset import Dataset from torch_frame.data.loader import DataLoader from torch_frame.data.tensor_frame import TensorFrame from torch_frame.typing import IndexSelectType from torch_frame.utils import infer_df_stype +# TODO: make it more safe +old_to_tensor = skorch.utils.to_tensor + + +def to_tensor(X, device, accept_sparse=False): + if isinstance(X, TensorFrame): + return X + return old_to_tensor(X, device, accept_sparse) + + +skorch.utils.to_tensor = to_tensor + +importlib.reload(skorch.net) + class NeuralNetPytorchFrameDataLoader(DataLoader): def __init__(self, dataset: Dataset | TensorFrame, *args, @@ -42,7 +41,7 @@ def __init__(self, dataset: Dataset | TensorFrame, *args, super().__init__(dataset, *args, **kwargs) self.device = device - def collate_fn( + def collate_fn( # type: ignore self, index: IndexSelectType) -> tuple[TensorFrame, Tensor | None]: index = torch.tensor(index) res = super().collate_fn(index).to(self.device)