-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcompute_stats_for_train.py
71 lines (57 loc) · 1.98 KB
/
compute_stats_for_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import json
import argparse
import numpy as np
from skimage.io import imread
from logger_utils import write_dict_to_json
def compute_stats(FLAGS):
    """Compute dataset-level mean and std of the first image channel.

    Every entry listed in ``FLAGS.dir_images`` is read with ``imread``,
    divided by 255.0, and reduced to the mean and standard deviation of
    channel 0. The per-image means are averaged. The per-image stds are
    combined as ``sqrt(n * sum(std_i ** 2) / (n - 1) ** 2)`` where ``n`` is
    the number of images; with a single image that image's own std is used,
    because the combined formula would divide by zero.

    Both values are printed and, rounded to four decimals, stored under the
    keys ``"mean"`` and ``"std"`` in a dict that is passed to
    ``write_dict_to_json`` together with ``FLAGS.file_json``.

    Args:
        FLAGS: Object exposing ``dir_images`` (directory containing the
            training images) and ``file_json`` (path of the JSON file to
            write).

    Raises:
        ValueError: If ``FLAGS.dir_images`` contains no entries.
    """
    image_names = sorted(os.listdir(FLAGS.dir_images))
    num_images = len(image_names)
    print(f"Num images: {num_images}")
    if num_images == 0:
        # Averaging nothing would silently write NaN into the stats file.
        raise ValueError(f"No images found in {FLAGS.dir_images}")
    print("Computing statistics for all the training images in the dataset")

    # Collect in plain lists: np.append copies the whole array on every call.
    per_image_means = []
    per_image_stds = []
    for image_name in image_names:
        image = imread(os.path.join(FLAGS.dir_images, image_name))
        # NOTE(review): presumably maps 8-bit pixel values into [0, 1] - confirm.
        image = image / 255.0
        # Only channel 0 is used; this indexing requires a 3-D image array.
        first_channel = image[:, :, 0]
        per_image_means.append(np.mean(first_channel))
        per_image_stds.append(np.std(first_channel))

    all_means = np.asarray(per_image_means)
    all_stds = np.asarray(per_image_stds)

    mean_of_images = np.mean(all_means)
    if num_images == 1:
        # (n - 1) ** 2 is zero here, so fall back to the lone image's std.
        std_of_images = all_stds[0]
    else:
        std_of_images = np.sqrt(
            num_images * np.sum(np.square(all_stds)) / (num_images - 1) ** 2
        )
    print(f"mean: {mean_of_images:.4f}, std: {std_of_images:.4f}")

    dict_stats = {
        "mean": round(mean_of_images, 4),
        "std": round(std_of_images, 4),
    }
    write_dict_to_json(FLAGS.file_json, dict_stats)
    print(f"Training image statistics saved in {FLAGS.file_json}")
    print("Completed computing statistics for all the training images in the dataset")
    return
def main():
    """Parse the command-line options and hand them to ``compute_stats``.

    Two string options are accepted, ``--file_json`` and ``--dir_images``;
    each falls back to the default defined at the top of this function.
    """
    default_json_path = "image_stats.json"
    default_image_dir = "/home/abhishek/Desktop/RUG/htsm_masterwork/oil-spill-detection-dataset/train/images/"

    # ArgumentDefaultsHelpFormatter makes --help display the defaults above.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    cli.add_argument(
        "--file_json",
        type=str,
        default=default_json_path,
        help="full path of json file to be saved",
    )
    cli.add_argument(
        "--dir_images",
        type=str,
        default=default_image_dir,
        help="full path of directory containing training images",
    )

    # parse_known_args tolerates unrecognised arguments; they are discarded.
    parsed_flags, _ = cli.parse_known_args()
    compute_stats(parsed_flags)
    return
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()