Skip to content

Commit

Permalink
keep working tomorrow...
Browse files Browse the repository at this point in the history
  • Loading branch information
SalvadorBrandolin committed Sep 10, 2024
1 parent 5f3719e commit 35a8810
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 25 deletions.
122 changes: 114 additions & 8 deletions ugropy/fragmentation_models/fragmentation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
All ugropy models (joback, unifac, psrk) are instances of the
FragmentationModel class.
"""

from abc import ABC, abstractmethod

import pandas as pd

from rdkit import Chem
from rdkit.Chem import Draw
from PIL import Image, ImageDraw, ImageFont
import numpy as np


class FragmentationModel(ABC):
Expand Down Expand Up @@ -38,15 +42,14 @@ class FragmentationModel(ABC):

def __init__(self, subgroups: pd.DataFrame) -> None:
    """Store the subgroups table and pre-build the detection patterns.

    Parameters
    ----------
    subgroups : pd.DataFrame
        Table iterated row by row; the index value of each row is used as
        the group key and its "smarts" entry is parsed into an RDKit query
        molecule.
    """
    self.subgroups = subgroups

    # Instantiate every mol object from its SMARTS representation once and
    # keep the mapping on the instance: `detect_groups` reads it back
    # through `self.detection_mols`.
    self.detection_mols = {
        group: Chem.MolFromSmarts(row["smarts"])
        for group, row in self.subgroups.iterrows()
    }


def detect_groups(self, molecule: Chem.Mol) -> pd.DataFrame:
def detect_groups(self, molecule: Chem.rdchem.Mol) -> pd.DataFrame:
"""Detect all the groups in the molecule.
Return a dictionary with the detected groups as keys and a tuple of
Expand All @@ -55,7 +58,7 @@ def detect_groups(self, molecule: Chem.Mol) -> pd.DataFrame:
Parameters
----------
mol : Chem.Mol
mol : Chem.rdchem.Mol
Molecule to detect the groups.
Returns
Expand All @@ -66,18 +69,121 @@ def detect_groups(self, molecule: Chem.Mol) -> pd.DataFrame:
detected_groups = {}
for group, mol in self.detection_mols.items():
matches = molecule.GetSubstructMatches(mol)

if matches:
detected_groups[group] = matches

return detected_groups

@abstractmethod
def set_fragmentation_result(
    self,
    molecule: Chem.rdchem.Mol,
    subgroups_occurrences: dict,
    subgroups_atoms_indexes: dict,
) -> "FragmentationResult":
    """Build the fragmentation result of a molecule.

    Abstract hook: the base implementation only raises, so every concrete
    model has to override it.

    Parameters
    ----------
    molecule : Chem.rdchem.Mol
        Molecule the result refers to.
    subgroups_occurrences : dict
        NOTE(review): presumably maps each subgroup to how many times it
        occurs in the molecule; confirm against the implementers.
    subgroups_atoms_indexes : dict
        NOTE(review): presumably maps each subgroup to the atom indexes it
        covers; confirm against the implementers.

    Returns
    -------
    FragmentationResult
        Result object produced by the concrete model.

    Raises
    ------
    NotImplementedError
        Always, when the base implementation is invoked.
    """
    raise NotImplementedError("Abstract Method not implemented.")


class FragmentationResult:
    """Plain holder for a molecule and the subgroups assigned to it.

    Parameters
    ----------
    molecule : Chem.rdchem.Mol
        Kept on the instance as ``mol_object``.
    subgroups : dict
        Kept on the instance as ``subgroups``.
    """

    def __init__(self, molecule: Chem.rdchem.Mol, subgroups: dict):
        self.mol_object, self.subgroups = molecule, subgroups

