From 6b644aa1fb47e0ed380f9068815af27e091db078 Mon Sep 17 00:00:00 2001 From: James Krieger Date: Fri, 20 Oct 2023 15:17:01 +0100 Subject: [PATCH] add threadpoolctl to clustenm app --- prody/apps/prody_apps/prody_clustenm.py | 48 ++++++++++++++++++------- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/prody/apps/prody_apps/prody_clustenm.py b/prody/apps/prody_apps/prody_clustenm.py index 8bb304090..f3dfd7faf 100644 --- a/prody/apps/prody_apps/prody_clustenm.py +++ b/prody/apps/prody_apps/prody_clustenm.py @@ -141,19 +141,41 @@ def prody_clustenm(pdb, **kwargs): except TypeError: raise TypeError("Please provide cutoff as a float or equation using math") - ens = prody.ClustENM(pdb.getTitle()) - ens.setAtoms(select) - ens.run(n_gens=ngens, n_modes=nmodes, - n_confs=nconfs, rmsd=eval(rmsd), - cutoff=cutoff, gamma=gamma, - maxclust=eval(maxclust), threshold=eval(threshold), - solvent=solvent, force_field=eval(forcefield), - sim=sim, temp=temp, t_steps_i=t_steps_i, - t_steps_g=eval(t_steps_g), - outlier=outlier, mzscore=mzscore, - sparse=sparse, kdtree=kdtree, turbo=turbo, - parallel=parallel, fitmap=fitmap, - fit_resolution=fit_resolution, **kwargs) + nproc = kwargs.get('nproc') + if nproc: + try: + from threadpoolctl import threadpool_limits + except ImportError: + raise ImportError('Please install threadpoolctl to control threads') + + with threadpool_limits(limits=nproc, user_api="blas"): + ens = prody.ClustENM(pdb.getTitle()) + ens.setAtoms(select) + ens.run(n_gens=ngens, n_modes=nmodes, + n_confs=nconfs, rmsd=eval(rmsd), + cutoff=cutoff, gamma=gamma, + maxclust=eval(maxclust), threshold=eval(threshold), + solvent=solvent, force_field=eval(forcefield), + sim=sim, temp=temp, t_steps_i=t_steps_i, + t_steps_g=eval(t_steps_g), + outlier=outlier, mzscore=mzscore, + sparse=sparse, kdtree=kdtree, turbo=turbo, + parallel=parallel, fitmap=fitmap, + fit_resolution=fit_resolution, **kwargs) + else: + ens = prody.ClustENM(pdb.getTitle()) + ens.setAtoms(select) + ens.run(n_gens=ngens, n_modes=nmodes, + n_confs=nconfs, rmsd=eval(rmsd), + cutoff=cutoff, gamma=gamma, + maxclust=eval(maxclust), threshold=eval(threshold), + solvent=solvent, force_field=eval(forcefield), + sim=sim, temp=temp, t_steps_i=t_steps_i, + t_steps_g=eval(t_steps_g), + outlier=outlier, mzscore=mzscore, + sparse=sparse, kdtree=kdtree, turbo=turbo, + parallel=parallel, fitmap=fitmap, + fit_resolution=fit_resolution, **kwargs) single = not kwargs.pop('multiple') outname = join(outdir, prefix)