feat/ai-chat: Add core components for Database chat #5

Merged
quadfaselt merged 5 commits from feat/ai-chat into main 2024-09-03 13:38:24 +00:00
4 changed files with 272 additions and 57 deletions
Showing only changes of commit 4b9fa0579e - Show all commits

View File

@@ -6,6 +6,7 @@ from dash import (
Output,
State,
callback,
dash_table,
dcc,
get_asset_url,
html,
@@ -15,12 +16,14 @@ from dash.exceptions import PreventUpdate
from .app_styles import header_style
from .data_chat import send_message
from .sql_utils import execute_query, test_db_connection
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
app = Dash(__name__, external_stylesheets=external_stylesheets)
app.index_string = header_style
err_style = {
"height": "0px",
"overflow": "hidden",
@@ -38,63 +41,114 @@ err_style = {
"align-items": "center",
}
start_value = "Stelle deine Frage an die Datenbank..."
app.layout = html.Div(
[
html.Div(
[
html.H1("Datenbank-Chat", className="heading"),
html.Img(src=get_asset_url("logo.png"), className="logo"),
],
className="header-container",
), # Header
dcc.Store(
id="tmp-value", data=start_value, storage_type="session"
), # Store previous prompt
dcc.Textarea(
id="input-field",
value=start_value,
style={"width": "96%", "height": 200, "margin-left": "20px"},
), # Input field
html.Div([]), # Needed for keeping the layout clean
html.Button(
"Abschicken",
id="submit-button",
n_clicks=0,
disabled=False,
style={"margin-left": "20px"},
), # Submit button
html.Div(
[html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=err_style
), # Error message (only visible if input is not updated but submit button is clicked)
dcc.Loading(
id="loading",
type="default",
children=[
html.Div(
"Hier erscheint die Antwort der Datenbank.",
id="text-output",
style={
"whiteSpace": "pre-line",
"margin-top": "30px",
"margin-left": "20px",
"margin-right": "20px",
"border": "2px solid #86bc25",
"border-radius": "15px",
"padding": "20px",
},
)
],
),
], # Output field
className="container",
)
def get_layout() -> html.Div:
"""Generate the layout for a Dash application.
This function creates a complex layout for a database chat interface
and a direct SQL query interface. It includes various Dash components
such as text areas, buttons, and a loading spinner.
The layout consists of:
- A header with title and logo
- A textarea for user input (database chat)
- A submit button for the database chat
- An error message area
- A loading spinner and output area for database responses
- A section for direct SQL queries, including a textarea and submit button
- An output area for SQL query results
Returns
-------
html.Div
A Dash html.Div component containing the entire layout of the application.
"""
global err_style
tmp_style = err_style.copy()
tmp_style["height"] = "0px"
start_value = "Stelle deine Frage an die Datenbank..."
layout = html.Div(
[
html.Div(
[
html.H1("Datenbank-Chat", className="heading"),
html.Img(src=get_asset_url("logo.png"), className="logo"),
],
className="header-container",
), # Header
dcc.Store(
id="tmp-value", data=start_value, storage_type="memory"
), # Store previous prompt
dcc.Textarea(
id="input-field",
value=start_value,
style={"width": "96%", "height": 200, "margin-left": "20px"},
), # Input field
html.Div([]), # Needed for keeping the layout clean
html.Button(
"Abschicken",
id="submit-button",
n_clicks=0,
disabled=False,
style={"margin-left": "20px"},
), # Submit button
html.Div(
[html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=tmp_style
), # Error message (only visible if input is not updated but submit button is clicked)
dcc.Loading(
id="loading",
type="default",
children=[
html.Div(
"Hier erscheint die Antwort der Datenbank.",
id="text-output",
style={
"whiteSpace": "pre-line",
"margin-top": "30px",
"margin-left": "20px",
"margin-right": "20px",
"border": "2px solid #86bc25",
"border-radius": "15px",
"padding": "20px",
},
)
],
),
html.H2(
"Direkte SQL-Abfrage",
style={"margin-left": "20px", "margin-top": "50px", "font-size": 24},
), # SQL Header
dcc.Textarea(
id="sql-input-field",
value="(Microsoft) SQL-Abfrage eingeben...",
style={"width": "96%", "height": 200, "margin-left": "20px"},
), # SQL Input field
html.Div([]), # Needed for keeping the layout clean
html.Button(
"Abschicken",
id="sql-submit-button",
n_clicks=0,
disabled=False,
style={"margin-left": "20px"},
), # Submit button
html.Div(id="sql-output", style={"margin-top": "10px"}), # SQL Output
],
className="container",
)
return layout
app.layout = html.Div([get_layout()])
@callback(
Output("text-output", "children"),
Output("tmp-value", "data"),
Output("error", "style"),
Output("error", "children"),
Input("submit-button", "n_clicks"),
State("input-field", "value"),
State("tmp-value", "data"),
@@ -108,7 +162,7 @@ app.layout = html.Div(
),
],
)
def update_output(n_clicks: int, value: str, data: str) -> Tuple[str, str, Dict[str, Any]]:
def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[str, str], Any]:
"""Update the output based on user input and button clicks.
Parameters
@@ -126,18 +180,91 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[str, str, Dict[
Updated output text, new stored value, and error style.
"""
global err_style
if n_clicks > 0 and value != data:
print(f"Value: {value}")
print(f"Data: {data}")
db_connected = test_db_connection()
if n_clicks > 0 and value != data and db_connected:
result = send_message(value)
err_style["height"] = "0px"
return result, value, err_style
return result, value, err_style, html.P("")
elif value == data:
err_style["height"] = "50px"
return no_update, no_update, err_style
err_child = html.P("Bitte eine neue Anfrage eingeben.")
return no_update, no_update, err_style, err_child
elif not db_connected:
err_style["height"] = "50px"
err_child = html.P(
(
"Fehler beim Herstellen der Verbindung zur "
"Datenbank. Bitte versuche es später erneut."
)
)
return no_update, no_update, err_style, err_child
raise PreventUpdate
@callback(
Output("sql-output", "children"),
Input("sql-submit-button", "n_clicks"),
State("sql-input-field", "value"),
prevent_initial_call=True,
)
def run_sql_query(n_clicks: int, value: str) -> str:
"""Run a SQL query and return the results.
Parameters
----------
n_clicks : int
Number of times the submit button has been clicked.
value : str
Current value of the input field.
Returns
-------
str
The results of the SQL query.
"""
if n_clicks > 0:
result = execute_query(value)
if isinstance(result, str):
global err_style
tmp_style = err_style.copy()
tmp_style["height"] = "80px"
tmp_style["padding"] = "20px"
err_child = html.Div(
[html.P(f"Fehler bei der Ausführung der Abfrage: {result}")], style=tmp_style
)
return err_child
else:
table_child = dash_table.DataTable(
id="table",
columns=[{"name": i, "id": i} for i in result.columns],
data=result.to_dict("records"),
page_size=10,
style_table={
"overflowX": "auto",
"margin": "auto",
"width": "96%",
},
style_cell={
"minWidth": "100px",
"width": "150px",
"maxWidth": "300px",
"overflow": "hidden",
"textOverflow": "ellipsis",
},
style_header={
"backgroundColor": "lightblue",
"fontWeight": "bold",
"color": "black",
},
)
return table_child
raise PreventUpdate
server = app.server
if __name__ == "__main__":

View File

@@ -12,7 +12,7 @@ header_style = """
justify-content: space-between;
align-items: center;
padding: 20px;
background-color: #f8f9fa;
background-color: #ffffff;
}
.heading {
font-size: 2.5em;

View File

@@ -26,10 +26,33 @@ def send_message(message: str) -> str:
str
The content of the assistant's response message.
"""
system_message = """
Du bist ein hilfsbereiter, fröhlicher Datenbankassistent.
Verwende beim Erstellen Ihrer Antworten das folgende Datenbankschema:
MEIN_DATENBANKSCHEMA
Füge Spaltenüberschriften in die Abfrageergebnisse ein.
Gib deine Antwort immer im folgenden JSON-Format an:
JSON FORMAT
Gib NUR JSON aus.
Ersetze in der vorangehenden JSON-Antwort "your-query" durch die Microsoft SQL Server Query,
um die angeforderten Daten abzurufen.
Ersetze in der vorangehenden JSON-Antwort "your-summary" durch eine Zusammenfassung der Abfrage.
Gib immer alle Spalten der Tabelle an.
Wenn die resultierende Abfrage nicht ausführbar ist, ersetze "your-query“ durch NA, aber ersetze
trotzdem "your-query" durch eine Zusammenfassung der Abfrage.
Verwende KEINE MySQL-Syntax.
Begrenze die SQL-Abfrage immer auf 100 Zeilen.
"""
system_message = "Du bist ein hilfreicher Assistent."
response = client.chat.completions.create(
model=deployment_name,
messages=[
{"role": "system", "content": "Du bist ein hilfreicher Assistent."},
{"role": "system", "content": system_message},
{"role": "user", "content": message},
],
)

65
app/sql_utils.py Normal file
View File

@@ -0,0 +1,65 @@
import os
from typing import Union
import pandas as pd
import pyodbc
def test_db_connection() -> bool:
"""Test the connection to Azure SQL Database.
This function attempts to establish a connection to an Azure SQL Database
using the connection string stored in the environment variable
'AZURE_SQL_CONNECTION_STRING'. It makes up to 5 attempts to connect,
with a timeout of 240 seconds for each attempt.
Returns
-------
bool
True if the connection was successful, False otherwise.
"""
connection_string = os.environ.get("AZURE_SQL_CONNECTION_STRING")
for i in range(5):
print(f"Trying to connect to Azure SQL Database... Attempt {i + 1}")
try:
pyodbc.connect(connection_string, timeout=240)
print("Connected to Azure SQL Database successfully!")
connected = True
break
except pyodbc.Error as e:
print(f"Error connecting to Azure SQL Database: {e}")
connected = False
return connected
def execute_query(query: str) -> Union[pd.DataFrame, str]:
"""Execute a SQL query on an Azure SQL Database and return the results.
This function connects to an Azure SQL Database using the connection string
stored in the environment variable 'AZURE_SQL_CONNECTION_STRING', executes
the provided SQL query, and returns the results as a pandas DataFrame.
Parameters
----------
query : str
The SQL query to execute.
Returns
-------
Union[pd.DataFrame, str]
A pandas DataFrame containing the query results if successful,
or a string containing the error message if an exception occurs.
"""
try:
connection_string = os.environ.get("AZURE_SQL_CONNECTION_STRING")
conn = pyodbc.connect(connection_string, timeout=240)
df = pd.read_sql(query, conn)
conn.close()
return df
except Exception as e:
return str(e)
finally:
if conn in locals():
conn.close()