Skip to content

Commit

Permalink
init synthtree with custom llm. make raw_prods a property
Browse files Browse the repository at this point in the history
  • Loading branch information
doncamilom committed Jan 13, 2024
1 parent cefe0e9 commit 9ba3b88
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 21 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -934,3 +934,6 @@ FodyWeavers.xsd
# End of https://www.toptal.com/developers/gitignore/api/macos,linux,windows,python,jupyternotebooks,jetbrains,pycharm,vim,emacs,visualstudiocode,visualstudio

scratch/
notebooks/see_trees/
_.py
data/
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ recursive-include docs/source *.png
global-exclude *.py[cod] __pycache__ *.so *.dylib .DS_Store *.gpickle

include README.md LICENSE
exclude tox.ini .bumpversion.cfg .readthedocs.yml .cruft.json CITATION.cff docker-compose.yml Dockerfile
exclude tox.ini .bumpversion.cfg .readthedocs.yml .cruft.json CITATION.cff docker-compose.yml Dockerfile *.ipynb
15 changes: 8 additions & 7 deletions src/jasyntho/doc_extract/synthdoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
self,
doc_src: str,
api_key: Optional[str] = None,
model: str = "gpt-4-0314",
startp: int = 0,
endp: Optional[int] = None,
verbose: bool = True,
Expand All @@ -41,23 +42,23 @@ def __init__(
api_key = api_key or os.environ["OPENAI_API_KEY"]

self.v = verbose
self.rxn_extract = Extractor("rxn_setup", api_key)
self.rxn_extract = Extractor("rxn_setup", api_key, model=model)
self.paragraphs = self._get_paragraphs(doc_src, start=startp, end=endp)

def extract_rss(self) -> list:
"""Extract reaction setups for each paragraph in the doc."""
ext = [p.extract(self.rxn_extract) for p in self.paragraphs]
self._report_process(ext)
products = [p for p in ext if not p.isempty()]
self.raw_prods = [p.extract(self.rxn_extract) for p in self.paragraphs]
self._report_process(self.raw_prods)
products = [p for p in self.raw_prods if not p.isempty()]
return products

async def async_extract_rss(self) -> list:
"""Extract reaction setups for each paragraph in the doc."""
ext = await asyncio.gather(
self.raw_prods = await asyncio.gather(
*[p.async_extract(self.rxn_extract) for p in self.paragraphs]
)
self._report_process(ext)
products = [p for p in ext if not p.isempty()]
self._report_process(self.raw_prods)
products = [p for p in self.raw_prods if not p.isempty()]
return products

def _report_process(self, raw_prods) -> None:
Expand Down
3 changes: 2 additions & 1 deletion src/jasyntho/doc_extract/synthtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ def __init__(
self,
doc_src: str,
api_key: Optional[str] = None,
model: str = "gpt-4-0314",
startp: int = 0,
endp: Optional[int] = None,
) -> None:
"""Initialize a SynthTree object."""
super(SynthTree, self).__init__(
doc_src, api_key, startp, endp
doc_src, api_key, model, startp, endp
) # TODO: select startp and endp automatically from doc_src

def build_tree(self):
Expand Down
18 changes: 14 additions & 4 deletions src/jasyntho/extract/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@ class Extractor:
Initializes extractor depending on the snippet class.
"""

def __init__(self, sclass: str, api_key: Optional[str] = None) -> None:
def __init__(
self,
sclass: str,
api_key: Optional[str] = None,
model: str = "gpt-4-0314",
) -> None:
"""Initialize extractor.
Input
sclass : str
Snippet class.
One of 'rxn_setup', 'rxn_workup', 'purification', 'analysis'
"""
self.extractor = self._init_extractor(sclass, api_key)
self.extractor = self._init_extractor(sclass, api_key, model)

def __call__(self, text: str) -> Any:
"""Execute the extractor."""
Expand All @@ -29,7 +34,12 @@ async def async_call(self, text: str) -> Any:
"""Execute extractor."""
return await self.extractor.async_call(text)

def _init_extractor(self, eclass: str, api_key: Optional[str] = None):
def _init_extractor(
self,
eclass: str,
api_key: Optional[str] = None,
model: str = "gpt-4-0314",
):
"""
Initialize a data extractor.
Expand All @@ -38,7 +48,7 @@ def _init_extractor(self, eclass: str, api_key: Optional[str] = None):
Type of extractor to initialize.
"""
if eclass == "rxn_setup":
return ReactionSetup()
return ReactionSetup(api_key=api_key, model=model)
elif eclass == "rxn_workup":
raise NotImplementedError()
elif eclass == "purification":
Expand Down
9 changes: 5 additions & 4 deletions src/jasyntho/extract/rxn_setup/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@
class ReactionSetup:
"""Extraction of structured data from reaction-setup snippet."""

def __init__(self, api_key=None):
def __init__(self, api_key=None, model="gpt-4-0314"):
"""Initialize the extractor."""
load_dotenv()

self.llm = "gpt-3.5-turbo"
self.llm = "gpt-4-1106-preview"
# self.llm = "gpt-4"
self.llm = model
self.client = instructor.patch(OpenAI())
self.aclient = instructor.apatch(AsyncOpenAI())

def __call__(self, text: str) -> Product:
"""Execute the extraction pipeline for a single paragraph."""
print(text)
print(self.client)
print(self.llm)
product = Product.from_paragraph(text, self.client, self.llm)
return product

Expand Down
13 changes: 10 additions & 3 deletions src/jasyntho/extract/rxn_setup/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@ class Substance(BaseModel):
"""A substance in a reaction."""

reference_key: Optional[str] = Field(
description=("Identifier for a substance described in text. "),
description=(
"Identifier for a substance described in text. "
"It can be a number or combination of numbers and letters."
),
)
substance_name: str = Field(
description="Name of the substance.",
)
role: Literal["reactant", "work-up", "solvent", "product"] = Field(
role: Literal["reactant", "work-up", "solvent", "product", "intermediate"] = Field(
description=(
"What is the role of the substance in the reaction. "
"'work-up' is reserved for substances used in subsequent "
Expand Down Expand Up @@ -111,6 +114,10 @@ def from_substancelist(cls, slist: SubstanceList):
for s in same_key:
if s.role != "reactant":
clean_ch.remove(s)
else:
for s in same_key[1:]: # remove all but first
clean_ch.remove(s)

child_final = [Substance.from_lm(s) for s in clean_ch]

return cls(
Expand Down Expand Up @@ -159,7 +166,7 @@ async def async_from_paragraph(
)
prd = cls.from_substancelist(subs_list)
except (openai.APITimeoutError, ValidationError) as e: # type: ignore
if isinstance(e, openai.APITimeoutError): # type: ignore
if isinstance(e, penai.APITimeoutError): # type: ignore
prd = cls.empty(note=e.message)
else:
prd = cls.empty(note="Validation error.")
Expand Down
7 changes: 6 additions & 1 deletion tests/test_synthtrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
def ex_tree():
"""Initialize document."""
oai_key = os.getenv("OPENAI_API_KEY")
doc = SynthTree("tests/examples/synth_SI_sub.pdf", oai_key)
print(oai_key)
doc = SynthTree(
"tests/examples/synth_SI_sub.pdf",
oai_key,
model="gpt-4-0613",
)
doc.build_tree()
return doc

Expand Down

0 comments on commit 9ba3b88

Please sign in to comment.