From aeeedee44fa284fd726b1dd017e75b4c28188e76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ali=20Alt=C4=B1parmak?= Date: Wed, 15 Jan 2025 14:36:56 +0300 Subject: [PATCH] Fix Flair model load --- src/download_models.py | 2 +- src/use_cases/GetFlairEntitiesUseCase.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/download_models.py b/src/download_models.py index 3bc4dad..de90a1c 100644 --- a/src/download_models.py +++ b/src/download_models.py @@ -1,7 +1,7 @@ import math from os import makedirs from os.path import join, exists -from huggingface_hub import snapshot_download, hf_hub_download +from huggingface_hub import snapshot_download from configuration import MODELS_PATH diff --git a/src/use_cases/GetFlairEntitiesUseCase.py b/src/use_cases/GetFlairEntitiesUseCase.py index 83d4b82..461137e 100644 --- a/src/use_cases/GetFlairEntitiesUseCase.py +++ b/src/use_cases/GetFlairEntitiesUseCase.py @@ -1,11 +1,11 @@ from pathlib import Path -from flair.models import SequenceTagger -from configuration import ROOT_PATH +from flair.nn import Classifier +from configuration import MODELS_PATH from domain.NamedEntity import NamedEntity from flair.data import Sentence, Span from domain.NamedEntityType import NamedEntityType -flair_model = SequenceTagger.load(Path(ROOT_PATH, "models", "flair", "pytorch_model.bin")) +flair_model = Classifier.load(Path(MODELS_PATH, "flair", "pytorch_model.bin")) class GetFlairEntitiesUseCase: