Skip to content

Commit

Permalink
change val to test
Browse files Browse the repository at this point in the history
  • Loading branch information
gabhijith authored Jul 19, 2023
1 parent e72be04 commit 6e85d0f
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions part2/2.JetTaggingMLP.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"X_train, X_val, y_train, y_val = train_test_split(features, target, test_size=0.33)\n",
"print(X_train.shape, X_val.shape, y_train.shape, y_val.shape)\n",
"X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2)\n",
"print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)\n",
"del features, target"
]
},
Expand Down Expand Up @@ -318,7 +318,7 @@
"\n",
"# train \n",
"history = model.fit(X_train, y_train, epochs=n_epochs, batch_size=batch_size, verbose = 2,\n",
" validation_data=(X_val, y_val), learning_rate=0.01,\n",
" validation_split=0.2,\n",
" # callbacks = [\n",
" # EarlyStopping(monitor='val_loss', patience=10, verbose=1),\n",
" # ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, verbose=1),\n",
Expand Down Expand Up @@ -364,16 +364,16 @@
"source": [
"import pandas as pd\n",
"from sklearn.metrics import roc_curve, auc\n",
"predict_val = model.predict(X_val)\n",
"predict_test = model.predict(X_test)\n",
"df = pd.DataFrame()\n",
"fpr = {}\n",
"tpr = {}\n",
"auc1 = {}\n",
"\n",
"plt.figure()\n",
"for i, label in enumerate(labels):\n",
" df[label] = y_val[:,i]\n",
" df[label + '_pred'] = predict_val[:,i]\n",
" df[label] = y_test[:,i]\n",
" df[label + '_pred'] = predict_test[:,i]\n",
"\n",
" fpr[label], tpr[label], threshold = roc_curve(df[label],df[label+'_pred'])\n",
"\n",
Expand Down

0 comments on commit 6e85d0f

Please sign in to comment.