diff --git a/prepare/tasks/schema_linking.py b/prepare/tasks/schema_linking.py new file mode 100644 index 0000000000..39a5d3fdf1 --- /dev/null +++ b/prepare/tasks/schema_linking.py @@ -0,0 +1,20 @@ +from typing import List + +from unitxt.blocks import Task +from unitxt.catalog import add_to_catalog + +add_to_catalog( + Task( + input_fields={ + "id": str, + "utterance": str, + "hint": str, + "schema": List[str], + }, + reference_fields={"linked_schema": List[str]}, + prediction_type=List[str], + metrics=["metrics.f1_macro_multi_label"], + ), + "tasks.schema_linking", + overwrite=True, +) diff --git a/prepare/templates/schema_linking/templates.py b/prepare/templates/schema_linking/templates.py new file mode 100644 index 0000000000..e7ff5c0adb --- /dev/null +++ b/prepare/templates/schema_linking/templates.py @@ -0,0 +1,32 @@ +from unitxt import add_to_catalog +from unitxt.templates import InputOutputTemplate, TemplatesList + +add_to_catalog( + InputOutputTemplate( + instruction="Select the most relevant SQL columns to the given text.", + input_format="Text: {utterance}\n\nColumns:{schema}", + output_format="{linked_schema}", + postprocessors=["processors.to_list_by_comma_space"], + ), + "templates.schema_linking.default", + overwrite=True, +) + +add_to_catalog( + InputOutputTemplate( + instruction="Select the most relevant SQL columns to the given text. You are also given a hint.", + input_format="Text: {utterance}\n\nHint: {hint}\n\nColumns:{schema}", + output_format="{linked_schema}", + postprocessors=["processors.to_list_by_comma_space"], + ), + "templates.schema_linking.with_hint", + overwrite=True, +) + +add_to_catalog( + TemplatesList( + ["templates.schema_linking.default", "templates.schema_linking.with_hint"] + ), + "templates.schema_linking.all", + overwrite=True, +) diff --git a/src/unitxt/catalog/tasks/schema_linking.json b/src/unitxt/catalog/tasks/schema_linking.json new file mode 100644 index 0000000000..fea5f3fcf4 --- /dev/null +++ b/src/unitxt/catalog/tasks/schema_linking.json @@ -0,0 +1,16 @@ +{ + "__type__": "task", + "input_fields": { + "id": "str", + "utterance": "str", + "hint": "str", + "schema": "List[str]" + }, + "reference_fields": { + "linked_schema": "List[str]" + }, + "prediction_type": "List[str]", + "metrics": [ + "metrics.f1_macro_multi_label" + ] +} diff --git a/src/unitxt/catalog/templates/schema_linking/all.json b/src/unitxt/catalog/templates/schema_linking/all.json new file mode 100644 index 0000000000..55bf250dc8 --- /dev/null +++ b/src/unitxt/catalog/templates/schema_linking/all.json @@ -0,0 +1,7 @@ +{ + "__type__": "templates_list", + "items": [ + "templates.schema_linking.default", + "templates.schema_linking.with_hint" + ] +} diff --git a/src/unitxt/catalog/templates/schema_linking/default.json b/src/unitxt/catalog/templates/schema_linking/default.json new file mode 100644 index 0000000000..88e49c2d0b --- /dev/null +++ b/src/unitxt/catalog/templates/schema_linking/default.json @@ -0,0 +1,9 @@ +{ + "__type__": "input_output_template", + "instruction": "Select the most relevant SQL columns to the given text.", + "input_format": "Text: {utterance}\n\nColumns:{schema}", + "output_format": "{linked_schema}", + "postprocessors": [ + "processors.to_list_by_comma_space" + ] +} diff --git a/src/unitxt/catalog/templates/schema_linking/with_hint.json b/src/unitxt/catalog/templates/schema_linking/with_hint.json new file mode 100644 index 0000000000..b7a29d3e32 --- /dev/null +++ b/src/unitxt/catalog/templates/schema_linking/with_hint.json @@ -0,0 +1,9 @@ +{ + "__type__": "input_output_template", + "instruction": "Select the most relevant SQL columns to the given text. You are also given a hint.", + "input_format": "Text: {utterance}\n\nHint: {hint}\n\nColumns:{schema}", + "output_format": "{linked_schema}", + "postprocessors": [ + "processors.to_list_by_comma_space" + ] +}