371 lines
12 KiB
Python
371 lines
12 KiB
Python
import json
|
|
import os
|
|
from typing import Any, Dict, Tuple
|
|
|
|
import dash_auth
|
|
import pandas as pd
|
|
from app_styles import header_style
|
|
from config import check_credentials
|
|
from dash import (
|
|
Dash,
|
|
Input,
|
|
Output,
|
|
State,
|
|
callback,
|
|
dash_table,
|
|
dcc,
|
|
get_asset_url,
|
|
html,
|
|
no_update,
|
|
)
|
|
from dash.exceptions import PreventUpdate
|
|
from data_chat import send_message
|
|
from sql_utils import execute_query, test_db_connection
|
|
|
|
# Presumably validates that the required credentials/config are present
# (see config.check_credentials) before anything else uses them.
check_credentials()

# First connection to the SQL database at import time to mitigate the long
# startup time of the first request. Best-effort: a failure is reported but
# does not stop the app from starting.
try:
    # test_db_connection() reports its outcome as a truthy/falsy value
    # (update_output uses it as `db_connected`), so a falsy result is a
    # failure as well -- not only a raised exception.
    if not test_db_connection():
        print(
            "Error for first connection to Azure SQL Database: "
            "connection test failed"
        )
except Exception as e:
    print(f"Error for first connection to Azure SQL Database: {e}")
|
|
|
|
# Base stylesheet loaded from codepen.io.
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]

app = Dash(__name__, external_stylesheets=external_stylesheets)

# HTTP Basic auth with a single username/password pair taken from the
# APP_UNAME / APP_PW environment variables.
# NOTE(review): os.getenv returns None for unset variables; presumably
# check_credentials() guarantees both are set -- confirm in config.py.
_app_username = os.getenv("APP_UNAME")
_app_password = os.getenv("APP_PW")
auth = dash_auth.BasicAuth(app, {_app_username: _app_password})

