From 18664fb749a741a97b544f28696e59f2d11a6809 Mon Sep 17 00:00:00 2001 From: Gabo Date: Fri, 17 Jan 2025 19:06:18 +0100 Subject: [PATCH] Use pdf entity response for PDF ner --- src/domain/NamedEntity.py | 1 + src/domain/NamedEntityGroup.py | 4 +- src/drivers/rest/GroupResponse.py | 9 +- src/drivers/rest/NamedEntitiesResponse.py | 47 +++++++--- src/drivers/rest/PDFNamedEntitiesResponse.py | 22 ----- src/drivers/rest/PDFNamedEntityResponse.py | 11 ++- src/drivers/rest/app.py | 5 +- src/tests/end_to_end/test_end_to_end.py | 92 +++++++++++++++----- 8 files changed, 128 insertions(+), 63 deletions(-) delete mode 100644 src/drivers/rest/PDFNamedEntitiesResponse.py diff --git a/src/domain/NamedEntity.py b/src/domain/NamedEntity.py index d4a3fc6..0421c47 100644 --- a/src/domain/NamedEntity.py +++ b/src/domain/NamedEntity.py @@ -29,6 +29,7 @@ def normalize_location(self, text): def normalize_date(self, text): if self.normalized_text: return self.normalized_text + parsers = [parser for parser in default_parsers if parser != "relative-time"] settings = {"STRICT_PARSING": True, "PARSERS": parsers} return dateparser.parse(text).strftime("%Y-%m-%d") if search_dates(self.text, settings=settings) else self.text diff --git a/src/domain/NamedEntityGroup.py b/src/domain/NamedEntityGroup.py index a50d492..0fefc15 100644 --- a/src/domain/NamedEntityGroup.py +++ b/src/domain/NamedEntityGroup.py @@ -4,11 +4,13 @@ from domain.NamedEntityType import NamedEntityType from rapidfuzz import fuzz +from domain.PDFNamedEntity import PDFNamedEntity + class NamedEntityGroup(BaseModel): type: NamedEntityType name: str - named_entities: list[NamedEntity] = list() + named_entities: list[NamedEntity | PDFNamedEntity] = list() def is_same_type(self, named_entity: NamedEntity) -> bool: return self.type == named_entity.type diff --git a/src/drivers/rest/GroupResponse.py b/src/drivers/rest/GroupResponse.py index 1d45e86..e7bdec8 100644 --- a/src/drivers/rest/GroupResponse.py +++ b/src/drivers/rest/GroupResponse.py @@ -1,11 +1,13 @@ from pydantic import BaseModel from domain.NamedEntityGroup import NamedEntityGroup +from domain.NamedEntityType import NamedEntityType from drivers.rest.NamedEntityResponse import NamedEntityResponse class GroupResponse(BaseModel): group_name: str + type: NamedEntityType entities_ids: list[int] entities_text: list[str] @@ -19,4 +21,9 @@ def from_named_entity_group(named_entity_group: NamedEntityGroup, entities: list entity_indexes.append(index) entity_texts.append(entity.text) - return GroupResponse(group_name=named_entity_group.name, entities_ids=entity_indexes, entities_text=entity_texts) + return GroupResponse( + group_name=named_entity_group.name, + type=entities[entity_indexes[0]].type if entity_indexes else NamedEntityType.PERSON, + entities_ids=entity_indexes, + entities_text=entity_texts, + ) diff --git a/src/drivers/rest/NamedEntitiesResponse.py b/src/drivers/rest/NamedEntitiesResponse.py index 4724795..6465a15 100644 --- a/src/drivers/rest/NamedEntitiesResponse.py +++ b/src/drivers/rest/NamedEntitiesResponse.py @@ -1,23 +1,48 @@ from pydantic import BaseModel +from domain.NamedEntity import NamedEntity from domain.NamedEntityGroup import NamedEntityGroup +from domain.PDFNamedEntity import PDFNamedEntity from drivers.rest.NamedEntityResponse import NamedEntityResponse from drivers.rest.GroupResponse import GroupResponse +from drivers.rest.PDFNamedEntityResponse import PDFNamedEntityResponse class NamedEntitiesResponse(BaseModel): - entities: list[NamedEntityResponse] + entities: list[NamedEntityResponse | PDFNamedEntityResponse] groups: list[GroupResponse] + def sort_entities(self): + if not self.entities: + return + + if isinstance(self.entities[0], PDFNamedEntityResponse): + self.entities.sort(key=lambda x: (x.segment.page_number, x.segment.segment_number, x.segment.character_start)) + else: + self.entities.sort(key=lambda x: x.character_start) + + def add_entity(self, entity: NamedEntity | PDFNamedEntity, group_name: str): + entity_response = ( + PDFNamedEntityResponse.from_pdf_named_entity(entity, group_name) + if isinstance(entity, PDFNamedEntity) + else NamedEntityResponse.from_named_entity(entity, group_name) + ) + self.entities.append(entity_response) + + def add_group(self, named_entity_group: NamedEntityGroup): + self.groups.append(GroupResponse.from_named_entity_group(named_entity_group, self.entities)) + @staticmethod def from_named_entity_groups(named_entity_groups: list[NamedEntityGroup]): - named_entity_responses = [ - NamedEntityResponse.from_named_entity(entity, group.name) - for group in named_entity_groups - for entity in group.named_entities - ] - named_entity_responses = sorted(named_entity_responses, key=lambda x: x.character_start) - group_responses = [ - GroupResponse.from_named_entity_group(group, named_entity_responses) for group in named_entity_groups - ] - return NamedEntitiesResponse(entities=named_entity_responses, groups=group_responses) + named_entities_response = NamedEntitiesResponse(entities=list(), groups=list()) + + for group in named_entity_groups: + for entity in group.named_entities: + named_entities_response.add_entity(entity, group.name) + + named_entities_response.sort_entities() + + for group in named_entity_groups: + named_entities_response.add_group(group) + + return named_entities_response diff --git a/src/drivers/rest/PDFNamedEntitiesResponse.py b/src/drivers/rest/PDFNamedEntitiesResponse.py deleted file mode 100644 index 9ef9028..0000000 --- a/src/drivers/rest/PDFNamedEntitiesResponse.py +++ /dev/null @@ -1,22 +0,0 @@ -from pydantic import BaseModel -from domain.NamedEntityGroup import NamedEntityGroup -from drivers.rest.GroupResponse import GroupResponse -from drivers.rest.PDFNamedEntityResponse import PDFNamedEntityResponse - - -class PDFNamedEntitiesResponse(BaseModel): - entities: list[PDFNamedEntityResponse] - groups: list[GroupResponse] - - @staticmethod - def from_named_entity_groups(named_entity_groups: list[NamedEntityGroup]): - pdf_named_entity_responses = [ - PDFNamedEntityResponse.from_named_entity(entity, group.name) - for group in named_entity_groups - for entity in group.named_entities - ] - pdf_named_entity_responses = sorted(pdf_named_entity_responses, key=lambda x: x.character_start) - group_responses = [ - GroupResponse.from_named_entity_group(group, pdf_named_entity_responses) for group in named_entity_groups - ] - return PDFNamedEntitiesResponse(entities=pdf_named_entity_responses, groups=group_responses) diff --git a/src/drivers/rest/PDFNamedEntityResponse.py b/src/drivers/rest/PDFNamedEntityResponse.py index 0301e35..6a69338 100644 --- a/src/drivers/rest/PDFNamedEntityResponse.py +++ b/src/drivers/rest/PDFNamedEntityResponse.py @@ -1,9 +1,14 @@ +from pydantic import BaseModel + +from domain.NamedEntityType import NamedEntityType from domain.PDFNamedEntity import PDFNamedEntity -from drivers.rest.NamedEntityResponse import NamedEntityResponse from drivers.rest.SegmentResponse import SegmentResponse -class PDFNamedEntityResponse(NamedEntityResponse): +class PDFNamedEntityResponse(BaseModel): + group_name: str + type: NamedEntityType + text: str page_number: int segment: SegmentResponse @@ -13,8 +18,6 @@ def from_pdf_named_entity(pdf_named_entity: PDFNamedEntity, group_name: str): group_name=group_name, type=pdf_named_entity.type, text=pdf_named_entity.text, - character_start=pdf_named_entity.character_start, - character_end=pdf_named_entity.character_end, page_number=pdf_named_entity.segment.page_number, segment=SegmentResponse.from_pdf_named_entity(pdf_named_entity), ) diff --git a/src/drivers/rest/app.py b/src/drivers/rest/app.py index 440101c..c456007 100644 --- a/src/drivers/rest/app.py +++ b/src/drivers/rest/app.py @@ -7,7 +7,6 @@ from domain.NamedEntity import NamedEntity from domain.NamedEntityGroup import NamedEntityGroup from drivers.rest.NamedEntitiesResponse import NamedEntitiesResponse -from drivers.rest.PDFNamedEntitiesResponse import PDFNamedEntitiesResponse from use_cases.NamedEntitiesFromPDFUseCase import NamedEntitiesFromPDFUseCase from use_cases.NamedEntitiesFromTextUseCase import NamedEntitiesFromTextUseCase from use_cases.NamedEntityMergerUseCase import NamedEntityMergerUseCase @@ -43,6 +42,6 @@ async def get_named_entities(text: str = Form("")): async def get_pdf_named_entities(file: UploadFile = File(...)): pdf_path: Path = pdf_content_to_pdf_path(file.file.read()) pdf_layout_analysis_repository = PDFLayoutAnalysisRepository() - entities = [entity for entity in NamedEntitiesFromPDFUseCase(pdf_layout_analysis_repository).get_entities(pdf_path)] + entities = NamedEntitiesFromPDFUseCase(pdf_layout_analysis_repository).get_entities(pdf_path) named_entity_groups: list[NamedEntityGroup] = NamedEntityMergerUseCase().merge(entities) - return PDFNamedEntitiesResponse.from_named_entity_groups(named_entity_groups) + return NamedEntitiesResponse.from_named_entity_groups(named_entity_groups) diff --git a/src/tests/end_to_end/test_end_to_end.py b/src/tests/end_to_end/test_end_to_end.py index dac0ee8..f553f0e 100644 --- a/src/tests/end_to_end/test_end_to_end.py +++ b/src/tests/end_to_end/test_end_to_end.py @@ -1,8 +1,11 @@ +from pathlib import Path from unittest import TestCase import requests +from configuration import ROOT_PATH from drivers.rest.GroupResponse import GroupResponse from drivers.rest.NamedEntityResponse import NamedEntityResponse +from drivers.rest.PDFNamedEntityResponse import PDFNamedEntityResponse class TestEndToEnd(TestCase): @@ -59,12 +62,14 @@ def test_text_extraction(self): self.assertEqual(5, len(groups_dict)) self.assertEqual("Tokyo", GroupResponse(**groups_dict[0]).group_name) + self.assertEqual("LOCATION", GroupResponse(**groups_dict[0]).type) self.assertEqual(1, len(GroupResponse(**groups_dict[0]).entities_ids)) self.assertEqual(1, len(GroupResponse(**groups_dict[0]).entities_text)) self.assertEqual("Tokyo", GroupResponse(**groups_dict[0]).entities_text[0]) self.assertEqual(0, GroupResponse(**groups_dict[0]).entities_ids[0]) self.assertEqual("2025-06-12", GroupResponse(**groups_dict[1]).group_name) + self.assertEqual("DATE", GroupResponse(**groups_dict[1]).type) self.assertEqual(2, len(GroupResponse(**groups_dict[1]).entities_ids)) self.assertEqual(2, len(GroupResponse(**groups_dict[1]).entities_text)) self.assertEqual("12 June 2025", GroupResponse(**groups_dict[1]).entities_text[0]) @@ -73,18 +78,21 @@ def test_text_extraction(self): self.assertEqual(5, GroupResponse(**groups_dict[1]).entities_ids[1]) self.assertEqual("Maria Rodriguez", GroupResponse(**groups_dict[2]).group_name) + self.assertEqual("PERSON", GroupResponse(**groups_dict[2]).type) self.assertEqual(1, len(GroupResponse(**groups_dict[2]).entities_ids)) self.assertEqual(1, len(GroupResponse(**groups_dict[2]).entities_text)) self.assertEqual("Maria Rodriguez", GroupResponse(**groups_dict[2]).entities_text[0]) self.assertEqual(2, GroupResponse(**groups_dict[2]).entities_ids[0]) self.assertEqual("Senate", GroupResponse(**groups_dict[3]).group_name) + self.assertEqual("ORGANIZATION", GroupResponse(**groups_dict[3]).type) self.assertEqual(1, len(GroupResponse(**groups_dict[3]).entities_ids)) self.assertEqual(1, len(GroupResponse(**groups_dict[3]).entities_text)) self.assertEqual("Senate", GroupResponse(**groups_dict[3]).entities_text[0]) self.assertEqual(3, GroupResponse(**groups_dict[3]).entities_ids[0]) self.assertEqual("Resolution No. 122", GroupResponse(**groups_dict[4]).group_name) + self.assertEqual("LAW", GroupResponse(**groups_dict[4]).type) self.assertEqual(1, len(GroupResponse(**groups_dict[4]).entities_ids)) self.assertEqual(1, len(GroupResponse(**groups_dict[4]).entities_text)) self.assertEqual("Resolution No. 122", GroupResponse(**groups_dict[4]).entities_text[0]) @@ -130,6 +138,7 @@ def test_text_extraction_for_dates(self): self.assertEqual(2, len(groups_dict)) self.assertEqual("2024-01-13", group_1.group_name) + self.assertEqual("DATE", group_1.type) self.assertEqual(2, len(group_1.entities_ids)) self.assertEqual(2, len(group_1.entities_text)) self.assertEqual("13th of January 2024", group_1.entities_text[0]) @@ -138,29 +147,70 @@ def test_text_extraction_for_dates(self): self.assertEqual(2, group_1.entities_ids[1]) self.assertEqual("13th of February", group_2.group_name) + self.assertEqual("DATE", group_2.type) self.assertEqual(1, len(group_2.entities_ids)) self.assertEqual(1, len(group_2.entities_text)) self.assertEqual("13th of February", group_2.entities_text[0]) self.assertEqual(1, group_2.entities_ids[0]) - # def test_pdf_extraction(self): - # pdf_path: Path = Path(ROOT_PATH, "src", "tests", "end_to_end", "test_pdfs", "test_document.pdf") - # with open(pdf_path, "rb") as pdf_file: - # files = {"file": pdf_file} - # response = requests.post(f"{self.service_url}/pdf", files=files) - # response_json = response.json() - # self.assertEqual(200, response.status_code) - # self.assertEqual(10, len(response_json)) - # self.assertEqual("PERSON", response_json[0]["type"]) - # self.assertEqual("Maria Rodriguez", response_json[0]["text"]) - # self.assertEqual("Maria Rodriguez", response_json[0]["normalized_text"]) - # self.assertEqual(0, response_json[0]["character_start"]) - # self.assertEqual(15, response_json[0]["character_end"]) - # expected_segment_text: str = ( - # "Maria Rodriguez visited the Louvre Museum in Paris, France, on Wednesday, July 12, 2023." - # ) - # self.assertEqual(expected_segment_text, response_json[0]["segment_text"]) - # self.assertEqual(1, response_json[0]["page_number"]) - # self.assertEqual(1, response_json[0]["segment_number"]) - # self.assertEqual(72, response_json[0]["bounding_box"]["left"]) - # self.assertEqual(74, response_json[0]["bounding_box"]["top"]) + def test_pdf_extraction(self): + pdf_path: Path = Path(ROOT_PATH, "src", "tests", "end_to_end", "test_pdfs", "test_document.pdf") + with open(pdf_path, "rb") as pdf_file: + files = {"file": pdf_file} + response = requests.post(f"{self.service_url}/pdf", files=files) + + entities_dict = response.json()["entities"] + groups_dict = response.json()["groups"] + + entity_0 = PDFNamedEntityResponse(**entities_dict[0]) + entity_9 = PDFNamedEntityResponse(**entities_dict[9]) + group_0 = GroupResponse(**groups_dict[0]) + group_7 = GroupResponse(**groups_dict[7]) + + self.assertEqual(200, response.status_code) + + self.assertEqual(10, len(entities_dict)) + + self.assertEqual("Maria Diaz Rodriguez", entity_0.group_name) + self.assertEqual("PERSON", entity_0.type) + self.assertEqual("Maria Rodriguez", entity_0.text) + segment_text: str = "Maria Rodriguez visited the Louvre Museum in Paris, France, on Wednesday, July 12, 2023." + self.assertEqual(segment_text, entity_0.segment.text) + self.assertEqual(1, entity_0.segment.page_number) + self.assertEqual(1, entity_0.segment.segment_number) + self.assertEqual(0, entity_0.segment.character_start) + self.assertEqual(15, entity_0.segment.character_end) + self.assertEqual(72, entity_0.segment.bounding_box.left) + self.assertEqual(74, entity_0.segment.bounding_box.top) + self.assertEqual(430, entity_0.segment.bounding_box.width) + self.assertEqual(34, entity_0.segment.bounding_box.height) + + self.assertEqual("Resolution No. 122", entity_9.text) + self.assertEqual("LAW", entity_9.type) + self.assertEqual("Resolution No. 122", entity_9.group_name) + segment_text: str = "The Senate passed Resolution No. 122, establishing a set of rules for the impeachment trial." + self.assertEqual(segment_text, entity_9.segment.text) + self.assertEqual(1, entity_9.segment.page_number) + self.assertEqual(5, entity_9.segment.segment_number) + self.assertEqual(18, entity_9.segment.character_start) + self.assertEqual(36, entity_9.segment.character_end) + self.assertEqual(72, entity_9.segment.bounding_box.left) + self.assertEqual(351, entity_9.segment.bounding_box.top) + self.assertEqual(440, entity_9.segment.bounding_box.width) + self.assertEqual(35, entity_9.segment.bounding_box.height) + + self.assertEqual(8, len(groups_dict)) + + self.assertEqual("Maria Diaz Rodriguez", group_0.group_name) + self.assertEqual("PERSON", group_0.type) + self.assertEqual(3, len(group_0.entities_ids)) + self.assertEqual(3, len(group_0.entities_text)) + self.assertEqual([0, 5, 6], group_0.entities_ids) + self.assertEqual(["Maria Rodriguez", "Maria Diaz Rodriguez", "M.D. Rodriguez"], group_0.entities_text) + + self.assertEqual("Resolution No. 122", group_7.group_name) + self.assertEqual("LAW", group_7.type) + self.assertEqual(1, len(group_7.entities_ids)) + self.assertEqual(1, len(group_7.entities_text)) + self.assertEqual(["Resolution No. 122"], group_7.entities_text) + self.assertEqual([9], group_7.entities_ids)