Skip to content

Commit

Permalink
Use pdf entity response for PDF ner
Browse files: browse the repository at this point in the history
  • Loading branch information
gabriel-piles committed Jan 17, 2025
1 parent 83a235c commit 18664fb
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 63 deletions.
1 change: 1 addition & 0 deletions src/domain/NamedEntity.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def normalize_location(self, text):
def normalize_date(self, text):
    """Return the date in ``text`` formatted as ``YYYY-MM-DD``.

    An already populated ``self.normalized_text`` is returned unchanged.
    Otherwise ``self.text`` is scanned with ``search_dates`` using every
    default parser except ``"relative-time"`` and ``STRICT_PARSING``
    enabled. When no date is found, or when ``text`` itself cannot be
    parsed, the raw ``self.text`` is returned instead of raising.

    :param text: the string handed to ``dateparser.parse``.
    :return: the normalized date string, or ``self.text`` as a fallback.
    """
    if self.normalized_text:
        return self.normalized_text

    parsers = [parser for parser in default_parsers if parser != "relative-time"]
    settings = {"STRICT_PARSING": True, "PARSERS": parsers}
    if not search_dates(self.text, settings=settings):
        return self.text

    # search_dates only proves that *some* date occurs inside self.text;
    # dateparser.parse may still return None for the full ``text`` string,
    # so guard before calling strftime.
    parsed_date = dateparser.parse(text)
    if parsed_date is None:
        return self.text
    return parsed_date.strftime("%Y-%m-%d")
Expand Down
4 changes: 3 additions & 1 deletion src/domain/NamedEntityGroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from domain.NamedEntityType import NamedEntityType
from rapidfuzz import fuzz

from domain.PDFNamedEntity import PDFNamedEntity


class NamedEntityGroup(BaseModel):
type: NamedEntityType
name: str
named_entities: list[NamedEntity] = list()
named_entities: list[NamedEntity | PDFNamedEntity] = list()

def is_same_type(self, named_entity: NamedEntity) -> bool:
    """Tell whether ``named_entity`` carries the same ``type`` as this group."""
    group_type, entity_type = self.type, named_entity.type
    return group_type == entity_type
Expand Down
9 changes: 8 additions & 1 deletion src/drivers/rest/GroupResponse.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from pydantic import BaseModel

from domain.NamedEntityGroup import NamedEntityGroup
from domain.NamedEntityType import NamedEntityType
from drivers.rest.NamedEntityResponse import NamedEntityResponse


class GroupResponse(BaseModel):
group_name: str
type: NamedEntityType
entities_ids: list[int]
entities_text: list[str]

