diff --git a/forte_wrapper/allennlp/allennlp_processors.py b/forte_wrapper/allennlp/allennlp_processors.py
index 53b87fa..7d58345 100644
--- a/forte_wrapper/allennlp/allennlp_processors.py
+++ b/forte_wrapper/allennlp/allennlp_processors.py
@@ -14,7 +14,7 @@
 import itertools
 import logging
-from typing import Any, Dict, Iterable, Iterator, List
+from typing import Any, Dict, Iterable, Iterator, List, Set

 import more_itertools
 from allennlp.predictors import Predictor
@@ -50,6 +50,12 @@ class AllenNLPProcessor(PackProcessor):
     # pylint: disable=attribute-defined-outside-init,unused-argument
     def initialize(self, resources: Resources, configs: Config):
         super().initialize(resources, configs)
+        if ("pos" in configs.processors or "depparse" in configs.processors
+                or "srl" in configs.processors):
+            if "tokenize" not in configs.processors:
+                raise ProcessorConfigError('tokenize is necessary in '
+                                           'configs.processors for '
+                                           'pos, depparse or srl')
         cuda_devices = itertools.cycle(configs['cuda_devices'])
         if configs.tag_formalism not in MODEL2URL:
             raise ProcessorConfigError('Incorrect value for tag_formalism')
@@ -225,3 +231,35 @@ def _create_srl(input_pack: DataPack, tokens: List[Token],
                                         tokens[arg_span.end].end)
                 link = PredicateLink(input_pack, pred, arg)
                 link.arg_type = label
+
+    @classmethod
+    def expected_types_and_attributes(cls) -> Dict[str, Set[str]]:
+        r"""Method to add the expected types of the current processor's
+        input, checked before running the processor when
+        :meth:`~forte.pipeline.Pipeline.enforce_consistency` is enabled
+        for the pipeline.
+        """
+        expectation_dict: Dict[str, Set[str]] = {
+            "ft.onto.base_ontology.Sentence": set()
+        }
+        return expectation_dict
+
+    def record(self, record_meta: Dict[str, Set[str]]):
+        r"""Method to add the output type records of the current processor
+        to :attr:`forte.data.data_pack.Meta.record`.
+
+        Args:
+            record_meta: the field in the datapack for type records that
+                need to be filled in for consistency checking.
+        """
+        if "tokenize" in self.configs.processors:
+            record_meta["ft.onto.base_ontology.Token"] = set()
+            if "pos" in self.configs.processors:
+                record_meta["ft.onto.base_ontology.Token"].add("pos")
+            if "depparse" in self.configs.processors:
+                record_meta["ft.onto.base_ontology.Dependency"] = {"rel_type"}
+            if "srl" in self.configs.processors:
+                record_meta["ft.onto.base_ontology.PredicateArgument"] = set()
+                record_meta["ft.onto.base_ontology.PredicateMention"] = set()
+                record_meta["ft.onto.base_ontology.PredicateLink"] = \
+                    {"arg_type"}
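Usage sketch (illustrative, not part of this diff): with the hooks above in place, a pipeline that turns on consistency checking will verify that some upstream component records ft.onto.base_ontology.Sentence before AllenNLPProcessor runs, and initialize() now rejects pos/depparse/srl requests that omit tokenize. StringReader, NLTKSentenceSegmenter, the comma-separated "processors" config string, and calling enforce_consistency() before initialize() are assumptions here, not things this diff establishes.

# Hedged sketch of a consistency-checked pipeline; names marked below
# as assumptions are for illustration only.
from forte.data.data_pack import DataPack
from forte.data.readers import StringReader
from forte.pipeline import Pipeline

from forte_wrapper.allennlp.allennlp_processors import AllenNLPProcessor
from forte_wrapper.nltk.nltk_processors import NLTKSentenceSegmenter  # assumed upstream

pipeline: Pipeline[DataPack] = Pipeline[DataPack]()
pipeline.set_reader(StringReader())
# Assumed to record ft.onto.base_ontology.Sentence, satisfying
# expected_types_and_attributes() of AllenNLPProcessor downstream.
pipeline.add(NLTKSentenceSegmenter())
# "tokenize" must be listed; initialize() now raises ProcessorConfigError
# when pos/depparse/srl are requested without it.
pipeline.add(AllenNLPProcessor(),
             config={"processors": "tokenize,pos,depparse"})
pipeline.enforce_consistency()  # activates the expected-type/record checks
pipeline.initialize()

pack: DataPack = pipeline.process("Forte wraps AllenNLP. It adds parses.")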
diff --git a/forte_wrapper/spacy/spacy_processors.py b/forte_wrapper/spacy/spacy_processors.py
index e580c4a..1662ef8 100644
--- a/forte_wrapper/spacy/spacy_processors.py
+++ b/forte_wrapper/spacy/spacy_processors.py
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Optional, Dict, Set

 import spacy
 from spacy.language import Language
 from spacy.cli.download import download

-from forte.common import ProcessExecutionException
+from forte.common import ProcessExecutionException, ProcessorConfigError
 from forte.common.configuration import Config
 from forte.common.resources import Resources
 from forte.data.data_pack import DataPack
@@ -48,6 +48,11 @@ def set_up(self):

     # pylint: disable=unused-argument
     def initialize(self, resources: Resources, configs: Config):
+        if "pos" in configs.processors or "lemma" in configs.processors:
+            if "tokenize" not in configs.processors:
+                raise ProcessorConfigError('tokenize is necessary in '
+                                           'configs.processors for '
+                                           'pos or lemma')
         self.processors = configs.processors
         self.lang_model = configs.lang
         self.set_up()
@@ -125,3 +130,20 @@ def _process(self, input_pack: DataPack):

         # Process sentence parses.
         self._process_parser(result.sents, input_pack)
+
+    def record(self, record_meta: Dict[str, Set[str]]):
+        r"""Method to add the output type records of the current processor
+        to :attr:`forte.data.data_pack.Meta.record`.
+
+        Args:
+            record_meta: the field in the datapack for type records that
+                need to be filled in for consistency checking.
+        """
+        record_meta["ft.onto.base_ontology.Sentence"] = set()
+        record_meta["ft.onto.base_ontology.EntityMention"] = set()
+        if "tokenize" in self.processors:
+            record_meta["ft.onto.base_ontology.Token"] = set()
+            if "pos" in self.processors:
+                record_meta["ft.onto.base_ontology.Token"].add("pos")
+            if "lemma" in self.processors:
+                record_meta["ft.onto.base_ontology.Token"].add("lemma")
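A second sketch (also illustrative): calling record() directly shows what the spaCy changes above write into a record dictionary. The class name SpacyProcessor (the diff shows only the module path), the direct attribute assignment standing in for initialize(), and the "processors" value are assumptions.

# Hedged sketch: inspect the type records the processor would add to a
# DataPack's meta for consistency checking.
from typing import Dict, Set

from forte_wrapper.spacy.spacy_processors import SpacyProcessor

proc = SpacyProcessor()
proc.processors = "tokenize,pos,lemma"  # normally assigned in initialize()

records: Dict[str, Set[str]] = {}
proc.record(records)
print(records)
# e.g. {'ft.onto.base_ontology.Sentence': set(),
#       'ft.onto.base_ontology.EntityMention': set(),
#       'ft.onto.base_ontology.Token': {'pos', 'lemma'}}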