From 7a9b3a4186e9247136742a4ff98b523fe2bfb66c Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Wed, 2 Oct 2024 10:51:53 +0200 Subject: [PATCH] fix(prompt-lab): app shouldn't crash when no prompts found --- .../src/ragbits/dev_kit/prompt_lab/app.py | 40 +++++++------------ 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py index 9e9d5776..a6db0b90 100644 --- a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py +++ b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py @@ -4,6 +4,7 @@ import gradio as gr import jinja2 from pydantic import BaseModel +from rich.console import Console from ragbits.core.llms import LiteLLM from ragbits.core.llms.clients import LiteLLMOptions @@ -38,28 +39,6 @@ class PromptState: temp_field_name: str = "" -def load_prompts_list(pattern: str, state: gr.State) -> gr.State: - """ - Fetches a list of prompts based on provided paths and updates the application state. - - This function takes a path-pattern for discovering prompt definition files and uses the - PromptDiscovery class to discover prompts within those files. The discovered prompts are then - stored in the application state object. - - Args: - pattern (str): A pattern for looking up prompt files. - state (gr.State): The Gradio state object to update with discovered prompts. - - Returns: - gr.State: The updated Gradio state object containing the list of discovered prompts. - """ - obj = PromptDiscovery(file_pattern=pattern) - discovered_prompts = list(obj.discover()) - state.value.prompts = discovered_prompts - - return state - - def render_prompt( index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any ) -> tuple[str, str, gr.State]: @@ -91,9 +70,7 @@ def render_prompt( prompt_object = prompt_class(input_data=input_data) state.current_prompt = prompt_object - chat_dict = {entry["role"]: entry["content"] for entry in prompt_object.chat} - - return chat_dict["system"], chat_dict["user"], state + return prompt_object.system_message, prompt_object.user_message, state def list_prompt_choices(state: gr.State) -> list[tuple[str, int]]: @@ -165,13 +142,24 @@ def lab_app( # pylint: disable=missing-param-doc Launches the interactive application for listing, rendering, and testing prompts defined within the current project. """ + prompts = PromptDiscovery(file_pattern=file_pattern).discover() + + if not prompts: + Console(stderr=True).print( + f"""No prompts were found for the given file pattern: [b]{file_pattern}[/b]. + +Please make sure that you are executing the command from the correct directory \ +or provide a custom file pattern using the [b]--file-pattern[/b] flag.""" + ) + return + with gr.Blocks() as gr_app: prompt_state_obj = PromptState() prompt_state_obj.llm_model_name = llm_model prompt_state_obj.llm_api_key = llm_api_key + prompt_state_obj.prompts = list(prompts) prompts_state = gr.State(value=prompt_state_obj) - prompts_state = load_prompts_list(pattern=file_pattern, state=prompts_state) prompt_selection_dropdown = gr.Dropdown( choices=list_prompt_choices(prompts_state), value=0, label="Select Prompt" )