# Custom HTML page template imported from app_styles.
app.index_string = header_style
|
|
|
|
# Markdown notice rendered below the chat input (via dcc.Markdown in
# get_layout): connection-time hint, LLM caveats, example questions and a
# link to the repository. The user-facing text is German.
notification_md = """
**Hinweise:**

Aufgrund des sparsamen Pricing-Tiers kann es einige Sekunden dauern, bis die
Verbindung zur Datenbank hergestellt wird. Im Falle eines Fehlers oder langer Ladedauer
(> 2 min.) gern ein- bis zweimal erneut versuchen (die Seite neu laden). Sobald die Verbindung
einmal hergestellt wurde, geht es schnell.

GPT-4o kann einige Fehler machen. Sollte dies passieren, wird eine Fehlermeldung angezeigt.
In diesem Fall lohnt es sich oft, die Anfrage leicht verändert erneut zu stellen und evtl.
zusätzliche Informationen zu geben.

Das Modell ist dazu aufgefordert, den Output stets auf 100 Zeilen zu begrenzen.

Alle Daten sind komplett zufällig generiert und haben keine Beziehung zu realen Personen.

**Beispielfragen**:

- Wie viele Kunden haben wir in Hannover?
- Zeige alle Kunden in Bremen.
- Berechne den gesamten Stromverbrauch aller Kunden in Magdeburg.
- Zeige alle Kunden, die zwischen 2021 und 2022 mindestens 200 Kubikmeter Gas verbraucht haben.
- Wie viele Kunden haben zwischen 2021 und 2022 weniger Strom verbraucht als zwischen 2022
und 2023?

Weitere Informationen zu den Daten, dem Code sowie zur Nutzung befinden sich in der README im
[Gitea-Repository](https://gitea.captain.particlephysics.de/quadfaselt/grid_application).
"""
|
|
|
|
# Base style of the collapsible error banner (component id "error").
# It starts collapsed (height 0 with hidden overflow); the callbacks in this
# module derive banner styles from it with a non-zero "height", and the CSS
# transition animates the change.
err_style = dict(
    [
        # collapse / expand behaviour
        ("height", "0px"),
        ("overflow", "hidden"),
        ("transition", "height 0.5s ease-in-out"),
        # appearance
        ("border-radius", "15px"),
        ("background-color", "#FFCCCB"),
        ("text-align", "center"),
        ("color", "#FF6B6B"),
        # spacing and text weight
        ("margin-top", "20px"),
        ("margin-left", "20px"),
        ("margin-right", "20px"),
        ("font-weight", "bold"),
        # center the message inside the banner
        ("display", "flex"),
        ("justify-content", "center"),
        ("align-items", "center"),
    ]
)
|
|
|
|
|
|
def render_table(
    df: pd.DataFrame,
    table_id: str = "table",
    page_size: int = 10,
) -> dash_table.DataTable:
    """Create a styled, paginated Dash DataTable from a pandas DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        The data to display. Every column becomes a table column whose
        display name and id are the column label.
    table_id : str, optional
        Component id of the table (default ``"table"``). Dash component ids
        should be unique on a page, so pass distinct ids when more than one
        table can be visible at the same time.
    page_size : int, optional
        Number of rows shown per page (default 10).

    Returns
    -------
    dash_table.DataTable
        A Dash DataTable component with styled layout and pagination.
    """
    return dash_table.DataTable(
        id=table_id,
        columns=[{"name": col, "id": col} for col in df.columns],
        # One dict per row, keyed by column label (matches the column ids).
        data=df.to_dict("records"),
        page_size=page_size,
        # Horizontally scrollable, centered table at 96% of the width.
        style_table={
            "overflowX": "auto",
            "margin": "auto",
            "width": "96%",
            "margin-top": "20px",
        },
        # Bounded column widths; overlong cell content is cut off with an
        # ellipsis instead of stretching the column.
        style_cell={
            "minWidth": "100px",
            "width": "150px",
            "maxWidth": "300px",
            "overflow": "hidden",
            "textOverflow": "ellipsis",
        },
        style_header={
            "backgroundColor": "lightblue",
            "fontWeight": "bold",
            "color": "black",
        },
    )
