From 1cc984b5d481f391f18627b20e9245cca34b2e81 Mon Sep 17 00:00:00 2001 From: Mandlin Sarah Date: Wed, 4 Sep 2024 13:21:55 -0700 Subject: [PATCH] Improve utils/plots.py: Add image format support and enhance readability --- utils/plots.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/utils/plots.py b/utils/plots.py index e019941..706ca02 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -6,28 +6,31 @@ import numpy as np import matplotlib.pyplot as plt +def load_img(filename, debug=False, norm=True, resize=None): + def _imread(filename): + # Attempt to read image with cv2, if it fails, raise an IOError + img = cv2.imread(filename) + if img is None: + raise IOError(f"Cannot open image file {filename}") + return img -def load_img (filename, debug=False, norm=True, resize=None): - img = cv2.imread(filename) + img = _imread(filename) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if norm: img = img / 255. img = img.astype(np.float32) if debug: - print (img.shape, img.dtype, img.min(), img.max()) - + print(img.shape, img.dtype, img.min(), img.max()) if resize: img = cv2.resize(img, (resize[0], resize[1])) - return img - -def plot_all (images, axis='off', figsize=(16, 8)): - +def plot_all(images, axis='off', figsize=(16, 8)): fig = plt.figure(figsize=figsize, dpi=80) nplots = len(images) - for i in range(nplots): - plt.subplot(1,nplots,i+1) - plt.axis(axis) - plt.imshow(images[i]) + for i, img in enumerate(images): + ax = fig.add_subplot(1, nplots, i+1) + ax.axis(axis) + ax.imshow(img) plt.show() +