lstm_ae_toy.py

import torch
import numpy as np
import matplotlib.pyplot as plt
from logic.Utils import parse_args
from logic.Data_Generators import load_syntethic_data, generate_syntethic_data
from logic.LSTMS import LSTM_AE
from logic.Trainers import Basic_Trainer, kfolds_train
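

# Plot a single sequence from the synthetic toy dataset against its time steps.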
def plot_signal_vs_time():
    dataset = generate_syntethic_data()
    plt.plot(dataset[0])
    plt.show()
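

# Run k-fold cross-validation over the hyperparameter grid, then plot two test
# sequences alongside their reconstructions from the best LSTM autoencoder.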
def find_best_hyperparams_and_reconstruct_syntethic_data(args):
    args.model = 'LSTM_AE'
    dataset, testset = load_syntethic_data()
    best_model, _ = kfolds_train(args, dataset, tune_hyperparams=True)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=len(testset), shuffle=False)
    # Take the first two test sequences (rather than the same one twice) and
    # add a feature dimension, giving a batch of shape (2, seq_len, 1).
    test_batch = next(iter(test_loader))
    test_samples = torch.stack([test_batch[0], test_batch[1]]).unsqueeze(-1)
    test_samples_reconstruction = best_model(test_samples)
    fig, ax = plt.subplots(1, 2)
    for i in range(2):
        ax[i].set_title(f"Sample {i+1}")
        ax[i].plot(test_samples[i].detach().numpy(), linewidth=2.5)
        ax[i].plot(test_samples_reconstruction[i].detach().numpy(), linewidth=5, alpha=0.5)
        ax[i].legend(['Original', 'Reconstruction'], loc='upper right')
    plt.show()
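

# Entry point: dispatch to the requested routine based on the --function argument.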
def main():
    args = parse_args()
    if args.function == 'plot_signal_vs_time':
        # called by:
        '''
        python3 lstm_ae_toy.py --function plot_signal_vs_time
        '''
        plot_signal_vs_time()
    elif args.function == 'find_best_hyperparams_and_reconstruct_syntethic_data':
        # called by:
        '''
        python3 lstm_ae_toy.py --function find_best_hyperparams_and_reconstruct_syntethic_data --model LSTM_AE --input_size 1 --hidden_size 8 16 32 --batch_size 128 --epochs 100 --learning_rate 0.1 0.01 0.001 --gradient_clipping 1 2 5
        '''
        find_best_hyperparams_and_reconstruct_syntethic_data(args)


if __name__ == '__main__':
    main()