|
|
|
|
|
|
def get_layout() -> html.Div:
    """Build the complete page layout of the Dash application.

    The page consists of a database chat interface (the question is sent to
    the LLM) and a direct SQL query interface. In order, it contains:

    - a header with title and logo, followed by a subtitle
    - a ``dcc.Store`` ("tmp-value") holding the previously submitted prompt
    - a textarea ("input-field") and submit button ("submit-button") for the chat
    - a Markdown notice about connection time, LLM errors and example questions
    - a collapsible error banner ("error"), initially collapsed
    - a loading spinner wrapping the answer area ("text-output")
    - a direct SQL section: heading, description, textarea
      ("sql-input-field"), submit button ("sql-submit-button") and an output
      area ("sql-output")

    Returns
    -------
    html.Div
        A Div containing the entire layout of the application.
    """
    # Build a fresh dict from the shared base style instead of aliasing the
    # module-level one; the banner starts collapsed.
    hidden_err_style = {**err_style, "height": "0px"}

    # Placeholder prompt. It also seeds the "previous prompt" store, so
    # submitting the unchanged placeholder counts as "no new request".
    start_value = "Stelle deine Frage an die Datenbank..."

    layout = html.Div(
        [
            html.Div(
                [
                    html.H1("Datenbank-Chat", className="heading"),
                    html.Img(src=get_asset_url("logo.png"), className="logo"),
                ],
                className="header-container",
            ),  # Header
            html.Div(
                "Ganz ohne SQL-Kenntnisse Daten zu Zählerstandmessungen unserer Kunden abrufen!",
                style={"margin-left": "20px", "font-weight": "bold", "font-size": "20px"},
            ),
            dcc.Store(
                id="tmp-value", data=start_value, storage_type="memory"
            ),  # Previously submitted prompt (compared in update_output)
            dcc.Textarea(
                id="input-field",
                value=start_value,
                style={"width": "96%", "height": 200, "margin-left": "20px"},
            ),  # Chat input field
            html.Div([]),  # Empty spacer, needed for keeping the layout clean
            html.Button(
                "Abschicken",
                id="submit-button",
                n_clicks=0,
                disabled=False,
                style={"margin-left": "20px"},
            ),  # Chat submit button
            dcc.Markdown(
                notification_md,
                style={
                    "margin-left": "20px",
                    "margin-top": "20px",
                    "margin-right": "10px",
                    "background-color": "#C7E6F5",
                    "border-radius": "5px",
                    "padding": "10px",
                },
            ),  # Notes on connection time, LLM errors and example questions
            # Error banner: collapsed until update_output expands it, e.g. for
            # an unchanged prompt, a database connection failure or an
            # LLM/SQL error.
            html.Div(
                [html.P("Bitte eine neue Anfrage eingeben.")],
                id="error",
                style=hidden_err_style,
            ),
            dcc.Loading(
                id="loading",
                type="default",
                children=[
                    html.Div(
                        "Hier erscheint die Antwort des KI-Modells.",
                        id="text-output",
                        style={
                            "whiteSpace": "pre-line",
                            "margin-top": "30px",
                            "margin-left": "20px",
                            "margin-right": "20px",
                            "border": "2px solid #86bc25",
                            "border-radius": "15px",
                            "padding": "20px",
                        },
                    )
                ],
            ),  # Spinner + answer area of the chat
            html.H2(
                "Direkte SQL-Abfrage",
                style={"margin-left": "20px", "margin-top": "50px", "font-size": 24},
            ),  # SQL header
            html.Div(
                (
                    "Hier kann der ausgegebene SQL-Code getestet oder mit selbst "
                    "geschriebenem Code verglichen werden."
                ),
                style={"margin-left": "20px", "font-weight": "bold", "font-size": "16px"},
            ),
            dcc.Textarea(
                id="sql-input-field",
                value="(Microsoft) SQL-Abfrage eingeben...",
                style={"width": "96%", "height": 200, "margin-left": "20px"},
            ),  # SQL input field
            html.Div([]),  # Empty spacer, needed for keeping the layout clean
            html.Button(
                "Abschicken",
                id="sql-submit-button",
                n_clicks=0,
                disabled=False,
                style={"margin-left": "20px"},
            ),  # SQL submit button
            html.Div(id="sql-output", style={"margin-top": "10px"}),  # SQL output
        ],
        className="container",
    )

    return layout
