forked from netabecker/Stegastamp_pytorch_version
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdecode_image.py
79 lines (64 loc) · 2.31 KB
/
decode_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import glob
import bchlib
import numpy as np
from PIL import Image, ImageOps, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import torch
from torchvision import transforms
# BCH error-correction settings passed to bchlib.BCH as prim_poly and t
# (bchlib's `t` is the number of correctable bit errors).
# NOTE(review): presumably these must match the values used when encoding.
BCH_POLYNOMIAL, BCH_BITS = 137, 5
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` turns every non-empty string (including 'False')
    into True, which made ``--cuda False`` impossible.
    """
    import argparse

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _list_input_files(args):
    """Return the image paths to decode, or None when no input was given.

    An explicit ``--image`` is always decoded. With ``--images_dir``, only
    regular files whose *base name* contains 'hidden' are kept, sorted so the
    output order is deterministic.
    """
    import os

    if args.image is not None:
        return [args.image]
    if args.images_dir is None:
        return None
    return sorted(
        path
        for path in glob.glob(os.path.join(args.images_dir, '*'))
        # Match on the base name so a directory called e.g. "hidden/" does not
        # let every file through; skip sub-directories, which PIL cannot open.
        if os.path.isfile(path) and 'hidden' in os.path.basename(path)
    )


def _decode_secret_bits(bits, bch):
    """Turn rounded per-bit decoder outputs into the embedded text.

    The first 96 bits are packed into bytes and split into data and
    ``bch.ecc_bytes`` ECC bytes. After BCH error correction, the data is
    decoded as UTF-8.

    Returns:
        The decoded string, or None when BCH decoding reports failure (-1) or
        the corrected payload is not valid UTF-8.
    """
    packet_binary = ''.join(str(int(bit)) for bit in bits[:96])
    packet = bytearray(
        int(packet_binary[i:i + 8], 2) for i in range(0, len(packet_binary), 8)
    )
    data, ecc = packet[:-bch.ecc_bytes], packet[-bch.ecc_bytes:]
    bitflips = bch.decode(data, ecc)
    if bitflips == -1:
        return None
    # With the bchlib >= 1.0 API (BCH(prim_poly=..., t=...)), decode() only
    # locates the errors; correct() applies the fixes to `data` in place.
    bch.correct(data, ecc)
    try:
        return data.decode('utf-8')
    except UnicodeDecodeError:
        return None


def main():
    """Decode embedded secrets from images with a saved decoder model.

    Command line: ``model`` (path to a ``torch.save``-d decoder module),
    ``--image`` or ``--images_dir``, ``--secret_size`` and ``--cuda``.
    Prints one line per image: the filename followed by the decoded text, or
    by 'Failed to decode'.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('model', type=str)
    parser.add_argument('--image', type=str, default=None)
    parser.add_argument('--images_dir', type=str, default=None)
    # Accepted for CLI compatibility; not otherwise used by this script.
    parser.add_argument('--secret_size', type=int, default=100)
    parser.add_argument('--cuda', type=_str2bool, default=True)
    args = parser.parse_args()

    files_list = _list_input_files(args)
    if files_list is None:
        print('Missing input image')
        return

    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # trusted model files. Newer torch releases default to weights_only=True,
    # which cannot load a whole pickled module; confirm against the pinned
    # torch version.
    # map_location='cpu' lets GPU-saved models load on CPU-only machines;
    # the model is moved to the GPU below when --cuda is set.
    decoder = torch.load(args.model, map_location='cpu')
    decoder.eval()
    if args.cuda:
        decoder = decoder.cuda()

    bch = bchlib.BCH(prim_poly=BCH_POLYNOMIAL, t=BCH_BITS)
    # Images are resized/cropped to 400x400 before decoding (presumably the
    # decoder's training resolution; TODO confirm).
    size = (400, 400)
    to_tensor = transforms.ToTensor()

    with torch.no_grad():
        for filename in files_list:
            # Close the underlying file handle once the pixels are converted.
            with Image.open(filename) as source:
                image = ImageOps.fit(source.convert("RGB"), size)
            batch = to_tensor(image).unsqueeze(0)
            if args.cuda:
                batch = batch.cuda()
            # .cpu() is a no-op for tensors already on the CPU.
            secret = decoder(batch).cpu()
            bits = np.round(np.array(secret[0]))
            code = _decode_secret_bits(bits, bch)
            if code is None:
                print(filename, 'Failed to decode')
            else:
                print(filename, code)


if __name__ == "__main__":
    main()