Skip to content

Commit

Permalink
Merge branch 'XanderJC:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
DrShushen authored Oct 2, 2024
2 parents a8bed6e + 13b0690 commit 7764b08
Show file tree
Hide file tree
Showing 7 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions medkit/domains/icu.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def __init__(self, y_dim=2):
VAE_config = {
"latent_size": 10,
"hidden_units": 100,
"lr": 1e-4,
"lr": 1e-5,
"hidden_layers": 3,
"adam_betas": (0.9, 0.9),
"epochs": 10,
"epochs": 200,
}

self.env_config_dict = {
Expand Down
2 changes: 1 addition & 1 deletion medkit/domains/ward.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, y_dim=2):
"lr": 1e-4,
"hidden_layers": 3,
"adam_betas": (0.9, 0.9),
"epochs": 10,
"epochs": 200,
}

self.env_config_dict = {
Expand Down
2 changes: 1 addition & 1 deletion medkit/environments/SequentialVAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def step(self, action):
action = action.reshape((1, 1))
action_one_hot = F.one_hot(action, self.domain.y_dim)

x = torch.cat((self.prev_latent, action_one_hot, self.static), 2)
x = torch.cat((self.prev_obs, action_one_hot, self.static.expand(1, 1, -1)), 2)

out, (self.hn, self.cn) = self.model.lstm(x, (self.hn, self.cn))

Expand Down
2 changes: 1 addition & 1 deletion medkit/environments/TForce.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def step(self, action):
action = action.reshape((1, 1))
action_one_hot = F.one_hot(action, self.domain.y_dim)

x = torch.cat((self.prev_obs, action_one_hot, self.static), 2)
x = torch.cat((self.prev_obs, action_one_hot, self.static.expand(1, 1, -1)), 2)

out, (self.hn, self.cn) = self.model.lstm(x, (self.hn, self.cn))

Expand Down
Binary file modified medkit/environments/saved_models/icu_8_statespace.pth
Binary file not shown.
Binary file modified medkit/initialisers/saved_models/icu_VAE.pth
Binary file not shown.
Binary file modified medkit/initialisers/saved_models/ward_VAE.pth
Binary file not shown.

0 comments on commit 7764b08

Please sign in to comment.