Skip to content

Commit

Permalink
Create text.py file containing tfidf
Browse files Browse the repository at this point in the history
  • Loading branch information
mat-shyR committed Nov 8, 2023
1 parent 32a861a commit 837fb8b
Showing 1 changed file with 197 additions and 0 deletions.
197 changes: 197 additions & 0 deletions verticapy/machine_learning/feature_extraction/text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from verticapy._typing import SQLRelation
from verticapy._utils._sql._sys import _executeSQL
from verticapy.core.vdataframe.base import vDataFrame

class Tfidf:
    """
    Create tfidf representation of documents.

    ``fit`` stores the vocabulary and its idf values in a database table
    named after the model; ``transform`` joins new documents against that
    table to compute L2-normalised tf-idf weights.

    Parameters
    ----------
    name: str
        Name of the model.
    schema: str
        Name of the schema in which the model table is created.
        If None, the table name is left unqualified.
    overwrite_model: bool, default=False
        If set to True, training a model with the same
        name as an existing model overwrites the
        existing model.
    lowercase: bool, default=True
        Converts all the elements to lowercase before
        processing.
    """

    def __init__(
        self,
        name: str = None,
        schema: str = None,
        overwrite_model: bool = False,
        lowercase: bool = True,
    ) -> None:
        self.parameters = {
            "name": name,
            "schema": schema,
            "overwrite_model": overwrite_model,
            "lowercase": lowercase,
        }

    def _model_relation(self) -> str:
        """Return the (optionally schema-qualified) name of the idf table."""
        name = self.parameters["name"]
        schema = self.parameters["schema"]
        if name is None:
            raise TypeError(
                "Tfidf requires a model 'name' to create the idf table."
            )
        if schema is None:
            return name
        return ".".join([schema, name])

    def _text_expression(self, column: str) -> str:
        """Return the SQL expression that yields the text to tokenise."""
        if self.parameters["lowercase"]:
            return f"LOWER({column})"
        # Plain column name: a set literal here would leak "{'col'}" into SQL.
        return column

    def _token_ctes(self, relation: SQLRelation, index: str, column: str) -> str:
        """Return the CTEs that split each document into one row per word."""
        text = self._text_expression(column)
        # Raw f-string: the regex backslashes reach the database untouched,
        # and the quantifier braces are doubled so they are emitted literally
        # instead of being parsed as an f-string replacement field.
        return rf"""words_by_post AS (
            SELECT
                {index} as row_id
                ,{text} as content
                ,string_to_array(
                    REGEXP_REPLACE(
                        TRIM(
                            REGEXP_REPLACE(
                                REGEXP_REPLACE(
                                    {text},'[^\w ]',''
                                ),
                                ' {{2,}}',' '
                            )
                        ),
                    '\s',',')
                ) words
                ,COUNT(*) OVER() docs_n
            FROM {relation}
        ),
        exploded AS (
            SELECT
                EXPLODE(words, words, content, row_id ) OVER(PARTITION best)
            FROM words_by_post
        )"""

    def _idf_query(self, relation: SQLRelation, index: str, column: str) -> str:
        """Return the statement that creates the vocabulary / idf table."""
        # NOTE(review): the usual smoothed idf is ln((1+n)/(1+df)) + 1; here
        # the "+1" sits inside LN. Left unchanged - confirm this is intended.
        return f"""CREATE TABLE {self.idf_} AS
        WITH
        tdc AS (
            SELECT
                count({index}) count_docs
            FROM {relation}
        ),
        {self._token_ctes(relation, index, column)}
        SELECT
            value AS word
            , COUNT(DISTINCT row_id) AS word_doc_count
            , count_docs
            ,LN((1+count_docs)/(1+word_doc_count)+1) idf_log
        FROM exploded
        CROSS JOIN tdc
        GROUP BY word,count_docs
        ORDER BY word_doc_count desc"""

    def _tfidf_query(self, relation: SQLRelation, index: str, column: str) -> str:
        """Return the query computing normalised tf-idf per document and word."""
        return f"""WITH
        {self._token_ctes(relation, index, column)},
        tf AS (
            SELECT
                row_id
                ,value as word
                ,count(*) as tf
            FROM exploded
            GROUP BY row_id,word,words
        )
        SELECT
            tf.row_id
            ,{self.idf_}.word
            ,tf*idf_log/sqrt(sum((tf*idf_log)^2) OVER(partition by tf.row_id)) tf_idf
        FROM tf
        INNER JOIN {self.idf_} on tf.word = {self.idf_}.word
        ORDER BY tf.row_id"""

    def fit(
        self,
        input_relation: SQLRelation,
        index: str,
        column: str,
    ) -> None:
        """Apply basic pre-processing. Create table with fitted vocabulary and idf values

        Parameters
        ----------
        input_relation: SQLRelation (db object or vdf)
        index: str
            Column name of the document id
        column: str
            Column name which contains the text

        Returns
        ----------
        None

        Raises
        ----------
        TypeError
            If the model was created without a name.
        """
        if isinstance(input_relation, vDataFrame):
            vdf = input_relation.copy()
        else:
            vdf = vDataFrame(input_relation)

        self.idf_ = self._model_relation()

        if self.parameters["overwrite_model"]:
            _executeSQL(f"DROP TABLE IF EXISTS {self.idf_}", print_time_sql=False)

        _executeSQL(self._idf_query(vdf, index, column), print_time_sql=False)

    def transform(
        self,
        input_relation: SQLRelation,
        index: str,
        column: str,
        pivot: bool = False,
    ) -> vDataFrame:
        """Apply basic pre-processing. Compute the tf-idf weight of every word of every document

        Parameters
        ----------
        input_relation: SQLRelation (db object or vdf)
        index: str
            Column name of the document id
        column: str
            Column name which contains the text
        pivot: bool, default=False
            If True it will Pivot the final table, to have 1 row per document and a sparse matrix.

        Returns
        ----------
        vDataFrame

        Raises
        ----------
        AttributeError
            If ``fit`` has not been called on this instance.
        """
        if getattr(self, "idf_", None) is None:
            raise AttributeError(
                "This Tfidf instance is not fitted yet; "
                "call 'fit' before 'transform'."
            )

        if isinstance(input_relation, vDataFrame):
            vdf = input_relation.copy()
        else:
            vdf = vDataFrame(input_relation)

        result = vDataFrame(self._tfidf_query(vdf, index, column))
        if not pivot:
            return result
        return result.pivot(index="row_id", columns="word", values="tf_idf", prefix="")

0 comments on commit 837fb8b

Please sign in to comment.