You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.utils.serialization import save_to_file, load_from_file
from sklearn.ensemble import VotingClassifier
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
import pandas as pd
import numpy as np
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import SuccessiveHalvingPruner
from sklearn.metrics import f1_score
from lightgbm import LGBMClassifier
import json
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/generic/plugin_ddpm.py:230, in TabDDPMPlugin._fit(self, X, *args, **kwargs)
227 self.expecting_conditional = True
229 # NOTE: cond may also be included in the dataframe
--> 230 self.model.fit(df, cond, **kwargs)
231 self.loss_history = self.model.loss_history
232 self.validation_history = self.model.val_history
File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/core/models/tabular_ddpm/init.py:163, in TabDDPM.fit(self, X, cond, **kwargs)
161 self.optimizer.zero_grad()
162 args = (x,) if cond is None else (x, y)
--> 163 loss_multi, loss_gauss = self.diffusion.mixed_loss(*args)
164 loss = loss_multi + loss_gauss
165 loss.backward()
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.utils.serialization import save_to_file, load_from_file
from sklearn.ensemble import VotingClassifier
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
import pandas as pd
import numpy as np
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import SuccessiveHalvingPruner
from sklearn.metrics import f1_score
from lightgbm import LGBMClassifier
import json
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from synthcity.utils.optuna_sample import suggest_all
tabddpm_model = Plugins().get("ddpm")
trial = optuna.create_study().ask()
params = suggest_all(trial, tabddpm_model.hyperparameter_space())
params['lr'] = 0.0001
params['num_timesteps'] = 100 # 데이터 생성 과정에서 30단계의 시간을 통해 모델이 데이터를 생성
params['batch_size'] = 10
params['n_iter'] = 1000
train_data = pd.read_csv('train_12000.csv').drop('Unnamed: 0', axis=1)
subset -> train_data
loader = GenericDataLoader(train_data, target_column="Fraud_Type")
TabDDPM 모델 초기화 및 학습
PLUGIN = "ddpm"
plugin_cls = type(Plugins().get(PLUGIN))
plugin = plugin_cls(**params).fit(loader)
#save_to_file(f'ddpm_{fraud_type}.pkl', syn_model)
#save_to_file(f'ddpm_{fraud_type}.pkl', plugin)
save_to_file(f'ddpm_all.pkl', plugin)
#reloaded = load_from_file('./adsgan_10_epochs.pkl')
합성 데이터 생성
#synthetic_subset = plugin.generate(count=N_Syn_len)
print('done')
I got this error while trying to learn tabddpm
Epoch: 0%| | 0/1000 [00:00<?, ?it/s]../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [0,0,0] Assertion
idx_dim >= 0 && idx_dim < index_size && "index out of bounds"
failed.../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [1,0,0] Assertion
idx_dim >= 0 && idx_dim < index_size && "index out of bounds"
failed.../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [3,0,0] Assertion
idx_dim >= 0 && idx_dim < index_size && "index out of bounds"
failed.../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [7,0,0] Assertion
idx_dim >= 0 && idx_dim < index_size && "index out of bounds"
failed.../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [8,0,0] Assertion
idx_dim >= 0 && idx_dim < index_size && "index out of bounds"
failed.Epoch: 0%| | 0/1000 [00:00<?, ?it/s]
RuntimeError Traceback (most recent call last)
Cell In[1], line 38
36 PLUGIN = "ddpm"
37 plugin_cls = type(Plugins().get(PLUGIN))
---> 38 plugin = plugin_cls(**params).fit(loader)
39 #save_to_file(f'ddpm_{fraud_type}.pkl', syn_model)
40 #save_to_file(f'ddpm_{fraud_type}.pkl', plugin)
41 save_to_file(f'ddpm_all.pkl', plugin)
File /opt/venv/lib/python3.10/site-packages/pydantic/decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()
File /opt/venv/lib/python3.10/site-packages/pydantic/decorator.py:134, in pydantic.decorator.ValidatedFunction.call()
File /opt/venv/lib/python3.10/site-packages/pydantic/decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()
File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/core/plugin.py:256, in Plugin.fit(self, X, *args, **kwargs)
248 X, self.compress_context = load_from_file(bkp_file)
250 self._training_schema = Schema(
251 data=X,
252 sampling_strategy=self.sampling_strategy,
253 random_state=self.random_state,
254 )
--> 256 output = self._fit(X, *args, **kwargs)
257 self.fitted = True
259 return output
File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/generic/plugin_ddpm.py:230, in TabDDPMPlugin._fit(self, X, *args, **kwargs)
227 self.expecting_conditional = True
229 # NOTE: cond may also be included in the dataframe
--> 230 self.model.fit(df, cond, **kwargs)
231 self.loss_history = self.model.loss_history
232 self.validation_history = self.model.val_history
File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/core/models/tabular_ddpm/init.py:163, in TabDDPM.fit(self, X, cond, **kwargs)
161 self.optimizer.zero_grad()
162 args = (x,) if cond is None else (x, y)
--> 163 loss_multi, loss_gauss = self.diffusion.mixed_loss(*args)
164 loss = loss_multi + loss_gauss
165 loss.backward()
File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py:665, in GaussianMultinomialDiffusion.mixed_loss(self, x, cond)
663 x_num_t = self.gaussian_q_sample(x_num, t, noise=noise)
664 if x_cat.shape[1] > 0:
--> 665 log_x_cat = index_to_log_onehot(x_cat.long(), self.num_classes)
666 log_x_cat_t = self.q_sample(log_x_start=log_x_cat, t=t)
668 x_in = torch.cat([x_num_t, log_x_cat_t], dim=1)
File /opt/venv/lib/python3.10/site-packages/synthcity/plugins/core/models/tabular_ddpm/utils.py:145, in index_to_log_onehot(x, num_classes)
143 onehots = []
144 for i in range(len(num_classes)):
--> 145 onehots.append(F.one_hot(x[:, i], num_classes[i]))
146 x_onehot = torch.cat(onehots, dim=1)
147 log_onehot = torch.log(x_onehot.float().clamp(min=1e-30))
RuntimeError: CUDA error: device-side assert triggered
Compile with
TORCH_USE_CUDA_DSA
to enable device-side assertions.The text was updated successfully, but these errors were encountered: