diff --git a/medkit/domains/icu.py b/medkit/domains/icu.py index 28f2f70..544e08b 100644 --- a/medkit/domains/icu.py +++ b/medkit/domains/icu.py @@ -82,10 +82,10 @@ def __init__(self, y_dim=2): VAE_config = { "latent_size": 10, "hidden_units": 100, - "lr": 1e-4, + "lr": 1e-5, "hidden_layers": 3, "adam_betas": (0.9, 0.9), - "epochs": 10, + "epochs": 200, } self.env_config_dict = { diff --git a/medkit/domains/ward.py b/medkit/domains/ward.py index f048aa5..0ed0487 100644 --- a/medkit/domains/ward.py +++ b/medkit/domains/ward.py @@ -86,7 +86,7 @@ def __init__(self, y_dim=2): "lr": 1e-4, "hidden_layers": 3, "adam_betas": (0.9, 0.9), - "epochs": 10, + "epochs": 200, } self.env_config_dict = { diff --git a/medkit/environments/SequentialVAE.py b/medkit/environments/SequentialVAE.py index 6fb2f2e..7eec3b7 100644 --- a/medkit/environments/SequentialVAE.py +++ b/medkit/environments/SequentialVAE.py @@ -198,7 +198,7 @@ def step(self, action): action = action.reshape((1, 1)) action_one_hot = F.one_hot(action, self.domain.y_dim) - x = torch.cat((self.prev_latent, action_one_hot, self.static), 2) + x = torch.cat((self.prev_obs, action_one_hot, self.static.expand(1, 1, -1)), 2) out, (self.hn, self.cn) = self.model.lstm(x, (self.hn, self.cn)) diff --git a/medkit/environments/TForce.py b/medkit/environments/TForce.py index e694736..117b5ef 100644 --- a/medkit/environments/TForce.py +++ b/medkit/environments/TForce.py @@ -102,7 +102,7 @@ def step(self, action): action = action.reshape((1, 1)) action_one_hot = F.one_hot(action, self.domain.y_dim) - x = torch.cat((self.prev_obs, action_one_hot, self.static), 2) + x = torch.cat((self.prev_obs, action_one_hot, self.static.expand(1, 1, -1)), 2) out, (self.hn, self.cn) = self.model.lstm(x, (self.hn, self.cn)) diff --git a/medkit/environments/saved_models/icu_8_statespace.pth b/medkit/environments/saved_models/icu_8_statespace.pth index 2cbe63c..ed4ecdf 100644 Binary files a/medkit/environments/saved_models/icu_8_statespace.pth and b/medkit/environments/saved_models/icu_8_statespace.pth differ diff --git a/medkit/initialisers/saved_models/icu_VAE.pth b/medkit/initialisers/saved_models/icu_VAE.pth index 0e81e57..7073bad 100644 Binary files a/medkit/initialisers/saved_models/icu_VAE.pth and b/medkit/initialisers/saved_models/icu_VAE.pth differ diff --git a/medkit/initialisers/saved_models/ward_VAE.pth b/medkit/initialisers/saved_models/ward_VAE.pth index d1f2f7a..4fb5629 100644 Binary files a/medkit/initialisers/saved_models/ward_VAE.pth and b/medkit/initialisers/saved_models/ward_VAE.pth differ