
Changing to instructor (#5)
* replaced llmchain with instructor

* removed CDE; everything is in pydantic now, plus an async implementation alongside the serial one

* fixed some problems with the typing implementation; handle empty products

* moved everything to pydantic models; extracting reasonable trees

* added a report of paragraphs processed; starting to connect nodes with NL

* modified tests; they need rewriting

* bumped the Python version; 3.8 doesn't work with instructor
doncamilom authored Nov 27, 2023
1 parent 09651e2 commit bdad464
Showing 31 changed files with 751 additions and 623 deletions.
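For context, the following is a minimal sketch of the extraction pattern this commit migrates to: instructor patching an OpenAI client so chat completions are parsed and validated directly into a pydantic model. The ReactionStep schema, prompt, and model name are illustrative assumptions, not code from this repository.

from openai import OpenAI
from pydantic import BaseModel, Field
import instructor


class ReactionStep(BaseModel):
    """One step of a synthesis procedure (hypothetical schema)."""

    action: str = Field(description="e.g. 'add', 'stir', 'filter'")
    chemical: str | None = None  # `X | None` syntax needs Python >= 3.10, matching the version bump below
    amount: str | None = None


# instructor.patch() wraps the client so `response_model` is validated against the pydantic schema.
client = instructor.patch(OpenAI())

step = client.chat.completions.create(
    model="gpt-3.5-turbo",
    response_model=ReactionStep,
    messages=[
        {"role": "user", "content": "5 mL of THF was added and the mixture was stirred for 2 h."}
    ],
)
print(step.model_dump())

The async implementation mentioned in the commit bullets presumably follows the same shape with an async client.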
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.8", "3.9"]
python-version: [ "3.10", "3.9"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
@@ -40,7 +40,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.8", "3.9"]
python-version: [ "3.10", "3.9"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
@@ -63,7 +63,7 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest ]
python-version: [ "3.8", "3.9"]
python-version: [ "3.10", "3.9"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
18 changes: 13 additions & 5 deletions notebooks/explor_chem_nlp/00_preprocess.ipynb
@@ -51,10 +51,12 @@
"\n",
"# Load original uspto csv with all fields.\n",
"uspto_csv_1 = pd.read_csv(\n",
" \"../../data/raw/Extracted_Data_2001_Sep2016_USPTOapplications_new.csv\", low_memory=False\n",
" \"../../data/raw/Extracted_Data_2001_Sep2016_USPTOapplications_new.csv\",\n",
" low_memory=False,\n",
")\n",
"uspto_csv_2 = pd.read_csv(\n",
" \"../../data/raw/Extracted_Data_1976_Sep2016_USPTOgrants_new.csv\", low_memory=False\n",
" \"../../data/raw/Extracted_Data_1976_Sep2016_USPTOgrants_new.csv\",\n",
" low_memory=False,\n",
")\n",
"uspto_csv = pd.concat([uspto_csv_1, uspto_csv_2])\n",
"\n",
@@ -207,7 +209,9 @@
"prg_rxn_sort = (\n",
" parags_db.reset_index()\n",
" .merge(\n",
" uspto_csv[[\"Paragraph Text\", \"Reaction Smiles\", \"Product List\"]].drop_duplicates(),\n",
" uspto_csv[\n",
" [\"Paragraph Text\", \"Reaction Smiles\", \"Product List\"]\n",
" ].drop_duplicates(),\n",
" right_on=\"Paragraph Text\",\n",
" left_on=\"Paragraph Text\",\n",
" how=\"outer\",\n",
@@ -386,7 +390,9 @@
"\n",
"\n",
"edit_distances = np.array([quality_filter(i) for i in segm_db.keys()])\n",
"print(f\"Edit distances of first 10 paragraphs/segmentations: {edit_distances[:10]}\")\n",
"print(\n",
" f\"Edit distances of first 10 paragraphs/segmentations: {edit_distances[:10]}\"\n",
")\n",
"\n",
"print(f\"\\n\\n{'edd':>4}{'Number of samples':>21}\\n\")\n",
"for t in [0, 5, 10, 50, 100, 500, 1000]:\n",
@@ -401,7 +407,9 @@
" # Simply repeat the value of remove_prgs[p] for each segment\n",
" keep_idx += [remove_prgs[j] for i in segm_db[p]]\n",
"\n",
"print(f\"\\nlen of discriminator (should be same as current len of sg_db): {len(keep_idx)}\\n\")\n",
"print(\n",
" f\"\\nlen of discriminator (should be same as current len of sg_db): {len(keep_idx)}\\n\"\n",
")\n",
"print(f\"len of sg_db: {len(sg_db)}\\n\")"
]
},
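The hunks above call a quality_filter over segm_db and tabulate edit distances; as a rough illustration of the idea, here is a self-contained sketch that scores a segmentation by the Levenshtein distance between the original paragraph and its re-joined segments. The helper names, toy data, and threshold are assumptions, not the notebook's actual code.

import numpy as np


def edit_distance(a: str, b: str) -> int:
    # Plain Levenshtein distance via dynamic programming (two rolling rows).
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = curr
    return prev[-1]


def quality_filter(paragraph: str, segments: list[str]) -> int:
    # A faithful segmentation should reproduce the paragraph almost verbatim.
    return edit_distance(paragraph, " ".join(segments))


# Toy stand-in for segm_db: paragraph -> list of text segments.
segm_db = {
    "NaOH was added to water. The mixture was stirred for 5 min.": [
        "NaOH was added to water.",
        "The mixture was stirred for 5 min.",
    ],
}

edit_distances = np.array([quality_filter(p, segs) for p, segs in segm_db.items()])
keep = edit_distances <= 10  # assumed threshold; the notebook tabulates several cut-offs
print(edit_distances, keep)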
14 changes: 11 additions & 3 deletions notebooks/explor_chem_nlp/01_map_embedds.ipynb
@@ -105,7 +105,9 @@
],
"source": [
"# First select rxn_setup only\n",
"sg_db_setup_d = {i: b for i, b in enumerate(sg_db) if b[\"sgm_cls\"] == \"reaction set-up\"}\n",
"sg_db_setup_d = {\n",
" i: b for i, b in enumerate(sg_db) if b[\"sgm_cls\"] == \"reaction set-up\"\n",
"}\n",
"sg_db_setup = [b for b in sg_db_setup_d.values()]\n",
"\n",
"# Select a batch size\n",
@@ -182,7 +184,11 @@
"def embed_batch(sentences):\n",
" # Tokenize sentences\n",
" encoded_input = tokenizer(\n",
" sentences, padding=True, truncation=True, return_tensors=\"pt\", max_length=256\n",
" sentences,\n",
" padding=True,\n",
" truncation=True,\n",
" return_tensors=\"pt\",\n",
" max_length=256,\n",
" ).to(device)\n",
"\n",
" # Compute token embeddings\n",
@@ -197,7 +203,9 @@
" torch.cuda.empty_cache()\n",
"\n",
" # normalize embeddings\n",
" sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)\n",
" sentence_embeddings = torch.nn.functional.normalize(\n",
" sentence_embeddings, p=2, dim=1\n",
" )\n",
" return sentence_embeddings.to(\"cpu\")\n",
"\n",
"\n",
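The fragments above belong to an embed_batch helper; a self-contained sketch of that flow (tokenize, encode, mean-pool, L2-normalize) might look like the following. The checkpoint name and the mean-pooling step are assumptions, since the notebook's pooling code is not shown in this diff.

import torch
from transformers import AutoModel, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "sentence-transformers/all-MiniLM-L6-v2"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)


def embed_batch(sentences: list[str]) -> torch.Tensor:
    encoded_input = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=256,
    ).to(device)

    with torch.no_grad():
        model_output = model(**encoded_input)

    # Mean-pool token embeddings, masking out padding positions.
    mask = encoded_input["attention_mask"].unsqueeze(-1).float()
    summed = (model_output.last_hidden_state * mask).sum(dim=1)
    sentence_embeddings = summed / mask.sum(dim=1).clamp(min=1e-9)

    # Normalize so downstream cosine similarity is a plain dot product.
    sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings.to("cpu")


print(embed_batch(["The mixture was stirred overnight.", "The solid was filtered off."]).shape)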
57 changes: 43 additions & 14 deletions notebooks/explor_chem_nlp/02_visualization.ipynb
@@ -45,7 +45,9 @@
" \"\"\"\n",
" fig, ax = plt.subplots(figsize=(15, 15))\n",
"\n",
" sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=labels, ax=ax, s=s, marker=\".\", palette=\"hsv\")\n",
" sns.scatterplot(\n",
" x=X[:, 0], y=X[:, 1], hue=labels, ax=ax, s=s, marker=\".\", palette=\"hsv\"\n",
" )\n",
"\n",
" for side in [\"left\", \"right\", \"top\", \"bottom\"]:\n",
" ax.spines[side].set_visible(False)\n",
@@ -432,7 +434,9 @@
"\n",
"unique_smis[\"canon_smi\"] = unique_smis[\"rxn_smi\"].progress_apply(preproc_rxn)\n",
"\n",
"unique_smis[\"canon_smi\"].to_csv(\"../../data/processed/uniq_rxnsmi.smi\", index=False, header=None)"
"unique_smis[\"canon_smi\"].to_csv(\n",
" \"../../data/processed/uniq_rxnsmi.smi\", index=False, header=None\n",
")"
]
},
{
@@ -444,12 +448,17 @@
"source": [
"# After mapping with namerxn, load file\n",
"rxn_classes = pd.read_csv(\n",
" \"../../data/processed/uniq_rxnsmi_cls.smi\", header=None, names=[\"rxn_smi\", \"class\"], sep=\" \"\n",
" \"../../data/processed/uniq_rxnsmi_cls.smi\",\n",
" header=None,\n",
" names=[\"rxn_smi\", \"class\"],\n",
" sep=\" \",\n",
")\n",
"unique_smis[\"class\"] = rxn_classes[\"class\"]\n",
"\n",
"# Merge with sg_df\n",
"sg_df_ = pd.merge(sg_df, unique_smis, left_on=\"rxn_smi\", right_on=\"rxn_smi\", how=\"left\")"
"sg_df_ = pd.merge(\n",
" sg_df, unique_smis, left_on=\"rxn_smi\", right_on=\"rxn_smi\", how=\"left\"\n",
")"
]
},
{
@@ -478,7 +487,9 @@
"}\n",
"\n",
"# Extract only first class\n",
"sg_df_[\"cls_1\"] = sg_df_[\"class\"].apply(lambda x: rxn_class_map[int(x.split(\".\")[0])])"
"sg_df_[\"cls_1\"] = sg_df_[\"class\"].apply(\n",
" lambda x: rxn_class_map[int(x.split(\".\")[0])]\n",
")"
]
},
{
@@ -616,7 +627,9 @@
"\n",
" Xs = torch.load(f\"../../data/processed/umap/UMAP_{cl_}_full.pt\")\n",
" labels = sg_df_.loc[sg_df_[\"sgm_cls\"] == cl, \"cls_1\"].values\n",
" fig = plot_umap(Xs, labels=labels, save=f\"../plots/{cl}_full_rxncls.png\", dpi=600)\n",
" fig = plot_umap(\n",
" Xs, labels=labels, save=f\"../plots/{cl}_full_rxncls.png\", dpi=600\n",
" )\n",
" figs.append(fig)\n",
"\n",
"# Also for rxnfps\n",
@@ -652,10 +665,12 @@
"# Here are the datasets that contain all data - including yield\n",
"\n",
"uspto_csv_1 = pd.read_csv(\n",
" \"../../data/raw/Extracted_Data_2001_Sep2016_USPTOapplications_new.csv\", low_memory=False\n",
" \"../../data/raw/Extracted_Data_2001_Sep2016_USPTOapplications_new.csv\",\n",
" low_memory=False,\n",
")\n",
"uspto_csv_2 = pd.read_csv(\n",
" \"../../data/raw/Extracted_Data_1976_Sep2016_USPTOgrants_new.csv\", low_memory=False\n",
" \"../../data/raw/Extracted_Data_1976_Sep2016_USPTOgrants_new.csv\",\n",
" low_memory=False,\n",
")\n",
"uspto_csv = pd.concat([uspto_csv_1, uspto_csv_2]).drop_duplicates(\n",
" subset=[\"Reaction Smiles\", \"Yield\"]\n",
@@ -719,7 +734,11 @@
" return np.nanmax(v)\n",
"\n",
"\n",
"yields = uspto_csv.groupby(\"Reaction Smiles\")[\"Yield\"].apply(proc_yield).reset_index()\n",
"yields = (\n",
" uspto_csv.groupby(\"Reaction Smiles\")[\"Yield\"]\n",
" .apply(proc_yield)\n",
" .reset_index()\n",
")\n",
"\n",
"yields.head"
]
@@ -871,7 +890,9 @@
],
"source": [
"# Merge with sg_df\n",
"sg_df_ = pd.merge(sg_df, yields, left_on=\"rxn_smi\", right_on=\"Reaction Smiles\", how=\"left\")\n",
"sg_df_ = pd.merge(\n",
" sg_df, yields, left_on=\"rxn_smi\", right_on=\"Reaction Smiles\", how=\"left\"\n",
")\n",
"\n",
"sg_df_[\"Yield bin\"] = pd.cut(sg_df_[\"Yield\"].fillna(0), 10)\n",
"sg_df_.head()"
@@ -970,7 +991,9 @@
" Xs = Xs[discr]\n",
" labels = pd.cut(labels[discr], 9)\n",
"\n",
" fig = plot_umap(Xs, labels=labels, save=f\"../plots/{cl}_full_yield.png\", dpi=600)\n",
" fig = plot_umap(\n",
" Xs, labels=labels, save=f\"../plots/{cl}_full_yield.png\", dpi=600\n",
" )\n",
" figs.append(fig)\n",
"\n",
"# Also for rxnfps\n",
@@ -1078,7 +1101,9 @@
"\n",
" def chunk(t):\n",
" n = 90\n",
" return \"\\n\".join([t[i * n : (i + 1) * n] for i in range(len(t) // n + 1)])\n",
" return \"\\n\".join(\n",
" [t[i * n : (i + 1) * n] for i in range(len(t) // n + 1)]\n",
" )\n",
"\n",
" df.txt_prg = df.txt_prg.apply(chunk)\n",
"\n",
@@ -1098,7 +1123,9 @@
" height=1800,\n",
" )\n",
" # Change the marker's border color and width\n",
" fig.update_traces(marker=dict(line=dict(width=0.001, color=\"DarkSlateGrey\")))\n",
" fig.update_traces(\n",
" marker=dict(line=dict(width=0.001, color=\"DarkSlateGrey\"))\n",
" )\n",
" fig.update_xaxes(showticklabels=False)\n",
" fig.update_yaxes(showticklabels=False)\n",
"\n",
@@ -1155,7 +1182,9 @@
"Xs = Xs[discr]\n",
"labels = pd.cut(labels[discr], 9)\n",
"\n",
"fig = plot_umap(Xs, labels=labels, save=f\"../plots/{cl}_full_yield.png\", dpi=600, s=0.3)\n",
"fig = plot_umap(\n",
" Xs, labels=labels, save=f\"../plots/{cl}_full_yield.png\", dpi=600, s=0.3\n",
")\n",
"figs.append(fig)"
]
},
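Several hunks above attach yields to reaction SMILES and bin them for colouring the UMAP plots; a compact sketch of that pattern, with made-up data and an assumed yield parser, is:

import numpy as np
import pandas as pd

# Toy stand-ins for the USPTO table and the segment dataframe.
uspto_csv = pd.DataFrame(
    {
        "Reaction Smiles": ["A>>B", "A>>B", "C>>D"],
        "Yield": ["45%", "52%", np.nan],
    }
)
sg_df = pd.DataFrame({"rxn_smi": ["A>>B", "C>>D"]})


def proc_yield(values: pd.Series) -> float:
    # Keep the largest numeric yield reported for a reaction; NaN when none parses.
    parsed = [float(str(v).rstrip("%")) for v in values if pd.notna(v)]
    return np.nanmax(parsed) if parsed else np.nan


yields = uspto_csv.groupby("Reaction Smiles")["Yield"].apply(proc_yield).reset_index()

# Merge onto the segment table and bin into 10 intervals for plotting, as in the notebook.
sg_df_ = pd.merge(sg_df, yields, left_on="rxn_smi", right_on="Reaction Smiles", how="left")
sg_df_["Yield bin"] = pd.cut(sg_df_["Yield"].fillna(0), 10)
print(sg_df_[["rxn_smi", "Yield", "Yield bin"]])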
9 changes: 7 additions & 2 deletions notebooks/explor_chem_nlp/03_nomic_atlas.ipynb
@@ -276,15 +276,20 @@
" meta_ss = [meta[i] for i in sampl_idx]\n",
"\n",
" smis = pd.Series([b[\"rxn_smi\"] for b in meta_ss], name=\"rxn_smi\")\n",
" rxn_clss = rcls_dd.merge(smis, left_on=\"orig_smi\", right_on=\"rxn_smi\", how=\"right\").cls_1\n",
" rxn_clss = rcls_dd.merge(\n",
" smis, left_on=\"orig_smi\", right_on=\"rxn_smi\", how=\"right\"\n",
" ).cls_1\n",
"\n",
" def pp(b, cls):\n",
" b[\"prd_str\"] = str(b[\"prd_str\"])\n",
" b[\"rxn_img\"] = cdk(b[\"rxn_smi\"])\n",
" b[\"rxn_cls\"] = cls\n",
" return b\n",
"\n",
" meta_ss = [pp(b, cls) for b, cls in tqdm(zip(meta_ss, rxn_clss), total=len(meta_ss))]\n",
" meta_ss = [\n",
" pp(b, cls)\n",
" for b, cls in tqdm(zip(meta_ss, rxn_clss), total=len(meta_ss))\n",
" ]\n",
"\n",
" response = atlas.map_embeddings(\n",
" name=f\"Semantic synthesis: {cl}\",\n",
39 changes: 30 additions & 9 deletions notebooks/explor_chem_nlp/utils.py
@@ -66,7 +66,11 @@ def sequence_calculator(seq1, seq2, list_examples):
total = total + value

output_string = (
"The number of sequence starting with " + seq1 + " and ending with " + seq2 + " is: "
"The number of sequence starting with "
+ seq1
+ " and ending with "
+ seq2
+ " is: "
)

return total
Expand All @@ -78,7 +82,9 @@ def sequence_matrix_creator(list_examples_quar):

for i in range(len(char_list)):
for j in range(len(char_list)):
output = sequence_calculator(char_list[i], char_list[j], list_examples_quar)
output = sequence_calculator(
char_list[i], char_list[j], list_examples_quar
)
ls.append(output)

arr = np.array(ls)
@@ -196,7 +202,9 @@ def rxn_setup_work_up_text(segm_parags):
text_segm = segm["text segment"][1:-1]

if cls == "reaction set-up":
if flag == 0: # a work-up is not in the string, rxn set-up can be concat to the string
if (
flag == 0
): # a work-up is not in the string, rxn set-up can be concat to the string
if rxn_set_up_str == "":
rxn_set_up_str += text_segm
else:
@@ -266,13 +274,16 @@ def segms_compress(segms):
for i in range(0, len(segms) - 1):
if class_check(segms[i]["text class"], segms[i + 1]["text class"]):
if str_temp != "":
str_temp = str_combine(str_temp, segms[i + 1]["text segment"][1:-1])
str_temp = str_combine(
str_temp, segms[i + 1]["text segment"][1:-1]
)
text_class = text_class
step_order = step_order

else:
str_temp = str_combine(
segms[i]["text segment"][1:-1], segms[i + 1]["text segment"][1:-1]
segms[i]["text segment"][1:-1],
segms[i + 1]["text segment"][1:-1],
)
text_class = segms[i]["text class"]
step_order = str(int(segms[i]["step order"]) - i + len(ls))
@@ -292,13 +303,19 @@ def segms_compress(segms):
dict_temp = {
"text segment": segms[i]["text segment"],
"text class": segms[i]["text class"],
"step order": str(int(segms[i]["step order"]) - i + len(ls)),
"step order": str(
int(segms[i]["step order"]) - i + len(ls)
),
}
ls.append(dict_temp)

if str_temp != "":
ls.append(
{"text segment": str_temp, "text class": text_class, "step order": step_order}
{
"text segment": str_temp,
"text class": text_class,
"step order": step_order,
}
)

else: # r-w-r-w
@@ -360,7 +377,9 @@ def check_segment_format(string):

# Split a string without removing the delimiter
substring_list = [
substring + add_delimiter for substring in string.split(remove_delimiter) if substring
substring + add_delimiter
for substring in string.split(remove_delimiter)
if substring
]

# Removing the trailing delimiter "}" in the last element
@@ -408,7 +427,9 @@ def check_edit_distance(parag, string):

# Split a string without removing the delimiter
substring_list = [
substring + add_delimiter for substring in string.split(remove_delimiter) if substring
substring + add_delimiter
for substring in string.split(remove_delimiter)
if substring
]

# Removing the trailing delimiter "}" in the last element
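The segms_compress changes above are pure reformatting, but the underlying idea, merging consecutive segments that share a text class so a paragraph collapses into alternating set-up and work-up blocks, can be sketched compactly. This is a simplified illustration that ignores the class_check compatibility rules and step-order bookkeeping of the real function:

def compress_segments(segments: list[dict]) -> list[dict]:
    compressed: list[dict] = []
    for seg in segments:
        if compressed and compressed[-1]["text class"] == seg["text class"]:
            # Same class as the previous block: concatenate the text.
            compressed[-1]["text segment"] += " " + seg["text segment"]
        else:
            compressed.append(
                {
                    "text segment": seg["text segment"],
                    "text class": seg["text class"],
                    "step order": str(len(compressed) + 1),
                }
            )
    return compressed


segments = [
    {"text segment": "NaOH was added.", "text class": "reaction set-up", "step order": "1"},
    {"text segment": "The mixture was stirred.", "text class": "reaction set-up", "step order": "2"},
    {"text segment": "The solid was filtered.", "text class": "work-up", "step order": "3"},
]
print(compress_segments(segments))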
4 changes: 3 additions & 1 deletion scripts/map_full_uspto.py
@@ -34,4 +34,6 @@

if i % backup_freq == 0:
pickle.dump(segm_map, f)
print(f"Last backup: {i}th epoch. Processed {len(segm_map)} samples so far")
print(
f"Last backup: {i}th epoch. Processed {len(segm_map)} samples so far"
)
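The change to scripts/map_full_uspto.py only reflows the backup print; the surrounding pattern, checkpointing the accumulated map every backup_freq iterations, looks roughly like this. The processing step and file name are placeholders:

import pickle

backup_freq = 100
segm_map: dict[str, str] = {}
paragraphs = [f"paragraph {i}" for i in range(1_000)]  # stand-in data

for i, parag in enumerate(paragraphs):
    segm_map[parag] = parag.upper()  # stand-in for the real segmentation/mapping call

    if i % backup_freq == 0:
        with open("segm_map_backup.pkl", "wb") as f:
            pickle.dump(segm_map, f)
        print(f"Last backup: {i}th iteration. Processed {len(segm_map)} samples so far")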