Files
grid_application/app/app.py
Tobias Quadfasel 4b9fa0579e feat(ai-chat): Add SQL query field for comparison
To compare the (not yet implemented) SQL query generated by
the LLM with a hand-written query, another text field was added. It passes
the query to `pyodbc`, which connects to our database; the
resulting rows are stored in a `pandas` dataframe and then displayed as a table
in Plotly Dash.

The SQL functionalities are implemented in the `sql_utils.py` module.

Additionally, some minor updates to the overall behavior and layout of
the app were implemented.
2024-09-02 20:43:48 +02:00

272 lines
8.3 KiB
Python

from typing import Any, Dict, Tuple
from dash import (
Dash,
Input,
Output,
State,
callback,
dash_table,
dcc,
get_asset_url,
html,
no_update,
)
from dash.exceptions import PreventUpdate
from .app_styles import header_style
from .data_chat import send_message
from .sql_utils import execute_query, test_db_connection
# External CSS loaded by the Dash app.
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
# Application object; the callbacks in this module are registered via `callback`.
app = Dash(__name__, external_stylesheets=external_stylesheets)
# Replace Dash's default HTML index template with the one from `app_styles`.
app.index_string = header_style
# Base style for error banners. With "height": "0px" and "overflow": "hidden"
# the banner starts collapsed; code that shows an error uses a non-zero height,
# and the CSS transition on `height` animates the change.
err_style = {
"height": "0px",
"overflow": "hidden",
"transition": "height 0.5s ease-in-out",
"border-radius": "15px",
"background-color": "#FFCCCB",
"text-align": "center",
"color": "#FF6B6B",
"margin-top": "20px",
"margin-left": "20px",
"margin-right": "20px",
"font-weight": "bold",
"display": "flex",
"justify-content": "center",
"align-items": "center",
}
def get_layout() -> html.Div:
    """Build the component tree for the database chat page.

    The page is assembled top to bottom from these parts:

    - a header holding the title and the logo image
    - a ``dcc.Store`` remembering the previously submitted prompt
    - the chat textarea and its submit button
    - an error banner that starts collapsed (height ``0px``)
    - a loading wrapper around the chat answer box
    - the "direct SQL query" heading, textarea and submit button
    - an initially empty container for the SQL result

    Returns
    -------
    html.Div
        The ``container`` Div holding every component listed above.
    """
    placeholder = "Stelle deine Frage an die Datenbank..."
    # Private copy of the shared error style, explicitly collapsed.
    hidden_error_style = {**err_style, "height": "0px"}

    header = html.Div(
        [
            html.H1("Datenbank-Chat", className="heading"),
            html.Img(src=get_asset_url("logo.png"), className="logo"),
        ],
        className="header-container",
    )

    # Holds the last prompt so a repeated submission can be detected.
    previous_prompt = dcc.Store(id="tmp-value", data=placeholder, storage_type="memory")

    chat_input = dcc.Textarea(
        id="input-field",
        value=placeholder,
        style={"width": "96%", "height": 200, "margin-left": "20px"},
    )
    chat_submit = html.Button(
        "Abschicken",
        id="submit-button",
        n_clicks=0,
        disabled=False,
        style={"margin-left": "20px"},
    )

    # Collapsed by default; the chat callback changes its height to show it.
    error_banner = html.Div(
        [html.P("Bitte eine neue Anfrage eingeben.")],
        id="error",
        style=hidden_error_style,
    )

    answer_box = html.Div(
        "Hier erscheint die Antwort der Datenbank.",
        id="text-output",
        style={
            "whiteSpace": "pre-line",
            "margin-top": "30px",
            "margin-left": "20px",
            "margin-right": "20px",
            "border": "2px solid #86bc25",
            "border-radius": "15px",
            "padding": "20px",
        },
    )
    answer_area = dcc.Loading(id="loading", type="default", children=[answer_box])

    sql_heading = html.H2(
        "Direkte SQL-Abfrage",
        style={"margin-left": "20px", "margin-top": "50px", "font-size": 24},
    )
    sql_input = dcc.Textarea(
        id="sql-input-field",
        value="(Microsoft) SQL-Abfrage eingeben...",
        style={"width": "96%", "height": 200, "margin-left": "20px"},
    )
    sql_submit = html.Button(
        "Abschicken",
        id="sql-submit-button",
        n_clicks=0,
        disabled=False,
        style={"margin-left": "20px"},
    )
    sql_result = html.Div(id="sql-output", style={"margin-top": "10px"})

    page_children = [
        header,
        previous_prompt,
        chat_input,
        html.Div([]),  # empty spacer that keeps the layout clean
        chat_submit,
        error_banner,
        answer_area,
        sql_heading,
        sql_input,
        html.Div([]),  # empty spacer that keeps the layout clean
        sql_submit,
        sql_result,
    ]
    return html.Div(page_children, className="container")
