-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create text.py file containing tfidf
- Loading branch information
Showing
1 changed file
with
197 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
from verticapy._typing import SQLRelation | ||
from verticapy._utils._sql._sys import _executeSQL | ||
from verticapy.core.vdataframe.base import vDataFrame | ||
|
||
class Tfidf:
    """
    Create a TF-IDF (term frequency - inverse document frequency)
    representation of documents.

    ``fit`` materializes a table holding the vocabulary and IDF values;
    ``transform`` joins new documents against that table to produce
    L2-normalized TF-IDF scores.

    Parameters
    ----------
    name: str
        Name of the model (the IDF table created by ``fit``).
    schema: str
        Schema in which the model table is created. When ``None``,
        the table name is used unqualified.
    overwrite_model: bool, default=False
        If set to True, training a model with the same
        name as an existing model overwrites the
        existing model.
    lowercase: bool, default=True
        Converts all the elements to lowercase before
        processing.
    """

    def __init__(
        self,
        name: str = None,
        schema: str = None,
        overwrite_model: bool = False,
        lowercase: bool = True,
    ) -> None:
        # Only stores hyper-parameters; no SQL is executed until fit().
        self.parameters = {
            "name": name,
            "schema": schema,
            "overwrite_model": overwrite_model,
            "lowercase": lowercase,
        }

    def _as_vdf(self, input_relation: "SQLRelation") -> "vDataFrame":
        """Coerce the input relation to a vDataFrame (copying when it
        already is one, so the caller's object is never mutated)."""
        if isinstance(input_relation, vDataFrame):
            return input_relation.copy()
        return vDataFrame(input_relation)

    def _text_expression(self, column: str) -> str:
        """Return the SQL expression selecting the text column,
        wrapped in LOWER() when the ``lowercase`` parameter is True.

        BUGFIX: the original assigned ``text = {column}`` — a Python
        set literal — instead of the plain column name, breaking the
        ``lowercase=False`` path.
        """
        if self.parameters["lowercase"]:
            return f"LOWER({column})"
        return column

    def _words_by_post_sql(self, vdf, index: str, text: str) -> str:
        """Return the shared CTE body tokenizing each document.

        Strips non-word characters, collapses runs of spaces, and
        splits the text on whitespace into an array of words.

        BUGFIX: the space-collapsing quantifier is escaped as
        ``{{2,}}``. The original f-string contained ``{2,}``, which
        Python evaluates as the tuple ``(2,)`` and injects into the
        SQL, corrupting the regex.
        """
        return f"""SELECT
                {index} as row_id
                ,{text} as content
                ,string_to_array(
                    REGEXP_REPLACE(
                        TRIM(
                            REGEXP_REPLACE(
                                REGEXP_REPLACE(
                                    {text},'[^\\w ]',''
                                ),
                                ' {{2,}}',' '
                            )
                        ),
                        '\\s',',')
                ) words
                ,COUNT(*) OVER() docs_n
            FROM {vdf}"""

    def fit(
        self,
        input_relation: "SQLRelation",
        index: str,
        column: str,
    ) -> None:
        """Apply basic pre-processing and create the table holding the
        fitted vocabulary and its IDF values.

        Parameters
        ----------
        input_relation: SQLRelation
            Relation (db object or vDataFrame) holding the corpus.
        index: str
            Column name of the document id.
        column: str
            Column name which contains the text.

        Returns
        -------
        None
        """
        vdf = self._as_vdf(input_relation)
        text = self._text_expression(column)

        # Qualified name of the IDF table. Filtering out falsy parts
        # avoids the TypeError the original raised when schema is None.
        name_parts = (self.parameters["schema"], self.parameters["name"])
        self.idf_ = ".".join(p for p in name_parts if p)

        if self.parameters["overwrite_model"]:
            _executeSQL(f"DROP TABLE IF EXISTS {self.idf_}", print_time_sql=False)

        # IDF uses the smoothed formula LN((1+N)/(1+df)+1).
        q_idf = f"""CREATE TABLE {self.idf_} AS
        WITH
        tdc AS (
            SELECT
                count({index}) count_docs
            FROM {vdf}
        ),
        words_by_post AS (
            {self._words_by_post_sql(vdf, index, text)}
        ),
        exploded AS (
            SELECT
                EXPLODE(words, words, content, row_id ) OVER(PARTITION best)
            FROM words_by_post
        )
        SELECT
            value AS word
            , COUNT(DISTINCT row_id) AS word_doc_count
            , count_docs
            ,LN((1+count_docs)/(1+word_doc_count)+1) idf_log
        FROM exploded
        CROSS JOIN tdc
        GROUP BY word,count_docs
        ORDER BY word_doc_count desc"""

        _executeSQL(q_idf, print_time_sql=False)

    def transform(
        self,
        input_relation: "SQLRelation",
        index: str,
        column: str,
        pivot: bool = False,
    ) -> "vDataFrame":
        """Compute L2-normalized TF-IDF scores for the input documents
        using the vocabulary/IDF table created by ``fit``.

        Parameters
        ----------
        input_relation: SQLRelation
            Relation (db object or vDataFrame) holding the documents.
        index: str
            Column name of the document id.
        column: str
            Column name which contains the text.
        pivot: bool, default=False
            If True, pivots the final table to one row per document
            (sparse matrix layout).

        Returns
        -------
        vDataFrame
        """
        vdf = self._as_vdf(input_relation)
        text = self._text_expression(column)

        q_tfidf = f"""WITH
        words_by_post AS (
            {self._words_by_post_sql(vdf, index, text)}
        ),
        exploded AS (
            SELECT
                EXPLODE(words, words, content, row_id ) OVER(PARTITION best)
            FROM words_by_post
        ),
        tf AS (
            SELECT
                row_id
                ,value as word
                ,count(*) as tf
            FROM exploded
            GROUP BY row_id,word,words
        )
        SELECT
            tf.row_id
            ,{self.idf_}.word
            ,tf*idf_log/sqrt(sum((tf*idf_log)^2) OVER(partition by tf.row_id)) tf_idf
        FROM tf
        INNER JOIN {self.idf_} on tf.word = {self.idf_}.word
        ORDER BY tf.row_id"""

        result = vDataFrame(q_tfidf)
        if pivot:
            return result.pivot(index="row_id", columns="word", values="tf_idf", prefix="")
        return result