feat(ai-chat): Add code logic for AI-based data chat

Add the first working code logic both in terms of backend and
frontend-related tasks. Add a detailled system message for improved
results. Add several UI improvements for result display and user
information. Add text input field for direct SQL code comparison.

The implementation of the openAI backend had to be changed due to strict
rate limits of azure OpenAI free tier and was replaced with a regular
openai API key.
This commit is contained in:
Tobias Quadfasel
2024-09-03 14:48:38 +02:00
parent 4b9fa0579e
commit 94b5545173
3 changed files with 189 additions and 54 deletions

View File

@@ -1,5 +1,8 @@
import json
from typing import Any, Dict, Tuple
import pandas as pd
from app_styles import header_style
from dash import (
Dash,
Input,
@@ -13,16 +16,40 @@ from dash import (
no_update,
)
from dash.exceptions import PreventUpdate
from .app_styles import header_style
from .data_chat import send_message
from .sql_utils import execute_query, test_db_connection
from data_chat import send_message
from sql_utils import execute_query, test_db_connection
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
app = Dash(__name__, external_stylesheets=external_stylesheets)
app.index_string = header_style
notification_md = """
**Hinweise:**
Aufgrund des sparsamen pricing Tiers kann es einige Sekunden dauern, bis die
Verbindung zur Datenbank hergestellt wird. Im Falle eines Fehlers gern ein-zwei mal erneut
versuchen.
GPT-4o kann einige Fehler machen. Sollte dies passieren wird eine Fehlermeldung angezeigt.
In diesem Fall lohnt es sich oft, die Anfrage leicht verändert erneut zu stellen und evtl
zusätzliche Informationen zu geben.
Das Modell ist dazu aufgefordert, den Output stets auf 100 Zeilen zu begrenzen.
Alle Daten sind komplett zufällig generiert und haben keine Beziehung zu realen Personen.
**Beispielfragen**:
- Wie viele Kunden haben wir in Hannover?
- Zeige alle Kunden in Bremen.
- Berechne den gesamten Stromverbrauch aller Kunden in Magdeburg.
- Zeige alle Kunden, die zwischen 2021 und 2022 mindestens 200 Kubikmeter Gas verbraucht haben.
- Wie viele Kunden haben zwischen 2021 und 2022 weniger Strom verbraucht als zwischen 2022
und 2023?
Weitere Informationen zu den Daten, dem Code sowie zur Nutzung befinden sich in der README im
[GiTea Repository](https://gitea.captain.particlephysics.de/quadfaselt/grid_application).
"""
err_style = {
"height": "0px",
@@ -42,6 +69,46 @@ err_style = {
}
def render_table(df: pd.DataFrame) -> dash_table.DataTable:
"""Create a Dash DataTable from a pandas DataFrame.
Parameters
----------
df : pd.DataFrame
The input DataFrame to be rendered as a table.
Returns
-------
dash_table.DataTable
A Dash DataTable component with styled layout and pagination.
"""
tab = dash_table.DataTable(
id="table",
columns=[{"name": i, "id": i} for i in df.columns],
data=df.to_dict("records"),
page_size=10,
style_table={
"overflowX": "auto",
"margin": "auto",
"width": "96%",
"margin-top": "20px",
},
style_cell={
"minWidth": "100px",
"width": "150px",
"maxWidth": "300px",
"overflow": "hidden",
"textOverflow": "ellipsis",
},
style_header={
"backgroundColor": "lightblue",
"fontWeight": "bold",
"color": "black",
},
)
return tab
def get_layout() -> html.Div:
"""Generate the layout for a Dash application.
@@ -53,6 +120,7 @@ def get_layout() -> html.Div:
- A header with title and logo
- A textarea for user input (database chat)
- A submit button for the database chat
- A Notification about the connection time and LLM performance
- An error message area
- A loading spinner and output area for database responses
- A section for direct SQL queries, including a textarea and submit button
@@ -78,6 +146,10 @@ def get_layout() -> html.Div:
],
className="header-container",
), # Header
html.Div(
"Ganz ohne SQL-Kenntnisse Daten zu Zählerstandmessungen unserer Kunden abrufen!",
style={"margin-left": "20px", "font-weight": "bold", "font-size": "20px"},
),
dcc.Store(
id="tmp-value", data=start_value, storage_type="memory"
), # Store previous prompt
@@ -94,6 +166,17 @@ def get_layout() -> html.Div:
disabled=False,
style={"margin-left": "20px"},
), # Submit button
dcc.Markdown(
notification_md,
style={
"margin-left": "20px",
"margin-top": "20px",
"margin-right": "10px",
"background-color": "#C7E6F5",
"border-radius": "5px",
"padding": "10px",
},
),
html.Div(
[html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=tmp_style
), # Error message (only visible if input is not updated but submit button is clicked)
@@ -102,7 +185,7 @@ def get_layout() -> html.Div:
type="default",
children=[
html.Div(
"Hier erscheint die Antwort der Datenbank.",
"Hier erscheint die Antwort des KI-Modells.",
id="text-output",
style={
"whiteSpace": "pre-line",
@@ -180,18 +263,28 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[
Updated output text, new stored value, and error style.
"""
global err_style
print(f"Value: {value}")
print(f"Data: {data}")
db_connected = test_db_connection()
if n_clicks > 0 and value != data and db_connected:
result = send_message(value)
err_style["height"] = "0px"
return result, value, err_style, html.P("")
elif value == data:
err_style["height"] = "50px"
err_child = html.P("Bitte eine neue Anfrage eingeben.")
return no_update, no_update, err_style, err_child
# parse LLM response to dict, then try to execute the query
try:
parsed_result = json.loads(result, strict=False)
result_table = execute_query(parsed_result["query"])
children = [
html.P([html.B("Zusammenfassung: "), f"{parsed_result['summary']}"]),
html.P([html.B("SQL Abfrage: "), f"{parsed_result['query']}"]),
render_table(result_table),
]
return children, value, err_style, html.P("")
except Exception as e:
err_style["height"] = "400px"
err_child = html.Div(f"Folgender Fehler ist aufgetreten: {e}.LLM Output: {result}.")
return no_update, value, err_style, err_child
elif not db_connected:
err_style["height"] = "50px"
err_child = html.P(
@@ -202,6 +295,12 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[
)
return no_update, no_update, err_style, err_child
elif value == data:
err_style["height"] = "50px"
err_child = html.P("Bitte eine neue Anfrage eingeben.")
return no_update, no_update, err_style, err_child
raise PreventUpdate
@@ -231,37 +330,14 @@ def run_sql_query(n_clicks: int, value: str) -> str:
if isinstance(result, str):
global err_style
tmp_style = err_style.copy()
tmp_style["height"] = "80px"
tmp_style["height"] = "100px"
tmp_style["padding"] = "20px"
err_child = html.Div(
[html.P(f"Fehler bei der Ausführung der Abfrage: {result}")], style=tmp_style
)
return err_child
else:
table_child = dash_table.DataTable(
id="table",
columns=[{"name": i, "id": i} for i in result.columns],
data=result.to_dict("records"),
page_size=10,
style_table={
"overflowX": "auto",
"margin": "auto",
"width": "96%",
},
style_cell={
"minWidth": "100px",
"width": "150px",
"maxWidth": "300px",
"overflow": "hidden",
"textOverflow": "ellipsis",
},
style_header={
"backgroundColor": "lightblue",
"fontWeight": "bold",
"color": "black",
},
)
return table_child
return render_table(result)
raise PreventUpdate