Skip to content

Commit

Permalink
Histo segmentation training and inference working
Browse files Browse the repository at this point in the history
  • Loading branch information
szmazurek committed Jan 16, 2025
1 parent 739f05d commit 7611b3a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
16 changes: 12 additions & 4 deletions GANDLF/models/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,8 +931,16 @@ def _process_prediction_logit_for_row_writing(
"""
return prediction_logit.cpu().max().item() / scaling_factor

def _print_currently_processed_subject(self, subject: torchio.Subject):
print("== Current subject:", subject["subject_id"], flush=True)
def _print_currently_processed_subject(self, subject):
if isinstance(subject, torchio.Subject):
subject_id = subject["subject_id"]
elif isinstance(subject, tuple):
# ugly corner histology inference handling, when incoming batch is
# a row from dataframe, not a torchio.Subject. This should be solved
# via some kind of polymorphism in the future
subject_data = subject[1]
subject_id = subject_data[self.params["headers"]["subjectIDHeader"]]
print("== Current subject:", subject_id, flush=True)

def _initialize_subject_dict_nontraining_mode(self, subject: torchio.Subject):
"""
Expand Down Expand Up @@ -1753,7 +1761,7 @@ def _histopathology_inference_step(self, row_index_tuple):
self.params["model"]["num_channels"],
patch_size_updated_after_transforms,
)
count_map, probabilities_map = self._iterate_over_hisopathology_loader(
count_map, probabilities_map = self._iterate_over_histopathology_loader(
histopathology_dataloader,
count_map,
probabilities_map,
Expand Down Expand Up @@ -1783,7 +1791,7 @@ def _histopathology_inference_step(self, row_index_tuple):
self.rows_to_write, inference_results_save_dir_for_subject
)

def _iterate_over_hisopathology_loader(
def _iterate_over_histopathology_loader(
self,
histopathology_dataloader,
count_map,
Expand Down
12 changes: 4 additions & 8 deletions testing/test_lightning_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ def test_port_model_inference_classification_histology_2d(device):
parameters["patch_size"] = 128
file_config_temp = write_temp_config_path(parameters)
parameters = ConfigManager(file_config_temp, version_check_flag=False)
os.remove(file_config_temp)
parameters["model"]["dimension"] = 2
# read and parse csv
training_data, parameters["headers"] = parseTrainingCSV(file_for_Training)
Expand Down Expand Up @@ -860,6 +861,7 @@ def test_port_model_inference_segmentation_histology_2d():
parameters["modality"] = "histo"
parameters["model"]["dimension"] = 2
parameters["model"]["class_list"] = [0, 255]
parameters["penalty_weights"] = [1, 1]
parameters["model"]["amp"] = True
parameters["model"]["num_channels"] = 3
parameters = populate_header_in_parameters(parameters, parameters["headers"])
Expand Down Expand Up @@ -912,11 +914,5 @@ def test_port_model_inference_segmentation_histology_2d():
TEST_DATA_DIRPATH + "/train_2d_histo_segmentation.csv"
)
inference_data.drop(index=inference_data.index[-1], axis=0, inplace=True)
inference_dataloader = torch.utils.data.DataLoader(
ImagesFromDataFrame(
inference_data, parameters, train=False, loader_type="testing"
),
batch_size=parameters["batch_size"],
shuffle=False,
)
trainer.predict(module, inference_dataloader)

trainer.predict(module, inference_data.iterrows())

0 comments on commit 7611b3a

Please sign in to comment.