Files
grid_application/app/app.py
Tobias Quadfasel 5cec810947 fix(sql-formatting): Fix SQL code formatting
Fixed SQL code formatting errors by:
- catching both single and double backslashes in the formatting
- explicitly telling LLM how to format linebreaks

Also did some changes to the UI and allowed general questions
about the database content to be asked.
2024-10-06 11:59:45 +02:00

383 lines
12 KiB
Python

import json
import os
from typing import Any, Dict, Tuple
import dash_auth
import pandas as pd
from app_styles import header_style
from config import check_credentials
from dash import (
Dash,
Input,
Output,
State,
callback,
dash_table,
dcc,
get_asset_url,
html,
no_update,
)
from dash.exceptions import PreventUpdate
from data_chat import send_message
from sql_utils import execute_query, test_db_connection
# Run the credential check from config before the app is built.
# NOTE(review): presumably verifies that the required secrets/environment
# variables are present -- confirm against config.check_credentials.
check_credentials()
# first connection to SQL database to mitigate long startup time
try:
    test_db_connection()
except Exception as e:
    # Best effort only: a failed warm-up is reported on stdout and must not
    # keep the application from starting.
    print(f"Error for first connection to Azure SQL Database: {e}")
# Base stylesheet, loaded from an external URL at page load.
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
app = Dash(__name__, external_stylesheets=external_stylesheets)
# Put the whole app behind basic auth; the single username/password pair is
# read from the APP_UNAME and APP_PW environment variables.
auth = dash_auth.BasicAuth(
    app,
    {os.getenv("APP_UNAME"): os.getenv("APP_PW")},
)
# Use the HTML index template provided by app_styles instead of Dash's default.
app.index_string = header_style
# Markdown text (German, user-facing) with usage hints and example questions.
# It is rendered through dcc.Markdown inside get_layout().
notification_md = """
**Hinweise:**
GPT-4o kann einige Fehler machen. Sollte dies passieren wird eine Fehlermeldung angezeigt.
In diesem Fall lohnt es sich oft, die Anfrage leicht verändert erneut zu stellen und evtl
zusätzliche Informationen zu geben.
Das Modell ist dazu aufgefordert, den Output stets auf 100 Zeilen zu begrenzen.
Alle Daten sind komplett zufällig generiert und haben keine Beziehung zu realen Personen.
Es können sowohl allgemeine Fragen gestellt werden, als auch Fragen, die eine SQL-Abfrage erfordern.
Beispiel für allgemeine Frage: 'Nenne mir alle Tabellen in der Datenbank, sowie die entsprechenden
Spalten und eine kurze Erklärung über deren Inhalt.'
**SQL-Beispielfragen**:
- Wie viele Kunden haben wir in Hannover?
- Zeige alle Kunden in Bremen.
- Berechne den gesamten Stromverbrauch aller Kunden in Magdeburg.
- Zeige alle Kunden, die zwischen 2021 und 2022 mindestens 200 Kubikmeter Gas verbraucht haben.
- Wie viele Kunden haben zwischen 2021 und 2022 weniger Strom verbraucht als zwischen 2022
und 2023?
Weitere Informationen zu den Daten, dem Code sowie zur Nutzung befinden sich in der README im
[GiTea Repository](https://gitea.captain.particlephysics.de/quadfaselt/grid_application).
"""
# Base CSS for the red error banners used by the layout and the callbacks.
# With "height" at 0px and "overflow" hidden the banner is collapsed; code
# that wants to show it supplies a non-zero height, and the "transition"
# entry animates the height change.
err_style = {
    "height": "0px",
    "overflow": "hidden",
    "transition": "height 0.5s ease-in-out",
    "border-radius": "15px",
    "background-color": "#FFCCCB",
    "text-align": "center",
    "color": "#FF6B6B",
    "margin-top": "20px",
    "margin-left": "20px",
    "margin-right": "20px",
    "font-weight": "bold",
    "display": "flex",
    "justify-content": "center",
    "align-items": "center",
}
def render_table(df: pd.DataFrame) -> dash_table.DataTable:
    """Turn a pandas DataFrame into a paginated, styled Dash DataTable.

    Parameters
    ----------
    df : pd.DataFrame
        Data to display; every column of the frame becomes a table column.

    Returns
    -------
    dash_table.DataTable
        Table component with id ``"table"`` showing ten rows per page.
    """
    # One column descriptor per DataFrame column; the label doubles as the id.
    column_specs = [{"name": label, "id": label} for label in df.columns]
    rows = df.to_dict("records")

    # Centered table that scrolls horizontally when it is too wide.
    outer_style = {
        "overflowX": "auto",
        "margin": "auto",
        "width": "96%",
        "margin-top": "20px",
    }
    # Bounded cell width; overlong content is cut off with an ellipsis.
    cell_style = {
        "minWidth": "100px",
        "width": "150px",
        "maxWidth": "300px",
        "overflow": "hidden",
        "textOverflow": "ellipsis",
    }
    header_cell_style = {
        "backgroundColor": "lightblue",
        "fontWeight": "bold",
        "color": "black",
    }

    return dash_table.DataTable(
        id="table",
        columns=column_specs,
        data=rows,
        page_size=10,
        style_table=outer_style,
        style_cell=cell_style,
        style_header=header_cell_style,
    )
