# nn.py
import torch
import torch.nn as nn


class TDA(nn.Module):
    """Applies an exponential time-decay weighting along the last (time) dimension."""

    def __init__(self, time_window, scaling_factor):
        super().__init__()
        self.time_window = time_window
        self.scaling_factor = scaling_factor

    def weight(self, t):
        # Decay weight for time offset t: scaling_factor * exp(-t / time_window).
        return self.scaling_factor * torch.exp(-t / self.time_window)

    def forward(self, x):
        # x is expected to have shape (batch, channels, time).
        t = torch.arange(x.shape[2], dtype=torch.float32, device=x.device)
        weights = self.weight(t)  # reuse weight() instead of duplicating the formula
        # Reshape and broadcast the weights to match the input tensor dimensions.
        weights = weights.view(1, 1, -1).expand_as(x)
        return x * weights


class TDAClip(nn.Module):
    """Clamps activations to [0, max_value], then optionally reduces over the time dimension."""

    def __init__(self, max_value, op=torch.mean):
        super().__init__()
        self.max_value = max_value
        self.op = op

    def forward(self, x):
        clamped_x = torch.clamp(x, min=0, max=self.max_value)
        if self.op is not None:
            # Apply the reduction (mean by default) along the time_window dimension.
            return self.op(clamped_x, dim=2)
        return clamped_x
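

# A minimal usage sketch: the (batch, channels, time) layout and the parameter
# values below are illustrative assumptions, not values prescribed by the
# classes above.
if __name__ == "__main__":
    x = torch.randn(4, 8, 16)  # assumed shape: (batch=4, channels=8, time=16)

    tda = TDA(time_window=10.0, scaling_factor=1.0)
    weighted = tda(x)
    print(weighted.shape)  # torch.Size([4, 8, 16]); weighting preserves the shape

    clip = TDAClip(max_value=1.0)  # op defaults to torch.mean over dim=2
    pooled = clip(weighted)
    print(pooled.shape)  # torch.Size([4, 8]); the time dimension is reduced away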