Skip to content

Commit

Permalink
Save best model at the right place
Browse files — browse the repository at this point in the history
  • Loading branch information
fteicht committed Mar 15, 2023
1 parent 4e5f8fd commit 4c27aec
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions gnn4rcpsp/learn_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def check_solution(data):
"makespan": [],
}
for data in data_batch.to_data_list():

t2t, dur, r2t, rc, ref_makespan = (
data.t2t.view(len(data.dur), -1).data.cpu().detach().numpy(),
data.dur.data.cpu().detach().numpy(),
Expand Down Expand Up @@ -303,7 +302,6 @@ def hook(module, input, output, name=name):
writer.add_scalar("loss", loss.item(), step)

if COMPUTE_METRICS:

for batch_idx, data in enumerate(test_loader):
data.to(device)
out = model(data)
Expand Down Expand Up @@ -356,7 +354,7 @@ def hook(module, input, output, name=name):
best_sgs_makespan_ref = epoch_violations["sgs_makespan_ref"]
shutil.copyfile(
f"saved_models/ResTransformer-256-50000/model_{kfold_idx}_{epoch}.tch",
"best_model.tch",
"saved_models/ResTransformer-256-50000/best_model.tch",
)

writer.flush()
Expand All @@ -368,7 +366,6 @@ def hook(module, input, output, name=name):


if __name__ == "__main__":

file_path = os.path.realpath(__file__)

if not os.path.exists(os.path.join(os.path.dirname(file_path), "data_list.tch")):
Expand Down Expand Up @@ -399,6 +396,14 @@ def hook(module, input, output, name=name):
# train_list += random.sample(medium_list, int(0.5 * len(medium_list)))
torch.save(train_list, "./train_list.tch")

# Only if we are in small-medium training mode
if os.path.exists(os.path.join(os.path.dirname(file_path), "validation_list.tch")):
validation_list = torch.load("validation_list.tch")
else:
big_inst_list = list(set(range(len(data_list))) - set(train_list))
validation_list = random.sample(big_inst_list, int(0.5 * len(big_inst_list)))
torch.save(validation_list, "./validation_list.tch")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Net = ResTransformer
# Net = ResGINE
Expand All @@ -419,4 +424,5 @@ def hook(module, input, output, name=name):
optimizer=optimizer,
device=device,
writer=writer,
validation_list=validation_list,
)

0 comments on commit 4c27aec

Please sign in to comment.