def get_layout() -> html.Div:
    """Assemble the complete page layout of the database chat application.

    From top to bottom the page contains:

    - the header with title and logo
    - a short introduction line
    - a store remembering the previously submitted prompt
    - the chat input textarea and its submit button
    - the Markdown box with usage hints
    - the (initially collapsed) error banner
    - a loading spinner wrapping the model answer box
    - the "direct SQL query" section: heading, description, textarea,
      submit button and the result container

    Returns
    -------
    html.Div
        Root container (CSS class ``container``) holding all components.
    """

    def _query_textarea(component_id: str, initial_text: str) -> dcc.Textarea:
        # Both query inputs (chat and SQL) share the same dimensions.
        return dcc.Textarea(
            id=component_id,
            value=initial_text,
            style={"width": "96%", "height": 200, "margin-left": "20px"},
        )

    def _submit_button(component_id: str) -> html.Button:
        # Both submit buttons look identical and start enabled with 0 clicks.
        return html.Button(
            "Abschicken",
            id=component_id,
            n_clicks=0,
            disabled=False,
            style={"margin-left": "20px"},
        )

    prompt_placeholder = "Stelle deine Frage an die Datenbank..."
    # The error banner starts collapsed; work on a copy so the shared
    # module-level style dict is not touched here.
    collapsed_error_style = {**err_style, "height": "0px"}

    page_header = html.Div(
        [
            html.H1("Datenbank-Chat", className="heading"),
            html.Img(src=get_asset_url("logo.png"), className="logo"),
        ],
        className="header-container",
    )
    chat_intro = html.Div(
        (
            "Chatte mit unserer SQL-Datenbank, die Daten zu Zählerstandmessungen der "
            "KundInnen enthält!"
        ),
        style={"margin-left": "20px", "font-weight": "bold", "font-size": "20px"},
    )
    # Holds the last submitted prompt so a repeated submit can be detected.
    previous_prompt = dcc.Store(
        id="tmp-value", data=prompt_placeholder, storage_type="memory"
    )
    hints_box = dcc.Markdown(
        notification_md,
        style={
            "margin-left": "20px",
            "margin-top": "20px",
            "margin-right": "10px",
            "background-color": "#C7E6F5",
            "border-radius": "5px",
            "padding": "10px",
        },
    )
    # Only becomes visible when a callback gives it a non-zero height.
    error_banner = html.Div(
        [html.P("Bitte eine neue Anfrage eingeben.")],
        id="error",
        style=collapsed_error_style,
    )
    answer_box = html.Div(
        "Hier erscheint die Antwort des KI-Modells.",
        id="text-output",
        style={
            "whiteSpace": "pre-line",
            "margin-top": "30px",
            "margin-left": "20px",
            "margin-right": "20px",
            "border": "2px solid #86bc25",
            "border-radius": "15px",
            "padding": "20px",
        },
    )
    answer_area = dcc.Loading(id="loading", type="default", children=[answer_box])
    sql_heading = html.H2(
        "Direkte SQL-Abfrage",
        style={"margin-left": "20px", "margin-top": "50px", "font-size": 24},
    )
    sql_intro = html.Div(
        (
            "Hier kann der ausgegebene SQL-Code getestet oder mit selbst"
            "geschriebenem Code verglichen werden."
        ),
        style={"margin-left": "20px", "font-weight": "bold", "font-size": "16px"},
    )
    sql_result = html.Div(id="sql-output", style={"margin-top": "10px"})

    return html.Div(
        [
            page_header,
            chat_intro,
            previous_prompt,
            _query_textarea("input-field", prompt_placeholder),
            html.Div([]),  # Needed for keeping the layout clean
            _submit_button("submit-button"),
            hints_box,
            error_banner,
            answer_area,
            sql_heading,
            sql_intro,
            _query_textarea("sql-input-field", "(Microsoft) SQL-Abfrage eingeben..."),
            html.Div([]),  # Needed for keeping the layout clean
            _submit_button("sql-submit-button"),
            sql_result,
        ],
        className="container",
    )
