Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I got this error while trying to learn tabddpm #301

Open
limhasic opened this issue Nov 4, 2024 · 0 comments
Open

I got this error while trying to learn tabddpm #301

limhasic opened this issue Nov 4, 2024 · 0 comments

Comments

@limhasic
Copy link

limhasic commented Nov 4, 2024

from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.utils.serialization import save_to_file, load_from_file
from sklearn.ensemble import VotingClassifier
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
import pandas as pd
import numpy as np
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import SuccessiveHalvingPruner
from sklearn.metrics import f1_score
from lightgbm import LGBMClassifier
import json
import os
# Set the CUDA-related environment variables for this run in a single step.
os.environ.update(
    {
        'CUDA_LAUNCH_BLOCKING': "1",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

from synthcity.utils.optuna_sample import suggest_all
# Sample a full hyperparameter set for the "ddpm" plugin from its declared
# search space, using a throwaway Optuna trial.
tabddpm_model = Plugins().get("ddpm")
trial = optuna.create_study().ask()
params = suggest_all(trial, tabddpm_model.hyperparameter_space())

# Override selected sampled values with fixed settings for this run.
params['lr'] = 0.0001
params['num_timesteps'] = 100  # number of diffusion timesteps passed to the plugin
params['batch_size'] = 10
params['n_iter'] = 1000

# Load the training table and drop the index column written by a previous to_csv.
train_data = pd.read_csv('train_12000.csv').drop('Unnamed: 0', axis=1)

# subset -> train_data

loader = GenericDataLoader(train_data, target_column="Fraud_Type")

# Initialize and train the TabDDPM model.
# NOTE(review): the traceback reported with this script fails inside
# F.one_hot(x[:, i], num_classes[i]) with "index out of bounds"; presumably a
# categorical value falls outside [0, num_classes[i]) - verify how the
# categorical columns and the target are encoded before fitting.

PLUGIN = "ddpm"
plugin_cls = type(Plugins().get(PLUGIN))
plugin = plugin_cls(**params).fit(loader)
#save_to_file(f'ddpm_{fraud_type}.pkl', syn_model)
#save_to_file(f'ddpm_{fraud_type}.pkl', plugin)
save_to_file('ddpm_all.pkl', plugin)
#reloaded = load_from_file('./adsgan_10_epochs.pkl')

# Generate synthetic data

#synthetic_subset = plugin.generate(count=N_Syn_len)
print('done')

I got this error while trying to train TabDDPM.

Epoch: 0%| | 0/1000 [00:00<?, ?it/s]../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [0,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [1,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [3,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [7,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [8,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
Epoch: 0%| | 0/1000 [00:00<?, ?it/s]

RuntimeError Traceback (most recent call last)
Cell In[1], line 38
36 PLUGIN = "ddpm"
37 plugin_cls = type(Plugins().get(PLUGIN))
---> 38 plugin = plugin_cls(**params).fit(loader)
39 #save_to_file(f'ddpm_{fraud_type}.pkl', syn_model)
40 #save_to_file(f'ddpm_{fraud_type}.pkl', plugin)
41 save_to_file(f'ddpm_all.pkl', plugin)

File /opt/venv/lib/python3.10/site-packages/pydantic/decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()

File /opt/venv/lib/python3.10/site-packages/pydantic/decorator.py:134, in pydantic.decorator.ValidatedFunction.call()

File /opt/venv/lib/python3.10/site-packages/pydantic/decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()

File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/core/plugin.py:256, in Plugin.fit(self, X, *args, **kwargs)
248 X, self.compress_context = load_from_file(bkp_file)
250 self._training_schema = Schema(
251 data=X,
252 sampling_strategy=self.sampling_strategy,
253 random_state=self.random_state,
254 )
--> 256 output = self._fit(X, *args, **kwargs)
257 self.fitted = True
259 return output

File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/generic/plugin_ddpm.py:230, in TabDDPMPlugin._fit(self, X, *args, **kwargs)
227 self.expecting_conditional = True
229 # NOTE: cond may also be included in the dataframe
--> 230 self.model.fit(df, cond, **kwargs)
231 self.loss_history = self.model.loss_history
232 self.validation_history = self.model.val_history

File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/core/models/tabular_ddpm/__init__.py:163, in TabDDPM.fit(self, X, cond, **kwargs)
161 self.optimizer.zero_grad()
162 args = (x,) if cond is None else (x, y)
--> 163 loss_multi, loss_gauss = self.diffusion.mixed_loss(*args)
164 loss = loss_multi + loss_gauss
165 loss.backward()

File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py:665, in GaussianMultinomialDiffusion.mixed_loss(self, x, cond)
663 x_num_t = self.gaussian_q_sample(x_num, t, noise=noise)
664 if x_cat.shape[1] > 0:
--> 665 log_x_cat = index_to_log_onehot(x_cat.long(), self.num_classes)
666 log_x_cat_t = self.q_sample(log_x_start=log_x_cat, t=t)
668 x_in = torch.cat([x_num_t, log_x_cat_t], dim=1)

File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/core/models/tabular_ddpm/utils.py:145, in index_to_log_onehot(x, num_classes)
143 onehots = []
144 for i in range(len(num_classes)):
--> 145 onehots.append(F.one_hot(x[:, i], num_classes[i]))
146 x_onehot = torch.cat(onehots, dim=1)
147 log_onehot = torch.log(x_onehot.float().clamp(min=1e-30))

RuntimeError: CUDA error: device-side assert triggered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant