-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* started brainstorming
* added test_curate_script, a prototype of what to insert in create_standardised_results so that it substitutes OntoGPT's grounding
* working version of curategpt implemented in main branch; still 3 relevant TODOs in extended scoring with manual paths...
* updated poetry files and added missing mkdocs.yml
* polished curategpt; API key setting will be done by hand and will be explained in the documentation. Tested a few times. Also removed a num_ppkt-related bug along the way
* got rid of local curategpt dependency and updated lock file etc.
* Update poetry.lock — fixed an error introduced during a manual merge
* solved the issue by regenerating the lock from the merged pyproject.toml
* solved the issue by regenerating the lock and adding the lost package putils
* Update qc.yml — maybe this saves some space and deals with https://github.com/monarch-initiative/pheval.llm/actions/runs/11918215693/job/33215129156?pr=55#step:6:71
* Update qc.yml — maybe poetry install is not needed at all, since tox recreates its own environment...?
* Update qc.yml — forgot to install tox; trying now with pip, alternatively try later with poetry, but not doing a full install
* Update qc.yml — add missing coverage package
* Update qc.yml — try having tox do everything
- Loading branch information
Showing
11 changed files
with
4,057 additions
and
3,412 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import yaml | ||
from pathlib import Path | ||
from typing import List | ||
from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo | ||
from oaklib import get_adapter | ||
|
||
|
||
def read_raw_result_yaml(raw_result_path: Path) -> List[dict]:
    """
    Load every YAML document contained in a raw result file.

    Args:
        raw_result_path (Path): Path to the raw result file.

    Returns:
        List[dict]: One entry per YAML document in the file.
    """
    # Strip stray end-of-transmission characters (\x04) that would break the
    # parser, then materialise the multi-document stream as a list.
    raw_text = raw_result_path.read_text().replace('\x04', '')
    return list(yaml.safe_load_all(raw_text))
|
||
|
||
# Prototype driver: ground OntoGPT raw results to MONDO and collect
# reciprocal-rank rows in `data`. Exploratory code — the input path is
# hard-coded and processing stops after ~20 phenopackets (see `j` below).

# Text annotator backed by the MONDO ontology (oaklib SQLite adapter).
annotator = get_adapter("sqlite:obo:mondo")
# NOTE(review): hard-coded local path — parameterize before reuse.
some_yaml_res = Path("/Users/leonardo/git/malco/out_openAI_models/raw_results/multimodel/gpt-4/results.yaml")

# Accumulates one row per (label, term) pair: {'label', 'term', 'score', 'rank'}.
data = []

if some_yaml_res.is_file():
    all_results = read_raw_result_yaml(some_yaml_res)
    # j counts processed results so the prototype can bail out early.
    j = 0
    for this_result in all_results:
        extracted_object = this_result.get("extracted_object")
        if extracted_object:  # Necessary because this is how I keep track of multiple runs
            ontogpt_text = this_result.get("input_text")
            # its a single string, should be parseable through curategpt
            cleaned_text = clean_service_answer(ontogpt_text)
            assert cleaned_text != "", "Cleaning failed: the cleaned text is empty."
            result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False)

            label = extracted_object.get('label')  # pubmed id
            # terms will now ONLY contain MONDO IDs OR 'N/A'. The latter should be dealt with downstream
            # result is [(line, [(mondo_id, label), ...]), ...]; take the first
            # grounding's CURIE for each diagnosis line.
            terms = [i[1][0][0] for i in result]
            #terms = extracted_object.get('terms') # list of strings, the mondo id or description
            if terms:
                # Note, the if allows for rerunning ppkts that failed due to connection issues
                # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field
                num_terms = len(terms)
                score = [1 / (i + 1) for i in range(num_terms)]  # score is reciprocal rank
                rank_list = [ i+1 for i in range(num_terms)]
                for term, scr, rank in zip(terms, score, rank_list):
                    data.append({'label': label, 'term': term, 'score': scr, 'rank': rank})
            # Early exit for prototyping: stop after 21 processed results.
            # NOTE(review): original indentation was lost in this rendering;
            # the increment/break are assumed to sit inside the
            # `if extracted_object:` branch — confirm against the repo.
            if j>20:
                break
            j += 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
import os
import re
from typing import List, Optional, Tuple

from curategpt.store import get_store
from oaklib.interfaces.text_annotator_interface import (
    TextAnnotationConfiguration,
    TextAnnotatorInterface,
)
|
||
|
||
# Compile a regex pattern to detect lines starting with "Differential Diagnosis"
# (optionally preceded by non-letter markup such as "**" or "1.").
# BUG FIX: the original class [^A-z] is a classic range mistake — A-z also
# spans '[', '\', ']', '^', '_' and '`', so a header like
# "_Differential Diagnosis" was NOT stripped. [^A-Za-z] is the intended class.
dd_re = re.compile(r"^[^A-Za-z]*Differential Diagnosis")


# Function to remove "Differential Diagnosis" header lines if present
def clean_service_answer(answer: str) -> str:
    """Remove every 'Differential Diagnosis' header line from the answer.

    Args:
        answer: Raw multi-line service answer.

    Returns:
        The answer with all header lines dropped, re-joined with newlines.
    """
    lines = answer.split('\n')
    # Filter out any line that starts with "Differential Diagnosis"
    cleaned_lines = [line for line in lines if not dd_re.match(line)]
    return '\n'.join(cleaned_lines)
|
||
# Strip list markup from a single diagnosis line: leading "1." style rank
# numbers, surrounding asterisks, and whitespace.
def clean_diagnosis_line(line: str) -> str:
    """Return the diagnosis text without leading rank numbers, asterisks,
    or surrounding whitespace."""
    without_rank = re.sub(r'^\**\d+\.\s*', '', line)  # drop "**1. " style prefixes
    return without_rank.strip('*').strip()
|
||
# Split a diagnosis into its main name and synonym if present
def split_diagnosis_and_synonym(diagnosis: str) -> Tuple[str, Optional[str]]:
    """Split the diagnosis into main name and synonym (if present in parentheses).

    Args:
        diagnosis: A cleaned diagnosis line, e.g. "Marfan syndrome (MFS)".

    Returns:
        (main_diagnosis, synonym), both stripped; synonym is None when no
        trailing parenthesised synonym is found.
        (Return annotation fixed: the original claimed Tuple[str, str] but
        the no-synonym path returns None.)
    """
    # Greedy match: with several parenthesised groups the LAST one is taken
    # as the synonym, e.g. "Foo (a) bar (b)" -> ("Foo (a) bar", "b").
    match = re.match(r'^(.*)\s*\((.*)\)\s*$', diagnosis)
    if match:
        main_diagnosis, synonym = match.groups()
        return main_diagnosis.strip(), synonym.strip()
    return diagnosis, None  # No parenthesised synonym present
|
||
def perform_curategpt_grounding(
    diagnosis: str,
    path: str,
    collection: str,
    database_type: str = "chromadb",
    limit: int = 1,
    relevance_factor: float = 0.23,
    verbose: bool = False
) -> List[Tuple[str, str]]:
    """
    Ground a diagnosis via curategpt when earlier grounding attempts failed.

    Parameters:
    - diagnosis: The diagnosis text to ground.
    - path: Path to the database; it must hold a curategpt index of Mondo.
    - collection: Name of the Mondo collection inside the database.
      NB: You can make this collection by running curategpt thusly:
      `curategpt ontology index --index-fields label,definition,relationships -p stagedb -c ont_mondo -m openai: sqlite:obo:mondo`
    - database_type: The type of database used for grounding (e.g., chromadb, duckdb).
    - limit: The number of search results to return.
    - relevance_factor: The distance threshold for relevance filtering.
    - verbose: Whether to print verbose output for debugging.

    Returns:
    - List of (Mondo ID, Label) tuples, or [('N/A', 'No grounding found')]
      when nothing relevant was returned.
    """
    # Open the backing store and run the semantic search.
    store = get_store(database_type, path)
    hits = store.search(diagnosis, collection=collection)

    # Keep only hits whose distance is within the relevance threshold.
    if relevance_factor is not None:
        hits = [(obj, dist, meta) for obj, dist, meta in hits if dist <= relevance_factor]

    # Collect (Mondo ID, label) pairs from the top `limit` hits.
    grounded_pairs = []
    for candidate, _distance, _meta in hits[:limit]:
        mondo_id = candidate.get("original_id")  # 'original_id' carries the Mondo CURIE
        mondo_label = candidate.get("label")
        if mondo_id and mondo_label:
            grounded_pairs.append((mondo_id, mondo_label))

    if not grounded_pairs:
        if verbose:
            print(f"No grounded IDs found for {diagnosis}")
        return [('N/A', 'No grounding found')]

    return grounded_pairs