Expand All @@ -19,4 +21,9 @@ def from_named_entity_group(named_entity_group: NamedEntityGroup, entities: list
entity_indexes.append(index)
entity_texts.append(entity.text)

return GroupResponse(group_name=named_entity_group.name, entities_ids=entity_indexes, entities_text=entity_texts)
return GroupResponse(
group_name=named_entity_group.name,
type=entities[entity_indexes[0]].type if entity_indexes else NamedEntityType.PERSON,
entities_ids=entity_indexes,
entities_text=entity_texts,
)
47 changes: 36 additions & 11 deletions src/drivers/rest/NamedEntitiesResponse.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,48 @@
from pydantic import BaseModel

from domain.NamedEntity import NamedEntity
from domain.NamedEntityGroup import NamedEntityGroup
from domain.PDFNamedEntity import PDFNamedEntity
from drivers.rest.NamedEntityResponse import NamedEntityResponse
from drivers.rest.GroupResponse import GroupResponse
from drivers.rest.PDFNamedEntityResponse import PDFNamedEntityResponse


class NamedEntitiesResponse(BaseModel):
    """REST payload with the extracted entities and the groups they belong to.

    ``entities`` may hold plain-text or PDF entity responses. Groups are
    built against the current ``entities`` list, so entities must be added
    and sorted before any group is added.
    """

    entities: list[NamedEntityResponse | PDFNamedEntityResponse]
    groups: list[GroupResponse]

    def sort_entities(self) -> None:
        """Sort ``entities`` in place.

        PDF entities are ordered by page, segment and character offset within
        the segment; plain-text entities by ``character_start``. The key is
        computed per element rather than from the first element only, so a
        list mixing both kinds no longer fails; plain-text entities then sort
        as if they were on page 0, segment 0.
        """

        def sort_key(entity_response):
            if isinstance(entity_response, PDFNamedEntityResponse):
                segment = entity_response.segment
                return segment.page_number, segment.segment_number, segment.character_start
            return 0, 0, entity_response.character_start

        self.entities.sort(key=sort_key)

    def add_entity(self, entity: NamedEntity | PDFNamedEntity, group_name: str) -> None:
        """Append the response model matching the concrete type of ``entity``."""
        if isinstance(entity, PDFNamedEntity):
            entity_response = PDFNamedEntityResponse.from_pdf_named_entity(entity, group_name)
        else:
            entity_response = NamedEntityResponse.from_named_entity(entity, group_name)
        self.entities.append(entity_response)

    def add_group(self, named_entity_group: NamedEntityGroup) -> None:
        """Append the group response built against the current ``entities`` list."""
        self.groups.append(GroupResponse.from_named_entity_group(named_entity_group, self.entities))

    @staticmethod
    def from_named_entity_groups(named_entity_groups: list[NamedEntityGroup]):
        """Build the full response from merged entity groups.

        Entities are collected and sorted first; groups are added afterwards
        so that they are built against the final entity order.
        """
        named_entities_response = NamedEntitiesResponse(entities=[], groups=[])

        for group in named_entity_groups:
            for entity in group.named_entities:
                named_entities_response.add_entity(entity, group.name)

        named_entities_response.sort_entities()

        for group in named_entity_groups:
            named_entities_response.add_group(group)

        return named_entities_response
22 changes: 0 additions & 22 deletions src/drivers/rest/PDFNamedEntitiesResponse.py

This file was deleted.

11 changes: 7 additions & 4 deletions src/drivers/rest/PDFNamedEntityResponse.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from pydantic import BaseModel

from domain.NamedEntityType import NamedEntityType
from domain.PDFNamedEntity import PDFNamedEntity
from drivers.rest.NamedEntityResponse import NamedEntityResponse
from drivers.rest.SegmentResponse import SegmentResponse


class PDFNamedEntityResponse(NamedEntityResponse):
class PDFNamedEntityResponse(BaseModel):
group_name: str
type: NamedEntityType
text: str
page_number: int
segment: SegmentResponse

Expand All @@ -13,8 +18,6 @@ def from_pdf_named_entity(pdf_named_entity: PDFNamedEntity, group_name: str):
group_name=group_name,
type=pdf_named_entity.type,
text=pdf_named_entity.text,
character_start=pdf_named_entity.character_start,
character_end=pdf_named_entity.character_end,
page_number=pdf_named_entity.segment.page_number,
segment=SegmentResponse.from_pdf_named_entity(pdf_named_entity),
)
5 changes: 2 additions & 3 deletions src/drivers/rest/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from domain.NamedEntity import NamedEntity
from domain.NamedEntityGroup import NamedEntityGroup
from drivers.rest.NamedEntitiesResponse import NamedEntitiesResponse
from drivers.rest.PDFNamedEntitiesResponse import PDFNamedEntitiesResponse
from use_cases.NamedEntitiesFromPDFUseCase import NamedEntitiesFromPDFUseCase
from use_cases.NamedEntitiesFromTextUseCase import NamedEntitiesFromTextUseCase
from use_cases.NamedEntityMergerUseCase import NamedEntityMergerUseCase
Expand Down Expand Up @@ -43,6 +42,6 @@ async def get_named_entities(text: str = Form("")):
async def get_pdf_named_entities(file: UploadFile = File(...)):
    """Extract the named entities of an uploaded PDF and return them grouped.

    The uploaded bytes are turned into a PDF path, entities are extracted
    from it, merged into groups, and serialized with the same
    ``NamedEntitiesResponse`` model used for plain text.
    """
    pdf_path: Path = pdf_content_to_pdf_path(file.file.read())
    pdf_layout_analysis_repository = PDFLayoutAnalysisRepository()
    entities = NamedEntitiesFromPDFUseCase(pdf_layout_analysis_repository).get_entities(pdf_path)
    named_entity_groups: list[NamedEntityGroup] = NamedEntityMergerUseCase().merge(entities)
    return NamedEntitiesResponse.from_named_entity_groups(named_entity_groups)
92 changes: 71 additions & 21 deletions src/tests/end_to_end/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from pathlib import Path
from unittest import TestCase
import requests

from configuration import ROOT_PATH
from drivers.rest.GroupResponse import GroupResponse
from drivers.rest.NamedEntityResponse import NamedEntityResponse
from drivers.rest.PDFNamedEntityResponse import PDFNamedEntityResponse