|
|
|
|
|
|
# Build the page once at import time and register it, wrapped in an outer
# Div, as the application's layout.
_page = get_layout()
app.layout = html.Div(children=[_page])
|
|
|
|
|
|
@callback(
    Output("text-output", "children"),
    Output("tmp-value", "data"),
    Output("error", "style"),
    Output("error", "children"),
    Input("submit-button", "n_clicks"),
    State("input-field", "value"),
    State("tmp-value", "data"),
    prevent_initial_call=True,
    running=[
        (Output("submit-button", "disabled"), True, False),
        (
            Output("submit-button", "style"),
            {"opacity": 0.5, "margin-left": "20px"},
            {"opacity": 1.0, "margin-left": "20px"},
        ),
    ],
)
def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[str, str], Any]:
    """Answer a chat prompt: ask the LLM for SQL, run it and render the result.

    The prompt is sent to the LLM only if it differs from the previously
    submitted one and the database is reachable. Otherwise the error banner
    is expanded with an explanation.

    Parameters
    ----------
    n_clicks : int
        Number of times the submit button has been clicked.
    value : str
        Current content of the chat input field.
    data : str
        Previously submitted prompt (from the "tmp-value" store).

    Returns
    -------
    Tuple[Any, Any, Dict[str, str], Any]
        ``(text-output children, new stored prompt, error-banner style,
        error-banner children)``. ``no_update`` marks outputs left unchanged.

    Raises
    ------
    PreventUpdate
        When none of the handled cases applies (e.g. ``n_clicks`` is 0).
    """
    # NOTE: err_style is module-level and shared by every session/thread of
    # the server, so banner styles are built as fresh dicts here and the
    # shared dict is never mutated.
    db_connected = test_db_connection()

    if n_clicks > 0 and value != data and db_connected:
        result = None  # LLM reply; stays None if send_message itself fails
        try:
            result = send_message(value)
            # The reply is parsed as JSON with "query" and "summary" keys;
            # strict=False allows control characters (e.g. raw newlines)
            # inside its strings.
            parsed_result = json.loads(result, strict=False)
            result_table = execute_query(parsed_result["query"])
            # execute_query returns a str (an error message) when the query
            # fails -- run_sql_query relies on the same convention. Surface
            # that message instead of letting render_table fail on a str.
            if isinstance(result_table, str):
                raise ValueError(result_table)
            children = [
                html.P([html.B("Zusammenfassung: "), f"{parsed_result['summary']}"]),
                html.P([html.B("SQL Abfrage: "), f"{parsed_result['query']}"]),
                render_table(result_table),
            ]
        except Exception as e:
            # The prompt is still stored as "previous", so resubmitting it
            # unchanged asks the user for a new (rephrased) request.
            error_style = {**err_style, "height": "400px"}
            err_child = html.Div(
                f"Folgender Fehler ist aufgetreten: {e}. LLM Output: {result}."
            )
            return no_update, value, error_style, err_child

        return children, value, {**err_style, "height": "0px"}, html.P("")

    if not db_connected:
        err_child = html.P(
            (
                "Fehler beim Herstellen der Verbindung zur "
                "Datenbank. Bitte versuche es später erneut."
            )
        )
        return no_update, no_update, {**err_style, "height": "50px"}, err_child

    if value == data:
        err_child = html.P("Bitte eine neue Anfrage eingeben.")
        return no_update, no_update, {**err_style, "height": "50px"}, err_child

    raise PreventUpdate
|
|
|
|
|
|
@callback(
    Output("sql-output", "children"),
    Input("sql-submit-button", "n_clicks"),
    State("sql-input-field", "value"),
    prevent_initial_call=True,
)
def run_sql_query(n_clicks: int, value: str) -> Any:
    """Execute the SQL typed into the direct-query field and show the result.

    Parameters
    ----------
    n_clicks : int
        Number of times the SQL submit button has been clicked.
    value : str
        SQL text from the "sql-input-field" textarea.

    Returns
    -------
    Any
        A DataTable built by ``render_table`` on success, or an ``html.Div``
        error banner containing the error message.

    Raises
    ------
    PreventUpdate
        When the button has not been clicked yet (``n_clicks`` is 0).
    """
    if n_clicks <= 0:
        raise PreventUpdate

    try:
        result = execute_query(value)
    except Exception as exc:
        # Show unexpected failures like reported SQL errors instead of
        # letting the callback fail without any message in the page.
        result = str(exc)

    # execute_query returns a str (an error message) when the query fails.
    if isinstance(result, str):
        # Fresh dict: the module-level err_style is shared and must not be
        # mutated.
        banner_style = {**err_style, "height": "100px", "padding": "20px"}
        return html.Div(
            [html.P(f"Fehler bei der Ausführung der Abfrage: {result}")],
            style=banner_style,
        )

    return render_table(result)
|
|
|
|
|
|
# WSGI application object for production servers (e.g. gunicorn); the Dash
# development server below is only used when this file is run directly.
server = app.server

if __name__ == "__main__":
    # Debug mode (dev tools, reloading, Flask/Werkzeug debugger) stays on by
    # default for local development; set DASH_DEBUG to "0", "false", "no" or
    # "off" to disable it, e.g. when the dev server is reachable by others.
    _debug = os.getenv("DASH_DEBUG", "1").strip().lower() not in {"0", "false", "no", "off"}
    app.run(debug=_debug)
|