# Build the page once at import time and wrap it in a root Div.
app.layout = html.Div([get_layout()])
@callback(
    Output("text-output", "children"),
    Output("tmp-value", "data"),
    Output("error", "style"),
    Output("error", "children"),
    Input("submit-button", "n_clicks"),
    State("input-field", "value"),
    State("tmp-value", "data"),
    prevent_initial_call=True,
    running=[
        # Disable and dim the submit button while the callback is running.
        (Output("submit-button", "disabled"), True, False),
        (
            Output("submit-button", "style"),
            {"opacity": 0.5, "margin-left": "20px"},
            {"opacity": 1.0, "margin-left": "20px"},
        ),
    ],
)
def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[str, str], Any]:
    """Answer a chat request and update answer box, prompt store and error banner.

    The prompt is sent to the language model via ``send_message``. The reply
    is parsed as JSON with the keys ``"summary"`` and ``"query"``. If the
    query is ``"NA"`` only the summary is shown; otherwise the query is
    executed and summary, query text and result table are shown.

    Parameters
    ----------
    n_clicks : int
        Number of times the submit button has been clicked.
    value : str
        Current value of the input field.
    data : str
        Previously submitted prompt (from the ``tmp-value`` store).

    Returns
    -------
    Tuple[Any, Any, Dict[str, str], Any]
        Children of the answer box, new value for the prompt store, style of
        the error banner and children of the error banner. ``no_update`` is
        returned for outputs that should keep their current value.

    Raises
    ------
    PreventUpdate
        If none of the handled cases applies (e.g. no click registered).
    """
    # Work on a per-call copy: the module-level dict is shared by all sessions
    # and worker threads, so mutating it could leak one request's banner state
    # into another request's response.
    banner_style = err_style.copy()
    db_connected = test_db_connection()
    if n_clicks > 0 and value != data and db_connected:
        result = send_message(value)
        banner_style["height"] = "0px"
        # parse LLM response to dict, then try to execute the query
        try:
            parsed_result = json.loads(result, strict=False)
            if parsed_result["query"] == "NA":
                # General question: no SQL was generated, only a text answer.
                children = [
                    html.P([html.B("Zusammenfassung: "), f"{parsed_result['summary']}"]),
                ]
            else:
                result_table = execute_query(parsed_result["query"])
                # execute_query reports a failed query as a string (this file's
                # SQL callback checks for that); treat it as an error here
                # instead of relying on the table rendering to fail.
                if isinstance(result_table, str):
                    raise ValueError(result_table)
                children = [
                    html.P([html.B("Zusammenfassung:\n"), f"{parsed_result['summary']}"]),
                    html.P([html.B("SQL Abfrage:\n"), f"{parsed_result['query']}"]),
                    render_table(result_table),
                ]
            return children, value, banner_style, html.P("")
        except Exception as e:
            # Boundary handler: malformed model output or a failing query must
            # not crash the callback. Log the cause, show a generic message.
            print(f"Error while processing the chat request: {e}")
            banner_style["height"] = "50px"
            err_child = html.Div(
                (
                    "Ein Fehler ist aufgetreten. Versuchen Sie, "
                    "die Anfrage genauer zu beschreiben und versuchen Sie es erneut."
                )
            )
            return no_update, value, banner_style, err_child
    elif not db_connected:
        banner_style["height"] = "50px"
        err_child = html.P(
            (
                "Fehler beim Herstellen der Verbindung zur "
                "Datenbank. Bitte versuche es später erneut."
            )
        )
        return no_update, no_update, banner_style, err_child
    elif value == data:
        # Same prompt as last time: ask for a new request instead of re-running.
        banner_style["height"] = "50px"
        err_child = html.P("Bitte eine neue Anfrage eingeben.")
        return no_update, no_update, banner_style, err_child
    raise PreventUpdate
@callback(
    Output("sql-output", "children"),
    Input("sql-submit-button", "n_clicks"),
    State("sql-input-field", "value"),
    prevent_initial_call=True,
)
def run_sql_query(n_clicks: int, value: str) -> Any:
    """Execute the SQL text from the input field and render the outcome.

    Parameters
    ----------
    n_clicks : int
        Number of times the SQL submit button has been clicked.
    value : str
        Current content of the SQL input field.

    Returns
    -------
    Any
        A Dash component: the result table on success, or a red error box
        containing the message when ``execute_query`` returns a string.

    Raises
    ------
    PreventUpdate
        If the button has not been clicked.
    """
    # Guard clause; also covers a missing click count instead of failing on
    # the comparison.
    if n_clicks is None or n_clicks <= 0:
        raise PreventUpdate

    result = execute_query(value)
    if not isinstance(result, str):
        return render_table(result)

    # A string result is the error message. Show it in an expanded copy of the
    # shared error style so the module-level dict stays unchanged.
    box_style = err_style.copy()
    box_style["height"] = "100px"
    box_style["padding"] = "20px"
    return html.Div(
        [html.P(f"Fehler bei der Ausführung der Abfrage: {result}")], style=box_style
    )
# Expose the app's underlying server object at module level.
# NOTE(review): presumably the entry point for a WSGI server in deployment --
# confirm against the deployment setup.
server = app.server
if __name__ == "__main__":
    # Only reached when the file is executed directly.
    # NOTE(review): debug=True is meant for local development; make sure this
    # path is never used to serve the app publicly.
    app.run(debug=True)