feat/ai-chat: Add core components for Database chat #5
150
app/app.py
150
app/app.py
@@ -1,5 +1,8 @@
|
|||||||
|
import json
|
||||||
from typing import Any, Dict, Tuple
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from app_styles import header_style
|
||||||
from dash import (
|
from dash import (
|
||||||
Dash,
|
Dash,
|
||||||
Input,
|
Input,
|
||||||
@@ -13,16 +16,40 @@ from dash import (
|
|||||||
no_update,
|
no_update,
|
||||||
)
|
)
|
||||||
from dash.exceptions import PreventUpdate
|
from dash.exceptions import PreventUpdate
|
||||||
|
from data_chat import send_message
|
||||||
from .app_styles import header_style
|
from sql_utils import execute_query, test_db_connection
|
||||||
from .data_chat import send_message
|
|
||||||
from .sql_utils import execute_query, test_db_connection
|
|
||||||
|
|
||||||
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
|
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
|
||||||
|
|
||||||
app = Dash(__name__, external_stylesheets=external_stylesheets)
|
app = Dash(__name__, external_stylesheets=external_stylesheets)
|
||||||
app.index_string = header_style
|
app.index_string = header_style
|
||||||
|
|
||||||
|
notification_md = """
|
||||||
|
**Hinweise:**
|
||||||
|
|
||||||
|
Aufgrund des sparsamen pricing Tiers kann es einige Sekunden dauern, bis die
|
||||||
|
Verbindung zur Datenbank hergestellt wird. Im Falle eines Fehlers gern ein-zwei mal erneut
|
||||||
|
versuchen.
|
||||||
|
|
||||||
|
GPT-4o kann einige Fehler machen. Sollte dies passieren wird eine Fehlermeldung angezeigt.
|
||||||
|
In diesem Fall lohnt es sich oft, die Anfrage leicht verändert erneut zu stellen und evtl
|
||||||
|
zusätzliche Informationen zu geben.
|
||||||
|
|
||||||
|
Das Modell ist dazu aufgefordert, den Output stets auf 100 Zeilen zu begrenzen.
|
||||||
|
|
||||||
|
Alle Daten sind komplett zufällig generiert und haben keine Beziehung zu realen Personen.
|
||||||
|
|
||||||
|
**Beispielfragen**:
|
||||||
|
- Wie viele Kunden haben wir in Hannover?
|
||||||
|
- Zeige alle Kunden in Bremen.
|
||||||
|
- Berechne den gesamten Stromverbrauch aller Kunden in Magdeburg.
|
||||||
|
- Zeige alle Kunden, die zwischen 2021 und 2022 mindestens 200 Kubikmeter Gas verbraucht haben.
|
||||||
|
- Wie viele Kunden haben zwischen 2021 und 2022 weniger Strom verbraucht als zwischen 2022
|
||||||
|
und 2023?
|
||||||
|
|
||||||
|
Weitere Informationen zu den Daten, dem Code sowie zur Nutzung befinden sich in der README im
|
||||||
|
[GiTea Repository](https://gitea.captain.particlephysics.de/quadfaselt/grid_application).
|
||||||
|
"""
|
||||||
|
|
||||||
err_style = {
|
err_style = {
|
||||||
"height": "0px",
|
"height": "0px",
|
||||||
@@ -42,6 +69,46 @@ err_style = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def render_table(df: pd.DataFrame) -> dash_table.DataTable:
|
||||||
|
"""Create a Dash DataTable from a pandas DataFrame.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
df : pd.DataFrame
|
||||||
|
The input DataFrame to be rendered as a table.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dash_table.DataTable
|
||||||
|
A Dash DataTable component with styled layout and pagination.
|
||||||
|
"""
|
||||||
|
tab = dash_table.DataTable(
|
||||||
|
id="table",
|
||||||
|
columns=[{"name": i, "id": i} for i in df.columns],
|
||||||
|
data=df.to_dict("records"),
|
||||||
|
page_size=10,
|
||||||
|
style_table={
|
||||||
|
"overflowX": "auto",
|
||||||
|
"margin": "auto",
|
||||||
|
"width": "96%",
|
||||||
|
"margin-top": "20px",
|
||||||
|
},
|
||||||
|
style_cell={
|
||||||
|
"minWidth": "100px",
|
||||||
|
"width": "150px",
|
||||||
|
"maxWidth": "300px",
|
||||||
|
"overflow": "hidden",
|
||||||
|
"textOverflow": "ellipsis",
|
||||||
|
},
|
||||||
|
style_header={
|
||||||
|
"backgroundColor": "lightblue",
|
||||||
|
"fontWeight": "bold",
|
||||||
|
"color": "black",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return tab
|
||||||
|
|
||||||
|
|
||||||
def get_layout() -> html.Div:
|
def get_layout() -> html.Div:
|
||||||
"""Generate the layout for a Dash application.
|
"""Generate the layout for a Dash application.
|
||||||
|
|
||||||
@@ -53,6 +120,7 @@ def get_layout() -> html.Div:
|
|||||||
- A header with title and logo
|
- A header with title and logo
|
||||||
- A textarea for user input (database chat)
|
- A textarea for user input (database chat)
|
||||||
- A submit button for the database chat
|
- A submit button for the database chat
|
||||||
|
- A Notification about the connection time and LLM performance
|
||||||
- An error message area
|
- An error message area
|
||||||
- A loading spinner and output area for database responses
|
- A loading spinner and output area for database responses
|
||||||
- A section for direct SQL queries, including a textarea and submit button
|
- A section for direct SQL queries, including a textarea and submit button
|
||||||
@@ -78,6 +146,10 @@ def get_layout() -> html.Div:
|
|||||||
],
|
],
|
||||||
className="header-container",
|
className="header-container",
|
||||||
), # Header
|
), # Header
|
||||||
|
html.Div(
|
||||||
|
"Ganz ohne SQL-Kenntnisse Daten zu Zählerstandmessungen unserer Kunden abrufen!",
|
||||||
|
style={"margin-left": "20px", "font-weight": "bold", "font-size": "20px"},
|
||||||
|
),
|
||||||
dcc.Store(
|
dcc.Store(
|
||||||
id="tmp-value", data=start_value, storage_type="memory"
|
id="tmp-value", data=start_value, storage_type="memory"
|
||||||
), # Store previous prompt
|
), # Store previous prompt
|
||||||
@@ -94,6 +166,17 @@ def get_layout() -> html.Div:
|
|||||||
disabled=False,
|
disabled=False,
|
||||||
style={"margin-left": "20px"},
|
style={"margin-left": "20px"},
|
||||||
), # Submit button
|
), # Submit button
|
||||||
|
dcc.Markdown(
|
||||||
|
notification_md,
|
||||||
|
style={
|
||||||
|
"margin-left": "20px",
|
||||||
|
"margin-top": "20px",
|
||||||
|
"margin-right": "10px",
|
||||||
|
"background-color": "#C7E6F5",
|
||||||
|
"border-radius": "5px",
|
||||||
|
"padding": "10px",
|
||||||
|
},
|
||||||
|
),
|
||||||
html.Div(
|
html.Div(
|
||||||
[html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=tmp_style
|
[html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=tmp_style
|
||||||
), # Error message (only visible if input is not updated but submit button is clicked)
|
), # Error message (only visible if input is not updated but submit button is clicked)
|
||||||
@@ -102,7 +185,7 @@ def get_layout() -> html.Div:
|
|||||||
type="default",
|
type="default",
|
||||||
children=[
|
children=[
|
||||||
html.Div(
|
html.Div(
|
||||||
"Hier erscheint die Antwort der Datenbank.",
|
"Hier erscheint die Antwort des KI-Modells.",
|
||||||
id="text-output",
|
id="text-output",
|
||||||
style={
|
style={
|
||||||
"whiteSpace": "pre-line",
|
"whiteSpace": "pre-line",
|
||||||
@@ -180,18 +263,28 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[
|
|||||||
Updated output text, new stored value, and error style.
|
Updated output text, new stored value, and error style.
|
||||||
"""
|
"""
|
||||||
global err_style
|
global err_style
|
||||||
print(f"Value: {value}")
|
|
||||||
print(f"Data: {data}")
|
|
||||||
db_connected = test_db_connection()
|
db_connected = test_db_connection()
|
||||||
if n_clicks > 0 and value != data and db_connected:
|
if n_clicks > 0 and value != data and db_connected:
|
||||||
result = send_message(value)
|
result = send_message(value)
|
||||||
err_style["height"] = "0px"
|
err_style["height"] = "0px"
|
||||||
return result, value, err_style, html.P("")
|
|
||||||
elif value == data:
|
|
||||||
|
|
||||||
err_style["height"] = "50px"
|
# parse LLM response to dict, then try to execute the query
|
||||||
err_child = html.P("Bitte eine neue Anfrage eingeben.")
|
try:
|
||||||
return no_update, no_update, err_style, err_child
|
parsed_result = json.loads(result, strict=False)
|
||||||
|
|
||||||
|
result_table = execute_query(parsed_result["query"])
|
||||||
|
children = [
|
||||||
|
html.P([html.B("Zusammenfassung: "), f"{parsed_result['summary']}"]),
|
||||||
|
html.P([html.B("SQL Abfrage: "), f"{parsed_result['query']}"]),
|
||||||
|
render_table(result_table),
|
||||||
|
]
|
||||||
|
return children, value, err_style, html.P("")
|
||||||
|
except Exception as e:
|
||||||
|
err_style["height"] = "400px"
|
||||||
|
err_child = html.Div(f"Folgender Fehler ist aufgetreten: {e}.LLM Output: {result}.")
|
||||||
|
|
||||||
|
return no_update, value, err_style, err_child
|
||||||
|
|
||||||
elif not db_connected:
|
elif not db_connected:
|
||||||
err_style["height"] = "50px"
|
err_style["height"] = "50px"
|
||||||
err_child = html.P(
|
err_child = html.P(
|
||||||
@@ -202,6 +295,12 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[
|
|||||||
)
|
)
|
||||||
return no_update, no_update, err_style, err_child
|
return no_update, no_update, err_style, err_child
|
||||||
|
|
||||||
|
elif value == data:
|
||||||
|
|
||||||
|
err_style["height"] = "50px"
|
||||||
|
err_child = html.P("Bitte eine neue Anfrage eingeben.")
|
||||||
|
return no_update, no_update, err_style, err_child
|
||||||
|
|
||||||
raise PreventUpdate
|
raise PreventUpdate
|
||||||
|
|
||||||
|
|
||||||
@@ -231,37 +330,14 @@ def run_sql_query(n_clicks: int, value: str) -> str:
|
|||||||
if isinstance(result, str):
|
if isinstance(result, str):
|
||||||
global err_style
|
global err_style
|
||||||
tmp_style = err_style.copy()
|
tmp_style = err_style.copy()
|
||||||
tmp_style["height"] = "80px"
|
tmp_style["height"] = "100px"
|
||||||
tmp_style["padding"] = "20px"
|
tmp_style["padding"] = "20px"
|
||||||
err_child = html.Div(
|
err_child = html.Div(
|
||||||
[html.P(f"Fehler bei der Ausführung der Abfrage: {result}")], style=tmp_style
|
[html.P(f"Fehler bei der Ausführung der Abfrage: {result}")], style=tmp_style
|
||||||
)
|
)
|
||||||
return err_child
|
return err_child
|
||||||
else:
|
else:
|
||||||
table_child = dash_table.DataTable(
|
return render_table(result)
|
||||||
id="table",
|
|
||||||
columns=[{"name": i, "id": i} for i in result.columns],
|
|
||||||
data=result.to_dict("records"),
|
|
||||||
page_size=10,
|
|
||||||
style_table={
|
|
||||||
"overflowX": "auto",
|
|
||||||
"margin": "auto",
|
|
||||||
"width": "96%",
|
|
||||||
},
|
|
||||||
style_cell={
|
|
||||||
"minWidth": "100px",
|
|
||||||
"width": "150px",
|
|
||||||
"maxWidth": "300px",
|
|
||||||
"overflow": "hidden",
|
|
||||||
"textOverflow": "ellipsis",
|
|
||||||
},
|
|
||||||
style_header={
|
|
||||||
"backgroundColor": "lightblue",
|
|
||||||
"fontWeight": "bold",
|
|
||||||
"color": "black",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return table_child
|
|
||||||
raise PreventUpdate
|
raise PreventUpdate
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,24 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from openai import AzureOpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# from openai import AzureOpenAI
|
||||||
|
|
||||||
# Set up credentials
|
# Set up credentials
|
||||||
# NOTE: When running locally, these have to be set in the environment
|
# NOTE: Usually I would use AzureOpenAI, but due to heavy rate
|
||||||
client = AzureOpenAI(
|
# limitations on azure trial accounts, I am using OpenAI directly
|
||||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
# for this project. However, this is how it would look like for
|
||||||
api_key=os.getenv("AZURE_OPENAI_KEY"),
|
# AzureOpenAI (credentials must be provided to environment):
|
||||||
api_version="2024-02-01",
|
# client = AzureOpenAI(
|
||||||
)
|
# azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||||
|
# api_key=os.getenv("AZURE_OPENAI_KEY"),
|
||||||
|
# api_version="2024-02-01",
|
||||||
|
# )
|
||||||
|
# MODEL = "sqlai" # deployment name
|
||||||
|
|
||||||
deployment_name = "sqlai"
|
# Set up the OpenAI client
|
||||||
|
MODEL = "gpt-4o"
|
||||||
|
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||||
|
|
||||||
|
|
||||||
def send_message(message: str) -> str:
|
def send_message(message: str) -> str:
|
||||||
@@ -28,15 +36,60 @@ def send_message(message: str) -> str:
|
|||||||
"""
|
"""
|
||||||
system_message = """
|
system_message = """
|
||||||
Du bist ein hilfsbereiter, fröhlicher Datenbankassistent.
|
Du bist ein hilfsbereiter, fröhlicher Datenbankassistent.
|
||||||
|
Du hilfst Benutzern bei der Erstellung von SQL-Abfragen für eine Datenbank eines
|
||||||
|
großen Energieversorgungsunternehmens. Die Datenbank enthält Tabellen für Adressen,
|
||||||
|
Zähler, Kunden und Ablesungen. Es werden Gaszähler (MeterType 'GAS') und Stromzähler
|
||||||
|
(MeterType 'ELT')unterschieden.
|
||||||
|
|
||||||
|
Besonders wichtig ist, dass die Ablesungen der Werte kumulativ sind. Wenn nach dem Verbrauch
|
||||||
|
gefragt wird, sollte der Unterschied zwischen zwei aufeinanderfolgenden Ablesungen berechnet
|
||||||
|
werden.
|
||||||
|
|
||||||
Verwende beim Erstellen Ihrer Antworten das folgende Datenbankschema:
|
Verwende beim Erstellen Ihrer Antworten das folgende Datenbankschema:
|
||||||
|
|
||||||
MEIN_DATENBANKSCHEMA
|
CREATE TABLE Addresses (
|
||||||
|
ID INT PRIMARY KEY IDENTITY(1,1),
|
||||||
|
StreetName NVARCHAR(100),
|
||||||
|
HouseNumber NVARCHAR(10),
|
||||||
|
City NVARCHAR(50),
|
||||||
|
PostalCode NVARCHAR(10),
|
||||||
|
Longitude FLOAT,
|
||||||
|
Latitude FLOAT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE Meters (
|
||||||
|
ID INT PRIMARY KEY IDENTITY(1,1),
|
||||||
|
Signature NVARCHAR(11),
|
||||||
|
MeterType NVARCHAR(3),
|
||||||
|
AddressID INT,
|
||||||
|
FOREIGN KEY (AddressID) REFERENCES Addresses(ID)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE Customers (
|
||||||
|
ID INT PRIMARY KEY IDENTITY(1,1),
|
||||||
|
FirstName NVARCHAR(100),
|
||||||
|
LastName NVARCHAR(100),
|
||||||
|
GasMeterID INT,
|
||||||
|
EltMeterID INT,
|
||||||
|
FOREIGN KEY (GasMeterID) REFERENCES Meters(ID),
|
||||||
|
FOREIGN KEY (EltMeterID) REFERENCES Meters(ID)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE Readings (
|
||||||
|
ID INT PRIMARY KEY IDENTITY(1,1),
|
||||||
|
CustomerID INT,
|
||||||
|
MeterID INT,
|
||||||
|
ReadingDate DATE,
|
||||||
|
ReadingValue INT,
|
||||||
|
FOREIGN KEY (CustomerID) REFERENCES Customers(ID),
|
||||||
|
FOREIGN KEY (MeterID) REFERENCES Meters(ID)
|
||||||
|
);
|
||||||
|
|
||||||
Füge Spaltenüberschriften in die Abfrageergebnisse ein.
|
Füge Spaltenüberschriften in die Abfrageergebnisse ein.
|
||||||
|
|
||||||
Gib deine Antwort immer im folgenden JSON-Format an:
|
Gib deine Antwort immer im folgenden JSON-Format an:
|
||||||
|
|
||||||
JSON FORMAT
|
{ "summary": "your-summary", "query": "your-query" }
|
||||||
|
|
||||||
Gib NUR JSON aus.
|
Gib NUR JSON aus.
|
||||||
Ersetze in der vorangehenden JSON-Antwort "your-query" durch die Microsoft SQL Server Query,
|
Ersetze in der vorangehenden JSON-Antwort "your-query" durch die Microsoft SQL Server Query,
|
||||||
@@ -45,15 +98,20 @@ def send_message(message: str) -> str:
|
|||||||
Gib immer alle Spalten der Tabelle an.
|
Gib immer alle Spalten der Tabelle an.
|
||||||
Wenn die resultierende Abfrage nicht ausführbar ist, ersetze "your-query“ durch NA, aber ersetze
|
Wenn die resultierende Abfrage nicht ausführbar ist, ersetze "your-query“ durch NA, aber ersetze
|
||||||
trotzdem "your-query" durch eine Zusammenfassung der Abfrage.
|
trotzdem "your-query" durch eine Zusammenfassung der Abfrage.
|
||||||
Verwende KEINE MySQL-Syntax.
|
Verwende KEINE MySQL-Syntax, sondern AUSSCHLIESSLICH Microsoft SQL.
|
||||||
Begrenze die SQL-Abfrage immer auf 100 Zeilen.
|
Begrenze die SQL-Abfrage immer auf 100 Zeilen.
|
||||||
|
Formatiere den Output bestmöglich.
|
||||||
"""
|
"""
|
||||||
system_message = "Du bist ein hilfreicher Assistent."
|
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model=deployment_name,
|
model=MODEL,
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": system_message},
|
{"role": "system", "content": system_message},
|
||||||
{"role": "user", "content": message},
|
{"role": "user", "content": message},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content
|
|
||||||
|
result_str = response.choices[0].message.content.replace("```json\n", "").replace("```", "")
|
||||||
|
if ("\n") not in result_str:
|
||||||
|
result_str = result_str.replace("\\", "\n")
|
||||||
|
return result_str
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ def test_db_connection() -> bool:
|
|||||||
This function attempts to establish a connection to an Azure SQL Database
|
This function attempts to establish a connection to an Azure SQL Database
|
||||||
using the connection string stored in the environment variable
|
using the connection string stored in the environment variable
|
||||||
'AZURE_SQL_CONNECTION_STRING'. It makes up to 5 attempts to connect,
|
'AZURE_SQL_CONNECTION_STRING'. It makes up to 5 attempts to connect,
|
||||||
with a timeout of 240 seconds for each attempt.
|
with a timeout of 480 seconds for each attempt.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -22,14 +22,15 @@ def test_db_connection() -> bool:
|
|||||||
for i in range(5):
|
for i in range(5):
|
||||||
print(f"Trying to connect to Azure SQL Database... Attempt {i + 1}")
|
print(f"Trying to connect to Azure SQL Database... Attempt {i + 1}")
|
||||||
try:
|
try:
|
||||||
pyodbc.connect(connection_string, timeout=240)
|
pyodbc.connect(connection_string, timeout=480)
|
||||||
print("Connected to Azure SQL Database successfully!")
|
print("Connected to Azure SQL Database successfully!")
|
||||||
connected = True
|
connected = True
|
||||||
break
|
break
|
||||||
|
|
||||||
except pyodbc.Error as e:
|
except Exception as e:
|
||||||
print(f"Error connecting to Azure SQL Database: {e}")
|
print(f"Error connecting to Azure SQL Database: {e}")
|
||||||
connected = False
|
connected = False
|
||||||
|
continue
|
||||||
|
|
||||||
return connected
|
return connected
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user