feat(ai-chat): Add SQL query field for comparison

To allow comparing the (not yet implemented) SQL query generated by
the LLM with a hand-written query, a second text field was added. Its
content is passed to `pyodbc`, which connects to our database and runs
the query; the resulting rows are stored in a `pandas` DataFrame and
then rendered as a table in Plotly Dash.

The SQL functionalities are implemented in the `sql_utils.py` module.

Additionally, some minor updates to the overall behavior and layout of
the app were implemented.
This commit is contained in:
Tobias Quadfasel
2024-09-02 20:43:48 +02:00
parent 923dc3b439
commit 4b9fa0579e
4 changed files with 272 additions and 57 deletions

View File

@@ -6,6 +6,7 @@ from dash import (
Output,
State,
callback,
dash_table,
dcc,
get_asset_url,
html,
@@ -15,12 +16,14 @@ from dash.exceptions import PreventUpdate
from .app_styles import header_style
from .data_chat import send_message
from .sql_utils import execute_query, test_db_connection
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
app = Dash(__name__, external_stylesheets=external_stylesheets)
app.index_string = header_style
err_style = {
"height": "0px",
"overflow": "hidden",
@@ -38,63 +41,114 @@ err_style = {
"align-items": "center",
}
start_value = "Stelle deine Frage an die Datenbank..."
app.layout = html.Div(
[
html.Div(
[
html.H1("Datenbank-Chat", className="heading"),
html.Img(src=get_asset_url("logo.png"), className="logo"),
],
className="header-container",
), # Header
dcc.Store(
id="tmp-value", data=start_value, storage_type="session"
), # Store previous prompt
dcc.Textarea(
id="input-field",
value=start_value,
style={"width": "96%", "height": 200, "margin-left": "20px"},
), # Input field
html.Div([]), # Needed for keeping the layout clean
html.Button(
"Abschicken",
id="submit-button",
n_clicks=0,
disabled=False,
style={"margin-left": "20px"},
), # Submit button
html.Div(
[html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=err_style
), # Error message (only visible if input is not updated but submit button is clicked)
dcc.Loading(
id="loading",
type="default",
children=[
html.Div(
"Hier erscheint die Antwort der Datenbank.",
id="text-output",
style={
"whiteSpace": "pre-line",
"margin-top": "30px",
"margin-left": "20px",
"margin-right": "20px",
"border": "2px solid #86bc25",
"border-radius": "15px",
"padding": "20px",
},
)
],
),
], # Output field
className="container",
)
def get_layout() -> html.Div:
    """Build the component tree for the database chat page.

    The page is assembled from two stacked sections that share one
    container:

    - the chat section: header (title and logo), a store holding the
      previously submitted prompt, the prompt textarea, its submit
      button, a collapsed error banner and a loading wrapper around the
      answer box;
    - the direct SQL section: a heading, the SQL textarea, its submit
      button and an initially empty output area.

    Returns
    -------
    html.Div
        The container element holding every component of the page.
    """
    # The error banner starts collapsed. A fresh dict is built so the
    # module-level ``err_style`` is only read here, never modified.
    collapsed_error_style = {**err_style, "height": "0px"}
    initial_prompt = "Stelle deine Frage an die Datenbank..."

    header = html.Div(
        [
            html.H1("Datenbank-Chat", className="heading"),
            html.Img(src=get_asset_url("logo.png"), className="logo"),
        ],
        className="header-container",
    )

    # Holds the last submitted prompt for comparison with the next one.
    previous_prompt_store = dcc.Store(
        id="tmp-value", data=initial_prompt, storage_type="memory"
    )

    chat_input = dcc.Textarea(
        id="input-field",
        value=initial_prompt,
        style={"width": "96%", "height": 200, "margin-left": "20px"},
    )
    chat_submit = html.Button(
        "Abschicken",
        id="submit-button",
        n_clicks=0,
        disabled=False,
        style={"margin-left": "20px"},
    )

    # Error banner; it is rendered with zero height until a callback
    # supplies a different style for it.
    error_banner = html.Div(
        [html.P("Bitte eine neue Anfrage eingeben.")],
        id="error",
        style=collapsed_error_style,
    )

    answer_box = html.Div(
        "Hier erscheint die Antwort der Datenbank.",
        id="text-output",
        style={
            "whiteSpace": "pre-line",
            "margin-top": "30px",
            "margin-left": "20px",
            "margin-right": "20px",
            "border": "2px solid #86bc25",
            "border-radius": "15px",
            "padding": "20px",
        },
    )
    answer_with_spinner = dcc.Loading(
        id="loading",
        type="default",
        children=[answer_box],
    )

    sql_heading = html.H2(
        "Direkte SQL-Abfrage",
        style={"margin-left": "20px", "margin-top": "50px", "font-size": 24},
    )
    sql_input = dcc.Textarea(
        id="sql-input-field",
        value="(Microsoft) SQL-Abfrage eingeben...",
        style={"width": "96%", "height": 200, "margin-left": "20px"},
    )
    sql_submit = html.Button(
        "Abschicken",
        id="sql-submit-button",
        n_clicks=0,
        disabled=False,
        style={"margin-left": "20px"},
    )
    sql_output = html.Div(id="sql-output", style={"margin-top": "10px"})

    return html.Div(
        [
            header,
            previous_prompt_store,
            chat_input,
            html.Div([]),  # Empty spacer that keeps the button on its own row
            chat_submit,
            error_banner,
            answer_with_spinner,
            sql_heading,
            sql_input,
            html.Div([]),  # Empty spacer that keeps the button on its own row
            sql_submit,
            sql_output,
        ],
        className="container",
    )
app.layout = html.Div([get_layout()])
@callback(
Output("text-output", "children"),
Output("tmp-value", "data"),
Output("error", "style"),
Output("error", "children"),
Input("submit-button", "n_clicks"),
State("input-field", "value"),
State("tmp-value", "data"),
@@ -108,7 +162,7 @@ app.layout = html.Div(
),
],
)
def update_output(n_clicks: int, value: str, data: str) -> Tuple[str, str, Dict[str, Any]]:
def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[str, str], Any]:
"""Update the output based on user input and button clicks.
Parameters
@@ -126,18 +180,91 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[str, str, Dict[
Updated output text, new stored value, and error style.
"""
global err_style
if n_clicks > 0 and value != data:
print(f"Value: {value}")
print(f"Data: {data}")
db_connected = test_db_connection()
if n_clicks > 0 and value != data and db_connected:
result = send_message(value)
err_style["height"] = "0px"
return result, value, err_style
return result, value, err_style, html.P("")
elif value == data:
err_style["height"] = "50px"
return no_update, no_update, err_style
err_child = html.P("Bitte eine neue Anfrage eingeben.")
return no_update, no_update, err_style, err_child
elif not db_connected:
err_style["height"] = "50px"
err_child = html.P(
(
"Fehler beim Herstellen der Verbindung zur "
"Datenbank. Bitte versuche es später erneut."
)
)
return no_update, no_update, err_style, err_child
raise PreventUpdate
@callback(
    Output("sql-output", "children"),
    Input("sql-submit-button", "n_clicks"),
    State("sql-input-field", "value"),
    prevent_initial_call=True,
)
def run_sql_query(n_clicks: int, value: str) -> Any:
    """Run the SQL text from the direct-query field and render the outcome.

    Parameters
    ----------
    n_clicks : int
        Number of times the SQL submit button has been clicked.
    value : str
        Current content of the SQL input field.

    Returns
    -------
    Any
        A Dash component for the ``sql-output`` area: an ``html.Div``
        carrying the error message when ``execute_query`` returns a
        string, otherwise a ``dash_table.DataTable`` built from the
        returned result.

    Raises
    ------
    PreventUpdate
        If the button has not been clicked yet.
    """
    # Guard first: nothing to do before the first click. ``None`` is
    # treated like zero clicks instead of failing on the comparison.
    if n_clicks is None or n_clicks <= 0:
        raise PreventUpdate

    result = execute_query(value)

    # ``execute_query`` signals failure by returning the error text.
    if isinstance(result, str):
        # Build a new dict so the shared module-level ``err_style`` is
        # only read, never mutated, by this callback.
        visible_error_style = {**err_style, "height": "80px", "padding": "20px"}
        return html.Div(
            [html.P(f"Fehler bei der Ausführung der Abfrage: {result}")],
            style=visible_error_style,
        )

    return dash_table.DataTable(
        id="table",
        # One table column per column of the query result.
        columns=[{"name": i, "id": i} for i in result.columns],
        data=result.to_dict("records"),
        page_size=10,
        style_table={
            "overflowX": "auto",
            "margin": "auto",
            "width": "96%",
        },
        style_cell={
            "minWidth": "100px",
            "width": "150px",
            "maxWidth": "300px",
            "overflow": "hidden",
            "textOverflow": "ellipsis",
        },
        style_header={
            "backgroundColor": "lightblue",
            "fontWeight": "bold",
            "color": "black",
        },
    )
server = app.server
if __name__ == "__main__":

View File

@@ -12,7 +12,7 @@ header_style = """
justify-content: space-between;
align-items: center;
padding: 20px;
background-color: #f8f9fa;
background-color: #ffffff;
}
.heading {
font-size: 2.5em;

View File

@@ -26,10 +26,33 @@ def send_message(message: str) -> str:
str
The content of the assistant's response message.
"""
system_message = """
Du bist ein hilfsbereiter, fröhlicher Datenbankassistent.
Verwende beim Erstellen Ihrer Antworten das folgende Datenbankschema:
MEIN_DATENBANKSCHEMA
Füge Spaltenüberschriften in die Abfrageergebnisse ein.
Gib deine Antwort immer im folgenden JSON-Format an:
JSON FORMAT
Gib NUR JSON aus.
Ersetze in der vorangehenden JSON-Antwort "your-query" durch die Microsoft SQL Server Query,
um die angeforderten Daten abzurufen.
Ersetze in der vorangehenden JSON-Antwort "your-summary" durch eine Zusammenfassung der Abfrage.
Gib immer alle Spalten der Tabelle an.
Wenn die resultierende Abfrage nicht ausführbar ist, ersetze "your-query“ durch NA, aber ersetze
trotzdem "your-query" durch eine Zusammenfassung der Abfrage.
Verwende KEINE MySQL-Syntax.
Begrenze die SQL-Abfrage immer auf 100 Zeilen.
"""
system_message = "Du bist ein hilfreicher Assistent."
response = client.chat.completions.create(
model=deployment_name,
messages=[
{"role": "system", "content": "Du bist ein hilfreicher Assistent."},
{"role": "system", "content": system_message},
{"role": "user", "content": message},
],
)

65
app/sql_utils.py Normal file
View File

@@ -0,0 +1,65 @@
import os
from typing import Union
import pandas as pd
import pyodbc
def test_db_connection() -> bool:
    """Test the connection to Azure SQL Database.

    The connection string is read from the environment variable
    'AZURE_SQL_CONNECTION_STRING'. Up to 5 connection attempts are made,
    each passing ``timeout=240`` to ``pyodbc.connect``. A connection that
    is opened successfully is closed again before returning, so the
    check does not leave a handle open.

    Returns
    -------
    bool
        True if a connection could be established, False otherwise
        (including when the environment variable is not set).
    """
    connection_string = os.environ.get("AZURE_SQL_CONNECTION_STRING")
    if not connection_string:
        # Without a connection string no attempt can succeed; report it
        # as a failed check instead of letting pyodbc raise on ``None``.
        print(
            "Error connecting to Azure SQL Database: "
            "AZURE_SQL_CONNECTION_STRING is not set"
        )
        return False

    for attempt in range(1, 6):
        print(f"Trying to connect to Azure SQL Database... Attempt {attempt}")
        try:
            conn = pyodbc.connect(connection_string, timeout=240)
        except pyodbc.Error as e:
            print(f"Error connecting to Azure SQL Database: {e}")
            continue
        # The connection was only opened as a probe; release it.
        conn.close()
        print("Connected to Azure SQL Database successfully!")
        return True
    return False
def execute_query(query: str) -> Union[pd.DataFrame, str]:
    """Execute a SQL query on an Azure SQL Database and return the results.

    The connection string is read from the environment variable
    'AZURE_SQL_CONNECTION_STRING'. The query is handed to
    ``pandas.read_sql`` over a ``pyodbc`` connection, which is closed
    again whether or not the query succeeds.

    Parameters
    ----------
    query : str
        The SQL query to execute. It is passed on unchanged.

    Returns
    -------
    Union[pd.DataFrame, str]
        A pandas DataFrame containing the query results if successful,
        or a string containing the error message if an exception occurs
        (while connecting or while running the query).
    """
    # Bound up front so the cleanup below can tell whether a connection
    # was ever opened, even when ``pyodbc.connect`` itself fails.
    conn = None
    try:
        connection_string = os.environ.get("AZURE_SQL_CONNECTION_STRING")
        conn = pyodbc.connect(connection_string, timeout=240)
        # NOTE(review): the query text is executed verbatim. That matches
        # the "direct SQL" purpose, but confirm the database login used
        # here has only the permissions end users should have.
        return pd.read_sql(query, conn)
    except Exception as e:
        # Deliberately broad: callers expect the error text as the return
        # value rather than an exception.
        return str(e)
    finally:
        if conn is not None:
            conn.close()