|
||
|
||
# Perform grounding on the text to MONDO ontology and return the result
def perform_oak_grounding(
    annotator: TextAnnotatorInterface,
    diagnosis: str,
    exact_match: bool = True,
    verbose: bool = False,
    include_list: Optional[List[str]] = None,
) -> List[Tuple[str, str]]:
    """
    Perform grounding for a diagnosis with an oaklib text annotator.

    The 'exact_match' flag controls whether exact (whole-text) or inexact
    (partial) matching is used. Results are filtered to CURIEs whose prefix
    appears in 'include_list' (default: ["MONDO:"]); duplicates are removed.
    (Docstring fixed: there is no 'exclude_list' parameter.)

    Returns a list of (CURIE, label) tuples, or
    [('N/A', 'No grounding found')] when nothing matches.
    """
    # BUG FIX: avoid a shared mutable default argument.
    if include_list is None:
        include_list = ["MONDO:"]

    config = TextAnnotationConfiguration(matches_whole_text=exact_match)
    annotations = list(annotator.annotate_text(diagnosis, configuration=config))

    # De-duplicate via a set while excluding CURIEs outside include_list.
    # NOTE(review): set iteration order is arbitrary, so the order of the
    # returned groundings is not deterministic — confirm callers do not rely
    # on it for ranking.
    filtered_annotations = list(
        {
            (ann.object_id, ann.object_label)
            for ann in annotations
            if any(ann.object_id.startswith(prefix) for prefix in include_list)
        }
    )

    if filtered_annotations:
        return filtered_annotations

    # Nothing grounded: optionally report, then return the 'N/A' placeholder.
    # (Dead `pass` statement from the original removed.)
    if verbose:
        match_type = "exact" if exact_match else "inexact"
        print(f"No {match_type} grounded IDs found for: {diagnosis}")
    return [('N/A', 'No grounding found')]
|
||
# Ground each line of a differential diagnosis to MONDO, falling back to
# curategpt when the oaklib annotator finds nothing.
def ground_diagnosis_text_to_mondo(
    annotator: TextAnnotatorInterface,
    differential_diagnosis: str,
    verbose: bool = False,
    include_list: Optional[List[str]] = None,
    use_ontogpt_grounding: bool = True,
    curategpt_path: str = "../curategpt/stagedb/",
    curategpt_collection: str = "ont_mondo",
    curategpt_database_type: str = "chromadb"
) -> List[Tuple[str, List[Tuple[str, str]]]]:
    """
    Ground every diagnosis line of 'differential_diagnosis' to MONDO.

    Parameters:
    - annotator: oaklib text annotator used for the first (exact) pass.
    - differential_diagnosis: Multi-line text, one candidate diagnosis per line.
    - verbose: Print debug output for failed groundings.
    - include_list: CURIE prefixes to keep (default ["MONDO:"]).
    - use_ontogpt_grounding: Fall back to curategpt when oak grounding fails.
    - curategpt_path / curategpt_collection / curategpt_database_type:
      Location and backend of the curategpt Mondo index.

    Returns:
    - List of (cleaned line, groundings) pairs, where groundings is a list of
      (Mondo ID, label) tuples, possibly [('N/A', 'No grounding found')].
    """
    # BUG FIX: avoid a shared mutable default argument.
    if include_list is None:
        include_list = ["MONDO:"]

    results = []

    # Split the input into lines and process each one
    for line in differential_diagnosis.splitlines():
        clean_line = clean_diagnosis_line(line)

        # Skip blank lines and header lines like "**Differential diagnosis:**".
        # BUG FIX: the needle must be lowercase — the original compared
        # "Differential diagnosis" (capital D) against clean_line.lower(),
        # which can never match, so headers were never skipped here.
        if not clean_line or "differential diagnosis" in clean_line.lower():
            continue

        # First pass: exact oak grounding of the full line.
        grounded = perform_oak_grounding(
            annotator, clean_line, exact_match=True, verbose=verbose, include_list=include_list
        )

        # Second pass: curategpt fallback when no grounding was found.
        if use_ontogpt_grounding and grounded == [('N/A', 'No grounding found')]:
            grounded = perform_curategpt_grounding(
                diagnosis=clean_line,
                path=curategpt_path,
                collection=curategpt_collection,
                database_type=curategpt_database_type,
                verbose=verbose
            )

        # If still no grounding is found, log the final failure.
        if grounded == [('N/A', 'No grounding found')] and verbose:
            print(f"Final grounding failed for: {clean_line}")

        # Append the grounded results (even if no grounding was found)
        results.append((clean_line, grounded))

    return results
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.