feat/ai-chat: Add core components for Database chat #5
147
app/app.py
147
app/app.py
@@ -6,6 +6,7 @@ from dash import (
|
|||||||
Output,
|
Output,
|
||||||
State,
|
State,
|
||||||
callback,
|
callback,
|
||||||
|
dash_table,
|
||||||
dcc,
|
dcc,
|
||||||
get_asset_url,
|
get_asset_url,
|
||||||
html,
|
html,
|
||||||
@@ -15,12 +16,14 @@ from dash.exceptions import PreventUpdate
|
|||||||
|
|
||||||
from .app_styles import header_style
|
from .app_styles import header_style
|
||||||
from .data_chat import send_message
|
from .data_chat import send_message
|
||||||
|
from .sql_utils import execute_query, test_db_connection
|
||||||
|
|
||||||
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
|
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]
|
||||||
|
|
||||||
app = Dash(__name__, external_stylesheets=external_stylesheets)
|
app = Dash(__name__, external_stylesheets=external_stylesheets)
|
||||||
app.index_string = header_style
|
app.index_string = header_style
|
||||||
|
|
||||||
|
|
||||||
err_style = {
|
err_style = {
|
||||||
"height": "0px",
|
"height": "0px",
|
||||||
"overflow": "hidden",
|
"overflow": "hidden",
|
||||||
@@ -38,8 +41,35 @@ err_style = {
|
|||||||
"align-items": "center",
|
"align-items": "center",
|
||||||
}
|
}
|
||||||
|
|
||||||
start_value = "Stelle deine Frage an die Datenbank..."
|
|
||||||
app.layout = html.Div(
|
def get_layout() -> html.Div:
|
||||||
|
"""Generate the layout for a Dash application.
|
||||||
|
|
||||||
|
This function creates a complex layout for a database chat interface
|
||||||
|
and a direct SQL query interface. It includes various Dash components
|
||||||
|
such as text areas, buttons, and a loading spinner.
|
||||||
|
|
||||||
|
The layout consists of:
|
||||||
|
- A header with title and logo
|
||||||
|
- A textarea for user input (database chat)
|
||||||
|
- A submit button for the database chat
|
||||||
|
- An error message area
|
||||||
|
- A loading spinner and output area for database responses
|
||||||
|
- A section for direct SQL queries, including a textarea and submit button
|
||||||
|
- An output area for SQL query results
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
html.Div
|
||||||
|
A Dash html.Div component containing the entire layout of the application.
|
||||||
|
"""
|
||||||
|
global err_style
|
||||||
|
tmp_style = err_style.copy()
|
||||||
|
tmp_style["height"] = "0px"
|
||||||
|
|
||||||
|
start_value = "Stelle deine Frage an die Datenbank..."
|
||||||
|
|
||||||
|
layout = html.Div(
|
||||||
[
|
[
|
||||||
html.Div(
|
html.Div(
|
||||||
[
|
[
|
||||||
@@ -49,7 +79,7 @@ app.layout = html.Div(
|
|||||||
className="header-container",
|
className="header-container",
|
||||||
), # Header
|
), # Header
|
||||||
dcc.Store(
|
dcc.Store(
|
||||||
id="tmp-value", data=start_value, storage_type="session"
|
id="tmp-value", data=start_value, storage_type="memory"
|
||||||
), # Store previous prompt
|
), # Store previous prompt
|
||||||
dcc.Textarea(
|
dcc.Textarea(
|
||||||
id="input-field",
|
id="input-field",
|
||||||
@@ -65,7 +95,7 @@ app.layout = html.Div(
|
|||||||
style={"margin-left": "20px"},
|
style={"margin-left": "20px"},
|
||||||
), # Submit button
|
), # Submit button
|
||||||
html.Div(
|
html.Div(
|
||||||
[html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=err_style
|
[html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=tmp_style
|
||||||
), # Error message (only visible if input is not updated but submit button is clicked)
|
), # Error message (only visible if input is not updated but submit button is clicked)
|
||||||
dcc.Loading(
|
dcc.Loading(
|
||||||
id="loading",
|
id="loading",
|
||||||
@@ -86,15 +116,39 @@ app.layout = html.Div(
|
|||||||
)
|
)
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
], # Output field
|
html.H2(
|
||||||
|
"Direkte SQL-Abfrage",
|
||||||
|
style={"margin-left": "20px", "margin-top": "50px", "font-size": 24},
|
||||||
|
), # SQL Header
|
||||||
|
dcc.Textarea(
|
||||||
|
id="sql-input-field",
|
||||||
|
value="(Microsoft) SQL-Abfrage eingeben...",
|
||||||
|
style={"width": "96%", "height": 200, "margin-left": "20px"},
|
||||||
|
), # SQL Input field
|
||||||
|
html.Div([]), # Needed for keeping the layout clean
|
||||||
|
html.Button(
|
||||||
|
"Abschicken",
|
||||||
|
id="sql-submit-button",
|
||||||
|
n_clicks=0,
|
||||||
|
disabled=False,
|
||||||
|
style={"margin-left": "20px"},
|
||||||
|
), # Submit button
|
||||||
|
html.Div(id="sql-output", style={"margin-top": "10px"}), # SQL Output
|
||||||
|
],
|
||||||
className="container",
|
className="container",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return layout
|
||||||
|
|
||||||
|
|
||||||
|
app.layout = html.Div([get_layout()])
|
||||||
|
|
||||||
|
|
||||||
@callback(
|
@callback(
|
||||||
Output("text-output", "children"),
|
Output("text-output", "children"),
|
||||||
Output("tmp-value", "data"),
|
Output("tmp-value", "data"),
|
||||||
Output("error", "style"),
|
Output("error", "style"),
|
||||||
|
Output("error", "children"),
|
||||||
Input("submit-button", "n_clicks"),
|
Input("submit-button", "n_clicks"),
|
||||||
State("input-field", "value"),
|
State("input-field", "value"),
|
||||||
State("tmp-value", "data"),
|
State("tmp-value", "data"),
|
||||||
@@ -108,7 +162,7 @@ app.layout = html.Div(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def update_output(n_clicks: int, value: str, data: str) -> Tuple[str, str, Dict[str, Any]]:
|
def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[str, str], Any]:
|
||||||
"""Update the output based on user input and button clicks.
|
"""Update the output based on user input and button clicks.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -126,18 +180,91 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[str, str, Dict[
|
|||||||
Updated output text, new stored value, and error style.
|
Updated output text, new stored value, and error style.
|
||||||
"""
|
"""
|
||||||
global err_style
|
global err_style
|
||||||
if n_clicks > 0 and value != data:
|
print(f"Value: {value}")
|
||||||
|
print(f"Data: {data}")
|
||||||
|
db_connected = test_db_connection()
|
||||||
|
if n_clicks > 0 and value != data and db_connected:
|
||||||
result = send_message(value)
|
result = send_message(value)
|
||||||
err_style["height"] = "0px"
|
err_style["height"] = "0px"
|
||||||
return result, value, err_style
|
return result, value, err_style, html.P("")
|
||||||
elif value == data:
|
elif value == data:
|
||||||
|
|
||||||
err_style["height"] = "50px"
|
err_style["height"] = "50px"
|
||||||
return no_update, no_update, err_style
|
err_child = html.P("Bitte eine neue Anfrage eingeben.")
|
||||||
|
return no_update, no_update, err_style, err_child
|
||||||
|
elif not db_connected:
|
||||||
|
err_style["height"] = "50px"
|
||||||
|
err_child = html.P(
|
||||||
|
(
|
||||||
|
"Fehler beim Herstellen der Verbindung zur "
|
||||||
|
"Datenbank. Bitte versuche es später erneut."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return no_update, no_update, err_style, err_child
|
||||||
|
|
||||||
raise PreventUpdate
|
raise PreventUpdate
|
||||||
|
|
||||||
|
|
||||||
|
@callback(
|
||||||
|
Output("sql-output", "children"),
|
||||||
|
Input("sql-submit-button", "n_clicks"),
|
||||||
|
State("sql-input-field", "value"),
|
||||||
|
prevent_initial_call=True,
|
||||||
|
)
|
||||||
|
def run_sql_query(n_clicks: int, value: str) -> str:
|
||||||
|
"""Run a SQL query and return the results.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
n_clicks : int
|
||||||
|
Number of times the submit button has been clicked.
|
||||||
|
value : str
|
||||||
|
Current value of the input field.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
The results of the SQL query.
|
||||||
|
"""
|
||||||
|
if n_clicks > 0:
|
||||||
|
result = execute_query(value)
|
||||||
|
if isinstance(result, str):
|
||||||
|
global err_style
|
||||||
|
tmp_style = err_style.copy()
|
||||||
|
tmp_style["height"] = "80px"
|
||||||
|
tmp_style["padding"] = "20px"
|
||||||
|
err_child = html.Div(
|
||||||
|
[html.P(f"Fehler bei der Ausführung der Abfrage: {result}")], style=tmp_style
|
||||||
|
)
|
||||||
|
return err_child
|
||||||
|
else:
|
||||||
|
table_child = dash_table.DataTable(
|
||||||
|
id="table",
|
||||||
|
columns=[{"name": i, "id": i} for i in result.columns],
|
||||||
|
data=result.to_dict("records"),
|
||||||
|
page_size=10,
|
||||||
|
style_table={
|
||||||
|
"overflowX": "auto",
|
||||||
|
"margin": "auto",
|
||||||
|
"width": "96%",
|
||||||
|
},
|
||||||
|
style_cell={
|
||||||
|
"minWidth": "100px",
|
||||||
|
"width": "150px",
|
||||||
|
"maxWidth": "300px",
|
||||||
|
"overflow": "hidden",
|
||||||
|
"textOverflow": "ellipsis",
|
||||||
|
},
|
||||||
|
style_header={
|
||||||
|
"backgroundColor": "lightblue",
|
||||||
|
"fontWeight": "bold",
|
||||||
|
"color": "black",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return table_child
|
||||||
|
raise PreventUpdate
|
||||||
|
|
||||||
|
|
||||||
server = app.server
|
server = app.server
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ header_style = """
|
|||||||
justify-content: space-between;
|
justify-content: space-between;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
padding: 20px;
|
padding: 20px;
|
||||||
background-color: #f8f9fa;
|
background-color: #ffffff;
|
||||||
}
|
}
|
||||||
.heading {
|
.heading {
|
||||||
font-size: 2.5em;
|
font-size: 2.5em;
|
||||||
|
|||||||
@@ -26,10 +26,33 @@ def send_message(message: str) -> str:
|
|||||||
str
|
str
|
||||||
The content of the assistant's response message.
|
The content of the assistant's response message.
|
||||||
"""
|
"""
|
||||||
|
system_message = """
|
||||||
|
Du bist ein hilfsbereiter, fröhlicher Datenbankassistent.
|
||||||
|
Verwende beim Erstellen Ihrer Antworten das folgende Datenbankschema:
|
||||||
|
|
||||||
|
MEIN_DATENBANKSCHEMA
|
||||||
|
|
||||||
|
Füge Spaltenüberschriften in die Abfrageergebnisse ein.
|
||||||
|
|
||||||
|
Gib deine Antwort immer im folgenden JSON-Format an:
|
||||||
|
|
||||||
|
JSON FORMAT
|
||||||
|
|
||||||
|
Gib NUR JSON aus.
|
||||||
|
Ersetze in der vorangehenden JSON-Antwort "your-query" durch die Microsoft SQL Server Query,
|
||||||
|
um die angeforderten Daten abzurufen.
|
||||||
|
Ersetze in der vorangehenden JSON-Antwort "your-summary" durch eine Zusammenfassung der Abfrage.
|
||||||
|
Gib immer alle Spalten der Tabelle an.
|
||||||
|
Wenn die resultierende Abfrage nicht ausführbar ist, ersetze "your-query“ durch NA, aber ersetze
|
||||||
|
trotzdem "your-query" durch eine Zusammenfassung der Abfrage.
|
||||||
|
Verwende KEINE MySQL-Syntax.
|
||||||
|
Begrenze die SQL-Abfrage immer auf 100 Zeilen.
|
||||||
|
"""
|
||||||
|
system_message = "Du bist ein hilfreicher Assistent."
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model=deployment_name,
|
model=deployment_name,
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "Du bist ein hilfreicher Assistent."},
|
{"role": "system", "content": system_message},
|
||||||
{"role": "user", "content": message},
|
{"role": "user", "content": message},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
65
app/sql_utils.py
Normal file
65
app/sql_utils.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import os
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import pyodbc
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_connection() -> bool:
|
||||||
|
"""Test the connection to Azure SQL Database.
|
||||||
|
|
||||||
|
This function attempts to establish a connection to an Azure SQL Database
|
||||||
|
using the connection string stored in the environment variable
|
||||||
|
'AZURE_SQL_CONNECTION_STRING'. It makes up to 5 attempts to connect,
|
||||||
|
with a timeout of 240 seconds for each attempt.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the connection was successful, False otherwise.
|
||||||
|
"""
|
||||||
|
connection_string = os.environ.get("AZURE_SQL_CONNECTION_STRING")
|
||||||
|
for i in range(5):
|
||||||
|
print(f"Trying to connect to Azure SQL Database... Attempt {i + 1}")
|
||||||
|
try:
|
||||||
|
pyodbc.connect(connection_string, timeout=240)
|
||||||
|
print("Connected to Azure SQL Database successfully!")
|
||||||
|
connected = True
|
||||||
|
break
|
||||||
|
|
||||||
|
except pyodbc.Error as e:
|
||||||
|
print(f"Error connecting to Azure SQL Database: {e}")
|
||||||
|
connected = False
|
||||||
|
|
||||||
|
return connected
|
||||||
|
|
||||||
|
|
||||||
|
def execute_query(query: str) -> Union[pd.DataFrame, str]:
|
||||||
|
"""Execute a SQL query on an Azure SQL Database and return the results.
|
||||||
|
|
||||||
|
This function connects to an Azure SQL Database using the connection string
|
||||||
|
stored in the environment variable 'AZURE_SQL_CONNECTION_STRING', executes
|
||||||
|
the provided SQL query, and returns the results as a pandas DataFrame.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query : str
|
||||||
|
The SQL query to execute.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Union[pd.DataFrame, str]
|
||||||
|
A pandas DataFrame containing the query results if successful,
|
||||||
|
or a string containing the error message if an exception occurs.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
connection_string = os.environ.get("AZURE_SQL_CONNECTION_STRING")
|
||||||
|
conn = pyodbc.connect(connection_string, timeout=240)
|
||||||
|
df = pd.read_sql(query, conn)
|
||||||
|
conn.close()
|
||||||
|
return df
|
||||||
|
except Exception as e:
|
||||||
|
return str(e)
|
||||||
|
finally:
|
||||||
|
if conn in locals():
|
||||||
|
conn.close()
|
||||||
Reference in New Issue
Block a user