-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtflite.py
25 lines (21 loc) · 1011 Bytes
/
tflite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from tensorflow import lite
import argparse
import logging
import os
from os.path import join, exists
# Convert a TensorFlow SavedModel checkpoint into a .tflite file.
# Usage: python tflite.py -model_epoch N -exp_name NAME [-saved_dir DIR] [-tflite_dir DIR]

logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse the converter's command-line options.

    Args:
        argv: Argument list to parse; ``None`` falls back to ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``model_epoch`` (int), ``exp_name``,
        ``saved_dir`` and ``tflite_dir`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Convert a SavedModel checkpoint to a TFLite model.')
    # Single-dash long option names are kept so existing invocations keep working.
    parser.add_argument('-model_epoch', type=int, required=True,
                        help='epoch number of the checkpoint to convert')
    parser.add_argument('-exp_name', type=str, required=True,
                        help='experiment name (sub-directory of -saved_dir)')
    parser.add_argument('-saved_dir', default='saved_models', type=str,
                        help='root directory holding the saved checkpoints')
    parser.add_argument('-tflite_dir', default='tflite_models', type=str,
                        help='directory the .tflite file is written to')
    return parser.parse_args(argv)


def saved_model_path(saved_dir, exp_name, model_epoch):
    """Return the checkpoint directory ``<saved_dir>/<exp_name>/<model_epoch>``."""
    return join(saved_dir, exp_name, str(model_epoch))


def tflite_output_path(tflite_dir, exp_name, model_epoch):
    """Return ``<tflite_dir>/converted_model_<exp_name>_epoch_<model_epoch>.tflite``."""
    filename = 'converted_model_{}_epoch_{}.tflite'.format(exp_name, model_epoch)
    return join(tflite_dir, filename)


def main(argv=None):
    """Convert the selected SavedModel checkpoint and write the .tflite file.

    Args:
        argv: Optional argument list. ``None`` means the process command line is used.

    Raises:
        FileNotFoundError: If the SavedModel directory does not exist.
    """
    # %(levelname)s instead of a hard-coded "[INFO]" so warnings/errors are labelled correctly.
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s [%(levelname)s] %(message)s')
    args = parse_args(argv)

    saved_path = saved_model_path(args.saved_dir, args.exp_name, args.model_epoch)
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if not os.path.isdir(saved_path):
        raise FileNotFoundError('SavedModel directory not found: %s' % saved_path)

    logger.info('Creating converter from saved path %s...', saved_path)
    converter = lite.TFLiteConverter.from_saved_model(saved_path)
    tflite_model = converter.convert()

    # exist_ok avoids the check-then-create race of exists() followed by makedirs().
    os.makedirs(args.tflite_dir, exist_ok=True)
    output_path = tflite_output_path(args.tflite_dir, args.exp_name, args.model_epoch)
    # Context manager guarantees the file is flushed and closed.
    with open(output_path, 'wb') as model_file:
        model_file.write(tflite_model)
    logger.info('Converting successful!')
    logger.info('TFLite model written to %s', output_path)


if __name__ == '__main__':
    main()