feat(ai-chat): Add SQL query field for comparison

In order to compare the (not yet implemented) SQL query generated by
the LLM with an actual query, another text field was added that parses
the query to `pyodbc`, which connects to our database, stores the
resulting rows in a `pandas` dataframe and then visualizes it as a table
in plotly dash.

The SQL functionalities are implemented in the `sql_utils.py` module.

Additionally, some minor updates to the overall behavior and layout of
the app were implemented.
This commit is contained in:
Tobias Quadfasel
2024-09-02 20:43:48 +02:00
parent 923dc3b439
commit 4b9fa0579e
4 changed files with 272 additions and 57 deletions

View File

@@ -6,6 +6,7 @@ from dash import (
Output,
State,
callback,
dash_table,
dcc,
get_asset_url,
html,
@@ -15,12 +16,14 @@ from dash.exceptions import PreventUpdate
from .app_styles import header_style
from .data_chat import send_message
from .sql_utils import execute_query, test_db_connection
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
app = Dash(__name__, external_stylesheets=external_stylesheets)
app.index_string = header_style
err_style = {
"height": "0px",
"overflow": "hidden",
@@ -38,63 +41,114 @@ err_style = {
"align-items": "center",
}
start_value = "Stelle deine Frage an die Datenbank..."
app.layout = html.Div(
[
html.Div(
[
html.H1("Datenbank-Chat", className="heading"),
html.Img(src=get_asset_url("logo.png"), className="logo"),
],
className="header-container",
), # Header
dcc.Store(
id="tmp-value", data=start_value, storage_type="session"
), # Store previous prompt
dcc.Textarea(
id="input-field",
value=start_value,
style={"width": "96%", "height": 200, "margin-left": "20px"},
), # Input field
html.Div([]), # Needed for keeping the layout clean
html.Button(
"Abschicken",
id="submit-button",
n_clicks=0,
disabled=False,
style={"margin-left": "20px"},
), # Submit button
html.Div(
[html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=err_style
), # Error message (only visible if input is not updated but submit button is clicked)
dcc.Loading(
id="loading",
type="default",
children=[
html.Div(
"Hier erscheint die Antwort der Datenbank.",
id="text-output",
style={
"whiteSpace": "pre-line",
"margin-top": "30px",
"margin-left": "20px",
"margin-right": "20px",
"border": "2px solid #86bc25",
"border-radius": "15px",
"padding": "20px",
},
)
],
),
], # Output field
className="container",
)
def get_layout() -> html.Div:
"""Generate the layout for a Dash application.
This function creates a complex layout for a database chat interface
and a direct SQL query interface. It includes various Dash components
such as text areas, buttons, and a loading spinner.
The layout consists of:
- A header with title and logo
- A textarea for user input (database chat)
- A submit button for the database chat
- An error message area
- A loading spinner and output area for database responses
- A section for direct SQL queries, including a textarea and submit button
- An output area for SQL query results
Returns
-------
html.Div
A Dash html.Div component containing the entire layout of the application.
"""
global err_style
tmp_style = err_style.copy()
tmp_style["height"] = "0px"
start_value = "Stelle deine Frage an die Datenbank..."
layout = html.Div(
[
html.Div(
[
html.H1("Datenbank-Chat", className="heading"),
html.Img(src=get_asset_url("logo.png"), className="logo"),
],
className="header-container",
), # Header
dcc.Store(
id="tmp-value", data=start_value, storage_type="memory"
), # Store previous prompt
dcc.Textarea(
id="input-field",
value=start_value,
style={"width": "96%", "height": 200, "margin-left": "20px"},
), # Input field
html.Div([]), # Needed for keeping the layout clean
html.Button(
"Abschicken",
id="submit-button",
n_clicks=0,
disabled=False,
style={"margin-left": "20px"},
), # Submit button
html.Div(
[html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=tmp_style
), # Error message (only visible if input is not updated but submit button is clicked)
dcc.Loading(
id="loading",
type="default",
children=[
html.Div(
"Hier erscheint die Antwort der Datenbank.",
id="text-output",
style={
"whiteSpace": "pre-line",
"margin-top": "30px",
"margin-left": "20px",
"margin-right": "20px",
"border": "2px solid #86bc25",
"border-radius": "15px",
"padding": "20px",
},
)
],
),
html.H2(
"Direkte SQL-Abfrage",
style={"margin-left": "20px", "margin-top": "50px", "font-size": 24},
), # SQL Header
dcc.Textarea(
id="sql-input-field",
value="(Microsoft) SQL-Abfrage eingeben...",
style={"width": "96%", "height": 200, "margin-left": "20px"},
), # SQL Input field
html.Div([]), # Needed for keeping the layout clean
html.Button(
"Abschicken",
id="sql-submit-button",
n_clicks=0,
disabled=False,
style={"margin-left": "20px"},
), # Submit button
html.Div(id="sql-output", style={"margin-top": "10px"}), # SQL Output
],
className="container",
)
return layout
app.layout = html.Div([get_layout()])
@callback(
Output("text-output", "children"),
Output("tmp-value", "data"),
Output("error", "style"),
Output("error", "children"),
Input("submit-button", "n_clicks"),
State("input-field", "value"),
State("tmp-value", "data"),
@@ -108,7 +162,7 @@ app.layout = html.Div(
),
],
)
def update_output(n_clicks: int, value: str, data: str) -> Tuple[str, str, Dict[str, Any]]:
def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[str, str], Any]:
"""Update the output based on user input and button clicks.
Parameters
@@ -126,18 +180,91 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[str, str, Dict[
Updated output text, new stored value, and error style.
"""
global err_style
if n_clicks > 0 and value != data:
print(f"Value: {value}")
print(f"Data: {data}")
db_connected = test_db_connection()
if n_clicks > 0 and value != data and db_connected:
result = send_message(value)
err_style["height"] = "0px"
return result, value, err_style
return result, value, err_style, html.P("")
elif value == data:
err_style["height"] = "50px"
return no_update, no_update, err_style
err_child = html.P("Bitte eine neue Anfrage eingeben.")
return no_update, no_update, err_style, err_child
elif not db_connected:
err_style["height"] = "50px"
err_child = html.P(
(
"Fehler beim Herstellen der Verbindung zur "
"Datenbank. Bitte versuche es später erneut."
)
)
return no_update, no_update, err_style, err_child
raise PreventUpdate
@callback(
Output("sql-output", "children"),
Input("sql-submit-button", "n_clicks"),
State("sql-input-field", "value"),
prevent_initial_call=True,
)
def run_sql_query(n_clicks: int, value: str) -> str:
"""Run a SQL query and return the results.
Parameters
----------
n_clicks : int
Number of times the submit button has been clicked.
value : str
Current value of the input field.
Returns
-------
str
The results of the SQL query.
"""
if n_clicks > 0:
result = execute_query(value)
if isinstance(result, str):
global err_style
tmp_style = err_style.copy()
tmp_style["height"] = "80px"
tmp_style["padding"] = "20px"
err_child = html.Div(
[html.P(f"Fehler bei der Ausführung der Abfrage: {result}")], style=tmp_style
)
return err_child
else:
table_child = dash_table.DataTable(
id="table",
columns=[{"name": i, "id": i} for i in result.columns],
data=result.to_dict("records"),
page_size=10,
style_table={
"overflowX": "auto",
"margin": "auto",
"width": "96%",
},
style_cell={
"minWidth": "100px",
"width": "150px",
"maxWidth": "300px",
"overflow": "hidden",
"textOverflow": "ellipsis",
},
style_header={
"backgroundColor": "lightblue",
"fontWeight": "bold",
"color": "black",
},
)
return table_child
raise PreventUpdate
server = app.server
if __name__ == "__main__":