def draw(
    mol_object: Chem.rdchem.Mol,
    subgroups: dict,
    model,  # only forwarded to `fit_atoms`; its expected type is not visible here
    title: str = "",
    width: int = 400,
    height: int = 200,
    title_font_size: int = 12,
    legend_font_size: int = 12,
    font: str = "Helvetica",
) -> Image.Image:
    """Create a PIL image of the fragmentation result with a legend.

    The molecule is rendered with the atoms of every subgroup highlighted,
    and a title plus one legend row per subgroup are drawn underneath it.

    Parameters
    ----------
    mol_object : Chem.rdchem.Mol
        Molecule to render.
    subgroups : dict
        Subgroups of the molecule; forwarded to ``fit_atoms``.
    model
        Fragmentation model; forwarded to ``fit_atoms``.
    title : str, optional
        Text drawn below the molecule, by default "".
    width : int, optional
        Width in pixels of the molecule drawing and of the final image.
    height : int, optional
        Height in pixels of the molecule drawing; the final image is taller
        because the title and the legend are appended below it.
    title_font_size : int, optional
        Font size of the title.
    legend_font_size : int, optional
        Font size of the legend labels; it also sets the legend row height.
    font : str, optional
        Font name or path tried first. If it cannot be opened, "arial.ttf"
        is tried, and finally Pillow's built-in default font is used.

    Returns
    -------
    Image.Image
        RGB image holding the molecule, the title and the legend.
    """
    # Fit the subgroups to the atoms of the molecule.
    fit = fit_atoms(mol_object, subgroups, model)

    # One colour per subgroup, assigned in the iteration order of `fit`.
    how_many_subgroups = len(fit.keys())
    colors_rgb = _generate_distinct_colors(how_many_subgroups)
    subgroup_colors = dict(zip(fit, colors_rgb))

    highlight = []
    atoms_colors = {}

    for subgroup, atoms in fit.items():
        atms = np.array(atoms).flatten()
        highlight.extend(atms.tolist())

        for at in atms:
            atoms_colors[int(at)] = subgroup_colors[subgroup]

    # Render the molecule itself.
    img = Draw.MolToImage(
        mol_object,
        size=(width, height),
        highlightAtoms=highlight,
        # RGB only, without the alpha channel
        highlightAtomColors={i: tuple(c[:3]) for i, c in atoms_colors.items()},
    )

    # Build a taller canvas with room for the title (50 px) and one legend
    # row per subgroup.
    row_height = legend_font_size + 5
    new_height = height + row_height * how_many_subgroups + 50
    final_img = Image.new("RGB", (width, new_height), "white")

    # Paste the molecule drawing at the top.
    final_img.paste(img, (0, 0))

    # Named `canvas` so the local does not shadow this function's own name.
    canvas = ImageDraw.Draw(final_img)

    # Honour the requested font first, then the previously hard-coded
    # "arial.ttf", and fall back to the built-in font when neither opens.
    for candidate in (font, "arial.ttf"):
        try:
            font_title = ImageFont.truetype(candidate, title_font_size)
            font_legend = ImageFont.truetype(candidate, legend_font_size)
        except OSError:
            continue
        break
    else:
        font_title = ImageFont.load_default()
        font_legend = ImageFont.load_default()

    # Draw the title.
    canvas.text(
        (width / 2, height + 10),
        title,
        fill="black",
        font=font_title,
        anchor="ms",
    )

    # Draw the legend: one row per subgroup (not per highlighted atom), so
    # the rows match the space reserved in `new_height`.
    legend_top = height + 50
    for i, (subgroup, color) in enumerate(subgroup_colors.items()):
        # Convert the 0-1 floats to 0-255 integers.
        r, g, b = (int(255 * c) for c in color[:3])
        row_top = legend_top + i * row_height

        # Colour swatch of the subgroup.
        canvas.rectangle([10, row_top, 30, row_top + row_height], fill=(r, g, b))

        # Name of the subgroup.
        canvas.text((40, row_top), str(subgroup), fill="black", font=font_legend)

    return final_img

def _generate_distinct_colors(n: int) -> list:
    """Return ``n`` RGBA colours taken cyclically from a ten-colour palette.

    Parameters
    ----------
    n : int
        Number of colours to produce.

    Returns
    -------
    list
        ``n`` tuples ``(red, green, blue, 0.65)``; the palette repeats from
        its first entry once its ten colours are used up.
    """
    palette = np.array(
        [
            [0.12156863, 0.46666667, 0.70588235],  # blue
            [1.0, 0.49803922, 0.05490196],  # orange
            [0.17254902, 0.62745098, 0.17254902],  # green
            [0.83921569, 0.15294118, 0.15686275],  # red
            [0.58039216, 0.40392157, 0.74117647],  # purple
            [0.54901961, 0.3372549, 0.29411765],  # brown
            [0.89019608, 0.46666667, 0.76078431],  # pink
            [0.49803922, 0.49803922, 0.49803922],  # grey
            [0.7372549, 0.74117647, 0.13333333],  # yellow
            [0.09019608, 0.74509804, 0.81176471],  # cyan
        ]
    )
    palette_size = len(palette)

    rgba = []
    for position in range(n):
        red, green, blue = palette[position % palette_size]
        rgba.append((red, green, blue, 0.65))
    return rgba

