Skip to content

Commit

Permalink
Feature clusterer import method (#76)
Browse files Browse the repository at this point in the history
* Add `import_clusterer` method

* Add test for new clusterer import

* Add docstring

* Fix "if main" of test file

* Addition to gitignore
  • Loading branch information
sbaldu authored Jan 13, 2025
1 parent 9f80fa3 commit 4d5eee0
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ vgcore.*
*file2.csv
*_output.csv
*_passed.py
*test_sissa_import.csv
63 changes: 53 additions & 10 deletions CLUEstering/CLUEstering.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def run_clue(self,
cluster_ids = np.array(cluster_id_is_seed[0])
is_seed = np.array(cluster_id_is_seed[1])
clusters = np.unique(cluster_ids)
n_seeds = np.sum([1 for i in clusters if i > -1])
n_seeds = np.sum(is_seed)
n_clusters = len(clusters)

cluster_points = [[] for _ in range(n_clusters)]
Expand Down Expand Up @@ -1185,24 +1185,67 @@ def to_csv(self, output_folder: str, file_name: str) -> None:
data = {}
for i in range(self.clust_data.n_dim):
data['x' + str(i)] = self.clust_data.coords.T[i]
data['weight'] = self.clust_data.weight
data['cluster_ids'] = self.clust_prop.cluster_ids
data['is_seed'] = self.clust_prop.is_seed

df_ = pd.DataFrame(data)
df_.to_csv(out_path,index=False)

def import_clusterer(self, input_folder: str, file_name: str) -> None:
"""
Imports the results of a previous clustering.
Parameters
----------
input_folder : string
Full path to the folder containing the file.
file_name : string
Name of the file, with the '.csv' suffix.
Modified attributes
-------------------
clust_data : clustering_data
Properties of the input data.
clust_prop : cluster_properties
Properties of the clusters found.
Returns
-------
None
"""

in_path = input_folder + file_name
df_ = pd.read_csv(in_path, dtype=float)
cluster_ids = np.asarray(df_["cluster_ids"], dtype=int)
is_seed = np.array(df_["is_seed"], dtype=int)

self._handle_dataframe(df_.iloc[:, :-2])

clusters = np.unique(cluster_ids)
n_seeds = np.sum(is_seed)
n_clusters = len(clusters)

cluster_points = [[] for _ in range(n_clusters)]
for i in range(self.clust_data.n_points):
cluster_points[cluster_ids[i]].append(i)

points_per_cluster = np.array([len(clust) for clust in cluster_points])
self.clust_prop = cluster_properties(n_clusters,
n_seeds,
clusters,
cluster_ids,
is_seed,
np.asarray(cluster_points, dtype=object),
points_per_cluster,
df_)

if __name__ == "__main__":
c = clusterer(0.8, 5, 1.)
c.read_data('./blob.csv')
c = clusterer(20., 10., 20.)
c.read_data('./sissa.csv')
c.input_plotter()
c.run_clue(backend="cpu serial", verbose=True)
# c.run_clue(backend="cpu tbb", verbose=True)
c.run_clue(backend="cpu tbb", verbose=True)
# c.run_clue(backend="gpu cuda", verbose=True)
# c.run_clue(backend="gpu hip", verbose=True)
c.cluster_plotter()
# c.to_csv('./','sissa_output_tbb.csv')
c.list_devices('cpu serial')
c.list_devices('cpu tbb')
c.list_devices('gpu cuda')
c.list_devices('gpu hip')
c.list_devices()
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from setuptools import setup
import subprocess

__version__ = "2.3.2"
__version__ = "2.3.2.1"

this_directory = Path(__file__).parent
long_description = (this_directory/'README.md').read_text()
Expand Down
47 changes: 47 additions & 0 deletions tests/test_clusterer_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
'''
Test the import of a clusterer from csv file
'''

from check_result import check_result
import os
import sys
import pandas as pd
import pytest
sys.path.insert(1, '../CLUEstering/')
import CLUEstering as clue

@pytest.fixture
def sissa():
'''
Returns the dataframe containing the sissa dataset
'''
return pd.read_csv("./test_datasets/sissa.csv")

def test_clusterer_import(sissa):
'''
Try importing a clusterer from csv file and check that it's equal to the original clusterer
'''
# Check if the output file already exists and if it does, delete it
if os.path.isfile('./test_sissa_import.csv'):
os.remove('./test_sissa_import.csv')

c = clue.clusterer(20., 10., 20.)
c.read_data(sissa)
c.run_clue()
c.to_csv('./', 'test_sissa_import.csv')

d = clue.clusterer(20., 10., 20.)
d.import_clusterer('./', 'test_sissa_import.csv')

assert c.clust_prop == d.clust_prop

if __name__ == "__main__":
c = clue.clusterer(20., 10., 20.)
c.read_data("./test_datasets/sissa.csv")
c.run_clue()
c.cluster_plotter()
c.to_csv('./', 'test_sissa_import.csv')

d = clue.clusterer(20., 10., 20.)
d.import_clusterer('./', 'test_sissa_import.csv')
d.cluster_plotter()

0 comments on commit 4d5eee0

Please sign in to comment.