Skip to content

Commit

Permalink
Replace old dense layer everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Jan 11, 2024
1 parent c23a59f commit e0fc2e1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 17 deletions.
12 changes: 1 addition & 11 deletions n3fit/src/n3fit/backends/keras_backend/base_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def apply_dense(xinput):


layers = {
"multi_dense": (
"dense": (
MultiDense,
{
"input_shape": (1,),
Expand All @@ -134,16 +134,6 @@ def apply_dense(xinput):
"replica_input": True,
},
),
"dense": (
Dense,
{
"input_shape": (1,),
"kernel_initializer": "glorot_normal",
"units": 5,
"activation": "sigmoid",
"kernel_regularizer": None,
},
),
"dense_per_flavour": (
dense_per_flavour,
{
Expand Down
7 changes: 3 additions & 4 deletions n3fit/src/n3fit/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def check_initializer(initializer):
def check_layer_type_implemented(parameters):
"""Checks whether the layer_type is implemented"""
layer_type = parameters.get("layer_type")
implemented_types = ["dense", "multi_dense", "dense_per_flavour"]
implemented_types = ["dense", "dense_per_flavour"]
if layer_type not in implemented_types:
raise CheckError(
f"Layer type {layer_type} not implemented, must be one of {implemented_types}"
Expand Down Expand Up @@ -427,10 +427,9 @@ def check_fiatlux_pdfs_id(replicas, fiatlux):
f"Cannot generate a photon replica with id larger than the number of replicas of the PDFs set {luxset.name}:\nreplica id={max_id}, replicas of {luxset.name} = {pdfs_ids}"
)


@make_argcheck
def check_multireplica_qed(replicas, fiatlux):
if fiatlux is not None:
if len(replicas) > 1:
raise CheckError(
"At the moment, running a multireplica QED fits is not allowed."
)
raise CheckError("At the moment, running a multireplica QED fits is not allowed.")
4 changes: 2 additions & 2 deletions n3fit/src/n3fit/model_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def initializer_generator(seed, i_layer):
# list_of_pdf_layers[d][r] is the layer at depth d for replica r
list_of_pdf_layers = []
for i_layer, (nodes_out, activation) in enumerate(zip(nodes_list, activations)):
if layer_type == "multi_dense":
if layer_type == "dense":
layers = base_layer_selector(
layer_type,
replica_seeds=replica_seeds,
Expand Down Expand Up @@ -777,7 +777,7 @@ def initializer_generator(seed, i_layer):
list_of_pdf_layers[-1] = [lambda x: concat(layer(x)) for layer in list_of_pdf_layers[-1]]

# Apply all layers to the input to create the models
if layer_type == "multi_dense":
if layer_type == "dense":
pdfs = x_input
for layer in list_of_pdf_layers:
pdfs = layer(pdfs)
Expand Down

0 comments on commit e0fc2e1

Please sign in to comment.