class TestEndToEnd(TestCase):
Expand Down Expand Up @@ -59,12 +62,14 @@ def test_text_extraction(self):
self.assertEqual(5, len(groups_dict))

self.assertEqual("Tokyo", GroupResponse(**groups_dict[0]).group_name)
self.assertEqual("LOCATION", GroupResponse(**groups_dict[0]).type)
self.assertEqual(1, len(GroupResponse(**groups_dict[0]).entities_ids))
self.assertEqual(1, len(GroupResponse(**groups_dict[0]).entities_text))
self.assertEqual("Tokyo", GroupResponse(**groups_dict[0]).entities_text[0])
self.assertEqual(0, GroupResponse(**groups_dict[0]).entities_ids[0])

self.assertEqual("2025-06-12", GroupResponse(**groups_dict[1]).group_name)
self.assertEqual("DATE", GroupResponse(**groups_dict[1]).type)
self.assertEqual(2, len(GroupResponse(**groups_dict[1]).entities_ids))
self.assertEqual(2, len(GroupResponse(**groups_dict[1]).entities_text))
self.assertEqual("12 June 2025", GroupResponse(**groups_dict[1]).entities_text[0])
Expand All @@ -73,18 +78,21 @@ def test_text_extraction(self):
self.assertEqual(5, GroupResponse(**groups_dict[1]).entities_ids[1])

self.assertEqual("Maria Rodriguez", GroupResponse(**groups_dict[2]).group_name)
self.assertEqual("PERSON", GroupResponse(**groups_dict[2]).type)
self.assertEqual(1, len(GroupResponse(**groups_dict[2]).entities_ids))
self.assertEqual(1, len(GroupResponse(**groups_dict[2]).entities_text))
self.assertEqual("Maria Rodriguez", GroupResponse(**groups_dict[2]).entities_text[0])
self.assertEqual(2, GroupResponse(**groups_dict[2]).entities_ids[0])

self.assertEqual("Senate", GroupResponse(**groups_dict[3]).group_name)
self.assertEqual("ORGANIZATION", GroupResponse(**groups_dict[3]).type)
self.assertEqual(1, len(GroupResponse(**groups_dict[3]).entities_ids))
self.assertEqual(1, len(GroupResponse(**groups_dict[3]).entities_text))
self.assertEqual("Senate", GroupResponse(**groups_dict[3]).entities_text[0])
self.assertEqual(3, GroupResponse(**groups_dict[3]).entities_ids[0])

self.assertEqual("Resolution No. 122", GroupResponse(**groups_dict[4]).group_name)
self.assertEqual("LAW", GroupResponse(**groups_dict[4]).type)
self.assertEqual(1, len(GroupResponse(**groups_dict[4]).entities_ids))
self.assertEqual(1, len(GroupResponse(**groups_dict[4]).entities_text))
self.assertEqual("Resolution No. 122", GroupResponse(**groups_dict[4]).entities_text[0])
Expand Down Expand Up @@ -130,6 +138,7 @@ def test_text_extraction_for_dates(self):
self.assertEqual(2, len(groups_dict))

self.assertEqual("2024-01-13", group_1.group_name)
self.assertEqual("DATE", group_1.type)
self.assertEqual(2, len(group_1.entities_ids))
self.assertEqual(2, len(group_1.entities_text))
self.assertEqual("13th of January 2024", group_1.entities_text[0])
Expand All @@ -138,29 +147,70 @@ def test_text_extraction_for_dates(self):
self.assertEqual(2, group_1.entities_ids[1])

self.assertEqual("13th of February", group_2.group_name)
self.assertEqual("DATE", group_2.type)
self.assertEqual(1, len(group_2.entities_ids))
self.assertEqual(1, len(group_2.entities_text))
self.assertEqual("13th of February", group_2.entities_text[0])
self.assertEqual(1, group_2.entities_ids[0])

