Skip to content

Commit

Permalink
fix: handle multiple number of topics
Browse files Browse the repository at this point in the history
  • Loading branch information
MariellaCC committed Nov 26, 2024
1 parent a29e6a3 commit 3a73ffa
Showing 1 changed file with 80 additions and 47 deletions.
127 changes: 80 additions & 47 deletions src/kiara_plugin/topic_modelling/modules/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ def process(self, inputs, outputs):

import gensim # type: ignore
from gensim import corpora # type: ignore
import numpy as np # type: ignore

# check coherence method
coherence = inputs.get_value_data("coherence")
Expand Down Expand Up @@ -274,7 +273,7 @@ def process(self, inputs, outputs):
model = gensim.models.ldamodel.LdaModel(corpus, id2word=id2word, **lda_kwargs)


raw_top_topics = model.top_topics(corpus=corpus, coherence=coherence)
raw_top_topics = model.top_topics(corpus=corpus, coherence=coherence, texts=tokens_list)

# Transform top topics into a more readable format
formatted_topics = []
Expand Down Expand Up @@ -310,7 +309,6 @@ class RunLdaCoherence(KiaraModule):
"""
This module is used to run LDA with model coherence score and to compare model coherences depending on the chosen number of topics.
https://radimrehurek.com/gensim_3.8.3/models/coherencemodel.html
"""

_module_type_name = "topic_modelling.lda_coherence"
Expand All @@ -331,10 +329,10 @@ def create_inputs_schema(self):
"doc": "Remove tokens that appear in more than no_above documents.",
"optional": True,
},
"num_topics": {
"type": "integer",
"doc": "Number of topics to process.",
"optional": True,
"num_topics_list": {
"type": "list",
"doc": "List of integers (maximum 5) specifying the numbers of topics to test. Each integer will generate a separate model.",
"optional": False,
},
"passes": {
"type": "integer",
Expand All @@ -346,21 +344,16 @@ def create_inputs_schema(self):
"doc": "Chunksize.",
"optional": True,
},
"iterations": {
"iterations": {
"type": "integer",
"doc": "Number of iterations.",
"optional": True,
},
"random_state": {
"random_state": {
"type": "integer",
"doc": "Random state.",
"optional": True,
},
"range_of_number_of_topics": {
"type": "list",
"doc": "The range of number of topics to test model coherence. The coherence score will be calculated for each number of topics in the range.",
"optional": False,
},
"coherence": {
"type": "string",
"doc": "Methodology to compute coherence. Possible values are 'c_v', 'u_mass', 'c_uci', 'c_npmi'.",
Expand All @@ -371,40 +364,61 @@ def create_inputs_schema(self):

def create_outputs_schema(self):
return {
"coherence": {
"coherence_scores": {
"type": "list",
"doc": "The 15 most common words overall."
"doc": "List of coherence scores, one for each model."
},
"print_topics": {
"type": "list",
"doc": "A list of the topics per model."
"doc": "List of dictionaries, each containing the number of topics and its corresponding topics."
}
}

def process(self, inputs, outputs):

import gensim # type: ignore
from gensim import corpora # type: ignore
from gensim.models.coherencemodel import CoherenceModel # type: ignore
import numpy as np # type: ignore
from gensim import corpora # type: ignore
from gensim.models.coherencemodel import CoherenceModel # type: ignore

# Get and convert KiaraList to Python list
topics_list = inputs.get_value_data("num_topics_list").list_data

# Validate each element is an integer
if not all(isinstance(x, int) for x in topics_list):
raise KiaraProcessingException(
"All values in num_topics_list must be integers"
)

# Validate list length
if len(topics_list) == 0:
raise KiaraProcessingException(
"num_topics_list cannot be empty"
)
if len(topics_list) > 5:
raise KiaraProcessingException(
"num_topics_list cannot contain more than 5 values"
)

# Validate values are positive
if any(x <= 0 for x in topics_list):
raise KiaraProcessingException(
"All values in num_topics_list must be positive integers"
)

tokens_array = inputs.get_value_data("tokens_array")
tokens_list = tokens_array.arrow_array.to_pylist()

# check coherence method
# Check coherence method
coherence = inputs.get_value_data("coherence")
if coherence not in ["c_v", "u_mass", "c_uci", "c_npmi"]:
raise KiaraProcessingException(f"Invalid coherence method: {coherence}")


# Prepare model parameters
input_values = {
"num_topics": inputs.get_value_data("num_topics"),
"passes": inputs.get_value_data("passes"),
"chunksize": inputs.get_value_data("chunksize"),
"iterations": inputs.get_value_data("iterations"),
"random_state": inputs.get_value_data("random_state")
}


lda_kwargs = {k: v for k, v in input_values.items() if v is not None}

id2word_kwargs = {
Expand All @@ -414,7 +428,6 @@ def process(self, inputs, outputs):
}.items() if v is not None
}


try:
# Create dictionary
id2word = corpora.Dictionary(tokens_list)
Expand All @@ -426,28 +439,48 @@ def process(self, inputs, outputs):
# Create corpus
corpus = [id2word.doc2bow(text) for text in tokens_list]

# Create and train model
model = gensim.models.ldamulticore.LdaMulticore(
corpus,
id2word=id2word,
**lda_kwargs
)

# Get raw topics and convert them
raw_topics = model.print_topics(num_words=30)
topics = []
for idx, topic in raw_topics:
# Convert any numpy types in the index and ensure topic is a string
topics.append((int(idx), str(topic)))

topics.sort(key=lambda x: x[0])

# Convert common words with explicit type conversion
common_words = [(str(word), int(count)) for word, count in id2word.most_common(15)]
# Lists to store results
all_models_topics = []
coherence_scores = []

# Process each number of topics
for num_topics in sorted(topics_list): # Sort for consistent order
# Create and train model
model = gensim.models.ldamulticore.LdaMulticore(
corpus=corpus,
id2word=id2word,
num_topics=num_topics,
**lda_kwargs
)

# Get raw topics and convert them
raw_topics = model.print_topics(num_words=30)
topics = []
for idx, topic in raw_topics:
topics.append((int(idx), str(topic)))
topics.sort(key=lambda x: x[0])

# Calculate coherence
coherencemodel = CoherenceModel(
model=model,
texts=tokens_list,
dictionary=id2word,
coherence=coherence
)
coherence_value = coherencemodel.get_coherence()

# Store results with model information
model_topics = {
"num_topics": num_topics,
"topics": topics,
"coherence": float(coherence_value)
}
all_models_topics.append(model_topics)
coherence_scores.append(float(coherence_value))

# Set outputs
outputs.set_value("print_topics", topics)
outputs.set_value("most_common_words", common_words)
outputs.set_value("print_topics", all_models_topics)
outputs.set_value("coherence_scores", coherence_scores)

except Exception as e:
raise KiaraProcessingException(f"Processing failed: {e}")

0 comments on commit 3a73ffa

Please sign in to comment.