
Commit

Add curategpt scoring (#55)
* started brainstorming

* added test_curate_script, a prototype of what to insert into create_standardised_results so that it can replace OntoGPT's grounding

* working version of curategpt implemented in the main branch; 3 relevant TODOs remain in extended scoring with manual paths...

* update poetry files and add missing mkdocs.yml

* polished curategpt; API key setting will be done by hand and will be explained in the documentation. Tested a few times. Also fixed a num_ppkt-related bug along the way

* got rid of local curategpt dependency and updated lock etc

* Update poetry.lock

Fixed an error introduced during the manual merge

* solve the issue by regenerating lock from merged pyproject.toml

* solve the issue by regenerating the lock file and adding the lost package psutil

* Update qc.yml

Maybe this will save some space and deal with

https://github.com/monarch-initiative/pheval.llm/actions/runs/11918215693/job/33215129156?pr=55#step:6:71

* Update qc.yml

Maybe poetry install is not needed at all, since tox recreates its own environment...?

* Update qc.yml

Forgot to install tox; trying now with pip. Alternatively, try later with poetry, but without doing a full install

* Update qc.yml

add missing coverage package

* Update qc.yml

Try having tox do everything
leokim-l authored Nov 20, 2024
1 parent ca75c84 commit bb6d78d
Showing 11 changed files with 4,057 additions and 3,412 deletions.
13 changes: 3 additions & 10 deletions .github/workflows/qc.yml
@@ -50,14 +50,7 @@ jobs:
with:
virtualenvs-create: true
virtualenvs-in-project: true
- name: Install dependencies
run: poetry install --no-interaction --no-root
- name: Install tox
run: pip install tox
- name: Run Doctests
run: poetry run tox -e doctest
- name: Generate coverage results
run: |
poetry run pip install -U pytest
poetry run coverage run -p -m pytest tests/
poetry run coverage combine
poetry run coverage xml
poetry run coverage report -m
run: poetry run tox
1,551 changes: 0 additions & 1,551 deletions notebooks/process_gpt_4o_and_plot.ipynb

This file was deleted.

1,333 changes: 0 additions & 1,333 deletions notebooks/process_gpt_o1_and_plot.ipynb

This file was deleted.

4,271 changes: 3,776 additions & 495 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -11,6 +11,8 @@ python = "^3.10"
pheval = "^0.3.2"
setuptools = "^69.5.1"
shelved-cache = "^0.3.1"
curategpt = "^0.2.2"
psutil = "^6.1.0"

[tool.poetry.plugins."pheval.plugins"]
template = "malco.runner:MalcoRunner"
53 changes: 53 additions & 0 deletions src/malco/analysis/test_curate_script.py
@@ -0,0 +1,53 @@
import yaml
from pathlib import Path
from typing import List
from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo
from oaklib import get_adapter


def read_raw_result_yaml(raw_result_path: Path) -> List[dict]:
"""
Read the raw result file.
Args:
raw_result_path(Path): Path to the raw result file.
Returns:
dict: Contents of the raw result file.
"""
with open(raw_result_path, 'r') as raw_result:
return list(yaml.safe_load_all(raw_result.read().replace(u'\x04',''))) # Load and convert to list


annotator = get_adapter("sqlite:obo:mondo")
some_yaml_res = Path("/Users/leonardo/git/malco/out_openAI_models/raw_results/multimodel/gpt-4/results.yaml")

data = []

if some_yaml_res.is_file():
all_results = read_raw_result_yaml(some_yaml_res)
j = 0
for this_result in all_results:
extracted_object = this_result.get("extracted_object")
if extracted_object: # Necessary because this is how I keep track of multiple runs
ontogpt_text = this_result.get("input_text")
# it's a single string; it should be parseable by curategpt
cleaned_text = clean_service_answer(ontogpt_text)
assert cleaned_text != "", "Cleaning failed: the cleaned text is empty."
result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False)

label = extracted_object.get('label') # pubmed id
# terms will now ONLY contain MONDO IDs OR 'N/A'. The latter should be dealt with downstream
terms = [i[1][0][0] for i in result]
#terms = extracted_object.get('terms') # list of strings, the mondo id or description
if terms:
# Note, the if allows for rerunning ppkts that failed due to connection issues
# We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field
num_terms = len(terms)
score = [1 / (i + 1) for i in range(num_terms)] # score is reciprocal rank
rank_list = [i + 1 for i in range(num_terms)]
for term, scr, rank in zip(terms, score, rank_list):
data.append({'label': label, 'term': term, 'score': scr, 'rank': rank})
if j>20:
break
j += 1
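
A minimal sketch (not part of this commit) of how these records could feed a mean-reciprocal-rank computation downstream; the truth mapping from phenopacket label to the correct MONDO ID is hypothetical:

import pandas as pd

df = pd.DataFrame(data)  # columns: label, term, score, rank
truth = {"PMID_12345_patient1": "MONDO:0007523"}  # hypothetical ground truth

reciprocal_ranks = []
for label, correct_id in truth.items():
    hits = df[(df["label"] == label) & (df["term"] == correct_id)]
    # A correct diagnosis that never appears in the ranked terms contributes 0.
    reciprocal_ranks.append(1.0 / hits["rank"].min() if not hits.empty else 0.0)

mrr = sum(reciprocal_ranks) / len(reciprocal_ranks)
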
173 changes: 173 additions & 0 deletions src/malco/post_process/extended_scoring.py
@@ -0,0 +1,173 @@
import re
import os
from oaklib.interfaces.text_annotator_interface import TextAnnotationConfiguration
from oaklib.interfaces.text_annotator_interface import TextAnnotatorInterface
from curategpt.store import get_store
from typing import List, Optional, Tuple


# Compile a regex pattern to detect lines starting with "Differential Diagnosis:"
dd_re = re.compile(r"^[^A-Za-z]*Differential Diagnosis")

# Function to clean and remove "Differential Diagnosis" header if present
def clean_service_answer(answer: str) -> str:
"""Remove the 'Differential Diagnosis' header if present, and clean the first line."""
lines = answer.split('\n')
# Filter out any line that starts with "Differential Diagnosis:"
cleaned_lines = [line for line in lines if not dd_re.match(line)]
return '\n'.join(cleaned_lines)

# Clean the diagnosis line by removing leading numbers, periods, asterisks, and spaces
def clean_diagnosis_line(line: str) -> str:
"""Remove leading numbers, asterisks, and unnecessary punctuation/spaces from the diagnosis."""
line = re.sub(r'^\**\d+\.\s*', '', line) # Remove leading numbers and periods
line = line.strip('*') # Remove asterisks around the text
return line.strip() # Strip any remaining spaces

# Split a diagnosis into its main name and synonym if present
def split_diagnosis_and_synonym(diagnosis: str) -> Tuple[str, Optional[str]]:
"""Split the diagnosis into main name and synonym (if present in parentheses)."""
match = re.match(r'^(.*)\s*\((.*)\)\s*$', diagnosis)
if match:
main_diagnosis, synonym = match.groups()
return main_diagnosis.strip(), synonym.strip()
return diagnosis, None # Return the original diagnosis if no synonym is found

def perform_curategpt_grounding(
diagnosis: str,
path: str,
collection: str,
database_type: str = "chromadb",
limit: int = 1,
relevance_factor: float = 0.23,
verbose: bool = False
) -> List[Tuple[str, str]]:
"""
Use curategpt to perform grounding for a given diagnosis when initial attempts fail.
Parameters:
- diagnosis: The diagnosis text to ground.
- path: The path to the database. You'll need to create an index of Mondo using curategpt in this db
- collection: The collection to search within curategpt. Name of mondo collection in the db
NB: You can make this collection by running curategpt thusly:
`curategpt ontology index --index-fields label,definition,relationships -p stagedb -c ont_mondo -m openai: sqlite:obo:mondo`
- database_type: The type of database used for grounding (e.g., chromadb, duckdb).
- limit: The number of search results to return.
- relevance_factor: The distance threshold for relevance filtering.
- verbose: Whether to print verbose output for debugging.
Returns:
- List of tuples: [(Mondo ID, Label), ...]
"""
# Initialize the database store
db = get_store(database_type, path)

# Perform the search using the provided diagnosis
results = db.search(diagnosis, collection=collection)

# Filter results based on relevance factor (distance)
if relevance_factor is not None:
results = [(obj, distance, _meta) for obj, distance, _meta in results if distance <= relevance_factor]

# Limit the results to the specified number (limit)
limited_results = results[:limit]

# Extract Mondo IDs and labels
pred_ids = []
pred_labels = []

for obj, distance, _meta in limited_results:
disease_mondo_id = obj.get("original_id") # Use the 'original_id' field for Mondo ID
disease_label = obj.get("label")

if disease_mondo_id and disease_label:
pred_ids.append(disease_mondo_id)
pred_labels.append(disease_label)

# Return as a list of tuples (Mondo ID, Label)
if len(pred_ids) == 0:
if verbose:
print(f"No grounded IDs found for {diagnosis}")
return [('N/A', 'No grounding found')]

return list(zip(pred_ids, pred_labels))


# Perform grounding on the text to MONDO ontology and return the result
def perform_oak_grounding(
annotator: TextAnnotatorInterface,
diagnosis: str,
exact_match: bool = True,
verbose: bool = False,
include_list: List[str] = ["MONDO:"],
) -> List[Tuple[str, str]]:
"""
Perform grounding for a diagnosis. The 'exact_match' flag controls whether exact or inexact
(partial) matching is used. Filter results to include only CURIEs whose prefix appears in the 'include_list'.
Remove redundant groundings from the result.
"""
config = TextAnnotationConfiguration(matches_whole_text=exact_match)
annotations = list(annotator.annotate_text(diagnosis, configuration=config))

# Filter and remove duplicates, while excluding unwanted general terms
filtered_annotations = list(
{
(ann.object_id, ann.object_label)
for ann in annotations
if any(ann.object_id.startswith(prefix) for prefix in include_list)
}
)

if filtered_annotations:
return filtered_annotations
else:
match_type = "exact" if exact_match else "inexact"
if verbose:
print(f"No {match_type} grounded IDs found for: {diagnosis}")
pass
return [('N/A', 'No grounding found')]

# Integrate curategpt as a fallback inside ground_diagnosis_text_to_mondo
def ground_diagnosis_text_to_mondo(
annotator: TextAnnotatorInterface,
differential_diagnosis: str,
verbose: bool = False,
include_list: List[str] = ["MONDO:"],
use_ontogpt_grounding: bool = True,
curategpt_path: str = "../curategpt/stagedb/",
curategpt_collection: str = "ont_mondo",
curategpt_database_type: str = "chromadb"
) -> List[Tuple[str, List[Tuple[str, str]]]]:
results = []

# Split the input into lines and process each one
for line in differential_diagnosis.splitlines():
clean_line = clean_diagnosis_line(line)

# Skip header lines like "**Differential diagnosis:**"
if not clean_line or "differential diagnosis" in clean_line.lower():
continue

# Try grounding the full line first (exact match)
grounded = perform_oak_grounding(annotator, clean_line, exact_match=True, verbose=verbose, include_list=include_list)

# Try grounding with curategpt if no grounding is found
if use_ontogpt_grounding and grounded == [('N/A', 'No grounding found')]:
grounded = perform_curategpt_grounding(
diagnosis=clean_line,
path=curategpt_path,
collection=curategpt_collection,
database_type=curategpt_database_type,
verbose=verbose
)

# If still no grounding is found, log the final failure
if grounded == [('N/A', 'No grounding found')]:
if verbose:
print(f"Final grounding failed for: {clean_line}")

# Append the grounded results (even if no grounding was found)
results.append((clean_line, grounded))

return results
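
An illustrative call (not part of the commit), assuming a Mondo index has been built locally with the "curategpt ontology index" command quoted in the docstring above; the diagnosis string and the expected output are examples only:

from malco.post_process.extended_scoring import perform_curategpt_grounding

groundings = perform_curategpt_grounding(
    diagnosis="Marfan syndrome",
    path="../curategpt/stagedb/",  # assumed location of the chromadb index
    collection="ont_mondo",
    database_type="chromadb",
    limit=1,
    relevance_factor=0.23,
    verbose=True,
)
# Expected shape: something like [("MONDO:0007947", "Marfan syndrome")], or
# [("N/A", "No grounding found")] if nothing passes the relevance threshold.
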
4 changes: 4 additions & 0 deletions src/malco/post_process/generate_plots.py
@@ -10,6 +10,9 @@ def make_plots(mrr_file, data_dir, languages, num_ppkt, models, topn_aggr_file,
plot_dir = data_dir.parents[0] / "plots"
plot_dir.mkdir(exist_ok=True)

# For plot filename labeling, use the lowest number of ppkts available across all models/languages etc.
num_ppkt = min(num_ppkt.values())

if comparing=="model":
name_string = str(len(models))
else:
@@ -39,6 +42,7 @@ def make_plots(mrr_file, data_dir, languages, num_ppkt, models, topn_aggr_file,

plt.xlabel("Number of Ranks in")
plt.ylabel("Percentage of Cases")
plt.ylim([0.0, 1.0])
plt.title("Rank Comparison for Differential Diagnosis")
plt.legend(title=comparing)
plot_path = plot_dir / ("barplot_" + name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png")
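
A toy illustration (not from the commit) of the num_ppkt handling added above, assuming num_ppkt arrives as a dict keyed by model or language:

num_ppkt = {"gpt-4": 4000, "gpt-4o": 3850, "en": 3900}  # hypothetical counts
num_ppkt = min(num_ppkt.values())  # -> 3850, the lowest count, used in the plot filename
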
18 changes: 13 additions & 5 deletions src/malco/post_process/post_process.py
@@ -16,22 +16,30 @@ def post_process(self) -> None:
output_dir = self.output_dir
langs = self.languages
models = self.models

curategpt = True

if self.modality == "several_languages":
for lang in langs:
raw_results_lang = raw_results_dir / "multilingual" / lang
output_lang = output_dir / "multilingual" / lang
raw_results_lang.mkdir(exist_ok=True, parents=True)
output_lang.mkdir(exist_ok=True, parents=True)

create_standardised_results(raw_results_dir=raw_results_lang,
output_dir=output_lang, output_file_name="results.tsv")
create_standardised_results(curategpt,
raw_results_dir=raw_results_lang,
output_dir=output_lang,
output_file_name="results.tsv",
)

elif self.modality == "several_models":
for model in models:
raw_results_model = raw_results_dir / "multimodel" / model
output_model = output_dir / "multimodel" / model
raw_results_model.mkdir(exist_ok=True, parents=True)
output_model.mkdir(exist_ok=True, parents=True)

create_standardised_results(raw_results_dir=raw_results_model,
output_dir=output_model, output_file_name="results.tsv")
create_standardised_results(curategpt,
raw_results_dir=raw_results_model,
output_dir=output_model,
output_file_name="results.tsv",
)
23 changes: 20 additions & 3 deletions src/malco/post_process/post_process_results_format.py
@@ -9,7 +9,9 @@
from pheval.utils.file_utils import all_files
from pheval.utils.phenopacket_utils import GeneIdentifierUpdater, create_hgnc_dict
from malco.post_process.df_save_util import safe_save_tsv

from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo
from oaklib import get_adapter



def read_raw_result_yaml(raw_result_path: Path) -> List[dict]:
@@ -26,9 +28,16 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]:
return list(yaml.safe_load_all(raw_result.read().replace(u'\x04',''))) # Load and convert to list


def create_standardised_results(raw_results_dir: Path, output_dir: Path,
output_file_name: str) -> pd.DataFrame:
def create_standardised_results(curategpt: bool,
raw_results_dir: Path,
output_dir: Path,
output_file_name: str
) -> pd.DataFrame:

data = []
if curategpt:
annotator = get_adapter("sqlite:obo:mondo")

for raw_result_path in raw_results_dir.iterdir():
if raw_result_path.is_file():
# Cannot have further files in raw_result_path!
@@ -39,6 +48,14 @@ def create_standardised_results(raw_results_dir: Path, output_dir: Path,
if extracted_object:
label = extracted_object.get('label')
terms = extracted_object.get('terms')
if curategpt and terms:
ontogpt_text = this_result.get("input_text")
# it's a single string; it should be parseable by curategpt
cleaned_text = clean_service_answer(ontogpt_text)
assert cleaned_text != "", "Cleaning failed: the cleaned text is empty."
result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False)
# terms will now ONLY contain MONDO IDs OR 'N/A'. The latter should be dealt with downstream
terms = [i[1][0][0] for i in result] # MONDO_ID
if terms:
# Note, the if allows for rerunning ppkts that failed due to connection issues
# We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field
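
A small worked example (not part of the commit) of the data shape behind terms = [i[1][0][0] for i in result], using made-up groundings:

# ground_diagnosis_text_to_mondo returns List[Tuple[str, List[Tuple[str, str]]]]:
# one (cleaned diagnosis line, [(MONDO ID, label), ...]) pair per input line.
result = [
    ("Marfan syndrome", [("MONDO:0007947", "Marfan syndrome")]),
    ("Some unrecognised disorder", [("N/A", "No grounding found")]),
]
terms = [i[1][0][0] for i in result]
# -> ["MONDO:0007947", "N/A"]: only the top MONDO ID (or 'N/A') per line
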
