Add the first working code logic for both the backend and the frontend. Add a detailed system message for improved results. Add several UI improvements for result display and user information. Add a text input field for direct SQL code comparison. The OpenAI backend had to be changed because of the strict rate limits of the Azure OpenAI free tier; it now uses a regular OpenAI API key.
348 lines
11 KiB
Python
348 lines
11 KiB
Python
import json
|
|
from typing import Any, Dict, Tuple
|
|
|
|
import pandas as pd
|
|
from app_styles import header_style
|
|
from dash import (
|
|
Dash,
|
|
Input,
|
|
Output,
|
|
State,
|
|
callback,
|
|
dash_table,
|
|
dcc,
|
|
get_asset_url,
|
|
html,
|
|
no_update,
|
|
)
|
|
from dash.exceptions import PreventUpdate
|
|
from data_chat import send_message
|
|
from sql_utils import execute_query, test_db_connection
|
|
|
|
# External CSS file handed to the Dash constructor below.
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]

# The Dash application object used by the rest of this module.
app = Dash(__name__, external_stylesheets=external_stylesheets)
# Use the HTML index template provided by app_styles instead of Dash's default.
app.index_string = header_style
|
|
|
|
notification_md = """
|
|
**Hinweise:**
|
|
|
|
Aufgrund des sparsamen pricing Tiers kann es einige Sekunden dauern, bis die
|
|
Verbindung zur Datenbank hergestellt wird. Im Falle eines Fehlers gern ein-zwei mal erneut
|
|
versuchen.
|
|
|
|
GPT-4o kann einige Fehler machen. Sollte dies passieren wird eine Fehlermeldung angezeigt.
|
|
In diesem Fall lohnt es sich oft, die Anfrage leicht verändert erneut zu stellen und evtl
|
|
zusätzliche Informationen zu geben.
|
|
|
|
Das Modell ist dazu aufgefordert, den Output stets auf 100 Zeilen zu begrenzen.
|
|
|
|
Alle Daten sind komplett zufällig generiert und haben keine Beziehung zu realen Personen.
|
|
|
|
**Beispielfragen**:
|
|
- Wie viele Kunden haben wir in Hannover?
|
|
- Zeige alle Kunden in Bremen.
|
|
- Berechne den gesamten Stromverbrauch aller Kunden in Magdeburg.
|
|
- Zeige alle Kunden, die zwischen 2021 und 2022 mindestens 200 Kubikmeter Gas verbraucht haben.
|
|
- Wie viele Kunden haben zwischen 2021 und 2022 weniger Strom verbraucht als zwischen 2022
|
|
und 2023?
|
|
|
|
Weitere Informationen zu den Daten, dem Code sowie zur Nutzung befinden sich in der README im
|
|
[GiTea Repository](https://gitea.captain.particlephysics.de/quadfaselt/grid_application).
|
|
"""
|
|
|
|
# Base CSS style for the red error boxes shown by the callbacks.
# The box is collapsed by default (height 0px, overflow hidden); callers reveal it
# by giving "height" a non-zero value, and the transition animates that change.
err_style = {
    "height": "0px",
    "overflow": "hidden",
    "transition": "height 0.5s ease-in-out",
    "border-radius": "15px",
    "background-color": "#FFCCCB",
    "text-align": "center",
    "color": "#FF6B6B",
    "margin-top": "20px",
    "margin-left": "20px",
    "margin-right": "20px",
    "font-weight": "bold",
    # Flexbox settings centre the message inside the box.
    "display": "flex",
    "justify-content": "center",
    "align-items": "center",
}
|
|
|
|
|
|
def render_table(df: pd.DataFrame, table_id: str = "table") -> dash_table.DataTable:
    """Create a Dash DataTable from a pandas DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        The input DataFrame to be rendered as a table. Its column labels are
        used both as displayed header names and as column ids.
    table_id : str, optional
        Component id assigned to the table. Defaults to ``"table"`` (the id
        that was previously hard-coded). Pass distinct ids when several tables
        can be visible on the page at the same time, because component ids are
        expected to be unique within a Dash layout.

    Returns
    -------
    dash_table.DataTable
        A Dash DataTable component with styled layout and pagination
        (10 rows per page).
    """
    return dash_table.DataTable(
        id=table_id,
        columns=[{"name": col, "id": col} for col in df.columns],
        data=df.to_dict("records"),
        page_size=10,
        style_table={
            "overflowX": "auto",
            "margin": "auto",
            "width": "96%",
            "margin-top": "20px",
        },
        # Bound the cell width and truncate long values with an ellipsis.
        style_cell={
            "minWidth": "100px",
            "width": "150px",
            "maxWidth": "300px",
            "overflow": "hidden",
            "textOverflow": "ellipsis",
        },
        style_header={
            "backgroundColor": "lightblue",
            "fontWeight": "bold",
            "color": "black",
        },
    )
