Skip to content

Commit

Permalink
Infer with validation set
Browse files Browse the repository at this point in the history
  • Loading branch information
fteicht committed Mar 15, 2023
1 parent 1a5b1d0 commit 8c28e41
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions gnn4rcpsp/infer_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def resource_constraint(r, t, r2t, dur, x, grad):
# TODO: vectorize?
data_list = data_batch.to_data_list()
for data in data_list:

t2t, dur, r2t, rc, ref_makespan = (
data.t2t.view(len(data.dur), -1).data.cpu().detach().numpy(),
data.dur.data.cpu().detach().numpy(),
Expand Down Expand Up @@ -277,7 +276,6 @@ def make_feasible_cpsat(data):
# TODO: vectorize?
data_list = data_batch.to_data_list()
for data in data_list:

t2t, dur, r2t, rc, ref_makespan = (
data.t2t.view(len(data.dur), -1).data.cpu().detach().numpy(),
data.dur.data.cpu().detach().numpy(),
Expand Down Expand Up @@ -408,7 +406,6 @@ def make_feasible_cpsat(data):
raise RuntimeError("Invalid CPSAT model.")

for k in cpsat_result:

cpsat_result[k] = np.mean([r for r in cpsat_result[k] if r > -1])

if len(data_list) == 1: # batch sizes of 1 so 1 batch == 1 instance
Expand Down Expand Up @@ -483,9 +480,9 @@ def make_feasible_sgs(data, just_dummy_version: bool = False):
sol = RCPSPSolution(problem=do_model, rcpsp_permutation=perm)
makespan = sol.get_max_end_time()
feasible_solution = [sol.get_start_time(t) for t in do_model.tasks_list]
print("Makespan GNN ", max(xorig + dur))
print("Ref makespan ", max(ref_makespan))
print("Obtained by post pro gnn ", max(feasible_solution))
# print("Makespan GNN ", max(xorig + dur))
# print("Ref makespan ", max(ref_makespan))
# print("Obtained by post pro gnn ", max(feasible_solution))
cpsat_result["feasibility_timing"].append(perf_counter() - cur_time)
cpsat_result["feasibility_rel_makespan_cor"].append(makespan / max(xorig + dur))
cpsat_result["feasibility_rel_makespan_ref"].append(
Expand Down Expand Up @@ -609,8 +606,11 @@ def test(

def script_gpd():
data_list = torch.load("../torch_data/data_list_sm.tch")
train_list = torch.load("../torch_data/train_list_all.tch")
test_list = list(set(range(len(data_list))) - set(train_list))
train_list = torch.load("../torch_data/train_list_sm.tch")
validation_list = torch.load("../torch_data/validation_list.tch")
test_list = list(
(set(range(len(data_list))) - set(train_list)) - set(validation_list)
)
test_loader = DataLoader(
[data_list[d] for d in test_list],
batch_size=1,
Expand All @@ -623,7 +623,7 @@ def script_gpd():
# model.load_state_dict(torch.load('saved_models/ResTransformer-256-50000/model_49900.tch'))
model.load_state_dict(
torch.load(
"../torch_data/model_ResTransformer_all_256_50000.tch",
"saved_models/ResTransformer-256-50000/best_model.tch",
map_location=device,
)
)
Expand Down

0 comments on commit 8c28e41

Please sign in to comment.