generator.py
import h5py
import numpy as np
import numpy.typing as npt
from pathlib import Path
from sklearn.model_selection import train_test_split
from tensorflow import data
from typing import List, Tuple


class RegionETGenerator:
    """Loads calorimeter region ET data from HDF5 files and prepares
    train/validation/test splits and tf.data pipelines."""

    def __init__(
        self, train_size: float = 0.5, val_size: float = 0.1, test_size: float = 0.4
    ):
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.random_state = 42

    def get_generator(
        self,
        X: npt.NDArray,
        y: npt.NDArray,
        batch_size: int,
        drop_remainder: bool = False,
    ) -> data.Dataset:
        """Wraps (X, y) in a shuffled, batched, prefetched tf.data.Dataset."""
        dataset = data.Dataset.from_tensor_slices((X, y))
        return (
            dataset.shuffle(210 * batch_size)
            .batch(batch_size, drop_remainder=drop_remainder)
            .prefetch(data.AUTOTUNE)
        )

    def get_data(self, datasets_paths: List[Path]) -> npt.NDArray:
        """Reads the CaloRegions arrays from all files and stacks them into a
        single (N, 18, 14, 1) float32 array."""
        inputs = []
        for dataset_path in datasets_paths:
            with h5py.File(dataset_path, "r") as h5f:
                inputs.append(h5f["CaloRegions"][:].astype("float32"))
        X = np.concatenate(inputs)
        X = np.reshape(X, (-1, 18, 14, 1))
        return X

    def get_data_split(
        self, datasets_paths: List[Path]
    ) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
        """Splits the loaded data into train/validation/test subsets according
        to the configured fractions."""
        X = self.get_data(datasets_paths)
        # First carve off the test set, then split the remainder into train
        # and validation so the overall fractions are preserved.
        X_train, X_test = train_test_split(
            X, test_size=self.test_size, random_state=self.random_state
        )
        X_train, X_val = train_test_split(
            X_train,
            test_size=self.val_size / (self.val_size + self.train_size),
            random_state=self.random_state,
        )
        return (X_train, X_val, X_test)

    def get_benchmark(
        self, datasets: dict, filter_acceptance: bool = True
    ) -> Tuple[dict, list]:
        """Loads the benchmark signal datasets and, where an AcceptanceFlag
        dataset is present, optionally keeps only accepted events. Returns the
        signal arrays keyed by name and the per-signal acceptance fractions."""
        signals = {}
        acceptance = []
        for dataset in datasets:
            if not dataset["use"]:
                continue
            signal_name = dataset["name"]
            for dataset_path in dataset["path"]:
                with h5py.File(dataset_path, "r") as h5f:
                    X = h5f["CaloRegions"][:].astype("float32")
                    X = np.reshape(X, (-1, 18, 14, 1))
                    try:
                        flags = h5f["AcceptanceFlag"][:].astype("bool")
                        fraction = np.round(100 * np.sum(flags) / len(flags), 2)
                    except KeyError:
                        # No acceptance information stored: keep all events.
                        flags = None
                        fraction = 100.0
                if filter_acceptance and flags is not None:
                    X = X[flags]
                signals[signal_name] = X
                acceptance.append({"signal": signal_name, "acceptance": fraction})
        return signals, acceptance
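

# Illustrative usage sketch, not part of the original file: the file names,
# batch size, and label choice below are hypothetical placeholders meant only
# to show how RegionETGenerator might be driven.
if __name__ == "__main__":
    # Assume HDF5 files that each contain a "CaloRegions" dataset.
    paths = [Path("example_part0.h5"), Path("example_part1.h5")]

    generator = RegionETGenerator()
    X_train, X_val, X_test = generator.get_data_split(paths)

    # For an autoencoder-style setup the targets equal the inputs; any other
    # label array of matching length would work the same way.
    train_ds = generator.get_generator(X_train, X_train, batch_size=32)
    val_ds = generator.get_generator(X_val, X_val, batch_size=32)

    print(f"train: {X_train.shape}, val: {X_val.shape}, test: {X_test.shape}")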