|
|
|
|
|
|
def get_layout() -> html.Div:
    """Build the complete page layout of the database chat application.

    The page is assembled top to bottom from these parts:

    - header bar with title and logo, followed by a short tagline
    - in-memory store remembering the previously submitted prompt
    - chat textarea and its submit button
    - info box rendered from ``notification_md``
    - error box, initially collapsed (height ``0px``)
    - loading spinner wrapping the area for the model's answer
    - direct SQL section: heading, textarea, submit button and result area

    Returns
    -------
    html.Div
        Container Div (CSS class ``container``) holding all components above.
    """
    initial_prompt = "Stelle deine Frage an die Datenbank..."
    # The error box must start collapsed, independent of the template's current height.
    collapsed_error_style = {**err_style, "height": "0px"}

    header = html.Div(
        [
            html.H1("Datenbank-Chat", className="heading"),
            html.Img(src=get_asset_url("logo.png"), className="logo"),
        ],
        className="header-container",
    )
    tagline = html.Div(
        "Ganz ohne SQL-Kenntnisse Daten zu Zählerstandmessungen unserer Kunden abrufen!",
        style={"margin-left": "20px", "font-weight": "bold", "font-size": "20px"},
    )

    # Holds the previous prompt so the callback can detect an unchanged input.
    previous_prompt = dcc.Store(id="tmp-value", data=initial_prompt, storage_type="memory")
    chat_input = dcc.Textarea(
        id="input-field",
        value=initial_prompt,
        style={"width": "96%", "height": 200, "margin-left": "20px"},
    )
    chat_submit = html.Button(
        "Abschicken",
        id="submit-button",
        n_clicks=0,
        disabled=False,
        style={"margin-left": "20px"},
    )

    info_box = dcc.Markdown(
        notification_md,
        style={
            "margin-left": "20px",
            "margin-top": "20px",
            "margin-right": "10px",
            "background-color": "#C7E6F5",
            "border-radius": "5px",
            "padding": "10px",
        },
    )
    # Only becomes visible once a callback assigns a non-zero height.
    error_box = html.Div(
        [html.P("Bitte eine neue Anfrage eingeben.")],
        id="error",
        style=collapsed_error_style,
    )

    answer_area = html.Div(
        "Hier erscheint die Antwort des KI-Modells.",
        id="text-output",
        style={
            "whiteSpace": "pre-line",
            "margin-top": "30px",
            "margin-left": "20px",
            "margin-right": "20px",
            "border": "2px solid #86bc25",
            "border-radius": "15px",
            "padding": "20px",
        },
    )
    answer_spinner = dcc.Loading(id="loading", type="default", children=[answer_area])

    sql_heading = html.H2(
        "Direkte SQL-Abfrage",
        style={"margin-left": "20px", "margin-top": "50px", "font-size": 24},
    )
    sql_input = dcc.Textarea(
        id="sql-input-field",
        value="(Microsoft) SQL-Abfrage eingeben...",
        style={"width": "96%", "height": 200, "margin-left": "20px"},
    )
    sql_submit = html.Button(
        "Abschicken",
        id="sql-submit-button",
        n_clicks=0,
        disabled=False,
        style={"margin-left": "20px"},
    )
    sql_result = html.Div(id="sql-output", style={"margin-top": "10px"})

    page_children = [
        header,
        tagline,
        previous_prompt,
        chat_input,
        html.Div([]),  # Empty spacer; needed for keeping the layout clean
        chat_submit,
        info_box,
        error_box,
        answer_spinner,
        sql_heading,
        sql_input,
        html.Div([]),  # Empty spacer; needed for keeping the layout clean
        sql_submit,
        sql_result,
    ]
    return html.Div(page_children, className="container")
