diff --git a/part4/4.JetTaggingPointCloud.ipynb b/part4/4.JetTaggingPointCloud.ipynb new file mode 100644 index 0000000..e0227f2 --- /dev/null +++ b/part4/4.JetTaggingPointCloud.ipynb @@ -0,0 +1,729 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BsonEqBekjyy" + }, + "source": [ + "# Jet Tagging with Set Transformer\n", + "\n", + "In this notebook we will see an implementation of the Transformer architecture for sets applied to the jet tagging task. For *sets* it is meant here a point cloud, i.e. a set of nodes without edges. We will instead use Multi-Head Attention to learn which nodes (or particles) have strong pair-wise interaction.\n", + "\n", + "The architecture was introduced by [J. Lee at al. (ICML 2019)](https://arxiv.org/abs/1810.00825) -- specifically designed to model interactions among elements in the input set without pre-defined edges. The model consists of an encoder and a decoder, both of which rely on attention mechanisms, as in the original Transformer implementation [by Vaswani](https://arxiv.org/abs/1706.03762). The main difference is that positional encoding is removed plus some other low level adaptions.\n", + "\n", + "We will use tensorflow for this implementation.\n", + "\n", + "Before you start, choose GPU as a hardware accelerator for this notebook. To do this first go to Edit -> Notebook Settings -> Choose GPU as a hardware accelerator." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "A7OS3w5WRSCj" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "from tensorflow.keras import layers\n", + "import h5py\n", + "import numpy as np\n", + "\n", + "print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qwekVVRzneqU" + }, + "source": [ + "## Dataset exploration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BHQGiyC4Pr4R" + }, + "outputs": [], + "source": [ + "! curl https://cernbox.cern.ch/s/6Ec5pGFEpFWeH6S/download -o Data-MLtutorial.tar.gz\n", + "! tar -xvzf Data-MLtutorial.tar.gz \n", + "! ls Data-MLtutorial/JetDataset/\n", + "! rm Data-MLtutorial.tar.gz " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "J9ZLcoKpPteG" + }, + "outputs": [], + "source": [ + "# let's open the file\n", + "data_dir = 'Data-MLtutorial/JetDataset/'\n", + "fileIN = data_dir+'jetImage_7_100p_30000_40000.h5'\n", + "f = h5py.File(fileIN)\n", + "# and see what it contains\n", + "print(list(f.keys()))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ktx1VjNoOu4c" + }, + "source": [ + "* 'jetImage' ,' jetImageECAL' and 'jetImageHCAL' contains the image representation of the jets . We will not use them today but build our point cloud from the other information.\n", + "* 'jetConstituentList' is the list of particles cointained in the jet. For each particle, a list of relevant quantities is stored. This is the dataset we will consider in this notebook.\n", + "* 'particleFeatureNames' is the list of the names corresponding to the quantities contained in 'jetConstituentList'\n", + "* 'jets' is the list of jets with the high-level jet features stored. We will only use jet ID from it, indecies [-6:-1]\n", + "* 'jetFeatureNames' is the list of the names corresponding to the quantities contained in 'jets'. These quantities are build using physics knowledge and correspond to high-level infromation and features per graph (as opposed to per node)\n", + "\n", + "The first 100 highest transverse momentum $p_T$ particles are considered for each jet.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Re7oXWWmPxz9" + }, + "outputs": [], + "source": [ + "target_onehot = np.array([])\n", + "jetList = np.array([])\n", + "jetImages = np.array([])\n", + "features_names = dict()\n", + "datafiles = ['jetImage_7_100p_0_10000.h5',\n", + " 'jetImage_7_100p_10000_20000.h5',\n", + " 'jetImage_7_100p_30000_40000.h5',\n", + " 'jetImage_7_100p_40000_50000.h5',\n", + " 'jetImage_7_100p_50000_60000.h5'\n", + " ]\n", + "for i_f,fileIN in enumerate(datafiles):\n", + " print(\"Appending %s\" %fileIN)\n", + " f = h5py.File(data_dir + fileIN)\n", + " jetList_file = np.array(f.get(\"jetConstituentList\"))\n", + " target_file = np.array(f.get('jets')[0:,-6:-1])\n", + " jetImages_file = np.array(f.get('jetImage'))\n", + " jetList = np.concatenate([jetList, jetList_file], axis=0) if jetList.size else jetList_file\n", + " target_onehot = np.concatenate([target_onehot, target_file], axis=0) if target_onehot.size else target_file\n", + " jetImages = np.concatenate([jetImages, jetImages_file], axis=0) if jetImages.size else jetImages_file\n", + " del jetList_file, target_file, jetImages_file\n", + " #save particles/nodes features names and their indecies in a dictionary\n", + " if i_f==0:\n", + " for feat_idx,feat_name in enumerate(list(f['particleFeatureNames'])[:-1]):\n", + " features_names[feat_name.decode(\"utf-8\").replace('j1_','')] = feat_idx\n", + " f.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7kQnL9vkP4rK" + }, + "source": [ + "The ground truth is incorporated in the ['j_g', 'j_q', 'j_w', 'j_z', 'j_t] vector of boolean, taking the form\n", + "* [1, 0, 0, 0, 0] for gluons\n", + "* [0, 1, 0, 0, 0] for quarks\n", + "* [0, 0, 1, 0, 0] for W\n", + "* [0, 0, 0, 1, 0] for Z \n", + "* [0, 0, 0, 0, 1] for top quarks\n", + "\n", + "This is what is called 'one-hot' encoding of a descrete label (typical of ground truth for classification problems). These labels are the 'target' for our classification tasks. Let's convert it back to single-column encoding :\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "84NSj2W7P477" + }, + "outputs": [], + "source": [ + "print(\"Labels for the first five entries in the dataset, one-hot encoded:\")\n", + "for i in range(5):\n", + " print(target_onehot[i])\n", + "print(target_onehot.shape)\n", + "target = np.argmax(target_onehot, axis=1)\n", + "print(target.shape)\n", + "print(\"Labels for the first five entries in the dataset, single column encoded:\")\n", + "for i in range(0,5):\n", + " print(target[i])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mqsd_aP__RIi" + }, + "source": [ + "Now our lables correspond to :\n", + "* 0 for gluons\n", + "* 1 for quarks\n", + "* 2 for W\n", + "* 3 for Z \n", + "* 4 for top quarks\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hyP15oxhP5ek" + }, + "outputs": [], + "source": [ + "num_classes = len(np.unique(target))\n", + "label_names= [\"gluon\", \"quark\", \"W\", \"Z\", \"top\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ik-6OX0LMJW7" + }, + "source": [ + "Now let's inspect our data. Each jet is a point cloud/graph with 100 particles/nodes, each of which has 16 features. We have a double-index dataset: (jet index, particle index). The list is cut at 100 constituents per jet. If less constituents are present in the jet/point cloud, the dataset is completed filling it with 0s (zero padding). Note : zero-padding is not using during the training, it is only used to store the ragged dataset.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YfHRopq0P8tW" + }, + "outputs": [], + "source": [ + "print('Jets shape : ',jetList.shape)\n", + "print('Target/Labels shape : ',target.shape)\n", + "print('Particles/Nodes features : ',list(features_names.keys()))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VBYwH4t8MhHm" + }, + "source": [ + "We are not interested in all features for now. For now we will only consider the same node features as were considered in the ParticleNet paper: ![Screenshot 2022-09-26 at 16.28.34.png]()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QWtB3vTWP_QY" + }, + "outputs": [], + "source": [ + "features_to_consider = 'etarel,phirel,pt,e,ptrel,erel,deltaR'.split(',')\n", + "features_idx = [features_names[name] for name in features_to_consider]\n", + "jetList = jetList[:,:,features_idx]\n", + "print(jetList.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-M8uvPR4mfI7" + }, + "source": [ + "Let's define basics hyperparamters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2mHCuVm6ZJaY" + }, + "outputs": [], + "source": [ + "batch_size=128\n", + "learning_rate=0.0001\n", + "epochs=20" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "unpVhZNfmotI" + }, + "source": [ + "In the original paper, multi-head attention is also applied in the decoder step to obtain a smarter pooling operation. For this excercise we will simplify the model and use instead a `Lambda` layer to apply a custom pooling function to the input tensor. In this case, the `Lambda` layer is being used to sum over the first dimension, i.e. over the elements in the output set of the previous layer, which has shape `(batch_size, n_elements, features)`. By summing over the first dimension (`axis=1`), we obtain a tensor of shape `(batch_size, features)` that represents an aggregation of each feature over the elements in the set.\n", + "\n", + "Here is the full model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "O7rzMn7wRcrP" + }, + "outputs": [], + "source": [ + "inputs = keras.Input(shape=(100,7), name='input')\n", + "x = layers.TimeDistributed(layers.Dense(64))(inputs)\n", + "x = layers.TimeDistributed(layers.LeakyReLU())(x)\n", + "x = layers.TimeDistributed(layers.Dense(64))(x)\n", + "x = layers.TimeDistributed(layers.LeakyReLU())(x)\n", + "x = layers.TimeDistributed(layers.Dense(64))(x)\n", + "x = layers.TimeDistributed(layers.LeakyReLU())(x)\n", + "x = layers.TimeDistributed(layers.Dense(64))(x)\n", + "x = layers.TimeDistributed(layers.LeakyReLU())(x)\n", + "x = layers.Lambda(lambda y: tf.reduce_sum(y, axis=1))(x)\n", + "x = layers.BatchNormalization()(x)\n", + "x = layers.Dense(64)(x)\n", + "x = layers.LeakyReLU()(x)\n", + "x = layers.Dense(64)(x)\n", + "x = layers.LeakyReLU()(x)\n", + "x = layers.Dense(16)(x)\n", + "x = layers.LeakyReLU()(x)\n", + "output = layers.Dense(5, dtype='float32')(x)\n", + "model = keras.models.Model(inputs=inputs, outputs=output)\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "G8NI-_bYdSAq" + }, + "outputs": [], + "source": [ + "model.compile(\n", + " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " optimizer=keras.optimizers.Adam(learning_rate=learning_rate),\n", + " metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DhAKhgMMcrwa" + }, + "outputs": [], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "X_train, X_val, y_train, y_val, y_train_onehot, y_val_onehot = train_test_split(jetList, target, target_onehot, test_size=0.1, shuffle=True)\n", + "print(X_train.shape, X_val.shape, y_train.shape, y_val.shape)\n", + "del jetList, target, target_onehot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lz7rfyeCdNF0" + }, + "outputs": [], + "source": [ + "history = model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Egcr8vMhp-2v" + }, + "source": [ + "We can now plot the validation and training loss evolution over the epochs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sjTOMuzAqGEr" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "fig,axes = plt.subplots(2)\n", + "\n", + "axes[0].plot(history.history[\"sparse_categorical_accuracy\"])\n", + "axes[0].plot(history.history[\"val_sparse_categorical_accuracy\"])\n", + "axes[0].set_title(\"Accuracy\")\n", + "axes[0].legend([\"Training\", \"Validation\"])\n", + "\n", + "axes[1].plot(history.history[\"loss\"])\n", + "axes[1].plot(history.history[\"val_loss\"])\n", + "axes[1].legend([\"Training\", \"Validation\"])\n", + "axes[1].set_title(\"Loss\")\n", + "\n", + "fig.show()\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CwrPPStDrS4J" + }, + "source": [ + "Now we finally evaluate the performance by plotting the ROC curves for the different classes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JKM0yYFfecJh" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "from sklearn.metrics import roc_curve, auc\n", + "predict_val = tf.nn.softmax(model.predict(X_val))\n", + "df = pd.DataFrame()\n", + "fpr = {}\n", + "tpr = {}\n", + "auc1 = {}\n", + "\n", + "plt.figure()\n", + "for i, label in enumerate(label_names):\n", + "\n", + " df[label] = y_val_onehot[:,i]\n", + " df[label + '_pred'] = predict_val[:,i]\n", + "\n", + " fpr[label], tpr[label], threshold = roc_curve(df[label],df[label+'_pred'])\n", + "\n", + " auc1[label] = auc(fpr[label], tpr[label])\n", + "\n", + " plt.plot(tpr[label],fpr[label],label='%s tagger, auc = %.1f%%'%(label,auc1[label]*100.))\n", + "plt.semilogy()\n", + "plt.xlabel(\"sig. efficiency\")\n", + "plt.ylabel(\"bkg. mistag rate\")\n", + "plt.ylim(0.000001,1)\n", + "plt.grid(True)\n", + "plt.legend(loc='lower right')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IzxPDanrrYZB" + }, + "source": [ + "As you can see the performance are not as good for other models we have trained on the same dataset. As mentioned at the beginning of the notebook training a transformer might be tricky. You can try the optional excercise below to improve the performance and surpass the other models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multi Head Attention recap\n", + "\n", + "Assume we have $n$ query vectors (corresponding to the $n$ elements in the set) each with dimension $d_q : Q \\in \\mathbb{R}^{n\\times d_q}$. In the jet tagging dataset $n=100$ and $d_q=7$.\n", + "\n", + "An attention function $\\mathrm{Att}(Q,K,V)$ is a function that maps queries $Q$ to outputs using $n_v$ key-value pairs $K \\in \\mathbb{R}^{n_v \\times d_q}, V \\in \\mathbb{R}^{n_v\\times d_v}$:\n", + "\n", + "$$\n", + "\\mathrm{Att}(Q,K,V;\\omega) = \\omega(QK^{T})V.\n", + "$$\n", + "\n", + "The pairwise dot product $QT^\\mathrm{T} \\in \\mathbb{R}^{n\\times n_v}$ measures how similar each pair of query and key vectors is, with weights computed with an activation function $\\omega$. The output $\\omega(QK^{T})V$ is a weighted sum of $V$ where a value gets more weight if its corresponding key has larger dot product with the query.\n", + "\n", + "Instead of computing a single attention function, the **multi-head attention** method first projects $Q, K, V$ onto $h$ different $d^M_q,d^M_q,d^M_v$-dimensional vectors, respectively. An attention function $\\mathrm{Att}(\\cdot; \\omega_j)$ is applied to each of these $h$ projections. The output is a linear transformation of the concatenation of all attention outputs:\n", + "\n", + "$$\n", + "\\mathrm{Multihead}(Q, K, V ; \\lambda, \\omega) = \\mathrm{concat}(O_1,..., O_h)W^O\n", + "$$\n", + "\n", + "$$\n", + "O_j = \\mathrm{Att}(QW^Q_j, KW^K_j, VW^V_j ; \\omega_j )\n", + "$$\n", + "\n", + "In other words, the model tells you what is the score of a particle in the set knowing its interaction with the other particles in the set given all features but in a way that the features are attended separately.\n", + "\n", + "Note that $\\mathrm{Multihead}(\\cdot, \\cdot, \\cdot; \\lambda)$ has learnable parameters $\\lambda =$ {$W^Q_j, W^K_j, W^V_j$}$_{j=1,...,h}$ where $W^Q_j, W^K_j \\in \\mathbb{R}^{d_q\\times d^M_q}, W^V_j \\in \\mathbb{R}^{d_v\\times d^M_v}, W^O \\in \\mathbb{R}^{hd^M_v\\times d}$. A typical choice for the dimension hyperparameters is $d^M_q = d_q /h, d^M_v = d_v /h, d = d_q$. For the Set Transformer we set $d_q = d_v = d$ and $d^M_q = d^M_v = d/h$. A scaled softmax $\\omega_j (\\cdot) = \\mathrm{softmax}(\\cdot/\\sqrt{d})$ is used.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Building the Set Transformer\n", + "\n", + "We will implement a simplified version of the [original Set Transformer architecture](https://arxiv.org/abs/1810.00825). The reason is because Transformers are typically computationally and data hungry. As an optional excercise at the end of the notebook you can try to implement the full model and test it on a simpler problem like the MNIST dataset classification (or on a larger jet class dataset).\n", + "\n", + "The architecture is based on the block called `MAB` (= Multihead Attention Block) which implements the following:\n", + "\n", + "$$\n", + "\\mathrm{MAB}(X, Y) = \\mathrm{LayerNorm}(H + \\mathrm{rFF}(H))\n", + "$$\n", + "\n", + "$$\n", + "H = \\mathrm{LayerNorm}(X + \\mathrm{Multihead}(X, X, X ; ω))\n", + "$$\n", + "\n", + "where $X \\in \\mathbb{R}^{n\\times d}$ is the input set and $\\mathrm{rFF}$ is any feedforward layer. Since $Q=K=V=X$, the MAB takes a set and performs *self-attention* between the elements in the set, resulting in a set of equal size. Since the output of MAB contains information about pairwise interactions among the elements in the input set $X$, we can stack multiple MABs to encode higher order interactions. This stack is the *encoder* part of the transformer. \n", + "\n", + "The `LayerNorm` normalizes the activations of a layer across the last dimension (feature dimension) of the input tensor. Specifically, it centers and scales each feature dimension independently by subtracting the mean and dividing by the standard deviation, which are computed over the corresponding feature dimension of the input tensor. As for `BatchNormalization` it has learnable $\\gamma$ (scaling) and $\\beta$ (shifting) parameters. The difference with respect to `BatchNormalization` is that the normalization is performed indipendently per each instance in the batch. `LayerNorm` leads to improved stability when you expect instances of different sizes (or different zero padding degree as in the jet tagging case)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RORSIHwVRPx4" + }, + "outputs": [], + "source": [ + "class SABTransformerBlock(tf.keras.layers.Layer):\n", + " def __init__(self, num_heads, hidden_units, mlp_hidden_units=128, dropout_rate=0.1, **kwargs):\n", + " super(SABTransformerBlock, self).__init__(**kwargs)\n", + " self.num_heads = num_heads\n", + " self.hidden_units = hidden_units\n", + " self.mlp_hidden_units = mlp_hidden_units\n", + " self.dropout_rate = dropout_rate\n", + "\n", + " def build(self, input_shape):\n", + " self.attention = tf.keras.layers.MultiHeadAttention(num_heads=self.num_heads, \n", + " key_dim=self.hidden_units//self.num_heads)\n", + " self.feedforward = tf.keras.Sequential([\n", + " Dense(units=self.mlp_hidden_units, activation=\"relu\"),\n", + " # Dropout(rate=self.dropout_rate),\n", + " Dense(units=input_shape[-1])\n", + " ])\n", + " self.layer_norm1 = LayerNormalization(epsilon=1e-6)\n", + " self.layer_norm2 = LayerNormalization(epsilon=1e-6)\n", + " self.dropout1 = Dropout(rate=self.dropout_rate)\n", + " self.dropout2 = Dropout(rate=self.dropout_rate)\n", + " super(SABTransformerBlock, self).build(input_shape)\n", + " \n", + " def call(self, inputs, mask=None):\n", + " attention_output = self.attention(inputs, inputs, attention_mask=mask)[0]\n", + " # attention_output = self.dropout1(attention_output)\n", + " attention_output = self.layer_norm1(inputs + attention_output)\n", + " feedforward_output = self.feedforward(attention_output)\n", + " # feedforward_output = self.dropout2(feedforward_output)\n", + " block_output = self.layer_norm2(attention_output + feedforward_output)\n", + " return block_output\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "inputs = keras.Input(shape=(100,7), name='input')\n", + "x = layers.TimeDistributed(layers.Dense(64))(inputs)\n", + "x = SABTransformerBlock(num_heads=8, hidden_units=64)(x)\n", + "x = SABTransformerBlock(num_heads=8, hidden_units=64)(x)\n", + "x = SABTransformerBlock(num_heads=8, hidden_units=64)(x)\n", + "x = layers.Lambda(lambda y: tf.reduce_sum(y, axis=1))(x)\n", + "x = layers.BatchNormalization()(x)\n", + "x = layers.Dense(64)(x)\n", + "x = layers.LeakyReLU()(x)\n", + "x = layers.Dense(64)(x)\n", + "x = layers.LeakyReLU()(x)\n", + "x = layers.Dense(16)(x)\n", + "x = layers.LeakyReLU()(x)\n", + "output = layers.Dense(5, dtype='float32')(x)\n", + "model_st = keras.models.Model(inputs=inputs, outputs=output)\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.compile(\n", + " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " optimizer=keras.optimizers.Adam(learning_rate=learning_rate),\n", + " metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "history = model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "fig,axes = plt.subplots(2)\n", + "\n", + "axes[0].plot(history.history[\"sparse_categorical_accuracy\"])\n", + "axes[0].plot(history.history[\"val_sparse_categorical_accuracy\"])\n", + "axes[0].set_title(\"Accuracy\")\n", + "axes[0].legend([\"Training\", \"Validation\"])\n", + "\n", + "axes[1].plot(history.history[\"loss\"])\n", + "axes[1].plot(history.history[\"val_loss\"])\n", + "axes[1].legend([\"Training\", \"Validation\"])\n", + "axes[1].set_title(\"Loss\")\n", + "\n", + "fig.show()\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we finally evaluate the performance by plotting the ROC curves for the different classes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "from sklearn.metrics import roc_curve, auc\n", + "predict_val = tf.nn.softmax(model.predict(X_val))\n", + "df = pd.DataFrame()\n", + "fpr = {}\n", + "tpr = {}\n", + "auc1 = {}\n", + "\n", + "plt.figure()\n", + "for i, label in enumerate(label_names):\n", + "\n", + " df[label] = y_val_onehot[:,i]\n", + " df[label + '_pred'] = predict_val[:,i]\n", + "\n", + " fpr[label], tpr[label], threshold = roc_curve(df[label],df[label+'_pred'])\n", + "\n", + " auc1[label] = auc(fpr[label], tpr[label])\n", + "\n", + " plt.plot(tpr[label],fpr[label],label='%s tagger, auc = %.1f%%'%(label,auc1[label]*100.))\n", + "plt.semilogy()\n", + "plt.xlabel(\"sig. efficiency\")\n", + "plt.ylabel(\"bkg. mistag rate\")\n", + "plt.ylim(0.000001,1)\n", + "plt.grid(True)\n", + "plt.legend(loc='lower right')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optional Excercise\n", + "\n", + "The original paper also use MH mechanism in the decoder step (while we used a simple sum over the latent space nodes). If you would like to try it out the `Lambda` layer should be replaced with the `PoolingByMultiHeadAttention` block below.\n", + "\n", + "Consider also the fact that it might be hard to train a Transformer architecture of this kind over the rather small dataset used here. Check out [this other dataset](https://events.mcs.cmu.edu/us-cms-2023/) for increased statistics or [this notebook](https://github.com/DLii-Research/tf-settransformer/blob/master/examples/mnist_pointcloud.ipynb) for a simpler task.\n", + "\n", + "Below is the starting point for a smarter decoder:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9qHj1_Y7ZU-R" + }, + "outputs": [], + "source": [ + "class PoolingByMultiHeadAttention(tf.keras.layers.Layer):\n", + " def __init__(self, num_heads, hidden_units, mlp_hidden_units=128, num_seeds=1, **kwargs):\n", + " super(PoolingByMultiHeadAttention, self).__init__(**kwargs)\n", + " self.num_heads = num_heads\n", + " self.hidden_units = hidden_units\n", + " self.mlp_hidden_units = mlp_hidden_units\n", + " self.num_seeds = num_seeds\n", + " \n", + " def build(self, input_shape):\n", + " \n", + " self.attention = tf.keras.layers.MultiHeadAttention(num_heads=self.num_heads, \n", + " key_dim=self.hidden_units)\n", + " \n", + " self.seed_vectors = self.add_weight(\n", + " shape=(1, self.num_seeds, self.hidden_units),\n", + " initializer=\"random_normal\",\n", + " trainable=True,\n", + " name=\"Seeds\")\n", + "\n", + " self.feedforward = tf.keras.Sequential([\n", + " layers.Dense(units=self.mlp_hidden_units, activation=\"relu\"),\n", + " layers.Dense(units=self.hidden_units)\n", + " ])\n", + " self.layer_norm1 = layers.LayerNormalization(epsilon=1e-6)\n", + " self.layer_norm2 = layers.LayerNormalization(epsilon=1e-6)\n", + " super(PoolingByMultiHeadAttention, self).build(input_shape)\n", + "\n", + " def call(self, inputs, training=None):\n", + " a = tf.expand_dims(self.seed_vectors, axis=0)\n", + " seeds = tf.tile(self.seed_vectors, [tf.shape(inputs)[0], 1, 1])\n", + " attention_output = self.attention(seeds, inputs)[0]\n", + " attention_output = self.layer_norm1(seeds + attention_output)\n", + " feedforward_output = self.feedforward(attention_output)\n", + " block_output = self.layer_norm2(attention_output + feedforward_output)\n", + " return block_output" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "authorship_tag": "ABX9TyPn4xtio5MeIQMG/e23naQt", + "include_colab_link": true, + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}