From 4b9fa0579e42df30990334d1f3c5d393cb774e2a Mon Sep 17 00:00:00 2001 From: Tobias Quadfasel Date: Mon, 2 Sep 2024 20:43:48 +0200 Subject: [PATCH] feat(ai-chat): Add SQL query field for comparison In order to compare the (not yet implemented) SQL query generated by the LLM with an actual query, another text field was added that parses the query to `pyodbc`, which connects to our database, stores the resulting rows in a `pandas` dataframe and then visualizes it as a table in plotly dash. The SQL functionalities are implemented in the `sql_utils.py` module. Additionally, some minor updates to the overall behavior and layout of the app were implemented. --- app/app.py | 237 +++++++++++++++++++++++++++++++++++----------- app/app_styles.py | 2 +- app/data_chat.py | 25 ++++- app/sql_utils.py | 65 +++++++++++++ 4 files changed, 272 insertions(+), 57 deletions(-) create mode 100644 app/sql_utils.py diff --git a/app/app.py b/app/app.py index 797969d..eaf6120 100644 --- a/app/app.py +++ b/app/app.py @@ -6,6 +6,7 @@ from dash import ( Output, State, callback, + dash_table, dcc, get_asset_url, html, @@ -15,12 +16,14 @@ from dash.exceptions import PreventUpdate from .app_styles import header_style from .data_chat import send_message +from .sql_utils import execute_query, test_db_connection external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"] app = Dash(__name__, external_stylesheets=external_stylesheets) app.index_string = header_style + err_style = { "height": "0px", "overflow": "hidden", @@ -38,63 +41,114 @@ err_style = { "align-items": "center", } -start_value = "Stelle deine Frage an die Datenbank..." -app.layout = html.Div( - [ - html.Div( - [ - html.H1("Datenbank-Chat", className="heading"), - html.Img(src=get_asset_url("logo.png"), className="logo"), - ], - className="header-container", - ), # Header - dcc.Store( - id="tmp-value", data=start_value, storage_type="session" - ), # Store previous prompt - dcc.Textarea( - id="input-field", - value=start_value, - style={"width": "96%", "height": 200, "margin-left": "20px"}, - ), # Input field - html.Div([]), # Needed for keeping the layout clean - html.Button( - "Abschicken", - id="submit-button", - n_clicks=0, - disabled=False, - style={"margin-left": "20px"}, - ), # Submit button - html.Div( - [html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=err_style - ), # Error message (only visible if input is not updated but submit button is clicked) - dcc.Loading( - id="loading", - type="default", - children=[ - html.Div( - "Hier erscheint die Antwort der Datenbank.", - id="text-output", - style={ - "whiteSpace": "pre-line", - "margin-top": "30px", - "margin-left": "20px", - "margin-right": "20px", - "border": "2px solid #86bc25", - "border-radius": "15px", - "padding": "20px", - }, - ) - ], - ), - ], # Output field - className="container", -) + +def get_layout() -> html.Div: + """Generate the layout for a Dash application. + + This function creates a complex layout for a database chat interface + and a direct SQL query interface. It includes various Dash components + such as text areas, buttons, and a loading spinner. + + The layout consists of: + - A header with title and logo + - A textarea for user input (database chat) + - A submit button for the database chat + - An error message area + - A loading spinner and output area for database responses + - A section for direct SQL queries, including a textarea and submit button + - An output area for SQL query results + + Returns + ------- + html.Div + A Dash html.Div component containing the entire layout of the application. + """ + global err_style + tmp_style = err_style.copy() + tmp_style["height"] = "0px" + + start_value = "Stelle deine Frage an die Datenbank..." + + layout = html.Div( + [ + html.Div( + [ + html.H1("Datenbank-Chat", className="heading"), + html.Img(src=get_asset_url("logo.png"), className="logo"), + ], + className="header-container", + ), # Header + dcc.Store( + id="tmp-value", data=start_value, storage_type="memory" + ), # Store previous prompt + dcc.Textarea( + id="input-field", + value=start_value, + style={"width": "96%", "height": 200, "margin-left": "20px"}, + ), # Input field + html.Div([]), # Needed for keeping the layout clean + html.Button( + "Abschicken", + id="submit-button", + n_clicks=0, + disabled=False, + style={"margin-left": "20px"}, + ), # Submit button + html.Div( + [html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=tmp_style + ), # Error message (only visible if input is not updated but submit button is clicked) + dcc.Loading( + id="loading", + type="default", + children=[ + html.Div( + "Hier erscheint die Antwort der Datenbank.", + id="text-output", + style={ + "whiteSpace": "pre-line", + "margin-top": "30px", + "margin-left": "20px", + "margin-right": "20px", + "border": "2px solid #86bc25", + "border-radius": "15px", + "padding": "20px", + }, + ) + ], + ), + html.H2( + "Direkte SQL-Abfrage", + style={"margin-left": "20px", "margin-top": "50px", "font-size": 24}, + ), # SQL Header + dcc.Textarea( + id="sql-input-field", + value="(Microsoft) SQL-Abfrage eingeben...", + style={"width": "96%", "height": 200, "margin-left": "20px"}, + ), # SQL Input field + html.Div([]), # Needed for keeping the layout clean + html.Button( + "Abschicken", + id="sql-submit-button", + n_clicks=0, + disabled=False, + style={"margin-left": "20px"}, + ), # Submit button + html.Div(id="sql-output", style={"margin-top": "10px"}), # SQL Output + ], + className="container", + ) + + return layout + + +app.layout = html.Div([get_layout()]) @callback( Output("text-output", "children"), Output("tmp-value", "data"), Output("error", "style"), + Output("error", "children"), Input("submit-button", "n_clicks"), State("input-field", "value"), State("tmp-value", "data"), @@ -108,7 +162,7 @@ app.layout = html.Div( ), ], ) -def update_output(n_clicks: int, value: str, data: str) -> Tuple[str, str, Dict[str, Any]]: +def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[str, str], Any]: """Update the output based on user input and button clicks. Parameters @@ -126,18 +180,91 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[str, str, Dict[ Updated output text, new stored value, and error style. """ global err_style - if n_clicks > 0 and value != data: + print(f"Value: {value}") + print(f"Data: {data}") + db_connected = test_db_connection() + if n_clicks > 0 and value != data and db_connected: result = send_message(value) err_style["height"] = "0px" - return result, value, err_style + return result, value, err_style, html.P("") elif value == data: err_style["height"] = "50px" - return no_update, no_update, err_style + err_child = html.P("Bitte eine neue Anfrage eingeben.") + return no_update, no_update, err_style, err_child + elif not db_connected: + err_style["height"] = "50px" + err_child = html.P( + ( + "Fehler beim Herstellen der Verbindung zur " + "Datenbank. Bitte versuche es später erneut." + ) + ) + return no_update, no_update, err_style, err_child raise PreventUpdate +@callback( + Output("sql-output", "children"), + Input("sql-submit-button", "n_clicks"), + State("sql-input-field", "value"), + prevent_initial_call=True, +) +def run_sql_query(n_clicks: int, value: str) -> str: + """Run a SQL query and return the results. + + Parameters + ---------- + n_clicks : int + Number of times the submit button has been clicked. + value : str + Current value of the input field. + + Returns + ------- + str + The results of the SQL query. + """ + if n_clicks > 0: + result = execute_query(value) + if isinstance(result, str): + global err_style + tmp_style = err_style.copy() + tmp_style["height"] = "80px" + tmp_style["padding"] = "20px" + err_child = html.Div( + [html.P(f"Fehler bei der Ausführung der Abfrage: {result}")], style=tmp_style + ) + return err_child + else: + table_child = dash_table.DataTable( + id="table", + columns=[{"name": i, "id": i} for i in result.columns], + data=result.to_dict("records"), + page_size=10, + style_table={ + "overflowX": "auto", + "margin": "auto", + "width": "96%", + }, + style_cell={ + "minWidth": "100px", + "width": "150px", + "maxWidth": "300px", + "overflow": "hidden", + "textOverflow": "ellipsis", + }, + style_header={ + "backgroundColor": "lightblue", + "fontWeight": "bold", + "color": "black", + }, + ) + return table_child + raise PreventUpdate + + server = app.server if __name__ == "__main__": diff --git a/app/app_styles.py b/app/app_styles.py index 9d46c7f..9e893b3 100644 --- a/app/app_styles.py +++ b/app/app_styles.py @@ -12,7 +12,7 @@ header_style = """ justify-content: space-between; align-items: center; padding: 20px; - background-color: #f8f9fa; + background-color: #ffffff; } .heading { font-size: 2.5em; diff --git a/app/data_chat.py b/app/data_chat.py index f7c55f2..369f7c7 100644 --- a/app/data_chat.py +++ b/app/data_chat.py @@ -26,10 +26,33 @@ def send_message(message: str) -> str: str The content of the assistant's response message. """ + system_message = """ + Du bist ein hilfsbereiter, fröhlicher Datenbankassistent. + Verwende beim Erstellen Ihrer Antworten das folgende Datenbankschema: + + MEIN_DATENBANKSCHEMA + + Füge Spaltenüberschriften in die Abfrageergebnisse ein. + + Gib deine Antwort immer im folgenden JSON-Format an: + + JSON FORMAT + + Gib NUR JSON aus. + Ersetze in der vorangehenden JSON-Antwort "your-query" durch die Microsoft SQL Server Query, + um die angeforderten Daten abzurufen. + Ersetze in der vorangehenden JSON-Antwort "your-summary" durch eine Zusammenfassung der Abfrage. + Gib immer alle Spalten der Tabelle an. + Wenn die resultierende Abfrage nicht ausführbar ist, ersetze "your-query“ durch NA, aber ersetze + trotzdem "your-query" durch eine Zusammenfassung der Abfrage. + Verwende KEINE MySQL-Syntax. + Begrenze die SQL-Abfrage immer auf 100 Zeilen. + """ + system_message = "Du bist ein hilfreicher Assistent." response = client.chat.completions.create( model=deployment_name, messages=[ - {"role": "system", "content": "Du bist ein hilfreicher Assistent."}, + {"role": "system", "content": system_message}, {"role": "user", "content": message}, ], ) diff --git a/app/sql_utils.py b/app/sql_utils.py new file mode 100644 index 0000000..079f6fb --- /dev/null +++ b/app/sql_utils.py @@ -0,0 +1,65 @@ +import os +from typing import Union + +import pandas as pd +import pyodbc + + +def test_db_connection() -> bool: + """Test the connection to Azure SQL Database. + + This function attempts to establish a connection to an Azure SQL Database + using the connection string stored in the environment variable + 'AZURE_SQL_CONNECTION_STRING'. It makes up to 5 attempts to connect, + with a timeout of 240 seconds for each attempt. + + Returns + ------- + bool + True if the connection was successful, False otherwise. + """ + connection_string = os.environ.get("AZURE_SQL_CONNECTION_STRING") + for i in range(5): + print(f"Trying to connect to Azure SQL Database... Attempt {i + 1}") + try: + pyodbc.connect(connection_string, timeout=240) + print("Connected to Azure SQL Database successfully!") + connected = True + break + + except pyodbc.Error as e: + print(f"Error connecting to Azure SQL Database: {e}") + connected = False + + return connected + + +def execute_query(query: str) -> Union[pd.DataFrame, str]: + """Execute a SQL query on an Azure SQL Database and return the results. + + This function connects to an Azure SQL Database using the connection string + stored in the environment variable 'AZURE_SQL_CONNECTION_STRING', executes + the provided SQL query, and returns the results as a pandas DataFrame. + + Parameters + ---------- + query : str + The SQL query to execute. + + Returns + ------- + Union[pd.DataFrame, str] + A pandas DataFrame containing the query results if successful, + or a string containing the error message if an exception occurs. + """ + try: + connection_string = os.environ.get("AZURE_SQL_CONNECTION_STRING") + conn = pyodbc.connect(connection_string, timeout=240) + df = pd.read_sql(query, conn) + conn.close() + return df + except Exception as e: + return str(e) + finally: + if conn in locals(): + conn.close()