|
|
|
|
|
|
# Install the page built by get_layout(), wrapped in an outer Div, as the app layout.
app.layout = html.Div([get_layout()])
|
|
|
|
|
|
@callback(
    Output("text-output", "children"),
    Output("tmp-value", "data"),
    Output("error", "style"),
    Output("error", "children"),
    Input("submit-button", "n_clicks"),
    State("input-field", "value"),
    State("tmp-value", "data"),
    prevent_initial_call=True,
    # While the callback runs, disable and dim the submit button.
    running=[
        (Output("submit-button", "disabled"), True, False),
        (
            Output("submit-button", "style"),
            {"opacity": 0.5, "margin-left": "20px"},
            {"opacity": 1.0, "margin-left": "20px"},
        ),
    ],
)
def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[str, str], Any]:
    """Answer a chat prompt: ask the LLM for a query, run it and show the result.

    Parameters
    ----------
    n_clicks : int
        Number of times the submit button has been clicked.
    value : str
        Current value of the input field.
    data : str
        Previously stored prompt; used to reject a repeated, unchanged prompt.

    Returns
    -------
    Tuple[Any, Any, Dict[str, str], Any]
        Four values, one per callback output:

        1. children of the answer area (summary, SQL query and result table),
           or ``no_update`` when an error is shown instead,
        2. the prompt to remember as "previous prompt", or ``no_update``,
        3. the style of the error box (its height controls visibility),
        4. the children of the error box.

    Raises
    ------
    PreventUpdate
        If none of the handled situations applies (e.g. no click registered).
    """
    # Work on a per-call copy. The module-level ``err_style`` is shared by every
    # request, so mutating it in place would leak state between users/threads.
    style = err_style.copy()
    db_connected = test_db_connection()

    if n_clicks > 0 and value != data and db_connected:
        result = send_message(value)
        style["height"] = "0px"

        # Parse the LLM response to a dict, then try to execute the query.
        # NOTE(review): the SQL text comes straight from the model and is executed
        # as-is; presumably the database login is restricted — confirm.
        try:
            parsed_result = json.loads(result, strict=False)
            result_table = execute_query(parsed_result["query"])
            # execute_query reports failures by returning the error text as a
            # string (run_sql_query checks for exactly that). Surface the text
            # instead of letting render_table fail with an unrelated error.
            if isinstance(result_table, str):
                raise RuntimeError(result_table)
            children = [
                html.P([html.B("Zusammenfassung: "), f"{parsed_result['summary']}"]),
                html.P([html.B("SQL Abfrage: "), f"{parsed_result['query']}"]),
                render_table(result_table),
            ]
        except Exception as e:
            # UI boundary: any failure (invalid JSON, missing keys, query error)
            # is shown to the user together with the raw model output.
            style["height"] = "400px"
            err_child = html.Div(f"Folgender Fehler ist aufgetreten: {e}.LLM Output: {result}.")
            return no_update, value, style, err_child

        return children, value, style, html.P("")

    if not db_connected:
        style["height"] = "50px"
        err_child = html.P(
            "Fehler beim Herstellen der Verbindung zur "
            "Datenbank. Bitte versuche es später erneut."
        )
        return no_update, no_update, style, err_child

    if value == data:
        # Same prompt as last time: ask for a new one instead of calling the LLM again.
        style["height"] = "50px"
        err_child = html.P("Bitte eine neue Anfrage eingeben.")
        return no_update, no_update, style, err_child

    raise PreventUpdate
|
|
|
|
|
|
@callback(
    Output("sql-output", "children"),
    Input("sql-submit-button", "n_clicks"),
    State("sql-input-field", "value"),
    prevent_initial_call=True,
)
def run_sql_query(n_clicks: int, value: str) -> Any:
    """Run the SQL text from the direct-query field and render the outcome.

    Parameters
    ----------
    n_clicks : int
        Number of times the SQL submit button has been clicked.
    value : str
        Current value of the SQL input field.

    Returns
    -------
    Any
        A Dash component for the SQL output area: the result table on success,
        or a red error box containing the message when ``execute_query``
        returns a string.

    Raises
    ------
    PreventUpdate
        If the button has not been clicked yet.
    """
    if n_clicks is None or n_clicks <= 0:
        raise PreventUpdate

    # NOTE(review): the user's SQL is executed verbatim; presumably the database
    # login is restricted to safe operations — confirm.
    result = execute_query(value)

    if isinstance(result, str):
        # A string result carries an error message. Show it in a visible copy
        # of the shared error style (the template itself stays collapsed).
        box_style = err_style.copy()
        box_style["height"] = "100px"
        box_style["padding"] = "20px"
        return html.Div(
            [html.P(f"Fehler bei der Ausführung der Abfrage: {result}")],
            style=box_style,
        )

    return render_table(result)
|
|
|
|
|
|
# Module-level handle to the app's underlying server object
# (presumably the entry point for a WSGI runner such as gunicorn — confirm).
server = app.server

# Start the built-in server in debug mode only when this file is executed directly.
if __name__ == "__main__":
    app.run(debug=True)
|