diff --git a/Foggy_CycleGAN.ipynb b/Foggy_CycleGAN.ipynb
index 911b10e..c912cce 100644
--- a/Foggy_CycleGAN.ipynb
+++ b/Foggy_CycleGAN.ipynb
@@ -69,16 +69,10 @@
"source": [
"import sys\n",
"colab = 'google.colab' in sys.modules\n",
- "if colab:\n",
- " # noinspection PyBroadException\n",
- " try:\n",
- " %tensorflow_version 2.x\n",
- " except Exception:\n",
- " pass\n",
"import tensorflow as tf"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -94,8 +88,8 @@
"# noinspection PyUnresolvedReferences\n",
"print(tf.__version__)"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -112,8 +106,8 @@
"\n",
"tfds.disable_progress_bar()"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -142,8 +136,8 @@
" os.chdir(project_dir)\n",
" print(\"Done.\")"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "markdown",
@@ -167,8 +161,8 @@
"IMG_WIDTH = 256\n",
"IMG_HEIGHT = 256"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -180,13 +174,11 @@
"source": [
"project_label = \"\" #@param {type:\"string\"}"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
- "execution_count": null,
- "outputs": [],
"source": [
"mount_path = None #to suppress warnings\n",
"drive_project_path = None\n",
@@ -205,7 +197,9 @@
"pycharm": {
"name": "#%%\n"
}
- }
+ },
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -219,8 +213,8 @@
"if colab:\n",
" !sh $PROJECT_DIR/copy_dataset.sh"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -235,13 +229,11 @@
"source": [
"test_split = 0.2 #@param {type:\"slider\", min:0.05, max:0.95, step:0.05}"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
- "execution_count": null,
- "outputs": [],
"source": [
"from lib.dataset import DatasetInitializer\n",
"\n",
@@ -257,7 +249,9 @@
"pycharm": {
"name": "#%%\n"
}
- }
+ },
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "markdown",
@@ -284,8 +278,8 @@
"OUTPUT_CHANNELS = 3\n",
"models_builder = ModelsBuilder()"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -297,6 +291,8 @@
"source": [
"use_transmission_map = False #@param{type: \"boolean\"}\n",
"use_gauss_filter = False #@param{type: \"boolean\"}\n",
+ "if use_gauss_filter and not use_transmission_map:\n",
+ " raise Exception(\"Gauss filter requires transmission map\")\n",
"use_resize_conv = False #@param{type: \"boolean\"}\n",
"\n",
"generator_clear2fog = models_builder.build_generator(use_transmission_map=use_transmission_map,\n",
@@ -304,8 +300,8 @@
" use_resize_conv=use_resize_conv)\n",
"generator_fog2clear = models_builder.build_generator(use_transmission_map=False)"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -320,8 +316,8 @@
"source": [
"tf.keras.utils.plot_model(generator_clear2fog, show_shapes=True, dpi=64, to_file='generator_clear2fog.png');"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -336,8 +332,8 @@
"source": [
"tf.keras.utils.plot_model(generator_fog2clear, show_shapes=True, dpi=64, to_file='generator_fog2clear.png');"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "markdown",
@@ -364,8 +360,8 @@
"discriminator_fog = models_builder.build_discriminator(use_intensity=use_intensity_for_fog_discriminator)\n",
"discriminator_clear = models_builder.build_discriminator(use_intensity=False)"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -378,8 +374,8 @@
"tf.keras.utils.plot_model(discriminator_fog, show_shapes=True, dpi=64, to_file=\"discriminator_fog.png\");\n",
"tf.keras.utils.plot_model(discriminator_clear, show_shapes=True, dpi=64, to_file=\"discriminator_clear.png\");"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "markdown",
@@ -407,8 +403,8 @@
"else:\n",
" weights_path = \"./weights/\""
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -424,8 +420,8 @@
"\n",
"trainer.configure_checkpoint(weights_path = weights_path, load_optimizers=False)"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -442,8 +438,8 @@
"for clear, fog in tf.data.Dataset.zip((sample_clear.take(1), sample_fog.take(1))):\n",
" plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog)"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -460,8 +456,8 @@
"for clear, fog in tf.data.Dataset.zip((sample_clear.take(1), sample_fog.take(1))):\n",
" plot_discriminators_predictions(discriminator_clear, clear, discriminator_fog, fog, use_intensity_for_fog_discriminator)"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "markdown",
@@ -486,8 +482,8 @@
"source": [
"use_tensorboard = True #@param{type:\"boolean\"}"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -512,8 +508,8 @@
" else:\n",
" print(url)"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -530,8 +526,8 @@
" trainer.image_log_path = os.path.join(drive_project_path,\"image_logs/\")\n",
" trainer.config_path = os.path.join(drive_project_path,\"trainer_config.json\")"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -546,8 +542,8 @@
"source": [
"trainer.load_config()"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
@@ -576,8 +572,8 @@
" use_intensity_for_fog_discriminator=use_intensity_for_fog_discriminator\n",
")"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "markdown",
@@ -605,13 +601,11 @@
"for clear, fog in zip(test_clear.take(5), test_fog.take(5)):\n",
" plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog)"
],
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
- "execution_count": null,
- "outputs": [],
"source": [
"for clear, fog in zip(sample_clear, sample_fog):\n",
" plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog)"
@@ -621,12 +615,12 @@
"pycharm": {
"name": "#%%\n"
}
- }
+ },
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
- "execution_count": null,
- "outputs": [],
"source": [
"from lib.plot import plot_clear2fog_intensity\n",
"from matplotlib import pyplot as plt\n",
@@ -651,12 +645,12 @@
"pycharm": {
"name": "#%%\n"
}
- }
+ },
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
- "execution_count": null,
- "outputs": [],
"source": [
"if colab:\n",
" !cd ./intensity; zip /content/intensity.zip *"
@@ -666,7 +660,9 @@
"pycharm": {
"name": "#%%\n"
}
- }
+ },
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "markdown",
@@ -682,8 +678,6 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "outputs": [],
"source": [
"from lib.plot import plot_clear2fog_intensity\n",
"from matplotlib import pyplot as plt\n",
@@ -691,7 +685,7 @@
"intensity_path = './intensity/'\n",
"from lib.tools import create_dir\n",
"create_dir(intensity_path)\n",
- "file_path = 'E:/Downloads/test-image.png'\n",
+ "file_path = './Downloads/test-image.png'\n",
"\n",
"image_clear = tf.io.decode_png(tf.io.read_file(file_path), channels=3)\n",
"image_clear, _ = datasetInit.preprocess_image_test(image_clear, 0)\n",
@@ -710,12 +704,12 @@
"pycharm": {
"name": "#%%\n"
}
- }
+ },
+ "outputs": [],
+ "execution_count": null
},
{
"cell_type": "code",
- "execution_count": null,
- "outputs": [],
"source": [
"if colab:\n",
" !cd ./intensity; zip /content/intensity.zip *"
@@ -725,7 +719,9 @@
"pycharm": {
"name": "#%%\n"
}
- }
+ },
+ "outputs": [],
+ "execution_count": null
}
]
-}
\ No newline at end of file
+}
diff --git a/README.md b/README.md
index a2e8f42..f0ad89e 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,6 @@
+> [!NOTE]
+> November 2024: New Pre-trained Models are available, check the [Pre-trained Models](#pre-trained-models) section.
+
# Foggy-CycleGAN
@@ -9,6 +12,13 @@ This project is the implementation for my Computer Science MSc thesis in the Uni
Dissertation:
[PDF] Simulating Weather Conditions on Digital Images (Debrecen, 2020).
+# Table of Content
+- [Description](#description)
+- [Code](#code)
+- [Notebook](#notebook)
+- [Results](#results)
+- [Pre-trained Models](#pre-trained-models)
+
## Description
**Foggy-CycleGAN** is a
CycleGAN model trained to synthesize fog on clear images. More details in the dissertation above.
@@ -16,30 +26,57 @@ Dissertation:
## Code
The full source code is available under GPL-3.0 License in my Github repository ghaiszaher/Foggy-CycleGAN
-## Pre-trained Models
-A version of pre-trained models used in the thesis can be found [here](https://drive.google.com/drive/folders/1QKsiaGkMFvtGcp072IG57MfY1o_D-L3k?usp=sharing).
-
## Notebook
A Jupyter Notebook file Foggy_CycleGAN.ipynb is available in the repository.
-
## Results
+(as of June 2020)
-
+
-
+
-
+
-
+
-
+
© Ghais Zaher 2020
+
+## Pre-trained Models
+As previous pre-trained models are no longer compatible with newer Keras/Tensorflow versions, I have retrained the model and made the new weights available to download.
+
+Each of the following models was trained in Google Colab using the same dataset, the parameters for building the models and number of trained epochs are a bit different:
+
+
+| Model | Trained Epochs | Config |
+|-------------------------------------------------------------------------------------------------------------|----------------|-------------------------------------------------------------------------------------|
+| [2020-06 (legacy)](https://drive.google.com/drive/folders/1QKsiaGkMFvtGcp072IG57MfY1o_D-L3k?usp=sharing) | 145 | `use_transmission_map=False`
`use_gauss_filter=False`
`use_resize_conv=False` |
+| [2024-11-17-rev1-000](https://drive.google.com/drive/folders/1--W53NNrVxS5pvrf8jDKCRmg4h4vD5lx?usp=sharing) | 522 | `use_transmission_map=False`
`use_gauss_filter=False`
`use_resize_conv=False` |
+| [2024-11-17-rev2-110](https://drive.google.com/drive/folders/1rQ7jmsv63uv6v45IVZmZ8w9CVktqJAfn?usp=sharing) | 100 | `use_transmission_map=True`
`use_gauss_filter=True`
`use_resize_conv=False` |
+| [2024-11-17-rev3-111](https://drive.google.com/drive/folders/1-0-z7KTMXTrwwUdeJtkUOBCWkwD6behO?usp=sharing) | 103 | `use_transmission_map=True`
`use_gauss_filter=True`
`use_resize_conv=True` |
+| [2024-11-17-rev4-001](https://drive.google.com/drive/folders/1hDxJtU0agbnPO2XrrPo26RQJKOePa6WX?usp=sharing) | 39 | `use_transmission_map=False`
`use_gauss_filter=False`
`use_resize_conv=True` |
+
+
+
+### Results
+The results of the new models are similar to the previous ones, here are some samples:
+
+
+| Clear | 2024-11-17-rev1-000 | 2024-11-17-rev2-110 | 2024-11-17-rev3-111 | 2024-11-17-rev4-001 |
+|---------------------------------------------------------|------------------------------------------------------------|------------------------------------------------------------|------------------------------------------------------------|------------------------------------------------------------|
+|
|
|
|
|
|
+|
|
|
|
|
|
+|
|
|
|
|
|
+|
|
|
|
|
|
+|
|
|
|
|
|
+
+
diff --git a/discriminator_clear.png b/discriminator_clear.png
index 703c65f..cc75df7 100644
Binary files a/discriminator_clear.png and b/discriminator_clear.png differ
diff --git a/discriminator_fog.png b/discriminator_fog.png
index dfb997d..cc75df7 100644
Binary files a/discriminator_fog.png and b/discriminator_fog.png differ
diff --git a/generator_clear2fog.png b/generator_clear2fog.png
index 2d4f395..dbb5b81 100644
Binary files a/generator_clear2fog.png and b/generator_clear2fog.png differ
diff --git a/generator_fog2clear.png b/generator_fog2clear.png
index af43abd..dbb5b81 100644
Binary files a/generator_fog2clear.png and b/generator_fog2clear.png differ
diff --git a/images/result-animated-01.gif b/images/results/2020-06/result-animated-01.gif
similarity index 100%
rename from images/result-animated-01.gif
rename to images/results/2020-06/result-animated-01.gif
diff --git a/images/result-sample-0.2.jpg b/images/results/2020-06/result-sample-0.2.jpg
similarity index 100%
rename from images/result-sample-0.2.jpg
rename to images/results/2020-06/result-sample-0.2.jpg
diff --git a/images/result-sample-0.25.jpg b/images/results/2020-06/result-sample-0.25.jpg
similarity index 100%
rename from images/result-sample-0.25.jpg
rename to images/results/2020-06/result-sample-0.25.jpg
diff --git a/images/result-sample-0.3.jpg b/images/results/2020-06/result-sample-0.3.jpg
similarity index 100%
rename from images/result-sample-0.3.jpg
rename to images/results/2020-06/result-sample-0.3.jpg
diff --git a/images/results/2024-11-17/clear/sample1.jpg b/images/results/2024-11-17/clear/sample1.jpg
new file mode 100644
index 0000000..bdf4f95
Binary files /dev/null and b/images/results/2024-11-17/clear/sample1.jpg differ
diff --git a/images/results/2024-11-17/clear/sample2.jpg b/images/results/2024-11-17/clear/sample2.jpg
new file mode 100644
index 0000000..c75387e
Binary files /dev/null and b/images/results/2024-11-17/clear/sample2.jpg differ
diff --git a/images/results/2024-11-17/clear/sample3.jpg b/images/results/2024-11-17/clear/sample3.jpg
new file mode 100644
index 0000000..8b27ab1
Binary files /dev/null and b/images/results/2024-11-17/clear/sample3.jpg differ
diff --git a/images/results/2024-11-17/clear/sample4.jpg b/images/results/2024-11-17/clear/sample4.jpg
new file mode 100644
index 0000000..166d807
Binary files /dev/null and b/images/results/2024-11-17/clear/sample4.jpg differ
diff --git a/images/results/2024-11-17/clear/sample5.jpg b/images/results/2024-11-17/clear/sample5.jpg
new file mode 100644
index 0000000..be554b8
Binary files /dev/null and b/images/results/2024-11-17/clear/sample5.jpg differ
diff --git a/images/results/2024-11-17/rev1-000/sample1.gif b/images/results/2024-11-17/rev1-000/sample1.gif
new file mode 100644
index 0000000..e0d524b
Binary files /dev/null and b/images/results/2024-11-17/rev1-000/sample1.gif differ
diff --git a/images/results/2024-11-17/rev1-000/sample2.gif b/images/results/2024-11-17/rev1-000/sample2.gif
new file mode 100644
index 0000000..8d22210
Binary files /dev/null and b/images/results/2024-11-17/rev1-000/sample2.gif differ
diff --git a/images/results/2024-11-17/rev1-000/sample3.gif b/images/results/2024-11-17/rev1-000/sample3.gif
new file mode 100644
index 0000000..f334df0
Binary files /dev/null and b/images/results/2024-11-17/rev1-000/sample3.gif differ
diff --git a/images/results/2024-11-17/rev1-000/sample4.gif b/images/results/2024-11-17/rev1-000/sample4.gif
new file mode 100644
index 0000000..a29ca18
Binary files /dev/null and b/images/results/2024-11-17/rev1-000/sample4.gif differ
diff --git a/images/results/2024-11-17/rev1-000/sample5.gif b/images/results/2024-11-17/rev1-000/sample5.gif
new file mode 100644
index 0000000..cef331d
Binary files /dev/null and b/images/results/2024-11-17/rev1-000/sample5.gif differ
diff --git a/images/results/2024-11-17/rev2-110/sample1.gif b/images/results/2024-11-17/rev2-110/sample1.gif
new file mode 100644
index 0000000..75ffa2f
Binary files /dev/null and b/images/results/2024-11-17/rev2-110/sample1.gif differ
diff --git a/images/results/2024-11-17/rev2-110/sample2.gif b/images/results/2024-11-17/rev2-110/sample2.gif
new file mode 100644
index 0000000..5622d61
Binary files /dev/null and b/images/results/2024-11-17/rev2-110/sample2.gif differ
diff --git a/images/results/2024-11-17/rev2-110/sample3.gif b/images/results/2024-11-17/rev2-110/sample3.gif
new file mode 100644
index 0000000..067d7da
Binary files /dev/null and b/images/results/2024-11-17/rev2-110/sample3.gif differ
diff --git a/images/results/2024-11-17/rev2-110/sample4.gif b/images/results/2024-11-17/rev2-110/sample4.gif
new file mode 100644
index 0000000..33e5d48
Binary files /dev/null and b/images/results/2024-11-17/rev2-110/sample4.gif differ
diff --git a/images/results/2024-11-17/rev2-110/sample5.gif b/images/results/2024-11-17/rev2-110/sample5.gif
new file mode 100644
index 0000000..45dcd74
Binary files /dev/null and b/images/results/2024-11-17/rev2-110/sample5.gif differ
diff --git a/images/results/2024-11-17/rev3-111/sample1.gif b/images/results/2024-11-17/rev3-111/sample1.gif
new file mode 100644
index 0000000..1815cdb
Binary files /dev/null and b/images/results/2024-11-17/rev3-111/sample1.gif differ
diff --git a/images/results/2024-11-17/rev3-111/sample2.gif b/images/results/2024-11-17/rev3-111/sample2.gif
new file mode 100644
index 0000000..13dfdd1
Binary files /dev/null and b/images/results/2024-11-17/rev3-111/sample2.gif differ
diff --git a/images/results/2024-11-17/rev3-111/sample3.gif b/images/results/2024-11-17/rev3-111/sample3.gif
new file mode 100644
index 0000000..ea0211f
Binary files /dev/null and b/images/results/2024-11-17/rev3-111/sample3.gif differ
diff --git a/images/results/2024-11-17/rev3-111/sample4.gif b/images/results/2024-11-17/rev3-111/sample4.gif
new file mode 100644
index 0000000..400c620
Binary files /dev/null and b/images/results/2024-11-17/rev3-111/sample4.gif differ
diff --git a/images/results/2024-11-17/rev3-111/sample5.gif b/images/results/2024-11-17/rev3-111/sample5.gif
new file mode 100644
index 0000000..bbe774d
Binary files /dev/null and b/images/results/2024-11-17/rev3-111/sample5.gif differ
diff --git a/images/results/2024-11-17/rev4-001/sample1.gif b/images/results/2024-11-17/rev4-001/sample1.gif
new file mode 100644
index 0000000..48f93c7
Binary files /dev/null and b/images/results/2024-11-17/rev4-001/sample1.gif differ
diff --git a/images/results/2024-11-17/rev4-001/sample2.gif b/images/results/2024-11-17/rev4-001/sample2.gif
new file mode 100644
index 0000000..7536157
Binary files /dev/null and b/images/results/2024-11-17/rev4-001/sample2.gif differ
diff --git a/images/results/2024-11-17/rev4-001/sample3.gif b/images/results/2024-11-17/rev4-001/sample3.gif
new file mode 100644
index 0000000..155b5cf
Binary files /dev/null and b/images/results/2024-11-17/rev4-001/sample3.gif differ
diff --git a/images/results/2024-11-17/rev4-001/sample4.gif b/images/results/2024-11-17/rev4-001/sample4.gif
new file mode 100644
index 0000000..141e188
Binary files /dev/null and b/images/results/2024-11-17/rev4-001/sample4.gif differ
diff --git a/images/results/2024-11-17/rev4-001/sample5.gif b/images/results/2024-11-17/rev4-001/sample5.gif
new file mode 100644
index 0000000..fbc77fd
Binary files /dev/null and b/images/results/2024-11-17/rev4-001/sample5.gif differ
diff --git a/lib/gauss.py b/lib/gauss.py
index 0ed03f9..ec0d9d9 100644
--- a/lib/gauss.py
+++ b/lib/gauss.py
@@ -1,12 +1,11 @@
def gauss_blur_model(input_shape, kernel_size=19, sigma=5, **kwargs):
import tensorflow as tf
import numpy as np
+
def matlab_style_gauss2D(shape=(3, 3), sigma=0.5):
"""
2D gaussian mask - should give the same result as MATLAB's
fspecial('gaussian',[shape],[sigma])
- #https://stackoverflow.com/questions/55643675/how-do-i-implement-gaussian-blurring-layer-in-keras
- #https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python/17201686#17201686
"""
m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m + 1, -n:n + 1]
@@ -18,10 +17,8 @@ def matlab_style_gauss2D(shape=(3, 3), sigma=0.5):
return h
class SymmetricPadding2D(tf.keras.layers.Layer):
- # Source: https://stackoverflow.com/a/55210905/11394663
- def __init__(self, output_dim, padding=(1, 1),
+ def __init__(self, padding=(1, 1),
data_format="channels_last", **kwargs):
- self.output_dim = output_dim
self.data_format = data_format
self.padding = padding
super(SymmetricPadding2D, self).__init__(**kwargs)
@@ -30,22 +27,30 @@ def build(self, input_shape):
super(SymmetricPadding2D, self).build(input_shape)
def call(self, inputs, **kwargs):
- if self.data_format is "channels_last":
- # (batch, depth, rows, cols, channels)
- pad = [[0, 0]] + [[i, i] for i in self.padding] + [[0, 0]]
- # elif self.data_format is "channels_first":
+ if self.data_format == "channels_last":
+ # (batch, rows, cols, channels)
+ pad = [[0, 0]] + [[p, p] for p in self.padding] + [[0, 0]]
else:
- # (batch, channels, depth, rows, cols)
- pad = [[0, 0], [0, 0]] + [[i, i] for i in self.padding]
+ # (batch, channels, rows, cols)
+ pad = [[0, 0], [0, 0]] + [[p, p] for p in self.padding]
paddings = tf.constant(pad)
- out = tf.pad(inputs, paddings, "REFLECT")
- return out
+ return tf.pad(inputs, paddings, "REFLECT")
def compute_output_shape(self, input_shape):
- return input_shape[0], self.output_dim
+ if self.data_format == "channels_last":
+ return (input_shape[0],
+ input_shape[1] + 2 * self.padding[0],
+ input_shape[2] + 2 * self.padding[1],
+ input_shape[3])
+ else:
+ return (input_shape[0],
+ input_shape[1],
+ input_shape[2] + 2 * self.padding[0],
+ input_shape[3] + 2 * self.padding[1])
if kernel_size % 2 == 0:
raise Exception("kernel size should be an odd number")
+
gauss_inputs = tf.keras.layers.Input(shape=input_shape)
kernel_weights = matlab_style_gauss2D(shape=(kernel_size, kernel_size), sigma=sigma)
@@ -53,12 +58,18 @@ def compute_output_shape(self, input_shape):
kernel_weights = np.expand_dims(kernel_weights, axis=-1)
kernel_weights = np.repeat(kernel_weights, in_channels, axis=-1) # apply the same filter on all the input channels
kernel_weights = np.expand_dims(kernel_weights, axis=-1) # for shape compatibility reasons
+
gauss_layer = tf.keras.layers.DepthwiseConv2D(kernel_size, use_bias=False, padding='valid')
p = (kernel_size - 1) // 2
- # noinspection PyCallingNonCallable
- x = SymmetricPadding2D(0, padding=[p, p])(gauss_inputs)
+
+ # Apply symmetric padding
+ x = SymmetricPadding2D(padding=[p, p])(gauss_inputs)
+
+ # Ensure the input to DepthwiseConv2D has the correct shape
x = gauss_layer(x)
- ########################
+
+ # Set the weights for the gaussian filter
gauss_layer.set_weights([kernel_weights])
gauss_layer.trainable = False
+
return tf.keras.Model(inputs=gauss_inputs, outputs=x, **kwargs)