diff --git a/Foggy_CycleGAN.ipynb b/Foggy_CycleGAN.ipynb index 911b10e..c912cce 100644 --- a/Foggy_CycleGAN.ipynb +++ b/Foggy_CycleGAN.ipynb @@ -69,16 +69,10 @@ "source": [ "import sys\n", "colab = 'google.colab' in sys.modules\n", - "if colab:\n", - " # noinspection PyBroadException\n", - " try:\n", - " %tensorflow_version 2.x\n", - " except Exception:\n", - " pass\n", "import tensorflow as tf" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -94,8 +88,8 @@ "# noinspection PyUnresolvedReferences\n", "print(tf.__version__)" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -112,8 +106,8 @@ "\n", "tfds.disable_progress_bar()" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -142,8 +136,8 @@ " os.chdir(project_dir)\n", " print(\"Done.\")" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -167,8 +161,8 @@ "IMG_WIDTH = 256\n", "IMG_HEIGHT = 256" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -180,13 +174,11 @@ "source": [ "project_label = \"\" #@param {type:\"string\"}" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "outputs": [], "source": [ "mount_path = None #to suppress warnings\n", "drive_project_path = None\n", @@ -205,7 +197,9 @@ "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -219,8 +213,8 @@ "if colab:\n", " !sh $PROJECT_DIR/copy_dataset.sh" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -235,13 +229,11 @@ "source": [ "test_split = 0.2 #@param {type:\"slider\", min:0.05, max:0.95, step:0.05}" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "outputs": [], "source": [ "from lib.dataset import DatasetInitializer\n", "\n", @@ -257,7 +249,9 @@ "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -284,8 +278,8 @@ "OUTPUT_CHANNELS = 3\n", "models_builder = ModelsBuilder()" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -297,6 +291,8 @@ "source": [ "use_transmission_map = False #@param{type: \"boolean\"}\n", "use_gauss_filter = False #@param{type: \"boolean\"}\n", + "if use_gauss_filter and not use_transmission_map:\n", + " raise Exception(\"Gauss filter requires transmission map\")\n", "use_resize_conv = False #@param{type: \"boolean\"}\n", "\n", "generator_clear2fog = models_builder.build_generator(use_transmission_map=use_transmission_map,\n", @@ -304,8 +300,8 @@ " use_resize_conv=use_resize_conv)\n", "generator_fog2clear = models_builder.build_generator(use_transmission_map=False)" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -320,8 +316,8 @@ "source": [ "tf.keras.utils.plot_model(generator_clear2fog, show_shapes=True, dpi=64, to_file='generator_clear2fog.png');" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -336,8 +332,8 @@ "source": [ 
"tf.keras.utils.plot_model(generator_fog2clear, show_shapes=True, dpi=64, to_file='generator_fog2clear.png');" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -364,8 +360,8 @@ "discriminator_fog = models_builder.build_discriminator(use_intensity=use_intensity_for_fog_discriminator)\n", "discriminator_clear = models_builder.build_discriminator(use_intensity=False)" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -378,8 +374,8 @@ "tf.keras.utils.plot_model(discriminator_fog, show_shapes=True, dpi=64, to_file=\"discriminator_fog.png\");\n", "tf.keras.utils.plot_model(discriminator_clear, show_shapes=True, dpi=64, to_file=\"discriminator_clear.png\");" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -407,8 +403,8 @@ "else:\n", " weights_path = \"./weights/\"" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -424,8 +420,8 @@ "\n", "trainer.configure_checkpoint(weights_path = weights_path, load_optimizers=False)" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -442,8 +438,8 @@ "for clear, fog in tf.data.Dataset.zip((sample_clear.take(1), sample_fog.take(1))):\n", " plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog)" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -460,8 +456,8 @@ "for clear, fog in tf.data.Dataset.zip((sample_clear.take(1), sample_fog.take(1))):\n", " plot_discriminators_predictions(discriminator_clear, clear, discriminator_fog, fog, use_intensity_for_fog_discriminator)" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -486,8 +482,8 @@ "source": [ "use_tensorboard = True #@param{type:\"boolean\"}" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -512,8 +508,8 @@ " else:\n", " print(url)" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -530,8 +526,8 @@ " trainer.image_log_path = os.path.join(drive_project_path,\"image_logs/\")\n", " trainer.config_path = os.path.join(drive_project_path,\"trainer_config.json\")" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -546,8 +542,8 @@ "source": [ "trainer.load_config()" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", @@ -576,8 +572,8 @@ " use_intensity_for_fog_discriminator=use_intensity_for_fog_discriminator\n", ")" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -605,13 +601,11 @@ "for clear, fog in zip(test_clear.take(5), test_fog.take(5)):\n", " plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog)" ], - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "outputs": [], "source": [ "for clear, fog in zip(sample_clear, sample_fog):\n", " plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog)" @@ -621,12 +615,12 @@ "pycharm": { 
"name": "#%%\n" } - } + }, + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "outputs": [], "source": [ "from lib.plot import plot_clear2fog_intensity\n", "from matplotlib import pyplot as plt\n", @@ -651,12 +645,12 @@ "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "outputs": [], "source": [ "if colab:\n", " !cd ./intensity; zip /content/intensity.zip *" @@ -666,7 +660,9 @@ "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -682,8 +678,6 @@ }, { "cell_type": "code", - "execution_count": null, - "outputs": [], "source": [ "from lib.plot import plot_clear2fog_intensity\n", "from matplotlib import pyplot as plt\n", @@ -691,7 +685,7 @@ "intensity_path = './intensity/'\n", "from lib.tools import create_dir\n", "create_dir(intensity_path)\n", - "file_path = 'E:/Downloads/test-image.png'\n", + "file_path = './Downloads/test-image.png'\n", "\n", "image_clear = tf.io.decode_png(tf.io.read_file(file_path), channels=3)\n", "image_clear, _ = datasetInit.preprocess_image_test(image_clear, 0)\n", @@ -710,12 +704,12 @@ "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "outputs": [], "source": [ "if colab:\n", " !cd ./intensity; zip /content/intensity.zip *" @@ -725,7 +719,9 @@ "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "execution_count": null } ] -} \ No newline at end of file +} diff --git a/README.md b/README.md index a2e8f42..f0ad89e 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,6 @@ +> [!NOTE] +> November 2024: New Pre-trained Models are available, check the [Pre-trained Models](#pre-trained-models) section. + # Foggy-CycleGAN

@@ -9,6 +12,13 @@ This project is the implementation for my Computer Science MSc thesis in the Uni
 Dissertation: [PDF] Simulating Weather Conditions on Digital Images (Debrecen, 2020).
+# Table of Contents
+- [Description](#description)
+- [Code](#code)
+- [Notebook](#notebook)
+- [Results](#results)
+- [Pre-trained Models](#pre-trained-models)
+
 ## Description
 **Foggy-CycleGAN** is a CycleGAN model trained to synthesize fog on clear images. More details in the dissertation above.
@@ -16,30 +26,57 @@ Dissertation:
 ## Code
 The full source code is available under GPL-3.0 License in my Github repository ghaiszaher/Foggy-CycleGAN
-## Pre-trained Models
-A version of pre-trained models used in the thesis can be found [here](https://drive.google.com/drive/folders/1QKsiaGkMFvtGcp072IG57MfY1o_D-L3k?usp=sharing).
-
 ## Notebook
 Open In Colab
 A Jupyter Notebook file Foggy_CycleGAN.ipynb is available in the repository.
-
 ## Results
+(as of June 2020)

- +

- +

- +

- +

-
+
© Ghais Zaher 2020
+
+## Pre-trained Models
+As the previous pre-trained models are no longer compatible with newer Keras/TensorFlow versions, I have retrained the model and made the new weights available for download.
+
+Each of the following models was trained in Google Colab using the same dataset; they differ in the parameters used to build the models and in the number of training epochs:
+
+
+| Model | Trained Epochs | Config |
+|-------|----------------|--------|
+| [2020-06 (legacy)](https://drive.google.com/drive/folders/1QKsiaGkMFvtGcp072IG57MfY1o_D-L3k?usp=sharing) | 145 | `use_transmission_map=False` <br> `use_gauss_filter=False` <br> `use_resize_conv=False` |
+| [2024-11-17-rev1-000](https://drive.google.com/drive/folders/1--W53NNrVxS5pvrf8jDKCRmg4h4vD5lx?usp=sharing) | 522 | `use_transmission_map=False` <br> `use_gauss_filter=False` <br> `use_resize_conv=False` |
+| [2024-11-17-rev2-110](https://drive.google.com/drive/folders/1rQ7jmsv63uv6v45IVZmZ8w9CVktqJAfn?usp=sharing) | 100 | `use_transmission_map=True` <br> `use_gauss_filter=True` <br> `use_resize_conv=False` |
+| [2024-11-17-rev3-111](https://drive.google.com/drive/folders/1-0-z7KTMXTrwwUdeJtkUOBCWkwD6behO?usp=sharing) | 103 | `use_transmission_map=True` <br> `use_gauss_filter=True` <br> `use_resize_conv=True` |
+| [2024-11-17-rev4-001](https://drive.google.com/drive/folders/1hDxJtU0agbnPO2XrrPo26RQJKOePa6WX?usp=sharing) | 39 | `use_transmission_map=False` <br> `use_gauss_filter=False` <br> `use_resize_conv=True` |
+
+
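The flags in the Config column are the arguments passed to `ModelsBuilder` when the models are built in the notebook. Below is a minimal sketch of rebuilding the models for one configuration before restoring downloaded weights; the `lib.models` import path and the `use_intensity` choice are assumptions here, so refer to the notebook cells for the exact setup:

```python
from lib.models import ModelsBuilder  # assumed import path for the repository's lib package

# Example: the 2024-11-17-rev2-110 configuration from the table above.
models_builder = ModelsBuilder()
generator_clear2fog = models_builder.build_generator(use_transmission_map=True,
                                                     use_gauss_filter=True,
                                                     use_resize_conv=False)
generator_fog2clear = models_builder.build_generator(use_transmission_map=False)

# The fog discriminator can optionally take the fog intensity as input (illustrative choice).
discriminator_fog = models_builder.build_discriminator(use_intensity=True)
discriminator_clear = models_builder.build_discriminator(use_intensity=False)

# The downloaded checkpoint is then restored the same way the notebook does it:
# trainer.configure_checkpoint(weights_path="path/to/downloaded/weights/", load_optimizers=False)
```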
+ +### Results +The results of the new models are similar to the previous ones, here are some samples: +
+ +| Clear | 2024-11-17-rev1-000 | 2024-11-17-rev2-110 | 2024-11-17-rev3-111 | 2024-11-17-rev4-001 | +|---------------------------------------------------------|------------------------------------------------------------|------------------------------------------------------------|------------------------------------------------------------|------------------------------------------------------------| +| | | | | | +| | | | | | +| | | | | | +| | | | | | +| | | | | | + +
diff --git a/discriminator_clear.png b/discriminator_clear.png index 703c65f..cc75df7 100644 Binary files a/discriminator_clear.png and b/discriminator_clear.png differ diff --git a/discriminator_fog.png b/discriminator_fog.png index dfb997d..cc75df7 100644 Binary files a/discriminator_fog.png and b/discriminator_fog.png differ diff --git a/generator_clear2fog.png b/generator_clear2fog.png index 2d4f395..dbb5b81 100644 Binary files a/generator_clear2fog.png and b/generator_clear2fog.png differ diff --git a/generator_fog2clear.png b/generator_fog2clear.png index af43abd..dbb5b81 100644 Binary files a/generator_fog2clear.png and b/generator_fog2clear.png differ diff --git a/images/result-animated-01.gif b/images/results/2020-06/result-animated-01.gif similarity index 100% rename from images/result-animated-01.gif rename to images/results/2020-06/result-animated-01.gif diff --git a/images/result-sample-0.2.jpg b/images/results/2020-06/result-sample-0.2.jpg similarity index 100% rename from images/result-sample-0.2.jpg rename to images/results/2020-06/result-sample-0.2.jpg diff --git a/images/result-sample-0.25.jpg b/images/results/2020-06/result-sample-0.25.jpg similarity index 100% rename from images/result-sample-0.25.jpg rename to images/results/2020-06/result-sample-0.25.jpg diff --git a/images/result-sample-0.3.jpg b/images/results/2020-06/result-sample-0.3.jpg similarity index 100% rename from images/result-sample-0.3.jpg rename to images/results/2020-06/result-sample-0.3.jpg diff --git a/images/results/2024-11-17/clear/sample1.jpg b/images/results/2024-11-17/clear/sample1.jpg new file mode 100644 index 0000000..bdf4f95 Binary files /dev/null and b/images/results/2024-11-17/clear/sample1.jpg differ diff --git a/images/results/2024-11-17/clear/sample2.jpg b/images/results/2024-11-17/clear/sample2.jpg new file mode 100644 index 0000000..c75387e Binary files /dev/null and b/images/results/2024-11-17/clear/sample2.jpg differ diff --git a/images/results/2024-11-17/clear/sample3.jpg b/images/results/2024-11-17/clear/sample3.jpg new file mode 100644 index 0000000..8b27ab1 Binary files /dev/null and b/images/results/2024-11-17/clear/sample3.jpg differ diff --git a/images/results/2024-11-17/clear/sample4.jpg b/images/results/2024-11-17/clear/sample4.jpg new file mode 100644 index 0000000..166d807 Binary files /dev/null and b/images/results/2024-11-17/clear/sample4.jpg differ diff --git a/images/results/2024-11-17/clear/sample5.jpg b/images/results/2024-11-17/clear/sample5.jpg new file mode 100644 index 0000000..be554b8 Binary files /dev/null and b/images/results/2024-11-17/clear/sample5.jpg differ diff --git a/images/results/2024-11-17/rev1-000/sample1.gif b/images/results/2024-11-17/rev1-000/sample1.gif new file mode 100644 index 0000000..e0d524b Binary files /dev/null and b/images/results/2024-11-17/rev1-000/sample1.gif differ diff --git a/images/results/2024-11-17/rev1-000/sample2.gif b/images/results/2024-11-17/rev1-000/sample2.gif new file mode 100644 index 0000000..8d22210 Binary files /dev/null and b/images/results/2024-11-17/rev1-000/sample2.gif differ diff --git a/images/results/2024-11-17/rev1-000/sample3.gif b/images/results/2024-11-17/rev1-000/sample3.gif new file mode 100644 index 0000000..f334df0 Binary files /dev/null and b/images/results/2024-11-17/rev1-000/sample3.gif differ diff --git a/images/results/2024-11-17/rev1-000/sample4.gif b/images/results/2024-11-17/rev1-000/sample4.gif new file mode 100644 index 0000000..a29ca18 Binary files /dev/null and 
b/images/results/2024-11-17/rev1-000/sample4.gif differ diff --git a/images/results/2024-11-17/rev1-000/sample5.gif b/images/results/2024-11-17/rev1-000/sample5.gif new file mode 100644 index 0000000..cef331d Binary files /dev/null and b/images/results/2024-11-17/rev1-000/sample5.gif differ diff --git a/images/results/2024-11-17/rev2-110/sample1.gif b/images/results/2024-11-17/rev2-110/sample1.gif new file mode 100644 index 0000000..75ffa2f Binary files /dev/null and b/images/results/2024-11-17/rev2-110/sample1.gif differ diff --git a/images/results/2024-11-17/rev2-110/sample2.gif b/images/results/2024-11-17/rev2-110/sample2.gif new file mode 100644 index 0000000..5622d61 Binary files /dev/null and b/images/results/2024-11-17/rev2-110/sample2.gif differ diff --git a/images/results/2024-11-17/rev2-110/sample3.gif b/images/results/2024-11-17/rev2-110/sample3.gif new file mode 100644 index 0000000..067d7da Binary files /dev/null and b/images/results/2024-11-17/rev2-110/sample3.gif differ diff --git a/images/results/2024-11-17/rev2-110/sample4.gif b/images/results/2024-11-17/rev2-110/sample4.gif new file mode 100644 index 0000000..33e5d48 Binary files /dev/null and b/images/results/2024-11-17/rev2-110/sample4.gif differ diff --git a/images/results/2024-11-17/rev2-110/sample5.gif b/images/results/2024-11-17/rev2-110/sample5.gif new file mode 100644 index 0000000..45dcd74 Binary files /dev/null and b/images/results/2024-11-17/rev2-110/sample5.gif differ diff --git a/images/results/2024-11-17/rev3-111/sample1.gif b/images/results/2024-11-17/rev3-111/sample1.gif new file mode 100644 index 0000000..1815cdb Binary files /dev/null and b/images/results/2024-11-17/rev3-111/sample1.gif differ diff --git a/images/results/2024-11-17/rev3-111/sample2.gif b/images/results/2024-11-17/rev3-111/sample2.gif new file mode 100644 index 0000000..13dfdd1 Binary files /dev/null and b/images/results/2024-11-17/rev3-111/sample2.gif differ diff --git a/images/results/2024-11-17/rev3-111/sample3.gif b/images/results/2024-11-17/rev3-111/sample3.gif new file mode 100644 index 0000000..ea0211f Binary files /dev/null and b/images/results/2024-11-17/rev3-111/sample3.gif differ diff --git a/images/results/2024-11-17/rev3-111/sample4.gif b/images/results/2024-11-17/rev3-111/sample4.gif new file mode 100644 index 0000000..400c620 Binary files /dev/null and b/images/results/2024-11-17/rev3-111/sample4.gif differ diff --git a/images/results/2024-11-17/rev3-111/sample5.gif b/images/results/2024-11-17/rev3-111/sample5.gif new file mode 100644 index 0000000..bbe774d Binary files /dev/null and b/images/results/2024-11-17/rev3-111/sample5.gif differ diff --git a/images/results/2024-11-17/rev4-001/sample1.gif b/images/results/2024-11-17/rev4-001/sample1.gif new file mode 100644 index 0000000..48f93c7 Binary files /dev/null and b/images/results/2024-11-17/rev4-001/sample1.gif differ diff --git a/images/results/2024-11-17/rev4-001/sample2.gif b/images/results/2024-11-17/rev4-001/sample2.gif new file mode 100644 index 0000000..7536157 Binary files /dev/null and b/images/results/2024-11-17/rev4-001/sample2.gif differ diff --git a/images/results/2024-11-17/rev4-001/sample3.gif b/images/results/2024-11-17/rev4-001/sample3.gif new file mode 100644 index 0000000..155b5cf Binary files /dev/null and b/images/results/2024-11-17/rev4-001/sample3.gif differ diff --git a/images/results/2024-11-17/rev4-001/sample4.gif b/images/results/2024-11-17/rev4-001/sample4.gif new file mode 100644 index 0000000..141e188 Binary files /dev/null and 
b/images/results/2024-11-17/rev4-001/sample4.gif differ diff --git a/images/results/2024-11-17/rev4-001/sample5.gif b/images/results/2024-11-17/rev4-001/sample5.gif new file mode 100644 index 0000000..fbc77fd Binary files /dev/null and b/images/results/2024-11-17/rev4-001/sample5.gif differ diff --git a/lib/gauss.py b/lib/gauss.py index 0ed03f9..ec0d9d9 100644 --- a/lib/gauss.py +++ b/lib/gauss.py @@ -1,12 +1,11 @@ def gauss_blur_model(input_shape, kernel_size=19, sigma=5, **kwargs): import tensorflow as tf import numpy as np + def matlab_style_gauss2D(shape=(3, 3), sigma=0.5): """ 2D gaussian mask - should give the same result as MATLAB's fspecial('gaussian',[shape],[sigma]) - #https://stackoverflow.com/questions/55643675/how-do-i-implement-gaussian-blurring-layer-in-keras - #https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python/17201686#17201686 """ m, n = [(ss - 1.) / 2. for ss in shape] y, x = np.ogrid[-m:m + 1, -n:n + 1] @@ -18,10 +17,8 @@ def matlab_style_gauss2D(shape=(3, 3), sigma=0.5): return h class SymmetricPadding2D(tf.keras.layers.Layer): - # Source: https://stackoverflow.com/a/55210905/11394663 - def __init__(self, output_dim, padding=(1, 1), + def __init__(self, padding=(1, 1), data_format="channels_last", **kwargs): - self.output_dim = output_dim self.data_format = data_format self.padding = padding super(SymmetricPadding2D, self).__init__(**kwargs) @@ -30,22 +27,30 @@ def build(self, input_shape): super(SymmetricPadding2D, self).build(input_shape) def call(self, inputs, **kwargs): - if self.data_format is "channels_last": - # (batch, depth, rows, cols, channels) - pad = [[0, 0]] + [[i, i] for i in self.padding] + [[0, 0]] - # elif self.data_format is "channels_first": + if self.data_format == "channels_last": + # (batch, rows, cols, channels) + pad = [[0, 0]] + [[p, p] for p in self.padding] + [[0, 0]] else: - # (batch, channels, depth, rows, cols) - pad = [[0, 0], [0, 0]] + [[i, i] for i in self.padding] + # (batch, channels, rows, cols) + pad = [[0, 0], [0, 0]] + [[p, p] for p in self.padding] paddings = tf.constant(pad) - out = tf.pad(inputs, paddings, "REFLECT") - return out + return tf.pad(inputs, paddings, "REFLECT") def compute_output_shape(self, input_shape): - return input_shape[0], self.output_dim + if self.data_format == "channels_last": + return (input_shape[0], + input_shape[1] + 2 * self.padding[0], + input_shape[2] + 2 * self.padding[1], + input_shape[3]) + else: + return (input_shape[0], + input_shape[1], + input_shape[2] + 2 * self.padding[0], + input_shape[3] + 2 * self.padding[1]) if kernel_size % 2 == 0: raise Exception("kernel size should be an odd number") + gauss_inputs = tf.keras.layers.Input(shape=input_shape) kernel_weights = matlab_style_gauss2D(shape=(kernel_size, kernel_size), sigma=sigma) @@ -53,12 +58,18 @@ def compute_output_shape(self, input_shape): kernel_weights = np.expand_dims(kernel_weights, axis=-1) kernel_weights = np.repeat(kernel_weights, in_channels, axis=-1) # apply the same filter on all the input channels kernel_weights = np.expand_dims(kernel_weights, axis=-1) # for shape compatibility reasons + gauss_layer = tf.keras.layers.DepthwiseConv2D(kernel_size, use_bias=False, padding='valid') p = (kernel_size - 1) // 2 - # noinspection PyCallingNonCallable - x = SymmetricPadding2D(0, padding=[p, p])(gauss_inputs) + + # Apply symmetric padding + x = SymmetricPadding2D(padding=[p, p])(gauss_inputs) + + # Ensure the input to DepthwiseConv2D has the correct shape x = gauss_layer(x) - 
########################
+
+    # Install the fixed Gaussian kernel as the layer weights and freeze the layer
     gauss_layer.set_weights([kernel_weights])
     gauss_layer.trainable = False
+
     return tf.keras.Model(inputs=gauss_inputs, outputs=x, **kwargs)
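For a quick sanity check of the patched `gauss_blur_model`, here is a minimal usage sketch (assuming TensorFlow 2.x is installed and the repository root is on the Python path; the random input batch is only a placeholder):

```python
import tensorflow as tf

from lib.gauss import gauss_blur_model

# Build a fixed, non-trainable Gaussian blur model for 256x256 RGB inputs,
# mirroring the defaults in lib/gauss.py (kernel_size must be odd).
blur = gauss_blur_model(input_shape=(256, 256, 3), kernel_size=19, sigma=5)

# Symmetric (REFLECT) padding followed by a 'valid' depthwise convolution keeps
# the spatial size unchanged, so the output shape matches the input shape.
images = tf.random.uniform((1, 256, 256, 3))
blurred = blur(images)
print(blurred.shape)  # (1, 256, 256, 3)
```

Because the kernel is set once via `set_weights` and the layer is marked non-trainable, the blur adds no trainable parameters when it is attached to the generator.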