# def test_pdf_extraction(self):
# pdf_path: Path = Path(ROOT_PATH, "src", "tests", "end_to_end", "test_pdfs", "test_document.pdf")
# with open(pdf_path, "rb") as pdf_file:
# files = {"file": pdf_file}
# response = requests.post(f"{self.service_url}/pdf", files=files)
# response_json = response.json()
# self.assertEqual(200, response.status_code)
# self.assertEqual(10, len(response_json))
# self.assertEqual("PERSON", response_json[0]["type"])
# self.assertEqual("Maria Rodriguez", response_json[0]["text"])
# self.assertEqual("Maria Rodriguez", response_json[0]["normalized_text"])
# self.assertEqual(0, response_json[0]["character_start"])
# self.assertEqual(15, response_json[0]["character_end"])
# expected_segment_text: str = (
# "Maria Rodriguez visited the Louvre Museum in Paris, France, on Wednesday, July 12, 2023."
# )
# self.assertEqual(expected_segment_text, response_json[0]["segment_text"])
# self.assertEqual(1, response_json[0]["page_number"])
# self.assertEqual(1, response_json[0]["segment_number"])
# self.assertEqual(72, response_json[0]["bounding_box"]["left"])
# self.assertEqual(74, response_json[0]["bounding_box"]["top"])
def test_pdf_extraction(self):
    """Post ``test_document.pdf`` to ``/pdf`` and check entities and groups.

    The status code and the list lengths are asserted before the payload is
    indexed, so a failing service is reported as an assertion failure rather
    than a ``KeyError`` or ``IndexError``.
    """
    pdf_path: Path = Path(ROOT_PATH, "src", "tests", "end_to_end", "test_pdfs", "test_document.pdf")
    with open(pdf_path, "rb") as pdf_file:
        files = {"file": pdf_file}
        response = requests.post(f"{self.service_url}/pdf", files=files)

    self.assertEqual(200, response.status_code)

    # Decode the body once and reuse it for both sections.
    response_json = response.json()
    entities_dict = response_json["entities"]
    groups_dict = response_json["groups"]

    self.assertEqual(10, len(entities_dict))
    self.assertEqual(8, len(groups_dict))

    entity_0 = PDFNamedEntityResponse(**entities_dict[0])
    entity_9 = PDFNamedEntityResponse(**entities_dict[9])
    group_0 = GroupResponse(**groups_dict[0])
    group_7 = GroupResponse(**groups_dict[7])

    self.assertEqual("Maria Diaz Rodriguez", entity_0.group_name)
    self.assertEqual("PERSON", entity_0.type)
    self.assertEqual("Maria Rodriguez", entity_0.text)
    first_segment_text = "Maria Rodriguez visited the Louvre Museum in Paris, France, on Wednesday, July 12, 2023."
    self.assertEqual(first_segment_text, entity_0.segment.text)
    self.assertEqual(1, entity_0.segment.page_number)
    self.assertEqual(1, entity_0.segment.segment_number)
    self.assertEqual(0, entity_0.segment.character_start)
    self.assertEqual(15, entity_0.segment.character_end)
    self.assertEqual(72, entity_0.segment.bounding_box.left)
    self.assertEqual(74, entity_0.segment.bounding_box.top)
    self.assertEqual(430, entity_0.segment.bounding_box.width)
    self.assertEqual(34, entity_0.segment.bounding_box.height)

    self.assertEqual("Resolution No. 122", entity_9.text)
    self.assertEqual("LAW", entity_9.type)
    self.assertEqual("Resolution No. 122", entity_9.group_name)
    last_segment_text = "The Senate passed Resolution No. 122, establishing a set of rules for the impeachment trial."
    self.assertEqual(last_segment_text, entity_9.segment.text)
    self.assertEqual(1, entity_9.segment.page_number)
    self.assertEqual(5, entity_9.segment.segment_number)
    self.assertEqual(18, entity_9.segment.character_start)
    self.assertEqual(36, entity_9.segment.character_end)
    self.assertEqual(72, entity_9.segment.bounding_box.left)
    self.assertEqual(351, entity_9.segment.bounding_box.top)
    self.assertEqual(440, entity_9.segment.bounding_box.width)
    self.assertEqual(35, entity_9.segment.bounding_box.height)

    self.assertEqual("Maria Diaz Rodriguez", group_0.group_name)
    self.assertEqual("PERSON", group_0.type)
    self.assertEqual(3, len(group_0.entities_ids))
    self.assertEqual(3, len(group_0.entities_text))
    self.assertEqual([0, 5, 6], group_0.entities_ids)
    self.assertEqual(["Maria Rodriguez", "Maria Diaz Rodriguez", "M.D. Rodriguez"], group_0.entities_text)

    self.assertEqual("Resolution No. 122", group_7.group_name)
    self.assertEqual("LAW", group_7.type)
    self.assertEqual(1, len(group_7.entities_ids))
    self.assertEqual(1, len(group_7.entities_text))
    self.assertEqual(["Resolution No. 122"], group_7.entities_text)
    self.assertEqual([9], group_7.entities_ids)

0 comments on commit 18664fb

Please sign in to comment.