# Install the page: the generated layout wrapped in an outer Div.
app.layout = html.Div([get_layout()])
@callback(
    Output("text-output", "children"),
    Output("tmp-value", "data"),
    Output("error", "style"),
    Output("error", "children"),
    Input("submit-button", "n_clicks"),
    State("input-field", "value"),
    State("tmp-value", "data"),
    prevent_initial_call=True,
    # While the callback runs, disable and dim the submit button;
    # restore it once the callback has finished.
    running=[
        (Output("submit-button", "disabled"), True, False),
        (
            Output("submit-button", "style"),
            {"opacity": 0.5, "margin-left": "20px"},
            {"opacity": 1.0, "margin-left": "20px"},
        ),
    ],
)
def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[str, str], Any]:
    """Handle a click on the chat submit button.

    The module-level ``err_style`` is only read here; every return value
    carries its own copy. Mutating the shared dict would leak state between
    concurrent requests and sessions.

    Parameters
    ----------
    n_clicks : int
        Number of times the submit button has been clicked.
    value : str
        Current value of the input field.
    data : str
        Previously stored prompt (contents of the ``tmp-value`` store).

    Returns
    -------
    Tuple[Any, Any, Dict[str, str], Any]
        ``(answer, stored prompt, error style, error children)``:

        - unchanged prompt: the first two entries are ``no_update`` and the
          error banner is shown asking for a new request;
        - database not reachable: the first two entries are ``no_update``
          and the error banner reports the connection failure;
        - otherwise: the result of ``send_message(value)``, the new prompt,
          a collapsed error banner and an empty paragraph.

    Raises
    ------
    PreventUpdate
        If none of the cases above applies (``n_clicks`` is not positive).
    """
    # Debug output of the current and the previously stored prompt.
    print(f"Value: {value}")
    print(f"Data: {data}")

    # Checked first so the database connection test is skipped when the
    # prompt has not changed (its result would not be used in that case).
    if value == data:
        shown_style = {**err_style, "height": "50px"}
        err_child = html.P("Bitte eine neue Anfrage eingeben.")
        return no_update, no_update, shown_style, err_child

    if not test_db_connection():
        shown_style = {**err_style, "height": "50px"}
        err_child = html.P(
            (
                "Fehler beim Herstellen der Verbindung zur "
                "Datenbank. Bitte versuche es später erneut."
            )
        )
        return no_update, no_update, shown_style, err_child

    if n_clicks > 0:
        result = send_message(value)
        hidden_style = {**err_style, "height": "0px"}
        return result, value, hidden_style, html.P("")

    raise PreventUpdate
@callback(
    Output("sql-output", "children"),
    Input("sql-submit-button", "n_clicks"),
    State("sql-input-field", "value"),
    prevent_initial_call=True,
)
def run_sql_query(n_clicks: int, value: str) -> Any:
    """Run the SQL query from the text field and render its outcome.

    Parameters
    ----------
    n_clicks : int
        Number of times the SQL submit button has been clicked.
    value : str
        Current value of the SQL input field; passed to ``execute_query``.

    Returns
    -------
    Any
        A Dash component, not a plain string: an ``html.Div`` error banner
        when ``execute_query`` returns a ``str`` (treated as the error
        message), otherwise a ``dash_table.DataTable`` built from the
        result's ``columns`` and ``to_dict("records")``
        (presumably a pandas DataFrame; confirm in ``sql_utils``).

    Raises
    ------
    PreventUpdate
        If ``n_clicks`` is not positive.
    """
    if n_clicks <= 0:
        raise PreventUpdate

    result = execute_query(value)

    if isinstance(result, str):
        # Expanded copy of the shared error style; the module-level dict
        # itself is left untouched.
        banner_style = {**err_style, "height": "80px", "padding": "20px"}
        return html.Div(
            [html.P(f"Fehler bei der Ausführung der Abfrage: {result}")],
            style=banner_style,
        )

    return dash_table.DataTable(
        id="table",
        columns=[{"name": i, "id": i} for i in result.columns],
        data=result.to_dict("records"),
        page_size=10,
        style_table={
            "overflowX": "auto",
            "margin": "auto",
            "width": "96%",
        },
        style_cell={
            "minWidth": "100px",
            "width": "150px",
            "maxWidth": "300px",
            "overflow": "hidden",
            "textOverflow": "ellipsis",
        },
        style_header={
            "backgroundColor": "lightblue",
            "fontWeight": "bold",
            "color": "black",
        },
    )
# Module-level handle on the server object behind the Dash app, so it can be
# referenced from outside this module (e.g. by a deployment entry point).
server = app.server

if __name__ == "__main__":
    # Executed as a script: start the app with debug mode enabled.
    app.run(debug=True)