diff --git a/gnn4rcpsp/learn_schedules.py b/gnn4rcpsp/learn_schedules.py
index e56ba8e..6ecd815 100644
--- a/gnn4rcpsp/learn_schedules.py
+++ b/gnn4rcpsp/learn_schedules.py
@@ -43,7 +43,6 @@ def check_solution(data):
         "makespan": [],
     }
     for data in data_batch.to_data_list():
-
         t2t, dur, r2t, rc, ref_makespan = (
             data.t2t.view(len(data.dur), -1).data.cpu().detach().numpy(),
             data.dur.data.cpu().detach().numpy(),
@@ -303,7 +302,6 @@ def hook(module, input, output, name=name):
             writer.add_scalar("loss", loss.item(), step)

             if COMPUTE_METRICS:
-
                 for batch_idx, data in enumerate(test_loader):
                     data.to(device)
                     out = model(data)
@@ -356,7 +354,7 @@ def hook(module, input, output, name=name):
                     best_sgs_makespan_ref = epoch_violations["sgs_makespan_ref"]
                     shutil.copyfile(
                         f"saved_models/ResTransformer-256-50000/model_{kfold_idx}_{epoch}.tch",
-                        "best_model.tch",
+                        "saved_models/ResTransformer-256-50000/best_model.tch",
                     )

         writer.flush()
@@ -368,7 +366,6 @@ def hook(module, input, output, name=name):


 if __name__ == "__main__":
-
     file_path = os.path.realpath(__file__)

     if not os.path.exists(os.path.join(os.path.dirname(file_path), "data_list.tch")):
@@ -399,6 +396,14 @@ def hook(module, input, output, name=name):
         # train_list += random.sample(medium_list, int(0.5 * len(medium_list)))
         torch.save(train_list, "./train_list.tch")

+    # Only if we are in small-medium training mode
+    if os.path.exists(os.path.join(os.path.dirname(file_path), "validation_list.tch")):
+        validation_list = torch.load("validation_list.tch")
+    else:
+        big_inst_list = list(set(range(len(data_list))) - set(train_list))
+        validation_list = random.sample(big_inst_list, int(0.5 * len(big_inst_list)))
+        torch.save(validation_list, "./validation_list.tch")
+
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     Net = ResTransformer
     # Net = ResGINE
@@ -419,4 +424,5 @@ def hook(module, input, output, name=name):
         optimizer=optimizer,
         device=device,
         writer=writer,
+        validation_list=validation_list,
     )