Skip to content

Commit

Permalink
Update model.py
Browse files Browse the repository at this point in the history
remove extraneous layer names, abbreviate kernel divergence function name
  • Loading branch information
tr7200 authored Jan 30, 2021
1 parent cddd728 commit f6b6344
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@
bd_input = Input(shape=(39,))
as_input = Input(shape=(12,))

kernel_divergence_fn = lambda q, p, _: tfp.distributions.kl_divergence(q, p) / (323 * 1.0)
krnl_dvrgnc_fn = lambda q, p, _: tfp.distributions.kl_divergence(q, p) / (323 * 1.0)


# outer chevron
sm_bd_combined = concatenate([sm_input, bd_input])
sm_bd_combined_out = tfp.layers.DenseFlipout(48, activation='relu', name = 'dense1', kernel_divergence_fn=kernel_divergence_fn)(sm_bd_combined )
sm_bd_combined_out = tfp.layers.DenseFlipout(48, activation='relu', kernel_divergence_fn=krnl_dvrgnc_fn)(sm_bd_combined)

as_and_sm_bd_combined = concatenate([sm_bd_combined_out, as_input])
as_and_sm_bd_combined_out = tfp.layers.DenseFlipout(3, activation='relu', name = 'dense2', kernel_divergence_fn=kernel_divergence_fn)(as_and_sm_bd_combined)
as_and_sm_bd_combined_out = tfp.layers.DenseFlipout(3, activation='relu', kernel_divergence_fn=krnl_dvrgnc_fn)(as_and_sm_bd_combined)
V_struct_1_out = tfp.layers.DistributionLambda(lambda t: tfp.distributions.Normal(loc=25 + t[..., :1],
validate_args=True,
allow_nan_stats=False,
Expand All @@ -78,8 +78,8 @@
# inner chevron
V1_out = V_struct_1([sm_input, bd_input, as_input])
V1_SM_BD_combined = concatenate([V1_out, bd_input, sm_input])
V1_SM_BD_combined_out1 = tfp.layers.DenseFlipout(37, activation='relu', name = 'dense3', kernel_divergence_fn=kernel_divergence_fn)(V1_SM_BD_combined)
V1_SM_BD_combined_out2 = tfp.layers.DenseFlipout(10, activation='relu', name = 'dense4', kernel_divergence_fn=kernel_divergence_fn)(V1_SM_BD_combined_out1)
V1_SM_BD_combined_out1 = tfp.layers.DenseFlipout(37, activation='relu', kernel_divergence_fn=krnl_dvrgnc_fn)(V1_SM_BD_combined)
V1_SM_BD_combined_out2 = tfp.layers.DenseFlipout(10, activation='relu', kernel_divergence_fn=krnl_dvrgnc_fn)(V1_SM_BD_combined_out1)
V2_out = tfp.layers.DistributionLambda(lambda t: tfp.distributions.Normal(loc=25 + t[..., :1],
validate_args=True,
allow_nan_stats=False,
Expand Down

0 comments on commit f6b6344

Please sign in to comment.