From 90aaf2e81e54181a7f8a25a6ce797e3e440ce11c Mon Sep 17 00:00:00 2001 From: Foivos Tsimpourlas Date: Wed, 10 Aug 2022 23:44:10 -0700 Subject: [PATCH 1/3] [WIP] Add uniform dataset sampler --- compiler_gym/datasets/datasets.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/compiler_gym/datasets/datasets.py b/compiler_gym/datasets/datasets.py index 05ecf0e45..b5f6f4af5 100644 --- a/compiler_gym/datasets/datasets.py +++ b/compiler_gym/datasets/datasets.py @@ -234,6 +234,32 @@ def benchmarks(self, with_deprecated: bool = False) -> Iterable[Benchmark]: (d.benchmarks() for d in self.datasets(with_deprecated=with_deprecated)) ) + def benchmarks_from_distrib( + self, + datasets: List[Dataset] = None, + weights: List[float] = None, + dataset_size: int = -1, + ) -> Iterable[Benchmark]: + """ + Foivos WIP. + Select a dataset to sample from with some weight probability. + If weights is None, select among `datasets` uniformly. + """ + datasets = datasets or self.datasets + if weights is None: + weights = [1 / len(datasets)] * len(datasets) + if len(weights) != len(datasets): + raise ValueError( + "Mismatch between datasets size: {} and sampling weights length: {}!".format( + len(datasets), len(weights) + ) + ) + idx = 0 + while dataset_size == -1 or idx < dataset_size: + sampled = np.random.choice(datasets, p=weights) + yield sampled.sample() + return + def benchmark_uris(self, with_deprecated: bool = False) -> Iterable[str]: """Enumerate the (possibly infinite) benchmark URIs. From df4d8141148374ae1760696d8df7901c3957296d Mon Sep 17 00:00:00 2001 From: Foivos Tsimpourlas Date: Fri, 12 Aug 2022 15:37:30 -0700 Subject: [PATCH 2/3] Wrap sampled dataset with round robin iterator --- compiler_gym/datasets/datasets.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/compiler_gym/datasets/datasets.py b/compiler_gym/datasets/datasets.py index b5f6f4af5..9674c49dc 100644 --- a/compiler_gym/datasets/datasets.py +++ b/compiler_gym/datasets/datasets.py @@ -236,7 +236,7 @@ def benchmarks(self, with_deprecated: bool = False) -> Iterable[Benchmark]: def benchmarks_from_distrib( self, - datasets: List[Dataset] = None, + datasets: List[str] = None, weights: List[float] = None, dataset_size: int = -1, ) -> Iterable[Benchmark]: @@ -245,7 +245,7 @@ def benchmarks_from_distrib( Select a dataset to sample from with some weight probability. If weights is None, select among `datasets` uniformly. """ - datasets = datasets or self.datasets + datasets = datasets or self._datasets.values() if weights is None: weights = [1 / len(datasets)] * len(datasets) if len(weights) != len(datasets): @@ -256,8 +256,11 @@ def benchmarks_from_distrib( ) idx = 0 while dataset_size == -1 or idx < dataset_size: - sampled = np.random.choice(datasets, p=weights) - yield sampled.sample() + sampled_key = np.random.choice(datasets, p=weights) + if sampled_key not in self._datasets: + raise LookupError(f"Dataset not found: {sampled_key}") + dataset = self._datasets[sampled_key] + return round_robin_iterables((dataset,)) return def benchmark_uris(self, with_deprecated: bool = False) -> Iterable[str]: From 71b53ad7348f5b0845cd050ebd7cb8fcf79f2d27 Mon Sep 17 00:00:00 2001 From: Foivos Tsimpourlas Date: Mon, 15 Aug 2022 17:16:13 -0700 Subject: [PATCH 3/3] convert dict values to list --- compiler_gym/datasets/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler_gym/datasets/datasets.py b/compiler_gym/datasets/datasets.py index 9674c49dc..0bf16e573 100644 --- a/compiler_gym/datasets/datasets.py +++ b/compiler_gym/datasets/datasets.py @@ -245,7 +245,7 @@ def benchmarks_from_distrib( Select a dataset to sample from with some weight probability. If weights is None, select among `datasets` uniformly. """ - datasets = datasets or self._datasets.values() + datasets = datasets or list(self._datasets.values()) if weights is None: weights = [1 / len(datasets)] * len(datasets) if len(weights) != len(datasets):