Create Contrastive Sampler / Utilities for constructing datasets. #298

Open
owenvallis opened this issue Oct 6, 2022 · 0 comments
Labels
component:samplers Data sampling related type:feature New feature


@owenvallis
Collaborator

Constructing datasets for the contrastive model is currently ad hoc and would benefit from some utility functions. For example, the following code builds the query, index, validation, and training splits by hand:

import random

import numpy as np
import tensorflow as tf

# Assumes x_raw_train, y_raw_train (NumPy arrays) and ds_info were loaded
# earlier, e.g. with tensorflow_datasets.

# Compute the indices for the query, index, val, and train splits
query_idxs, index_idxs, val_idxs, train_idxs = [], [], [], []

for cid in range(ds_info.features["label"].num_classes):
    idxs = tf.random.shuffle(tf.where(y_raw_train == cid))
    idxs = tf.reshape(idxs, (-1,)).numpy()  # flat array of this class's indices
    query_idxs.extend(idxs[:100])  # 100 query examples per class
    index_idxs.extend(idxs[100:200])  # 100 index examples per class
    val_idxs.extend(idxs[200:300])  # 100 validation examples per class
    train_idxs.extend(idxs[300:])  # The remaining examples are used for training


random.shuffle(query_idxs)
random.shuffle(index_idxs)
random.shuffle(val_idxs)
random.shuffle(train_idxs)


def create_split(idxs: list) -> tuple:
    """Gather the examples at `idxs` into (x, y) tensors (float32 features, int64 labels)."""
    x, y = [], []
    for idx in idxs:
        x.append(x_raw_train[int(idx)])
        y.append(y_raw_train[int(idx)])
    return tf.convert_to_tensor(np.array(x), dtype=tf.float32), tf.convert_to_tensor(
        np.array(y), dtype=tf.int64
    )


x_query, y_query = create_split(query_idxs)
x_index, y_index = create_split(index_idxs)
x_val, y_val = create_split(val_idxs)
x_train, y_train = create_split(train_idxs)
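A reusable helper along the lines this issue asks for might look like the sketch below. It is NumPy-only and hypothetical: `per_class_split` and `gather_split` are illustrative names, not existing functions in this project or in TensorFlow. For each class it shuffles that class's indices, takes a fixed-size slice for each requested split, puts everything left over into a final split, and then shuffles each split across classes using one seeded generator.

```python
from typing import Sequence, Tuple

import numpy as np


def per_class_split(
    labels: np.ndarray, split_sizes: Sequence[int], seed: int = 0
) -> Tuple[np.ndarray, ...]:
    """Hypothetical utility: split example indices per class.

    Returns len(split_sizes) + 1 index arrays. Split k gets split_sizes[k]
    examples from every class; the last split gets all remaining examples.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels).reshape(-1)
    # Slice boundaries within each class, e.g. [100, 100, 100] -> [100, 200, 300].
    boundaries = np.cumsum(split_sizes, dtype=np.int64)
    needed = int(boundaries[-1]) if len(boundaries) else 0
    splits = [[] for _ in range(len(split_sizes) + 1)]

    for cid in np.unique(labels):
        # Shuffle this class's indices, then cut them at the boundaries.
        idxs = rng.permutation(np.flatnonzero(labels == cid))
        if len(idxs) < needed:
            raise ValueError(
                f"class {cid} has {len(idxs)} examples, need at least {needed}"
            )
        for split, chunk in zip(splits, np.split(idxs, boundaries)):
            split.append(chunk)

    # Merge the per-class chunks and shuffle so classes are interleaved.
    result = []
    for chunks in splits:
        merged = np.concatenate(chunks)
        rng.shuffle(merged)
        result.append(merged)
    return tuple(result)


def gather_split(x: np.ndarray, y: np.ndarray, idxs: np.ndarray):
    """Hypothetical helper: select the examples at `idxs` via NumPy fancy indexing."""
    return x[idxs], y[idxs]


# Toy data: 4 classes with 5 examples each.
x = np.arange(40, dtype=np.float32).reshape(20, 2)
y = np.repeat(np.arange(4), 5)
query_idxs, index_idxs, val_idxs, train_idxs = per_class_split(y, [1, 1, 1], seed=42)
x_query, y_query = gather_split(x, y, query_idxs)
print(len(query_idxs), len(index_idxs), len(val_idxs), len(train_idxs))  # 4 4 4 8
```

For the splits in the snippet above, the equivalent call would presumably be `per_class_split(y_raw_train, [100, 100, 100])`, returning query, index, validation, and training indices in that order; the results can be passed to `tf.convert_to_tensor` if tensors are needed. Drawing every shuffle from a single seeded `np.random.Generator`, instead of mixing `tf.random.shuffle` with `random.shuffle`, makes the splits reproducible from one seed. Note that this sketch iterates over the classes present in `labels` rather than `range(num_classes)`.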