Skip to content

Commit

Permalink
kicking down the can
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Mar 4, 2024
1 parent 870c91f commit 5231f38
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 16 deletions.
9 changes: 1 addition & 8 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import re

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.keras import optimizers as Kopt
Expand All @@ -16,12 +15,6 @@

import n3fit.backends.keras_backend.operations as op

# Check the TF version to check if legacy-mode is needed (TF < 2.2)
tf_version = tf.__version__.split(".")
if int(tf_version[0]) == 2 and int(tf_version[1]) < 2:
raise NotImplementedError("n3fit needs TF > 2.2 in order to work")


# We need a function to transform tensors to numpy/python primitives
# which is not part of the official TF interface and can change with the version
if hasattr(tf_utils, "to_numpy_or_python_type"):
Expand Down Expand Up @@ -484,7 +477,7 @@ def set_layer_replica_weights(layer, weights, i_replica: int):
if is_stacked_single_replicas(layer):
layer.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights)
return

full_weights = [w.numpy() for w in layer.weights]
for w_old, w_new in zip(full_weights, weights):
w_old[i_replica : i_replica + 1] = w_new
Expand Down
4 changes: 2 additions & 2 deletions n3fit/src/n3fit/backends/keras_backend/multi_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def call(self, inputs):
If the input already contains multiple replica outputs, it is equivalent
to applying each replica to its corresponding input.
"""
if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
# cast always
inputs = tf.cast(inputs, dtype=self.compute_dtype)

outputs = self.matmul(inputs)

Expand Down
5 changes: 3 additions & 2 deletions n3fit/src/n3fit/backends/keras_backend/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy as np
import numpy.typing as npt
import tensorflow as tf
import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Lambda as keras_Lambda
Expand Down Expand Up @@ -249,8 +250,8 @@ def concatenate(tensor_list, axis=-1, target_shape=None, name=None):
Concatenates a list of numbers or tensor into a bigger tensor
If the target shape is given, the output is reshaped to said shape
"""
concatenated_tensor = tf.concat(tensor_list, axis, name=name)
if target_shape:
concatenated_tensor = keras.ops.concatenate(tensor_list, axis=axis)
if target_shape is not None:
return K.reshape(concatenated_tensor, target_shape)
else:
return concatenated_tensor
Expand Down
3 changes: 0 additions & 3 deletions n3fit/src/n3fit/layers/observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ def __init__(self, fktable_data, fktable_arr, operation_name, nfl=14, **kwargs):
self.operation = op.c_to_py_fun(operation_name)
self.output_dim = self.fktables[0].shape[0]

def compute_output_shape(self, input_shape):
return (self.output_dim, None)

# Overridables
@abstractmethod
def gen_mask(self, basis):
Expand Down
2 changes: 1 addition & 1 deletion n3fit/src/n3fit/model_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ def compute_unnormalized_pdf(x):
# add batch and flavor dimensions
photon_integrals = op.batchit(op.batchit(photons.integral))
else:
photon_integrals = np.zeros((1, num_replicas, 1))
photon_integrals = op.numpy_to_tensor(np.zeros((1, num_replicas, 1)))

PDFs_normalized = sumrule_layer(
{
Expand Down

0 comments on commit 5231f38

Please sign in to comment.