diff --git a/ontolearn/concept_learner.py b/ontolearn/concept_learner.py index bc2cc475..b69655d9 100644 --- a/ontolearn/concept_learner.py +++ b/ontolearn/concept_learner.py @@ -1271,15 +1271,16 @@ def fit_one(self, pos: Union[Set[OWLNamedIndividual], Set[str]], neg: Union[Set[ neg_str = neg else: raise ValueError(f"Invalid input type, was expecting OWLNamedIndividual or str but found {type(pos[0])}") - Pos = np.random.choice(pos_str, size=(self.num_predictions, len(pos_str)), replace=True) - Neg = np.random.choice(neg_str, size=(self.num_predictions, len(neg_str)), replace=True) - assert self.load_pretrained and self.m, \ "No pretrained model found. Please first train NCES2" - - dataset = NCES2DatasetInference([("", Pos_str, Neg_str) for (Pos_str, Neg_str) in zip(Pos, Neg)], - self.instance_embeddings, - self.vocab, self.inv_vocab, False, self.sorted_examples) + + #data, triples_data, k, vocab, inv_vocab, num_examples, sampling_strategy='p', num_pred_per_lp=1, random_sample=False + dataset = ROCESDatasetInference([("", pos_str, neg_str)], + self.triples_data, + self.vocab, self.inv_vocab, + self.num_examples, + sampling_strategy="nces2", + num_pred_per_lp=self.num_predictions) dataloader = DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, diff --git a/ontolearn/data_struct.py b/ontolearn/data_struct.py index 16638b72..50a021a8 100644 --- a/ontolearn/data_struct.py +++ b/ontolearn/data_struct.py @@ -147,6 +147,81 @@ def clear(self): self.current_states.clear() self.next_states.clear() self.rewards.clear() + + +class CLIPDataset(torch.utils.data.Dataset): # pragma: no cover + + def __init__(self, data: list, embeddings, shuffle_examples, example_sizes: list=None, + k=5, sorted_examples=True): + super().__init__() + self.data = data + self.embeddings = embeddings + self.shuffle_examples = shuffle_examples + self.example_sizes = example_sizes + self.k = k + self.sorted_examples = sorted_examples + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + key, value = self.data[idx] + pos = value['positive examples'] + neg = value['negative examples'] + length = value['length'] + if self.example_sizes is not None: + k_pos, k_neg = random.choice(self.example_sizes) + k_pos = min(k_pos, len(pos)) + k_neg = min(k_neg, len(neg)) + selected_pos = random.sample(pos, k_pos) + selected_neg = random.sample(neg, k_neg) + elif self.k is not None: + prob_pos_set = 1.0/(1+np.array(range(min(self.k, len(pos)), len(pos)+1, self.k))) + prob_pos_set = prob_pos_set/prob_pos_set.sum() + prob_neg_set = 1.0/(1+np.array(range(min(self.k, len(neg)), len(neg)+1, self.k))) + prob_neg_set = prob_neg_set/prob_neg_set.sum() + k_pos = np.random.choice(range(min(self.k, len(pos)), len(pos)+1, self.k), replace=False, p=prob_pos_set) + k_neg = np.random.choice(range(min(self.k, len(neg)), len(neg)+1, self.k), replace=False, p=prob_neg_set) + selected_pos = random.sample(pos, k_pos) + selected_neg = random.sample(neg, k_neg) + else: + selected_pos = pos + selected_neg = neg + if self.shuffle_examples: + random.shuffle(selected_pos) + random.shuffle(selected_neg) + + datapoint_pos = torch.FloatTensor(self.embeddings.loc[selected_pos].values.squeeze()) + datapoint_neg = torch.FloatTensor(self.embeddings.loc[selected_neg].values.squeeze()) + + return datapoint_pos, datapoint_neg, torch.LongTensor([length]) + + +class CLIPDatasetInference(torch.utils.data.Dataset): # pragma: no cover + + def __init__(self, data: list, embeddings, shuffle_examples, + sorted_examples=True): + super().__init__() + self.data = data + self.embeddings = embeddings + self.shuffle_examples = shuffle_examples + self.sorted_examples = sorted_examples + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + _, pos, neg = self.data[idx] + if self.sorted_examples: + pos, neg = sorted(pos), sorted(neg) + elif self.shuffle_examples: + random.shuffle(pos) + random.shuffle(neg) + + datapoint_pos = torch.FloatTensor(self.embeddings.loc[pos].values.squeeze()) + datapoint_neg = torch.FloatTensor(self.embeddings.loc[pos].values.squeeze()) + + return datapoint_pos, datapoint_neg class NCESBaseDataset: # pragma: no cover @@ -256,85 +331,10 @@ def __getitem__(self, idx): datapoint_neg = torch.FloatTensor(self.embeddings.loc[neg].values.squeeze()) return datapoint_pos, datapoint_neg - - -class CLIPDataset(torch.utils.data.Dataset): # pragma: no cover - - def __init__(self, data: list, embeddings, shuffle_examples, example_sizes: list=None, - k=5, sorted_examples=True): - super().__init__() - self.data = data - self.embeddings = embeddings - self.shuffle_examples = shuffle_examples - self.example_sizes = example_sizes - self.k = k - self.sorted_examples = sorted_examples - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - key, value = self.data[idx] - pos = value['positive examples'] - neg = value['negative examples'] - length = value['length'] - if self.example_sizes is not None: - k_pos, k_neg = random.choice(self.example_sizes) - k_pos = min(k_pos, len(pos)) - k_neg = min(k_neg, len(neg)) - selected_pos = random.sample(pos, k_pos) - selected_neg = random.sample(neg, k_neg) - elif self.k is not None: - prob_pos_set = 1.0/(1+np.array(range(min(self.k, len(pos)), len(pos)+1, self.k))) - prob_pos_set = prob_pos_set/prob_pos_set.sum() - prob_neg_set = 1.0/(1+np.array(range(min(self.k, len(neg)), len(neg)+1, self.k))) - prob_neg_set = prob_neg_set/prob_neg_set.sum() - k_pos = np.random.choice(range(min(self.k, len(pos)), len(pos)+1, self.k), replace=False, p=prob_pos_set) - k_neg = np.random.choice(range(min(self.k, len(neg)), len(neg)+1, self.k), replace=False, p=prob_neg_set) - selected_pos = random.sample(pos, k_pos) - selected_neg = random.sample(neg, k_neg) - else: - selected_pos = pos - selected_neg = neg - if self.shuffle_examples: - random.shuffle(selected_pos) - random.shuffle(selected_neg) - - datapoint_pos = torch.FloatTensor(self.embeddings.loc[selected_pos].values.squeeze()) - datapoint_neg = torch.FloatTensor(self.embeddings.loc[selected_neg].values.squeeze()) - - return datapoint_pos, datapoint_neg, torch.LongTensor([length]) - - -class CLIPDatasetInference(torch.utils.data.Dataset): # pragma: no cover - - def __init__(self, data: list, embeddings, shuffle_examples, - sorted_examples=True): - super().__init__() - self.data = data - self.embeddings = embeddings - self.shuffle_examples = shuffle_examples - self.sorted_examples = sorted_examples - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - _, pos, neg = self.data[idx] - if self.sorted_examples: - pos, neg = sorted(pos), sorted(neg) - elif self.shuffle_examples: - random.shuffle(pos) - random.shuffle(neg) - - datapoint_pos = torch.FloatTensor(self.embeddings.loc[pos].values.squeeze()) - datapoint_neg = torch.FloatTensor(self.embeddings.loc[pos].values.squeeze()) - - return datapoint_pos, datapoint_neg -class ROCESDataset(BaseDataset, torch.utils.data.Dataset): +class ROCESDataset(NCESBaseDataset, torch.utils.data.Dataset): def __init__(self, data, triples_data, vocab, inv_vocab, sampling_strategy="p"): super(ROCESDataset, self).__init__(vocab, inv_vocab) @@ -360,16 +360,28 @@ def __getitem__(self, idx): key, value = self.data[idx] pos = value['positive examples'] neg = value['negative examples'] - if self.sampling_strategy != 'uniform': + if self.sampling_strategy == 'p': prob_pos_set = 1.0/(1+np.array(range(min(self.k, len(pos)), len(pos)+1, self.k))) prob_pos_set = prob_pos_set/prob_pos_set.sum() prob_neg_set = 1.0/(1+np.array(range(min(self.k, len(neg)), len(neg)+1, self.k))) prob_neg_set = prob_neg_set/prob_neg_set.sum() k_pos = np.random.choice(range(min(self.k, len(pos)), len(pos)+1, self.k), replace=False, p=prob_pos_set) k_neg = np.random.choice(range(min(self.k, len(neg)), len(neg)+1, self.k), replace=False, p=prob_neg_set) + elif self.sampling_strategy == 'nces2': + if random.random() > 0.5: + prob_pos_set = 1.0/(1+np.array(range(min(self.k, len(pos)), len(pos)+1, self.k))) + prob_pos_set = prob_pos_set/prob_pos_set.sum() + prob_neg_set = 1.0/(1+np.array(range(min(self.k, len(neg)), len(neg)+1, self.k))) + prob_neg_set = prob_neg_set/prob_neg_set.sum() + k_pos = max(1, 2*len(pos)//3) + k_neg = max(1, 2*len(neg)//3) + else: + k_pos = len(pos) + k_neg = len(neg) else: k_pos = np.random.choice(range(min(self.k, len(pos)), len(pos)+1, self.k), replace=False) k_neg = np.random.choice(range(min(self.k, len(neg)), len(neg)+1, self.k), replace=False) + selected_pos = random.sample(pos, k_pos) selected_neg = random.sample(neg, k_neg) @@ -380,9 +392,9 @@ def __getitem__(self, idx): return datapoint_pos, datapoint_neg, torch.cat([torch.tensor(labels), self.vocab['PAD']*torch.ones(max(0,self.max_length-length))]).long() -class ROCESDatasetInference(BaseDataset, torch.utils.data.Dataset): +class ROCESDatasetInference(NCESBaseDataset, torch.utils.data.Dataset): - def __init__(self, data, triples_data, k, vocab, inv_vocab, sampling_strategy, num_examples, num_pred_per_lp=1, random_sample=False): + def __init__(self, data, triples_data, k, vocab, inv_vocab, num_examples, sampling_strategy='p', num_pred_per_lp=1): super(ROCESDatasetInference, self).__init__(vocab, inv_vocab) self.data = data self.triples_data = triples_data @@ -408,7 +420,7 @@ def __getitem__(self, idx): pos = value['positive examples'] neg = value['negative examples'] - if self.sampling_strategy != 'uniform': + if self.sampling_strategy == 'p': prob_pos_set = 1.0/(1+np.array(range(min(self.k, len(pos)), len(pos)+1, self.k))) prob_pos_set = prob_pos_set/prob_pos_set.sum() prob_neg_set = 1.0/(1+np.array(range(min(self.k, len(neg)), len(neg)+1, self.k))) @@ -419,6 +431,13 @@ def __getitem__(self, idx): k_neg = np.random.choice(range(min(self.k, len(neg)), len(neg)+1, self.k), size=(self.num_predictions,), replace=True, p=prob_neg_set) + elif self.sampling_strategy == "nces2": + k_pos = np.random.choice([len(pos), 2*len(pos)//3], + size=(self.num_predictions,), + replace=True) + k_neg = np.random.choice([len(neg), 2*len(neg)//3], + size=(self.num_predictions,), + replace=True) else: k_pos = np.random.choice(range(min(self.k, len(pos)), len(pos)+1, self.k), size=(self.num_predictions,), replace=True)