Skip to content

Commit

Permalink
Support nltk>=3.9 to fix vulnerability (#629)
Browse files Browse the repository at this point in the history
* Replace punkt with punkt_tab in word_length

* Replace punkt with punkt_tab in meteor

* Replace punkt with punkt_tab in docs

* Revert temporary pin nltk<3.9

This reverts commit d1a15f6.

* Fix import in word_length

* Fix import in word_length

* Fix style
  • Loading branch information
albertvillanova authored Sep 13, 2024
1 parent eb4dac2 commit b3f3c02
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 6 deletions.
4 changes: 2 additions & 2 deletions docs/source/creating_and_sharing.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def _download_and_prepare(self, dl_manager):
self.scorer = score.BleurtScorer(os.path.join(model_path, self.config_name))
```

-Or if you need to download the NLTK `"punkt"` resources:
+Or if you need to download the NLTK `"punkt_tab"` resources:

```py
def _download_and_prepare(self, dl_manager):
import nltk
-nltk.download("punkt")
+nltk.download("punkt_tab")
```

Next, we need to define how the computation of the evaluation module works.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/transformers_integrations.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def preprocess_function(examples):
tokenized_billsum = billsum.map(preprocess_function, batched=True)

# Setup evaluation
-nltk.download("punkt", quiet=True)
+nltk.download("punkt_tab", quiet=True)
metric = evaluate.load("rouge")

def compute_metrics(eval_preds):
Expand Down
14 changes: 13 additions & 1 deletion measurements/word_length/word_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@

import datasets
from nltk import word_tokenize
+from packaging import version

import evaluate


+if evaluate.config.PY_VERSION < version.parse("3.8"):
+    import importlib_metadata
+else:
+    import importlib.metadata as importlib_metadata
+
+
+NLTK_VERSION = version.parse(importlib_metadata.version("nltk"))
+
_DESCRIPTION = """
Returns the average length (in terms of the number of words) of the input data.
"""
Expand Down Expand Up @@ -75,7 +84,10 @@ def _info(self):
def _download_and_prepare(self, dl_manager):
import nltk

-nltk.download("punkt")
+if NLTK_VERSION >= version.Version("3.9.0"):
+    nltk.download("punkt_tab")
+else:
+    nltk.download("punkt")

def _compute(self, data, tokenizer=word_tokenize):
"""Returns the average word length of the input data"""
Expand Down
4 changes: 3 additions & 1 deletion metrics/meteor/meteor.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def _download_and_prepare(self, dl_manager):
import nltk

nltk.download("wordnet")
-if NLTK_VERSION >= version.Version("3.6.5"):
+if NLTK_VERSION >= version.Version("3.9.0"):
+    nltk.download("punkt_tab")
+elif NLTK_VERSION >= version.Version("3.6.5"):
     nltk.download("punkt")
 if NLTK_VERSION >= version.Version("3.6.6"):
     nltk.download("omw-1.4")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
 "absl-py",
 "charcut>=1.1.1", # for charcut_mt
 "cer>=1.2.0", # for characTER
-"nltk<3.9", # for NIST and probably others; temporarily pin < 3.9 to avoid "Resource punkt_tab not found" (GH-622)
+"nltk", # for NIST and probably others
"pytest",
"pytest-datadir",
"pytest-xdist",
Expand Down

0 comments on commit b3f3c02

Please sign in to comment.