Constructing datasets for the contrastive model is currently ad hoc and could benefit from some utility functions. For example:
```python
import random

import numpy as np
import tensorflow as tf

# Assumes x_raw_train, y_raw_train (NumPy arrays) and ds_info were loaded
# earlier, e.g. via tensorflow_datasets with with_info=True.

# Compute the indices for the query, index, val, and train splits
query_idxs, index_idxs, val_idxs, train_idxs = [], [], [], []
for cid in range(ds_info.features["label"].num_classes):
    idxs = tf.random.shuffle(tf.where(y_raw_train == cid))
    idxs = tf.reshape(idxs, (-1,))
    query_idxs.extend(idxs[:100])  # 100 query examples per class
    index_idxs.extend(idxs[100:200])  # 100 index examples per class
    val_idxs.extend(idxs[200:300])  # 100 validation examples per class
    train_idxs.extend(idxs[300:])  # All remaining examples are used for training

random.shuffle(query_idxs)
random.shuffle(index_idxs)
random.shuffle(val_idxs)
random.shuffle(train_idxs)


def create_split(idxs: list) -> tuple:
    """Gather the examples at `idxs` into (x, y) tensors."""
    x, y = [], []
    for idx in idxs:
        x.append(x_raw_train[int(idx)])
        y.append(y_raw_train[int(idx)])
    return (
        tf.convert_to_tensor(np.array(x), dtype=tf.float32),
        tf.convert_to_tensor(np.array(y), dtype=tf.int64),
    )


x_query, y_query = create_split(query_idxs)
x_index, y_index = create_split(index_idxs)
x_val, y_val = create_split(val_idxs)
x_train, y_train = create_split(train_idxs)
```
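One possible shape for such a utility is a helper that takes the label array and a mapping of split names to per-class counts, and returns shuffled index lists. The name `split_indices_by_class` and its signature are hypothetical, not an existing API in this project. This sketch uses only the standard library, so it works on any sequence of hashable labels (for example a NumPy array or a list, but not a `tf.Tensor`, whose elements are unhashable).

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, Sequence


def split_indices_by_class(
    labels: Sequence[Hashable],
    split_sizes: Dict[str, Optional[int]],
    seed: Optional[int] = None,
) -> Dict[str, List[int]]:
    """Hypothetical helper: stratified per-class index splits.

    `split_sizes` maps a split name to the number of examples taken from
    each class. At most one split may use None, meaning "all remaining
    examples of each class".
    """
    remainder = [name for name, size in split_sizes.items() if size is None]
    if len(remainder) > 1:
        raise ValueError("Only one split may take the remaining examples.")

    rng = random.Random(seed)

    # Group example indices by class label.
    by_class: Dict[Hashable, List[int]] = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)

    splits: Dict[str, List[int]] = {name: [] for name in split_sizes}
    for cls, idxs in by_class.items():
        rng.shuffle(idxs)  # shuffle within the class before slicing
        start = 0
        for name, size in split_sizes.items():
            if size is None:
                continue  # the remainder split is filled last
            if start + size > len(idxs):
                raise ValueError(f"Class {cls!r} has only {len(idxs)} examples.")
            splits[name].extend(idxs[start:start + size])
            start += size
        if remainder:
            splits[remainder[0]].extend(idxs[start:])

    # Shuffle each split so that classes are interleaved.
    for idxs in splits.values():
        rng.shuffle(idxs)
    return splits
```

With NumPy arrays, the splits from the example above could then be built with `splits = split_indices_by_class(y_raw_train, {"query": 100, "index": 100, "val": 100, "train": None}, seed=0)` and gathered with fancy indexing, e.g. `x_query = x_raw_train[splits["query"]]`. Keeping index selection separate from gathering lets the same helper serve in-memory arrays or `tf.gather` on tensors.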
owenvallis