Skip to content

Commit

Permalink
multiple bugfixes for fragment generator
Browse files Browse the repository at this point in the history
  • Loading branch information
menoliu committed May 22, 2024
1 parent 05f1307 commit 3123337
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 30 deletions.
97 changes: 67 additions & 30 deletions src/idpconfgen/cli_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
"""
import argparse
import pickle
import random
import re
from collections import defaultdict
from collections import Counter, defaultdict
from functools import partial
from itertools import combinations, cycle, product
from multiprocessing import Pool
from random import randint, random

import numpy as np

Expand Down Expand Up @@ -106,6 +106,7 @@
)
from idpconfgen.libs.libmulticore import pool_function
from idpconfgen.libs.libparse import (
count_occurrences,
get_trimer_seq_njit,
remap_sequence,
remove_empty_keys,
Expand Down Expand Up @@ -874,12 +875,13 @@ def main(
custom_contacts_weight = np.clip(custom_contacts_weight, 0, 100)
min_contacts_weight = custom_contacts_weight / 100.0
log.info(T(f'Choosing at most {max_contacts} contacts for every conformer. Custom-contacts will be chosen with a probability of {custom_contacts_weight} %')) # noqa: E501
for _ in range(nconfs):
for n in range(nconfs):
for i, norm_inter_mtx in enumerate(blended_inter_mtxs):
case = combo_chains[i]
contacts_counter = randint(1, max_contacts)
random.seed(random_seed + n)
contacts_counter = random.randint(1, max_contacts)
while contacts_counter - 1 > 0:
custom = random() < min_contacts_weight
custom = random.random() < min_contacts_weight
if not custom or cus_inter_res[i] is None:
x_coords, y_coords = select_contacts(
coords=norm_inter_mtx,
Expand All @@ -900,10 +902,11 @@ def main(
contacts_counter -= len(pair1)
if ignore_intra is False:
for i, norm_intra_mtx in enumerate(blended_intra_mtxs):
contacts_counter = randint(1, max_contacts)
random.seed(random_seed + n)
contacts_counter = random.randint(1, max_contacts)
case = input_seq_keys[i]
while contacts_counter - 1 > 0:
custom = random() < min_contacts_weight
custom = random.random() < min_contacts_weight
if not custom or list(cus_intra_res.values())[i] is None: # noqa: E501
x_coords, y_coords = select_contacts(
coords=norm_intra_mtx,
Expand All @@ -925,9 +928,10 @@ def main(
log.info(S('done'))
else:
log.info(T(f'Choosing at most {max_contacts} contacts for every conformer from the blended contact heatmap.')) # noqa: E501
for _ in range(nconfs):
for n in range(nconfs):
for i, norm_inter_mtx in enumerate(blended_inter_mtxs):
contacts_counter = randint(1, max_contacts)
random.seed(random_seed + n)
contacts_counter = random.randint(1, max_contacts)
case = combo_chains[i]
while contacts_counter > 0:
x_coords, y_coords = select_contacts(
Expand All @@ -939,7 +943,8 @@ def main(
contacts_counter -= len(x_coords)
if ignore_intra is False:
for i, norm_intra_mtx in enumerate(blended_intra_mtxs):
contacts_counter = randint(1, max_contacts)
random.seed(random_seed + n)
contacts_counter = random.randint(1, max_contacts)
case = input_seq_keys[i]
while contacts_counter > 0:
x_coords, y_coords = select_contacts(
Expand All @@ -950,7 +955,7 @@ def main(
selected_contacts["Y"][contact_type[0]][case].append(y_coords) # noqa: E501
contacts_counter -= len(x_coords)
log.info(S('done'))

# NOTE work with generalizable inter- for IDP-Folded and IDP-IDP before
# algorithm for intramolecular contacts
for conf in range(nconfs):
Expand All @@ -963,39 +968,57 @@ def main(
inter_x_coords = selected_contacts["X"][contact_type[1]][chains]
inter_y_coords = selected_contacts["Y"][contact_type[1]][chains]

for i, x_coords in enumerate(inter_x_coords):
xy = []
y_coords = inter_y_coords[i]
for j, x in enumerate(x_coords):
xy.append((x, y_coords[j]))
res_combos = []
distances = []
for coords in xy:
d, r = get_contact_distances(
coords,
res,
inter_mtx,
folded=True,
)
res_combos.append(r)
distances.append(d)
xy = []
x_coords = inter_x_coords[conf]
y_coords = inter_y_coords[conf]
for j, x in enumerate(x_coords):
xy.append((x, y_coords[j]))

res_combos = []
distances = []
for coords in xy:
d, r = get_contact_distances(
coords,
res,
inter_mtx,
folded=True,
)
res_combos.append(r)
distances.append(d)

# We want to only build fragments of IDP
seq1_id = chains[0]
seq2_id = chains[1]
# If we have repeats of residues where there shouldn't be repeats
removed_idxs = []
# We have two IDPs
if seq1_id[0] == ">" and seq2_id[0] == ">":
pass
# IDP-fld case
elif seq1_id[0] != ">" and seq2_id[0] == ">":
idp_sequences = []
in_seq = input_seq[seq2_id]
for res_pair in res_combos:
idp_seq = ""
idp_res = res_pair[1]
for r in idp_res:
idp_seq += input_seq[seq2_id][r]
idp_seq += in_seq[r]
idp_sequences.append(idp_seq)


seq_counts = count_occurrences(in_seq, set(idp_sequences))

valid_idps = []
used_counts = Counter()

for index, substring in enumerate(idp_sequences):
if used_counts[substring] < seq_counts[substring]:
valid_idps.append(substring)
used_counts[substring] += 1
else:
removed_idxs.append(index)

idp_sequences = valid_idps

SLICEDICT_XMERS = []
XMERPROBS = []
GET_ADJ = []
Expand Down Expand Up @@ -1085,6 +1108,20 @@ def main(
# First element in res contains our residues of interest
coords = find_ca_coords_fld(fld_data, res[0])
fld_contact_coords.append(coords)

# Remove information that was related to repeat sequences
if len(removed_idxs) > 0:
for i in removed_idxs:
distances.pop(i)
fld_contact_coords.pop(i)

# If the number of coords do not match the number of CA distances
# remove the last CA coordinate
for i, d in enumerate(distances):
coords = fld_contact_coords[i]
while len(coords) > len(d):
coords.pop(-1)
fld_contact_coords[i] = coords


def populate_globals(
Expand Down Expand Up @@ -1276,7 +1313,7 @@ def conformer_generator(
PLACE_SIDECHAIN_TEMPLATE = place_sidechain_template
RAD_60 = np.radians(60)
RC = np.random.choice
RINT = randint
RINT = random.randint
ROT_COORDINATES = rotate_coordinates_Q_njit
RRD10 = rrd10_njit
SIDECHAIN_TEMPLATES = sidechain_templates
Expand Down
26 changes: 26 additions & 0 deletions src/idpconfgen/libs/libparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import ast
import subprocess
from collections import Counter
from functools import partial
from itertools import product, repeat
from operator import setitem
Expand Down Expand Up @@ -813,3 +814,28 @@ def adjust_iterable_length(lst, desired_length):
return newlst.join(lst)

return lst


def count_occurrences(main_str, substrings):
"""
Count occurrences of substrings in the main string.
Parameters
----------
main_str : str
Main string, usually an amino acid sequence
substrings : list of str
List of substrings
Returns
-------
counts : int
Number of occurences in main string
"""
counts = Counter()
for s in substrings:
count = main_str.count(s)
counts[s] = count

return counts

0 comments on commit 3123337

Please sign in to comment.