Add curategpt scoring #55

Merged (15 commits, Nov 20, 2024)
13 changes: 3 additions & 10 deletions .github/workflows/qc.yml
@@ -50,14 +50,7 @@ jobs:
      with:
        virtualenvs-create: true
        virtualenvs-in-project: true
    - name: Install dependencies
      run: poetry install --no-interaction --no-root
    - name: Install tox
      run: pip install tox
    - name: Run Doctests
      run: poetry run tox -e doctest
    - name: Generate coverage results
      run: |
        poetry run pip install -U pytest
        poetry run coverage run -p -m pytest tests/
        poetry run coverage combine
        poetry run coverage xml
        poetry run coverage report -m
      run: poetry run tox
1,551 changes: 0 additions & 1,551 deletions notebooks/process_gpt_4o_and_plot.ipynb

This file was deleted.

1,333 changes: 0 additions & 1,333 deletions notebooks/process_gpt_o1_and_plot.ipynb

This file was deleted.

4,271 changes: 3,776 additions & 495 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -11,6 +11,8 @@ python = "^3.10"
pheval = "^0.3.2"
setuptools = "^69.5.1"
shelved-cache = "^0.3.1"
curategpt = "^0.2.2"
psutil = "^6.1.0"

[tool.poetry.plugins."pheval.plugins"]
template = "malco.runner:MalcoRunner"
53 changes: 53 additions & 0 deletions src/malco/analysis/test_curate_script.py
@@ -0,0 +1,53 @@
import yaml
from pathlib import Path
from typing import List

from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo
from oaklib import get_adapter


def read_raw_result_yaml(raw_result_path: Path) -> List[dict]:
    """
    Read the raw result file.

    Args:
        raw_result_path (Path): Path to the raw result file.

    Returns:
        List[dict]: Contents of the raw result file.
    """
    with open(raw_result_path, 'r') as raw_result:
        return list(yaml.safe_load_all(raw_result.read().replace(u'\x04', '')))  # Load and convert to list


annotator = get_adapter("sqlite:obo:mondo")
some_yaml_res = Path("/Users/leonardo/git/malco/out_openAI_models/raw_results/multimodel/gpt-4/results.yaml")

data = []

if some_yaml_res.is_file():
    all_results = read_raw_result_yaml(some_yaml_res)
    j = 0
    for this_result in all_results:
        extracted_object = this_result.get("extracted_object")
        if extracted_object:  # Necessary because this is how I keep track of multiple runs
            ontogpt_text = this_result.get("input_text")
            # it's a single string, should be parseable through curategpt
            cleaned_text = clean_service_answer(ontogpt_text)
            assert cleaned_text != "", "Cleaning failed: the cleaned text is empty."
            result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False)

            label = extracted_object.get('label')  # pubmed id
            # terms will now ONLY contain MONDO IDs OR 'N/A'. The latter should be dealt with downstream
            terms = [i[1][0][0] for i in result]  # take the top grounding's ID for each diagnosis line
            # terms = extracted_object.get('terms')  # list of strings, the mondo id or description
            if terms:
                # Note, the if allows for rerunning ppkts that failed due to connection issues
                # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field
                num_terms = len(terms)
                score = [1 / (i + 1) for i in range(num_terms)]  # score is reciprocal rank
                rank_list = [i + 1 for i in range(num_terms)]
                for term, scr, rank in zip(terms, score, rank_list):
                    data.append({'label': label, 'term': term, 'score': scr, 'rank': rank})
            if j > 20:
                break
            j += 1
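
As a side note, here is a minimal sketch (mock data, not part of the PR) of the nested structure that the comprehension above unpacks, and the reciprocal-rank scores that fall out of it:

# Each entry pairs one diagnosis line with its list of (ID, label) groundings
result = [
    ("Marfan syndrome", [("MONDO:0007947", "Marfan syndrome")]),
    ("Some unrecognized disease", [("N/A", "No grounding found")]),
]
terms = [i[1][0][0] for i in result]              # -> ['MONDO:0007947', 'N/A']
score = [1 / (i + 1) for i in range(len(terms))]  # -> [1.0, 0.5]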
173 changes: 173 additions & 0 deletions src/malco/post_process/extended_scoring.py
@@ -0,0 +1,173 @@
import re
from typing import List, Optional, Tuple

from oaklib.interfaces.text_annotator_interface import TextAnnotationConfiguration, TextAnnotatorInterface
from curategpt.store import get_store


# Compile a regex pattern to detect lines starting with "Differential Diagnosis"
# (optionally preceded by list markers or other non-letter characters)
dd_re = re.compile(r"^[^A-Za-z]*Differential Diagnosis")

# Function to clean and remove the "Differential Diagnosis" header if present
def clean_service_answer(answer: str) -> str:
    """Remove the 'Differential Diagnosis' header if present, and clean the first line."""
    lines = answer.split('\n')
    # Filter out any line that starts with "Differential Diagnosis"
    cleaned_lines = [line for line in lines if not dd_re.match(line)]
    return '\n'.join(cleaned_lines)
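
# Illustrative example (assumed input, not from the PR): for an answer like
#   "### Differential Diagnosis:\n1. **Marfan syndrome**\n2. Stickler syndrome"
# clean_service_answer drops the header line and returns
#   "1. **Marfan syndrome**\n2. Stickler syndrome"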

# Clean the diagnosis line by removing leading numbers, periods, asterisks, and spaces
def clean_diagnosis_line(line: str) -> str:
    """Remove leading numbers, asterisks, and unnecessary punctuation/spaces from the diagnosis."""
    line = re.sub(r'^\**\d+\.\s*', '', line)  # Remove leading numbers and periods
    line = line.strip('*')  # Remove asterisks around the text
    return line.strip()  # Strip any remaining spaces


# Split a diagnosis into its main name and synonym if present
def split_diagnosis_and_synonym(diagnosis: str) -> Tuple[str, Optional[str]]:
    """Split the diagnosis into main name and synonym (if present in parentheses)."""
    match = re.match(r'^(.*)\s*\((.*)\)\s*$', diagnosis)
    if match:
        main_diagnosis, synonym = match.groups()
        return main_diagnosis.strip(), synonym.strip()
    return diagnosis, None  # Return the original diagnosis if no synonym is found

def perform_curategpt_grounding(
    diagnosis: str,
    path: str,
    collection: str,
    database_type: str = "chromadb",
    limit: int = 1,
    relevance_factor: float = 0.23,
    verbose: bool = False,
) -> List[Tuple[str, str]]:
    """
    Use curategpt to perform grounding for a given diagnosis when initial attempts fail.

    Parameters:
    - diagnosis: The diagnosis text to ground.
    - path: The path to the database. You will need to create an index of Mondo in this db using curategpt.
    - collection: The collection to search within curategpt, i.e. the name of the Mondo collection in the db.
      NB: You can build this collection by running curategpt as follows:
      `curategpt ontology index --index-fields label,definition,relationships -p stagedb -c ont_mondo -m openai: sqlite:obo:mondo`
    - database_type: The type of database used for grounding (e.g., chromadb, duckdb).
    - limit: The number of search results to return.
    - relevance_factor: The distance threshold for relevance filtering.
    - verbose: Whether to print verbose output for debugging.

    Returns:
    - List of tuples: [(Mondo ID, Label), ...]
    """
    # Initialize the database store
    db = get_store(database_type, path)

    # Perform the search using the provided diagnosis
    results = db.search(diagnosis, collection=collection)

    # Filter results based on relevance factor (distance)
    if relevance_factor is not None:
        results = [(obj, distance, _meta) for obj, distance, _meta in results if distance <= relevance_factor]

    # Limit the results to the specified number (limit)
    limited_results = results[:limit]

    # Extract Mondo IDs and labels
    pred_ids = []
    pred_labels = []

    for obj, distance, _meta in limited_results:
        disease_mondo_id = obj.get("original_id")  # Use the 'original_id' field for the Mondo ID
        disease_label = obj.get("label")

        if disease_mondo_id and disease_label:
            pred_ids.append(disease_mondo_id)
            pred_labels.append(disease_label)

    # Return as a list of tuples (Mondo ID, Label)
    if len(pred_ids) == 0:
        if verbose:
            print(f"No grounded IDs found for {diagnosis}")
        return [('N/A', 'No grounding found')]

    return list(zip(pred_ids, pred_labels))


# Perform grounding of the text to the MONDO ontology and return the result
def perform_oak_grounding(
    annotator: TextAnnotatorInterface,
    diagnosis: str,
    exact_match: bool = True,
    verbose: bool = False,
    include_list: List[str] = ["MONDO:"],
) -> List[Tuple[str, str]]:
    """
    Perform grounding for a diagnosis. The 'exact_match' flag controls whether exact or inexact
    (partial) matching is used. Filter results to include only CURIEs whose prefix appears in
    the 'include_list', and remove redundant groundings from the result.
    """
    config = TextAnnotationConfiguration(matches_whole_text=exact_match)
    annotations = list(annotator.annotate_text(diagnosis, configuration=config))

    # Filter and remove duplicates, while excluding unwanted general terms
    filtered_annotations = list(
        {
            (ann.object_id, ann.object_label)
            for ann in annotations
            if any(ann.object_id.startswith(prefix) for prefix in include_list)
        }
    )

    if filtered_annotations:
        return filtered_annotations
    else:
        match_type = "exact" if exact_match else "inexact"
        if verbose:
            print(f"No {match_type} grounded IDs found for: {diagnosis}")
        return [('N/A', 'No grounding found')]

# Ground a differential-diagnosis text to MONDO, falling back to curategpt when OAK grounding fails
def ground_diagnosis_text_to_mondo(
    annotator: TextAnnotatorInterface,
    differential_diagnosis: str,
    verbose: bool = False,
    include_list: List[str] = ["MONDO:"],
    use_ontogpt_grounding: bool = True,
    curategpt_path: str = "../curategpt/stagedb/",
    curategpt_collection: str = "ont_mondo",
    curategpt_database_type: str = "chromadb",
) -> List[Tuple[str, List[Tuple[str, str]]]]:
    results = []

    # Split the input into lines and process each one
    for line in differential_diagnosis.splitlines():
        clean_line = clean_diagnosis_line(line)

        # Skip empty lines and header lines like "**Differential diagnosis:**"
        if not clean_line or "differential diagnosis" in clean_line.lower():
            continue

        # Try grounding the full line first (exact match)
        grounded = perform_oak_grounding(annotator, clean_line, exact_match=True, verbose=verbose, include_list=include_list)

        # Try grounding with curategpt if no grounding is found
        if use_ontogpt_grounding and grounded == [('N/A', 'No grounding found')]:
            grounded = perform_curategpt_grounding(
                diagnosis=clean_line,
                path=curategpt_path,
                collection=curategpt_collection,
                database_type=curategpt_database_type,
                verbose=verbose,
            )

        # If still no grounding is found, log the final failure
        if grounded == [('N/A', 'No grounding found')] and verbose:
            print(f"Final grounding failed for: {clean_line}")

        # Append the grounded results (even if no grounding was found)
        results.append((clean_line, grounded))

    return results
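
For orientation, a minimal usage sketch of the new module: the adapter string and defaults mirror the code above, the curategpt fallback additionally assumes a Mondo index under the default ../curategpt/stagedb/ path, and the MONDO ID in the comment is illustrative.

from oaklib import get_adapter
from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo

annotator = get_adapter("sqlite:obo:mondo")  # OAK adapter over the Mondo SQLite build
answer = "Differential Diagnosis:\n1. **Marfan syndrome**\n2. Ehlers-Danlos syndrome"
cleaned = clean_service_answer(answer)  # drops the header line
for line, groundings in ground_diagnosis_text_to_mondo(annotator, cleaned, verbose=True):
    print(line, groundings)  # e.g. ('Marfan syndrome', [('MONDO:0007947', 'Marfan syndrome')])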
4 changes: 4 additions & 0 deletions src/malco/post_process/generate_plots.py
@@ -10,6 +10,9 @@ def make_plots(mrr_file, data_dir, languages, num_ppkt, models, topn_aggr_file,
    plot_dir = data_dir.parents[0] / "plots"
    plot_dir.mkdir(exist_ok=True)

    # For plot filename labeling, use the lowest number of ppkts available across all models/languages etc.
    num_ppkt = min(num_ppkt.values())

    if comparing == "model":
        name_string = str(len(models))
    else:
@@ -39,6 +42,7 @@ def make_plots(mrr_file, data_dir, languages, num_ppkt, models, topn_aggr_file,

    plt.xlabel("Number of Ranks in")
    plt.ylabel("Percentage of Cases")
    plt.ylim([0.0, 1.0])
    plt.title("Rank Comparison for Differential Diagnosis")
    plt.legend(title=comparing)
    plot_path = plot_dir / ("barplot_" + name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png")
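
To make the new num_ppkt handling concrete, a two-line sketch with a hypothetical per-model count dict (numbers made up):

num_ppkt = {"gpt-4": 385, "gpt-4o": 397, "o1-preview": 380}  # counts per model (hypothetical)
num_ppkt = min(num_ppkt.values())  # -> 380, the lowest count available across all runs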
18 changes: 13 additions & 5 deletions src/malco/post_process/post_process.py
@@ -16,22 +16,30 @@ def post_process(self) -> None:
        output_dir = self.output_dir
        langs = self.languages
        models = self.models

        curategpt = True

        if self.modality == "several_languages":
            for lang in langs:
                raw_results_lang = raw_results_dir / "multilingual" / lang
                output_lang = output_dir / "multilingual" / lang
                raw_results_lang.mkdir(exist_ok=True, parents=True)
                output_lang.mkdir(exist_ok=True, parents=True)

                create_standardised_results(curategpt,
                                            raw_results_dir=raw_results_lang,
                                            output_dir=output_lang,
                                            output_file_name="results.tsv",
                                            )

        elif self.modality == "several_models":
            for model in models:
                raw_results_model = raw_results_dir / "multimodel" / model
                output_model = output_dir / "multimodel" / model
                raw_results_model.mkdir(exist_ok=True, parents=True)
                output_model.mkdir(exist_ok=True, parents=True)

                create_standardised_results(curategpt,
                                            raw_results_dir=raw_results_model,
                                            output_dir=output_model,
                                            output_file_name="results.tsv",
                                            )
23 changes: 20 additions & 3 deletions src/malco/post_process/post_process_results_format.py
@@ -9,7 +9,9 @@
from pheval.utils.file_utils import all_files
from pheval.utils.phenopacket_utils import GeneIdentifierUpdater, create_hgnc_dict
from malco.post_process.df_save_util import safe_save_tsv
from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo
from oaklib import get_adapter


def read_raw_result_yaml(raw_result_path: Path) -> List[dict]:
@@ -26,9 +28,16 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]:
        return list(yaml.safe_load_all(raw_result.read().replace(u'\x04', '')))  # Load and convert to list


def create_standardised_results(raw_results_dir: Path, output_dir: Path,
output_file_name: str) -> pd.DataFrame:
def create_standardised_results(curategpt: bool,
raw_results_dir: Path,
output_dir: Path,
output_file_name: str
) -> pd.DataFrame:

data = []
if curategpt:
annotator = get_adapter("sqlite:obo:mondo")

for raw_result_path in raw_results_dir.iterdir():
if raw_result_path.is_file():
# Cannot have further files in raw_result_path!
Expand All @@ -39,6 +48,14 @@ def create_standardised_results(raw_results_dir: Path, output_dir: Path,
                if extracted_object:
                    label = extracted_object.get('label')
                    terms = extracted_object.get('terms')
                    if curategpt and terms:
                        ontogpt_text = this_result.get("input_text")
                        # it's a single string, should be parseable through curategpt
                        cleaned_text = clean_service_answer(ontogpt_text)
                        assert cleaned_text != "", "Cleaning failed: the cleaned text is empty."
                        result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False)
                        # terms will now ONLY contain MONDO IDs OR 'N/A'. The latter should be dealt with downstream
                        terms = [i[1][0][0] for i in result]  # MONDO ID
                    if terms:
                        # Note, the if allows for rerunning ppkts that failed due to connection issues
                        # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field
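
Finally, a short sketch of how the updated entry point might be called with the new flag (paths hypothetical, signature as in the diff above):

from pathlib import Path
from malco.post_process.post_process_results_format import create_standardised_results

create_standardised_results(
    curategpt=True,  # re-ground each answer to MONDO via extended_scoring
    raw_results_dir=Path("out/raw_results/multimodel/gpt-4"),
    output_dir=Path("out/multimodel/gpt-4"),
    output_file_name="results.tsv",
)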