# Manual smoke test. Guarded so that importing this module does not build a
# molecule or open an image viewer; it only runs when the file is executed
# directly as a script.
if __name__ == "__main__":
    mol = Chem.MolFromSmiles('CCO')
    subgroups = {'Grupo 1': [0, 1], 'Grupo 2': [2]}  # Example
    img_with_legend = draw(mol, subgroups, model=None, title="Molécula con subgrupos")
    img_with_legend.show()
2 changes: 1 addition & 1 deletion ugropy/refactor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .fragmentation_unifac import unifac2


# NOTE(review): "unifac" is exported here, but the only import visible in this
# hunk binds `unifac2`; confirm the exported name matches what is imported.
__all__ = ["Fragment", "FragmentationModel", "unifac"]
__all__ = ["Fragment", "FragmentationModel", "unifac"]
2 changes: 1 addition & 1 deletion ugropy/refactor/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ class Fragment:
def __init__(self, name: str, smarts: str):
    """Create a fragment and parse its SMARTS pattern.

    Parameters
    ----------
    name : str
        Name of the fragment.
    smarts : str
        SMARTS pattern of the fragment.
    """
    self.name = name
    self.smarts = smarts
    # Parse the pattern a single time and keep the resulting query molecule.
    self.mol_object = Chem.MolFromSmarts(smarts)
28 changes: 13 additions & 15 deletions ugropy/refactor/fragmentation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ def detect_fragments(self, molecule: Chem.rdchem.Mol):
if match:
batch.add_fragment(fragment.name, match)
return batch



class DetectionBatch:
def __init__(self, molecule: Chem.rdchem.Mol):
def __init__(self, molecule: Chem.rdchem.Mol):
self.n = molecule.GetNumAtoms()
self.fragments = {}
self.overlaped_fragments = {}
Expand All @@ -37,7 +38,6 @@ def __init__(self, molecule: Chem.rdchem.Mol):
self.solution = {}

self.has_overlap = False


def get_groups(self):
self.build_overlap_matrix()
Expand All @@ -59,7 +59,6 @@ def get_groups(self):
for frag in self.solution_atoms.keys():
self.solution[frag] = len(self.solution_atoms[frag])


def add_fragment(self, fragment_name: str, fragments: tuple):
for i, f in enumerate(fragments):
self.fragments[f"{fragment_name}_{i}"] = list(f)
Expand All @@ -81,7 +80,9 @@ def get_overlaped_fragments(self):
def solve_overlap(self):
universe = set(self.overlaped_atoms)

all_elements = set(itertools.chain.from_iterable(self.overlaped_fragments.values()))
all_elements = set(
itertools.chain.from_iterable(self.overlaped_fragments.values())
)

universe.update(all_elements)

Expand All @@ -98,21 +99,18 @@ def solve_overlap(self):
for i, subset in enumerate(self.overlaped_fragments.values()):
if elem in subset:
sum_list.append(x[i])

# print(f"Restricción para el elemento {elem}: {sum_list} == 1")
problem += pulp.lpSum(sum_list) == 1

solver = pulp.getSolver('PULP_CBC_CMD', msg=False)
solver = pulp.getSolver("PULP_CBC_CMD", msg=False)

problem.solve(solver)

selected_subsets = [name for i, name in enumerate(self.overlaped_fragments.keys()) if pulp.value(x[i]) == 1]
selected_subsets = [
name
for i, name in enumerate(self.overlaped_fragments.keys())
if pulp.value(x[i]) == 1
]

self.selected_fragments = selected_subsets







0 comments on commit 35a8810

Please sign in to comment.