Add support for GLiNER models, closes #862
davidmezzetti committed Jan 31, 2025
1 parent ac4c058 commit 13bafb0
Showing 3 changed files with 90 additions and 2 deletions.
78 changes: 76 additions & 2 deletions src/python/txtai/pipeline/text/entity.py
@@ -2,6 +2,17 @@
Entity module
"""

# Conditional import
try:
from gliner import GLiNER

GLINER = True
except ImportError:
GLINER = False

from transformers.utils import cached_file

from ...models import Models
from ..hfpipeline import HFPipeline


@@ -11,7 +22,18 @@ class Entity(HFPipeline):
"""

def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs):
super().__init__("token-classification", path, quantize, gpu, model, **kwargs)
# Create a new entity pipeline
self.gliner = self.isgliner(path)
if self.gliner:
if not GLINER:
raise ImportError('GLiNER is not available - install "pipeline" extra to enable')

# GLiNER entity pipeline
self.pipeline = GLiNER.from_pretrained(path)
self.pipeline = self.pipeline.to(Models.device(Models.deviceid(gpu)))
else:
# Standard entity pipeline
super().__init__("token-classification", path, quantize, gpu, model, **kwargs)

def __call__(self, text, labels=None, aggregate="simple", flatten=None, join=False, workers=0):
"""
@@ -30,7 +52,7 @@ def __call__(self, text, labels=None, aggregate="simple", flatten=None, join=False, workers=0):
"""

# Run token classification pipeline
results = self.pipeline(text, aggregation_strategy=aggregate, num_workers=workers)
results = self.execute(text, labels, aggregate, workers)

# Convert results to a list if necessary
if isinstance(text, str):
@@ -50,6 +72,58 @@ def __call__(self, text, labels=None, aggregate="simple", flatten=None, join=False, workers=0):

return outputs[0] if isinstance(text, str) else outputs

def isgliner(self, path):
"""
Tests if path is a GLiNER model.

Args:
path: model path

Returns:
True if this is a GLiNER model, False otherwise
"""

try:
# Test if this model has a gliner_config.json file
return cached_file(path_or_repo_id=path, filename="gliner_config.json") is not None

# Ignore this error - invalid repo or directory
except OSError:
pass

return False

def execute(self, text, labels, aggregate, workers):
"""
Runs the entity extraction pipeline.

Args:
text: text|list
labels: list of entity type labels to accept, defaults to None which accepts all
aggregate: method to combine multi token entities - options are "simple" (default), "first", "average" or "max"
workers: number of concurrent workers to use for processing data, defaults to 0

Returns:
list of entities and labels
"""

if self.gliner:
# Extract entities with GLiNER. Use default CoNLL-2003 labels when not otherwise provided.
results = self.pipeline.batch_predict_entities(
text if isinstance(text, list) else [text], labels if labels else ["person", "organization", "location"]
)

# Map results to same format as Transformers token classifier
entities = []
for result in results:
entities.append([{"word": x["text"], "entity_group": x["label"], "score": x["score"]} for x in result])

# Return extracted entities
return entities if isinstance(text, list) else entities[0]

# Standard Transformers token classification pipeline
return self.pipeline(text, aggregation_strategy=aggregate, num_workers=workers)

def accept(self, etype, labels):
"""
Determines if entity type is in the list of valid entity types.
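Usage note: with this change, the Entity pipeline detects GLiNER models automatically (via gliner_config.json) and routes them through GLiNER instead of the Transformers token classifier. A minimal usage sketch, mirroring the new testGliner test below; the model path comes from that test and the default labels come from execute():

from txtai.pipeline import Entity

# A path containing gliner_config.json is detected as a GLiNER model
entity = Entity("neuml/gliner-bert-tiny")

# With no labels argument, execute() falls back to CoNLL-2003 style labels:
# person, organization, location
print(entity("My name is John Smith.", flatten=True))
# ["John Smith"]

# Custom entity types are passed straight through to GLiNER
print(entity("My name is John Smith.", labels=["person"], flatten=True))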
5 changes: 5 additions & 0 deletions test/python/testoptional.py
@@ -27,6 +27,7 @@ def setUpClass(cls):
"docling.document_converter",
"duckdb",
"fastapi",
"gliner",
"grand-cypher",
"grand-graph",
"hnswlib",
@@ -186,6 +187,7 @@ def testPipeline(self):
AudioMixer,
AudioStream,
Caption,
Entity,
FileToHTML,
HFOnnx,
HFTrainer,
@@ -213,6 +215,9 @@ def testPipeline(self):
with self.assertRaises(ImportError):
Caption()

with self.assertRaises(ImportError):
Entity("neuml/gliner-bert-tiny")

with self.assertRaises(ImportError):
FileToHTML(backend="docling")

9 changes: 9 additions & 0 deletions test/python/testpipeline/testtext/testentity.py
@@ -52,3 +52,12 @@ def testEntityTypes(self):
# Run entity extraction
entities = self.entity("Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg", labels=["PER"])
self.assertFalse(entities)

def testGliner(self):
"""
Test entity pipeline with a GLiNER model
"""

entity = Entity("neuml/gliner-bert-tiny")
entities = entity("My name is John Smith.", flatten=True)
self.assertEqual(entities, ["John Smith"])
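The assertion above holds because execute() maps GLiNER predictions into the same dict format the Transformers token classifier emits, so the rest of __call__ is unchanged. An illustrative sketch of that mapping (the score value is made up):

# Shape of a GLiNER prediction for "My name is John Smith." (illustrative score)
result = [{"text": "John Smith", "label": "person", "score": 0.98}]

# Mapping applied in Entity.execute()
entities = [{"word": x["text"], "entity_group": x["label"], "score": x["score"]} for x in result]
# [{"word": "John Smith", "entity_group": "person", "score": 0.98}]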
