forked from camlab-ethz/poseidon
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathera5.py
60 lines (46 loc) · 1.87 KB
/
era5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import h5py
from scOT.problems.base import BaseTimeDataset
class ERA5_UV(BaseTimeDataset):
    """Time-dependent dataset over the two ERA5 fields labelled ``[10U,10V]``.

    Each sample stacks every HDF5 dataset in ``ERA5.h5`` as one channel at two
    time indices (offsets ``t1`` and ``t2`` from ``self._idx_map``): the first
    becomes ``pixel_values``, the second ``labels``. Both are normalized with
    the per-channel mean/std in ``self.constants``.
    """

    # Time steps per year used for the split sizes: 365 days x 4 samples per
    # day (presumably 6-hourly data); leap days are not counted.
    _STEPS_PER_YEAR = 365 * 4

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Number of windows of ``max_num_time_steps`` consecutive steps that
        # fit in each span (total - max_num_time_steps + 1). Spans, by year
        # arithmetic: all data 1995-2010, validation 2005-2007, test 2007-2010.
        window_slack = self.max_num_time_steps - 1
        self.N_max = (2010 - 1995) * self._STEPS_PER_YEAR - window_slack
        self.N_val = (2007 - 2005) * self._STEPS_PER_YEAR - window_slack
        self.N_test = (2010 - 2007) * self._STEPS_PER_YEAR - window_slack

        self.resolution = 128

        data_path = self._move_to_local_scratch(self.data_path + "/ERA5.h5")
        # Kept open for the dataset's lifetime; samples are read lazily in
        # __getitem__.
        self.reader = h5py.File(data_path, "r")
        # Channel order when stacking = iteration order of the file's keys.
        # It must match the order of the mean/std constants below
        # (TODO confirm: expected to be [10U, 10V]).
        self.keys = list(self.reader.keys())

        self.constants = {
            # Divisor applied to the lead time ``t`` in __getitem__
            # (presumably 14 days x 4 steps/day — confirm).
            "time": 14.0 * 4.0,
            # Per-channel statistics, shaped (2, 1, 1) to broadcast over the
            # spatial dims of a (channels, H, W) stack.
            "mean": torch.tensor([-0.05483689, 0.18707459]).unsqueeze(1).unsqueeze(1),
            "std": torch.tensor([5.2594, 4.5301833]).unsqueeze(1).unsqueeze(1),
        }
        self.input_dim = 2
        self.label_description = "[10U,10V]"

        self.post_init()

    def _read_normalized(self, time_index):
        """Read every channel at ``time_index`` and return it normalized.

        Args:
            time_index: Absolute index into the leading axis of each HDF5
                dataset.

        Returns:
            A float32 tensor with one channel per entry of ``self.keys``
            (stacked along dim 0), normalized as ``(x - mean) / std``.
        """
        channels = [
            # ``[:, :-1]`` drops the last column of each 2-D slice —
            # presumably to trim the grid to ``self.resolution`` columns
            # (TODO confirm against the file layout).
            torch.from_numpy(self.reader[key][time_index][:, :-1]).type(torch.float32)
            for key in self.keys
        ]
        fields = torch.stack(channels, dim=0)
        return (fields - self.constants["mean"]) / self.constants["std"]

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        Keys:
            ``pixel_values``: normalized fields at offset ``t1``.
            ``labels``: normalized fields at offset ``t2``.
            ``time``: ``t`` divided by ``self.constants["time"]``.
        """
        i, t, t1, t2 = self._idx_map(idx)
        time = t / self.constants["time"]

        base = i + self.start
        inputs = self._read_normalized(base + t1)
        label = self._read_normalized(base + t2)

        return {
            "pixel_values": inputs,
            "labels": label,
            "time": time,
        }