-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
31 lines (24 loc) · 1.18 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import argparse
from DataGenerator import DataGenerator
from Network import Network
import numpy as np
from Parameters import Parameters
import tensorflow as tf
from utils import load_settings_yaml
def resolve_data_type(requested):
    """Convert the ``data_type`` setting into the NumPy type used for training.

    The settings file can only carry the dtype as a string such as
    ``"np.float32"``. It therefore has to be turned into a real NumPy type here
    rather than inside ``Parameters``. Only ``np.float32`` is supported: any
    other value is still mapped to ``np.float32``, but a warning is emitted so
    a misconfigured run file does not go unnoticed.

    Args:
        requested: The raw ``data_type`` value taken from the loaded settings.

    Returns:
        ``np.float32``.
    """
    import warnings

    if requested not in ("np.float32", "float32", np.float32):
        # Coercion used to be silent; surface it instead of hiding it.
        warnings.warn(
            f"Unsupported data_type {requested!r} in settings; using np.float32.",
            stacklevel=2,
        )
    return np.float32


def main(argv=None):
    """Load the run settings, build the data generator and network, and train.

    Args:
        argv: Optional list of command-line arguments. ``None`` lets argparse
            read ``sys.argv[1:]``, which is how the script is normally run.
    """
    # Disable eager execution before any TensorFlow objects are created.
    # Presumably Network builds a TF1-style graph; confirm before removing.
    tf.compat.v1.disable_eager_execution()

    # Define ArgumentParser
    parser = argparse.ArgumentParser()
    parser.add_argument("--run", help="Location/path of the run.yaml file. This is usually structured as a path.", default="runs/run.yaml", required=False)
    args = parser.parse_args(argv)

    yaml_path = args.run
    # Load all settings from the .yaml file (load_settings_yaml returns a dict).
    settings_yaml = load_settings_yaml(yaml_path)
    params = Parameters(settings_yaml, yaml_path)
    # Done here rather than in Parameters, because the settings file cannot
    # express a NumPy type directly.
    params.data_type = resolve_data_type(params.data_type)

    # Create the custom data generator.
    datagen = DataGenerator(params)
    # The network needs the hyper-parameters from params, and a data generator
    # object to know how to generate training data.
    network = Network(params, datagen, training=True)
    network.train()


if __name__ == "__main__":
    main()