From 0c4cfda60b1e3224ceee1e67934cac5705b71d71 Mon Sep 17 00:00:00 2001 From: KyleErwin Date: Tue, 28 Jan 2025 09:55:33 +0200 Subject: [PATCH] add schema linking Signed-off-by: KyleErwin --- prepare/tasks/schema_linking.py | 20 ++++++++++++ prepare/templates/schema_linking/templates.py | 32 +++++++++++++++++++ src/unitxt/catalog/tasks/schema_linking.json | 16 ++++++++++ .../catalog/templates/schema_linking/all.json | 7 ++++ .../templates/schema_linking/default.json | 9 ++++++ .../templates/schema_linking/with_hint.json | 9 ++++++ 6 files changed, 93 insertions(+) create mode 100644 prepare/tasks/schema_linking.py create mode 100644 prepare/templates/schema_linking/templates.py create mode 100644 src/unitxt/catalog/tasks/schema_linking.json create mode 100644 src/unitxt/catalog/templates/schema_linking/all.json create mode 100644 src/unitxt/catalog/templates/schema_linking/default.json create mode 100644 src/unitxt/catalog/templates/schema_linking/with_hint.json diff --git a/prepare/tasks/schema_linking.py b/prepare/tasks/schema_linking.py new file mode 100644 index 0000000000..39a5d3fdf1 --- /dev/null +++ b/prepare/tasks/schema_linking.py @@ -0,0 +1,20 @@ +from typing import List + +from unitxt.blocks import Task +from unitxt.catalog import add_to_catalog + +add_to_catalog( + Task( + input_fields={ + "id": str, + "utterance": str, + "hint": str, + "schema": List[str], + }, + reference_fields={"linked_schema": List[str]}, + prediction_type=List[str], + metrics=["metrics.f1_macro_multi_label"], + ), + "tasks.schema_linking", + overwrite=True, +) diff --git a/prepare/templates/schema_linking/templates.py b/prepare/templates/schema_linking/templates.py new file mode 100644 index 0000000000..e7ff5c0adb --- /dev/null +++ b/prepare/templates/schema_linking/templates.py @@ -0,0 +1,32 @@ +from unitxt import add_to_catalog +from unitxt.templates import InputOutputTemplate, TemplatesList + +add_to_catalog( + InputOutputTemplate( + instruction="Select the most relevant SQL columns to the given text.", + input_format="Text: {utterance}\n\nColumns:{schema}", + output_format="{linked_schema}", + postprocessors=["processors.to_list_by_comma_space"], + ), + "templates.schema_linking.default", + overwrite=True, +) + +add_to_catalog( + InputOutputTemplate( + instruction="Select the most relevant SQL columns to the given text. You are also given a hint.", + input_format="Text: {utterance}\n\nHint: {hint}\n\nColumns:{schema}", + output_format="{linked_schema}", + postprocessors=["processors.to_list_by_comma_space"], + ), + "templates.schema_linking.with_hint", + overwrite=True, +) + +add_to_catalog( + TemplatesList( + ["templates.schema_linking.default", "templates.schema_linking.with_hint"] + ), + "templates.schema_linking.all", + overwrite=True, +) diff --git a/src/unitxt/catalog/tasks/schema_linking.json b/src/unitxt/catalog/tasks/schema_linking.json new file mode 100644 index 0000000000..fea5f3fcf4 --- /dev/null +++ b/src/unitxt/catalog/tasks/schema_linking.json @@ -0,0 +1,16 @@ +{ + "__type__": "task", + "input_fields": { + "id": "str", + "utterance": "str", + "hint": "str", + "schema": "List[str]" + }, + "reference_fields": { + "linked_schema": "List[str]" + }, + "prediction_type": "List[str]", + "metrics": [ + "metrics.f1_macro_multi_label" + ] +} diff --git a/src/unitxt/catalog/templates/schema_linking/all.json b/src/unitxt/catalog/templates/schema_linking/all.json new file mode 100644 index 0000000000..55bf250dc8 --- /dev/null +++ b/src/unitxt/catalog/templates/schema_linking/all.json @@ -0,0 +1,7 @@ +{ + "__type__": "templates_list", + "items": [ + "templates.schema_linking.default", + "templates.schema_linking.with_hint" + ] +} diff --git a/src/unitxt/catalog/templates/schema_linking/default.json b/src/unitxt/catalog/templates/schema_linking/default.json new file mode 100644 index 0000000000..88e49c2d0b --- /dev/null +++ b/src/unitxt/catalog/templates/schema_linking/default.json @@ -0,0 +1,9 @@ +{ + "__type__": "input_output_template", + "instruction": "Select the most relevant SQL columns to the given text.", + "input_format": "Text: {utterance}\n\nColumns:{schema}", + "output_format": "{linked_schema}", + "postprocessors": [ + "processors.to_list_by_comma_space" + ] +} diff --git a/src/unitxt/catalog/templates/schema_linking/with_hint.json b/src/unitxt/catalog/templates/schema_linking/with_hint.json new file mode 100644 index 0000000000..b7a29d3e32 --- /dev/null +++ b/src/unitxt/catalog/templates/schema_linking/with_hint.json @@ -0,0 +1,9 @@ +{ + "__type__": "input_output_template", + "instruction": "Select the most relevant SQL columns to the given text. You are also given a hint.", + "input_format": "Text: {utterance}\n\nHint: {hint}\n\nColumns:{schema}", + "output_format": "{linked_schema}", + "postprocessors": [ + "processors.to_list_by_comma_space" + ] +}