Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(prompt-lab): app shouldn't crash when no prompts found #53

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 14 additions & 26 deletions packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import gradio as gr
import jinja2
from pydantic import BaseModel
from rich.console import Console

from ragbits.core.llms import LiteLLM
from ragbits.core.llms.clients import LiteLLMOptions
Expand Down Expand Up @@ -38,28 +39,6 @@ class PromptState:
temp_field_name: str = ""


def load_prompts_list(pattern: str, state: gr.State) -> gr.State:
"""
Fetches a list of prompts based on provided paths and updates the application state.

This function takes a path-pattern for discovering prompt definition files and uses the
PromptDiscovery class to discover prompts within those files. The discovered prompts are then
stored in the application state object.

Args:
pattern (str): A pattern for looking up prompt files.
state (gr.State): The Gradio state object to update with discovered prompts.

Returns:
gr.State: The updated Gradio state object containing the list of discovered prompts.
"""
obj = PromptDiscovery(file_pattern=pattern)
discovered_prompts = list(obj.discover())
state.value.prompts = discovered_prompts

return state


def render_prompt(
index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any
) -> tuple[str, str, gr.State]:
Expand Down Expand Up @@ -91,9 +70,7 @@ def render_prompt(
prompt_object = prompt_class(input_data=input_data)
state.current_prompt = prompt_object

chat_dict = {entry["role"]: entry["content"] for entry in prompt_object.chat}

return chat_dict["system"], chat_dict["user"], state
return prompt_object.system_message, prompt_object.user_message, state


def list_prompt_choices(state: gr.State) -> list[tuple[str, int]]:
Expand Down Expand Up @@ -165,13 +142,24 @@ def lab_app( # pylint: disable=missing-param-doc
Launches the interactive application for listing, rendering, and testing prompts
defined within the current project.
"""
prompts = PromptDiscovery(file_pattern=file_pattern).discover()

if not prompts:
Console(stderr=True).print(
f"""No prompts were found for the given file pattern: [b]{file_pattern}[/b].

Please make sure that you are executing the command from the correct directory \
or provide a custom file pattern using the [b]--file-pattern[/b] flag."""
)
return

with gr.Blocks() as gr_app:
prompt_state_obj = PromptState()
prompt_state_obj.llm_model_name = llm_model
prompt_state_obj.llm_api_key = llm_api_key
prompt_state_obj.prompts = list(prompts)

prompts_state = gr.State(value=prompt_state_obj)
prompts_state = load_prompts_list(pattern=file_pattern, state=prompts_state)
prompt_selection_dropdown = gr.Dropdown(
choices=list_prompt_choices(prompts_state), value=0, label="Select Prompt"
)
Expand Down
Loading