feat/ai-chat: Add core components for Database chat #5

Merged
quadfaselt merged 5 commits from feat/ai-chat into main 2024-09-03 13:38:24 +00:00
3 changed files with 189 additions and 54 deletions
Showing only changes of commit 94b5545173 - Show all commits

View File

@@ -1,5 +1,8 @@
import json
from typing import Any, Dict, Tuple from typing import Any, Dict, Tuple
import pandas as pd
from app_styles import header_style
from dash import ( from dash import (
Dash, Dash,
Input, Input,
@@ -13,16 +16,40 @@ from dash import (
no_update, no_update,
) )
from dash.exceptions import PreventUpdate from dash.exceptions import PreventUpdate
from data_chat import send_message
from .app_styles import header_style from sql_utils import execute_query, test_db_connection
from .data_chat import send_message
from .sql_utils import execute_query, test_db_connection
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"] external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
app = Dash(__name__, external_stylesheets=external_stylesheets) app = Dash(__name__, external_stylesheets=external_stylesheets)
app.index_string = header_style app.index_string = header_style
notification_md = """
**Hinweise:**
Aufgrund des sparsamen pricing Tiers kann es einige Sekunden dauern, bis die
Verbindung zur Datenbank hergestellt wird. Im Falle eines Fehlers gern ein-zwei mal erneut
versuchen.
GPT-4o kann einige Fehler machen. Sollte dies passieren wird eine Fehlermeldung angezeigt.
In diesem Fall lohnt es sich oft, die Anfrage leicht verändert erneut zu stellen und evtl
zusätzliche Informationen zu geben.
Das Modell ist dazu aufgefordert, den Output stets auf 100 Zeilen zu begrenzen.
Alle Daten sind komplett zufällig generiert und haben keine Beziehung zu realen Personen.
**Beispielfragen**:
- Wie viele Kunden haben wir in Hannover?
- Zeige alle Kunden in Bremen.
- Berechne den gesamten Stromverbrauch aller Kunden in Magdeburg.
- Zeige alle Kunden, die zwischen 2021 und 2022 mindestens 200 Kubikmeter Gas verbraucht haben.
- Wie viele Kunden haben zwischen 2021 und 2022 weniger Strom verbraucht als zwischen 2022
und 2023?
Weitere Informationen zu den Daten, dem Code sowie zur Nutzung befinden sich in der README im
[GiTea Repository](https://gitea.captain.particlephysics.de/quadfaselt/grid_application).
"""
err_style = { err_style = {
"height": "0px", "height": "0px",
@@ -42,6 +69,46 @@ err_style = {
} }
def render_table(df: pd.DataFrame) -> dash_table.DataTable:
"""Create a Dash DataTable from a pandas DataFrame.
Parameters
----------
df : pd.DataFrame
The input DataFrame to be rendered as a table.
Returns
-------
dash_table.DataTable
A Dash DataTable component with styled layout and pagination.
"""
tab = dash_table.DataTable(
id="table",
columns=[{"name": i, "id": i} for i in df.columns],
data=df.to_dict("records"),
page_size=10,
style_table={
"overflowX": "auto",
"margin": "auto",
"width": "96%",
"margin-top": "20px",
},
style_cell={
"minWidth": "100px",
"width": "150px",
"maxWidth": "300px",
"overflow": "hidden",
"textOverflow": "ellipsis",
},
style_header={
"backgroundColor": "lightblue",
"fontWeight": "bold",
"color": "black",
},
)
return tab
def get_layout() -> html.Div: def get_layout() -> html.Div:
"""Generate the layout for a Dash application. """Generate the layout for a Dash application.
@@ -53,6 +120,7 @@ def get_layout() -> html.Div:
- A header with title and logo - A header with title and logo
- A textarea for user input (database chat) - A textarea for user input (database chat)
- A submit button for the database chat - A submit button for the database chat
- A Notification about the connection time and LLM performance
- An error message area - An error message area
- A loading spinner and output area for database responses - A loading spinner and output area for database responses
- A section for direct SQL queries, including a textarea and submit button - A section for direct SQL queries, including a textarea and submit button
@@ -78,6 +146,10 @@ def get_layout() -> html.Div:
], ],
className="header-container", className="header-container",
), # Header ), # Header
html.Div(
"Ganz ohne SQL-Kenntnisse Daten zu Zählerstandmessungen unserer Kunden abrufen!",
style={"margin-left": "20px", "font-weight": "bold", "font-size": "20px"},
),
dcc.Store( dcc.Store(
id="tmp-value", data=start_value, storage_type="memory" id="tmp-value", data=start_value, storage_type="memory"
), # Store previous prompt ), # Store previous prompt
@@ -94,6 +166,17 @@ def get_layout() -> html.Div:
disabled=False, disabled=False,
style={"margin-left": "20px"}, style={"margin-left": "20px"},
), # Submit button ), # Submit button
dcc.Markdown(
notification_md,
style={
"margin-left": "20px",
"margin-top": "20px",
"margin-right": "10px",
"background-color": "#C7E6F5",
"border-radius": "5px",
"padding": "10px",
},
),
html.Div( html.Div(
[html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=tmp_style [html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=tmp_style
), # Error message (only visible if input is not updated but submit button is clicked) ), # Error message (only visible if input is not updated but submit button is clicked)
@@ -102,7 +185,7 @@ def get_layout() -> html.Div:
type="default", type="default",
children=[ children=[
html.Div( html.Div(
"Hier erscheint die Antwort der Datenbank.", "Hier erscheint die Antwort des KI-Modells.",
id="text-output", id="text-output",
style={ style={
"whiteSpace": "pre-line", "whiteSpace": "pre-line",
@@ -180,18 +263,28 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[
Updated output text, new stored value, and error style. Updated output text, new stored value, and error style.
""" """
global err_style global err_style
print(f"Value: {value}")
print(f"Data: {data}")
db_connected = test_db_connection() db_connected = test_db_connection()
if n_clicks > 0 and value != data and db_connected: if n_clicks > 0 and value != data and db_connected:
result = send_message(value) result = send_message(value)
err_style["height"] = "0px" err_style["height"] = "0px"
return result, value, err_style, html.P("")
elif value == data:
err_style["height"] = "50px" # parse LLM response to dict, then try to execute the query
err_child = html.P("Bitte eine neue Anfrage eingeben.") try:
return no_update, no_update, err_style, err_child parsed_result = json.loads(result, strict=False)
result_table = execute_query(parsed_result["query"])
children = [
html.P([html.B("Zusammenfassung: "), f"{parsed_result['summary']}"]),
html.P([html.B("SQL Abfrage: "), f"{parsed_result['query']}"]),
render_table(result_table),
]
return children, value, err_style, html.P("")
except Exception as e:
err_style["height"] = "400px"
err_child = html.Div(f"Folgender Fehler ist aufgetreten: {e}.LLM Output: {result}.")
return no_update, value, err_style, err_child
elif not db_connected: elif not db_connected:
err_style["height"] = "50px" err_style["height"] = "50px"
err_child = html.P( err_child = html.P(
@@ -202,6 +295,12 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[
) )
return no_update, no_update, err_style, err_child return no_update, no_update, err_style, err_child
elif value == data:
err_style["height"] = "50px"
err_child = html.P("Bitte eine neue Anfrage eingeben.")
return no_update, no_update, err_style, err_child
raise PreventUpdate raise PreventUpdate
@@ -231,37 +330,14 @@ def run_sql_query(n_clicks: int, value: str) -> str:
if isinstance(result, str): if isinstance(result, str):
global err_style global err_style
tmp_style = err_style.copy() tmp_style = err_style.copy()
tmp_style["height"] = "80px" tmp_style["height"] = "100px"
tmp_style["padding"] = "20px" tmp_style["padding"] = "20px"
err_child = html.Div( err_child = html.Div(
[html.P(f"Fehler bei der Ausführung der Abfrage: {result}")], style=tmp_style [html.P(f"Fehler bei der Ausführung der Abfrage: {result}")], style=tmp_style
) )
return err_child return err_child
else: else:
table_child = dash_table.DataTable( return render_table(result)
id="table",
columns=[{"name": i, "id": i} for i in result.columns],
data=result.to_dict("records"),
page_size=10,
style_table={
"overflowX": "auto",
"margin": "auto",
"width": "96%",
},
style_cell={
"minWidth": "100px",
"width": "150px",
"maxWidth": "300px",
"overflow": "hidden",
"textOverflow": "ellipsis",
},
style_header={
"backgroundColor": "lightblue",
"fontWeight": "bold",
"color": "black",
},
)
return table_child
raise PreventUpdate raise PreventUpdate

View File

@@ -1,16 +1,24 @@
import os import os
from openai import AzureOpenAI from openai import OpenAI
# from openai import AzureOpenAI
# Set up credentials # Set up credentials
# NOTE: When running locally, these have to be set in the environment # NOTE: Usually I would use AzureOpenAI, but due to heavy rate
client = AzureOpenAI( # limitations on azure trial accounts, I am using OpenAI directly
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), # for this project. However, this is how it would look like for
api_key=os.getenv("AZURE_OPENAI_KEY"), # AzureOpenAI (credentials must be provided to environment):
api_version="2024-02-01", # client = AzureOpenAI(
) # azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
# api_key=os.getenv("AZURE_OPENAI_KEY"),
# api_version="2024-02-01",
# )
# MODEL = "sqlai" # deployment name
deployment_name = "sqlai" # Set up the OpenAI client
MODEL = "gpt-4o"
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def send_message(message: str) -> str: def send_message(message: str) -> str:
@@ -28,15 +36,60 @@ def send_message(message: str) -> str:
""" """
system_message = """ system_message = """
Du bist ein hilfsbereiter, fröhlicher Datenbankassistent. Du bist ein hilfsbereiter, fröhlicher Datenbankassistent.
Du hilfst Benutzern bei der Erstellung von SQL-Abfragen für eine Datenbank eines
großen Energieversorgungsunternehmens. Die Datenbank enthält Tabellen für Adressen,
Zähler, Kunden und Ablesungen. Es werden Gaszähler (MeterType 'GAS') und Stromzähler
(MeterType 'ELT')unterschieden.
Besonders wichtig ist, dass die Ablesungen der Werte kumulativ sind. Wenn nach dem Verbrauch
gefragt wird, sollte der Unterschied zwischen zwei aufeinanderfolgenden Ablesungen berechnet
werden.
Verwende beim Erstellen Ihrer Antworten das folgende Datenbankschema: Verwende beim Erstellen Ihrer Antworten das folgende Datenbankschema:
MEIN_DATENBANKSCHEMA CREATE TABLE Addresses (
ID INT PRIMARY KEY IDENTITY(1,1),
StreetName NVARCHAR(100),
HouseNumber NVARCHAR(10),
City NVARCHAR(50),
PostalCode NVARCHAR(10),
Longitude FLOAT,
Latitude FLOAT
);
CREATE TABLE Meters (
ID INT PRIMARY KEY IDENTITY(1,1),
Signature NVARCHAR(11),
MeterType NVARCHAR(3),
AddressID INT,
FOREIGN KEY (AddressID) REFERENCES Addresses(ID)
);
CREATE TABLE Customers (
ID INT PRIMARY KEY IDENTITY(1,1),
FirstName NVARCHAR(100),
LastName NVARCHAR(100),
GasMeterID INT,
EltMeterID INT,
FOREIGN KEY (GasMeterID) REFERENCES Meters(ID),
FOREIGN KEY (EltMeterID) REFERENCES Meters(ID)
);
CREATE TABLE Readings (
ID INT PRIMARY KEY IDENTITY(1,1),
CustomerID INT,
MeterID INT,
ReadingDate DATE,
ReadingValue INT,
FOREIGN KEY (CustomerID) REFERENCES Customers(ID),
FOREIGN KEY (MeterID) REFERENCES Meters(ID)
);
Füge Spaltenüberschriften in die Abfrageergebnisse ein. Füge Spaltenüberschriften in die Abfrageergebnisse ein.
Gib deine Antwort immer im folgenden JSON-Format an: Gib deine Antwort immer im folgenden JSON-Format an:
JSON FORMAT { "summary": "your-summary", "query": "your-query" }
Gib NUR JSON aus. Gib NUR JSON aus.
Ersetze in der vorangehenden JSON-Antwort "your-query" durch die Microsoft SQL Server Query, Ersetze in der vorangehenden JSON-Antwort "your-query" durch die Microsoft SQL Server Query,
@@ -45,15 +98,20 @@ def send_message(message: str) -> str:
Gib immer alle Spalten der Tabelle an. Gib immer alle Spalten der Tabelle an.
Wenn die resultierende Abfrage nicht ausführbar ist, ersetze "your-query“ durch NA, aber ersetze Wenn die resultierende Abfrage nicht ausführbar ist, ersetze "your-query“ durch NA, aber ersetze
trotzdem "your-query" durch eine Zusammenfassung der Abfrage. trotzdem "your-query" durch eine Zusammenfassung der Abfrage.
Verwende KEINE MySQL-Syntax. Verwende KEINE MySQL-Syntax, sondern AUSSCHLIESSLICH Microsoft SQL.
Begrenze die SQL-Abfrage immer auf 100 Zeilen. Begrenze die SQL-Abfrage immer auf 100 Zeilen.
Formatiere den Output bestmöglich.
""" """
system_message = "Du bist ein hilfreicher Assistent."
response = client.chat.completions.create( response = client.chat.completions.create(
model=deployment_name, model=MODEL,
messages=[ messages=[
{"role": "system", "content": system_message}, {"role": "system", "content": system_message},
{"role": "user", "content": message}, {"role": "user", "content": message},
], ],
) )
return response.choices[0].message.content
result_str = response.choices[0].message.content.replace("```json\n", "").replace("```", "")
if ("\n") not in result_str:
result_str = result_str.replace("\\", "\n")
return result_str

View File

@@ -11,7 +11,7 @@ def test_db_connection() -> bool:
This function attempts to establish a connection to an Azure SQL Database This function attempts to establish a connection to an Azure SQL Database
using the connection string stored in the environment variable using the connection string stored in the environment variable
'AZURE_SQL_CONNECTION_STRING'. It makes up to 5 attempts to connect, 'AZURE_SQL_CONNECTION_STRING'. It makes up to 5 attempts to connect,
with a timeout of 240 seconds for each attempt. with a timeout of 480 seconds for each attempt.
Returns Returns
------- -------
@@ -22,14 +22,15 @@ def test_db_connection() -> bool:
for i in range(5): for i in range(5):
print(f"Trying to connect to Azure SQL Database... Attempt {i + 1}") print(f"Trying to connect to Azure SQL Database... Attempt {i + 1}")
try: try:
pyodbc.connect(connection_string, timeout=240) pyodbc.connect(connection_string, timeout=480)
print("Connected to Azure SQL Database successfully!") print("Connected to Azure SQL Database successfully!")
connected = True connected = True
break break
except pyodbc.Error as e: except Exception as e:
print(f"Error connecting to Azure SQL Database: {e}") print(f"Error connecting to Azure SQL Database: {e}")
connected = False connected = False
continue
return connected return connected