diff --git a/app/app.py b/app/app.py index eaf6120..32cdacb 100644 --- a/app/app.py +++ b/app/app.py @@ -1,5 +1,8 @@ +import json from typing import Any, Dict, Tuple +import pandas as pd +from app_styles import header_style from dash import ( Dash, Input, @@ -13,16 +16,40 @@ from dash import ( no_update, ) from dash.exceptions import PreventUpdate - -from .app_styles import header_style -from .data_chat import send_message -from .sql_utils import execute_query, test_db_connection +from data_chat import send_message +from sql_utils import execute_query, test_db_connection external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"] app = Dash(__name__, external_stylesheets=external_stylesheets) app.index_string = header_style +notification_md = """ +**Hinweise:** + +Aufgrund des sparsamen pricing Tiers kann es einige Sekunden dauern, bis die +Verbindung zur Datenbank hergestellt wird. Im Falle eines Fehlers gern ein-zwei mal erneut +versuchen. + +GPT-4o kann einige Fehler machen. Sollte dies passieren wird eine Fehlermeldung angezeigt. +In diesem Fall lohnt es sich oft, die Anfrage leicht verändert erneut zu stellen und evtl +zusätzliche Informationen zu geben. + +Das Modell ist dazu aufgefordert, den Output stets auf 100 Zeilen zu begrenzen. + +Alle Daten sind komplett zufällig generiert und haben keine Beziehung zu realen Personen. + +**Beispielfragen**: +- Wie viele Kunden haben wir in Hannover? +- Zeige alle Kunden in Bremen. +- Berechne den gesamten Stromverbrauch aller Kunden in Magdeburg. +- Zeige alle Kunden, die zwischen 2021 und 2022 mindestens 200 Kubikmeter Gas verbraucht haben. +- Wie viele Kunden haben zwischen 2021 und 2022 weniger Strom verbraucht als zwischen 2022 +und 2023? + +Weitere Informationen zu den Daten, dem Code sowie zur Nutzung befinden sich in der README im +[GiTea Repository](https://gitea.captain.particlephysics.de/quadfaselt/grid_application). +""" err_style = { "height": "0px", @@ -42,6 +69,46 @@ err_style = { } +def render_table(df: pd.DataFrame) -> dash_table.DataTable: + """Create a Dash DataTable from a pandas DataFrame. + + Parameters + ---------- + df : pd.DataFrame + The input DataFrame to be rendered as a table. + + Returns + ------- + dash_table.DataTable + A Dash DataTable component with styled layout and pagination. + """ + tab = dash_table.DataTable( + id="table", + columns=[{"name": i, "id": i} for i in df.columns], + data=df.to_dict("records"), + page_size=10, + style_table={ + "overflowX": "auto", + "margin": "auto", + "width": "96%", + "margin-top": "20px", + }, + style_cell={ + "minWidth": "100px", + "width": "150px", + "maxWidth": "300px", + "overflow": "hidden", + "textOverflow": "ellipsis", + }, + style_header={ + "backgroundColor": "lightblue", + "fontWeight": "bold", + "color": "black", + }, + ) + return tab + + def get_layout() -> html.Div: """Generate the layout for a Dash application. @@ -53,6 +120,7 @@ def get_layout() -> html.Div: - A header with title and logo - A textarea for user input (database chat) - A submit button for the database chat + - A Notification about the connection time and LLM performance - An error message area - A loading spinner and output area for database responses - A section for direct SQL queries, including a textarea and submit button @@ -78,6 +146,10 @@ def get_layout() -> html.Div: ], className="header-container", ), # Header + html.Div( + "Ganz ohne SQL-Kenntnisse Daten zu Zählerstandmessungen unserer Kunden abrufen!", + style={"margin-left": "20px", "font-weight": "bold", "font-size": "20px"}, + ), dcc.Store( id="tmp-value", data=start_value, storage_type="memory" ), # Store previous prompt @@ -94,6 +166,17 @@ def get_layout() -> html.Div: disabled=False, style={"margin-left": "20px"}, ), # Submit button + dcc.Markdown( + notification_md, + style={ + "margin-left": "20px", + "margin-top": "20px", + "margin-right": "10px", + "background-color": "#C7E6F5", + "border-radius": "5px", + "padding": "10px", + }, + ), html.Div( [html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=tmp_style ), # Error message (only visible if input is not updated but submit button is clicked) @@ -102,7 +185,7 @@ def get_layout() -> html.Div: type="default", children=[ html.Div( - "Hier erscheint die Antwort der Datenbank.", + "Hier erscheint die Antwort des KI-Modells.", id="text-output", style={ "whiteSpace": "pre-line", @@ -180,18 +263,28 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[ Updated output text, new stored value, and error style. """ global err_style - print(f"Value: {value}") - print(f"Data: {data}") db_connected = test_db_connection() if n_clicks > 0 and value != data and db_connected: result = send_message(value) err_style["height"] = "0px" - return result, value, err_style, html.P("") - elif value == data: - err_style["height"] = "50px" - err_child = html.P("Bitte eine neue Anfrage eingeben.") - return no_update, no_update, err_style, err_child + # parse LLM response to dict, then try to execute the query + try: + parsed_result = json.loads(result, strict=False) + + result_table = execute_query(parsed_result["query"]) + children = [ + html.P([html.B("Zusammenfassung: "), f"{parsed_result['summary']}"]), + html.P([html.B("SQL Abfrage: "), f"{parsed_result['query']}"]), + render_table(result_table), + ] + return children, value, err_style, html.P("") + except Exception as e: + err_style["height"] = "400px" + err_child = html.Div(f"Folgender Fehler ist aufgetreten: {e}.LLM Output: {result}.") + + return no_update, value, err_style, err_child + elif not db_connected: err_style["height"] = "50px" err_child = html.P( @@ -202,6 +295,12 @@ def update_output(n_clicks: int, value: str, data: str) -> Tuple[Any, Any, Dict[ ) return no_update, no_update, err_style, err_child + elif value == data: + + err_style["height"] = "50px" + err_child = html.P("Bitte eine neue Anfrage eingeben.") + return no_update, no_update, err_style, err_child + raise PreventUpdate @@ -231,37 +330,14 @@ def run_sql_query(n_clicks: int, value: str) -> str: if isinstance(result, str): global err_style tmp_style = err_style.copy() - tmp_style["height"] = "80px" + tmp_style["height"] = "100px" tmp_style["padding"] = "20px" err_child = html.Div( [html.P(f"Fehler bei der Ausführung der Abfrage: {result}")], style=tmp_style ) return err_child else: - table_child = dash_table.DataTable( - id="table", - columns=[{"name": i, "id": i} for i in result.columns], - data=result.to_dict("records"), - page_size=10, - style_table={ - "overflowX": "auto", - "margin": "auto", - "width": "96%", - }, - style_cell={ - "minWidth": "100px", - "width": "150px", - "maxWidth": "300px", - "overflow": "hidden", - "textOverflow": "ellipsis", - }, - style_header={ - "backgroundColor": "lightblue", - "fontWeight": "bold", - "color": "black", - }, - ) - return table_child + return render_table(result) raise PreventUpdate diff --git a/app/data_chat.py b/app/data_chat.py index 369f7c7..4cc3106 100644 --- a/app/data_chat.py +++ b/app/data_chat.py @@ -1,16 +1,24 @@ import os -from openai import AzureOpenAI +from openai import OpenAI + +# from openai import AzureOpenAI # Set up credentials -# NOTE: When running locally, these have to be set in the environment -client = AzureOpenAI( - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), - api_key=os.getenv("AZURE_OPENAI_KEY"), - api_version="2024-02-01", -) +# NOTE: Usually I would use AzureOpenAI, but due to heavy rate +# limitations on azure trial accounts, I am using OpenAI directly +# for this project. However, this is how it would look like for +# AzureOpenAI (credentials must be provided to environment): +# client = AzureOpenAI( +# azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), +# api_key=os.getenv("AZURE_OPENAI_KEY"), +# api_version="2024-02-01", +# ) +# MODEL = "sqlai" # deployment name -deployment_name = "sqlai" +# Set up the OpenAI client +MODEL = "gpt-4o" +client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) def send_message(message: str) -> str: @@ -28,15 +36,60 @@ def send_message(message: str) -> str: """ system_message = """ Du bist ein hilfsbereiter, fröhlicher Datenbankassistent. + Du hilfst Benutzern bei der Erstellung von SQL-Abfragen für eine Datenbank eines + großen Energieversorgungsunternehmens. Die Datenbank enthält Tabellen für Adressen, + Zähler, Kunden und Ablesungen. Es werden Gaszähler (MeterType 'GAS') und Stromzähler + (MeterType 'ELT')unterschieden. + + Besonders wichtig ist, dass die Ablesungen der Werte kumulativ sind. Wenn nach dem Verbrauch + gefragt wird, sollte der Unterschied zwischen zwei aufeinanderfolgenden Ablesungen berechnet + werden. + Verwende beim Erstellen Ihrer Antworten das folgende Datenbankschema: - MEIN_DATENBANKSCHEMA + CREATE TABLE Addresses ( + ID INT PRIMARY KEY IDENTITY(1,1), + StreetName NVARCHAR(100), + HouseNumber NVARCHAR(10), + City NVARCHAR(50), + PostalCode NVARCHAR(10), + Longitude FLOAT, + Latitude FLOAT + ); + + CREATE TABLE Meters ( + ID INT PRIMARY KEY IDENTITY(1,1), + Signature NVARCHAR(11), + MeterType NVARCHAR(3), + AddressID INT, + FOREIGN KEY (AddressID) REFERENCES Addresses(ID) + ); + + CREATE TABLE Customers ( + ID INT PRIMARY KEY IDENTITY(1,1), + FirstName NVARCHAR(100), + LastName NVARCHAR(100), + GasMeterID INT, + EltMeterID INT, + FOREIGN KEY (GasMeterID) REFERENCES Meters(ID), + FOREIGN KEY (EltMeterID) REFERENCES Meters(ID) + ); + + CREATE TABLE Readings ( + ID INT PRIMARY KEY IDENTITY(1,1), + CustomerID INT, + MeterID INT, + ReadingDate DATE, + ReadingValue INT, + FOREIGN KEY (CustomerID) REFERENCES Customers(ID), + FOREIGN KEY (MeterID) REFERENCES Meters(ID) + ); Füge Spaltenüberschriften in die Abfrageergebnisse ein. Gib deine Antwort immer im folgenden JSON-Format an: - JSON FORMAT + { "summary": "your-summary", "query": "your-query" } Gib NUR JSON aus. Ersetze in der vorangehenden JSON-Antwort "your-query" durch die Microsoft SQL Server Query, @@ -45,15 +98,20 @@ def send_message(message: str) -> str: Gib immer alle Spalten der Tabelle an. Wenn die resultierende Abfrage nicht ausführbar ist, ersetze "your-query“ durch NA, aber ersetze trotzdem "your-query" durch eine Zusammenfassung der Abfrage. - Verwende KEINE MySQL-Syntax. + Verwende KEINE MySQL-Syntax, sondern AUSSCHLIESSLICH Microsoft SQL. Begrenze die SQL-Abfrage immer auf 100 Zeilen. + Formatiere den Output bestmöglich. """ - system_message = "Du bist ein hilfreicher Assistent." + response = client.chat.completions.create( - model=deployment_name, + model=MODEL, messages=[ {"role": "system", "content": system_message}, {"role": "user", "content": message}, ], ) - return response.choices[0].message.content + + result_str = response.choices[0].message.content.replace("```json\n", "").replace("```", "") + if ("\n") not in result_str: + result_str = result_str.replace("\\", "\n") + return result_str diff --git a/app/sql_utils.py b/app/sql_utils.py index 079f6fb..c45f05d 100644 --- a/app/sql_utils.py +++ b/app/sql_utils.py @@ -11,7 +11,7 @@ def test_db_connection() -> bool: This function attempts to establish a connection to an Azure SQL Database using the connection string stored in the environment variable 'AZURE_SQL_CONNECTION_STRING'. It makes up to 5 attempts to connect, - with a timeout of 240 seconds for each attempt. + with a timeout of 480 seconds for each attempt. Returns ------- @@ -22,14 +22,15 @@ def test_db_connection() -> bool: for i in range(5): print(f"Trying to connect to Azure SQL Database... Attempt {i + 1}") try: - pyodbc.connect(connection_string, timeout=240) + pyodbc.connect(connection_string, timeout=480) print("Connected to Azure SQL Database successfully!") connected = True break - except pyodbc.Error as e: + except Exception as e: print(f"Error connecting to Azure SQL Database: {e}") connected = False + continue return connected