From 923dc3b439c03743d99e39bae3edad0763f80fca Mon Sep 17 00:00:00 2001 From: Tobias Quadfasel Date: Sat, 31 Aug 2024 23:38:14 +0200 Subject: [PATCH] feat(ai-chat): Add first version of ai chat as well as frontend Includes the first version of a rudimentary chat app, still without the SQL capabilities that we want later. For now, we can connect to the Azure OpenAI source and then have the response displayed in a plotly dash webapp. Some styling and UI elements were also added, such as logos. UI components are designed that the user cannot enter the same query twice and cannot click the submit button as long as the query is running. --- app/app.py | 140 +++++++++++++++++++++++++++++++++++++++++++- app/app_styles.py | 37 ++++++++++++ app/assets/logo.png | Bin 0 -> 9441 bytes app/data_chat.py | 36 ++++++++++++ 4 files changed, 210 insertions(+), 3 deletions(-) create mode 100644 app/app_styles.py create mode 100644 app/assets/logo.png create mode 100644 app/data_chat.py diff --git a/app/app.py b/app/app.py index a83c9fc..797969d 100644 --- a/app/app.py +++ b/app/app.py @@ -1,8 +1,142 @@ -from dash import Dash, html +from typing import Any, Dict, Tuple -app = Dash() +from dash import ( + Dash, + Input, + Output, + State, + callback, + dcc, + get_asset_url, + html, + no_update, +) +from dash.exceptions import PreventUpdate + +from .app_styles import header_style +from .data_chat import send_message + +external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"] + +app = Dash(__name__, external_stylesheets=external_stylesheets) +app.index_string = header_style + +err_style = { + "height": "0px", + "overflow": "hidden", + "transition": "height 0.5s ease-in-out", + "border-radius": "15px", + "background-color": "#FFCCCB", + "text-align": "center", + "color": "#FF6B6B", + "margin-top": "20px", + "margin-left": "20px", + "margin-right": "20px", + "font-weight": "bold", + "display": "flex", + "justify-content": "center", + "align-items": "center", +} + +start_value = "Stelle deine Frage an die Datenbank..." +app.layout = html.Div( + [ + html.Div( + [ + html.H1("Datenbank-Chat", className="heading"), + html.Img(src=get_asset_url("logo.png"), className="logo"), + ], + className="header-container", + ), # Header + dcc.Store( + id="tmp-value", data=start_value, storage_type="session" + ), # Store previous prompt + dcc.Textarea( + id="input-field", + value=start_value, + style={"width": "96%", "height": 200, "margin-left": "20px"}, + ), # Input field + html.Div([]), # Needed for keeping the layout clean + html.Button( + "Abschicken", + id="submit-button", + n_clicks=0, + disabled=False, + style={"margin-left": "20px"}, + ), # Submit button + html.Div( + [html.P("Bitte eine neue Anfrage eingeben.")], id="error", style=err_style + ), # Error message (only visible if input is not updated but submit button is clicked) + dcc.Loading( + id="loading", + type="default", + children=[ + html.Div( + "Hier erscheint die Antwort der Datenbank.", + id="text-output", + style={ + "whiteSpace": "pre-line", + "margin-top": "30px", + "margin-left": "20px", + "margin-right": "20px", + "border": "2px solid #86bc25", + "border-radius": "15px", + "padding": "20px", + }, + ) + ], + ), + ], # Output field + className="container", +) + + +@callback( + Output("text-output", "children"), + Output("tmp-value", "data"), + Output("error", "style"), + Input("submit-button", "n_clicks"), + State("input-field", "value"), + State("tmp-value", "data"), + prevent_initial_call=True, + running=[ + (Output("submit-button", "disabled"), True, False), + ( + Output("submit-button", "style"), + {"opacity": 0.5, "margin-left": "20px"}, + {"opacity": 1.0, "margin-left": "20px"}, + ), + ], +) +def update_output(n_clicks: int, value: str, data: str) -> Tuple[str, str, Dict[str, Any]]: + """Update the output based on user input and button clicks. + + Parameters + ---------- + n_clicks : int + Number of times the submit button has been clicked. + value : str + Current value of the input field. + data : str + Previously stored value. + + Returns + ------- + Tuple[str, str, Dict[str, Any]] + Updated output text, new stored value, and error style. + """ + global err_style + if n_clicks > 0 and value != data: + result = send_message(value) + err_style["height"] = "0px" + return result, value, err_style + elif value == data: + + err_style["height"] = "50px" + return no_update, no_update, err_style + + raise PreventUpdate -app.layout = [html.Div(children="Hello World")] server = app.server diff --git a/app/app_styles.py b/app/app_styles.py new file mode 100644 index 0000000..9d46c7f --- /dev/null +++ b/app/app_styles.py @@ -0,0 +1,37 @@ +header_style = """ + + + + {%metas%} + {%title%} + {%favicon%} + {%css%} + + + + {%app_entry%} + + + +""" diff --git a/app/assets/logo.png b/app/assets/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..9399c006b5addefbfeff15c328023abc074fa4fb GIT binary patch literal 9441 zcmbW7WmHsO*!IsbFobjt-5}ipf20JI2I+2)2I*!92`T9kK}ryiE~%j#=@teAB&5@! zp7H(uet4d<)`@l2S$m(e_dff+uHU|6HPv6?;ZWfK0Dz~YD6a(oAYat=8!QOwcx#C^ zg*suoD;jzN059{u9pqOc?Sr~W>7`)crR`?tbxiigzwNbInUw_M-xXsL#j51< z=g&t7`C0<~{QS}w&D^g~PEMZ6wIh!HJsIWK{j;~Xw+Qz%ye&F9`W59ShA#mDL6TtK zlOsCRW*hVWD_(0~UhNz;3U7-NFBvoW>jzFSJ5%$L!O@9D5r-NNX;dI&? z-f{{Q?cT#{^g3)nTV2;P+Z_+`Q4 zK}NsQZuGYQ0avTiK2K!%!Z%b^PAzIjNj|7>UYJ~73Qb?|rhtjYqVyj3zXFRR%y=C; zImb8L1K$@P2wnyqj$z9(3oYr0MHu9jD<)0FzlCX`OGbTO!}K)c2uaGm68eQ_fA7ev zSzG^cXOSaTWJ6mFE<_g3ub8wD|F%yDTR{@NqwCQvyue(y()=wkfzk`Iv!=xTR3Wx~ zvYQunV}Rvo+{veB(V4o=nX5UFXjhTNkqIZ{aO2^aHf}TA~8eP)V( zvW+{2>Hb}fgfF5NA99p1U;3%mBccdFmz0WQLi#<8Y==Y#4F6p?Y5Y9WH|wm@SSUUvbCzldWrpC*UD_~R)cS-U~#XrLX#o@I;P2LA!@C|-pd*% zSG;i+g>}$WIxz7Rq4^`i-xmHQmM@(3fct`$cI}wGZ4zXG9@q@1u zj1gU!rPXNAL$7vbKhS&YspSji2vZxb7Ct9DH7$rK-sGJ7DFL$#wZQtULHhlDW> zRX9EAFLu$nFHPjb(Y@SmxzE&7j|Xud=s@fQ3oW?L80YU(2a|5c8|y{4aRgp^8v98R zxuTVQBfP@gAdCrBdGF>@xwp@`CF2?)LDacokcnI$yuClW+^BCH(1J-}`nuIczzG=R zoQj@IAJZH(XXsOo4}Zl>`|JadzIqFm0A_`zO6BYJ6D1ZRBYfN@F1@ND4BwBxDDr%0 zSkRc%IA?$^g5a7X(-Ay$?CFYnk^`l!w207`r)UxYAvj+EJ2y9VW5@56OZP-vRH3-x zJ(=2o9_-Oy3{^BCe?&}3tM6kQp9p)!j?%L09@abJ;r-j#-C8DeAD|3Wk^A}7q`h^x@w7{2UjP0UxwvC9$BDKMz6~rb zqrF7i!#R?}pKNVB;K@Hbjc>3+H)4TKlWC@vyHy;Wlyh~sEiAjeswlufZ$f3IcV>+4 zo)nZ%pLkkn`Fv}>=l#545y|k^z^uRO7rY2%tena$dN+q^WVBVF2X5&l=4?9Gu)S+V z=R4A7G~rgPPE|(e;~X`g6x_pIe==$0mIVy9?7zG98%)EgxQ-v*7Xesicw&R>A_^ph ztDpQpvpD}OBBs1|duUQE%PMgU*x9O?>F)C6{A*#z%r9VqcT4POkyr4110TD$ z_g4!ti%iF9(BF`TNAoJPpr7GhQE-CM?*+Eanxw<7mxN8+-_*LCe`j>Y!xj|oKWuAt z{T=8c9l5Ey+K+|I{0B5OJ>cL8nN%(Rq7@q4ikEW&W{3QA44^kkp2n9v{a4luW605* zc=oB>Wdk%#?oJQ8g2%2uk`7STzwbk!1JVsF6+fDTTKgju4$2_Ix-1d41P&rJ!Df*8 z7$hTZSV&z^utAV+@U^Gs9poGglq#Grf zWFA(_SYULr=6#GOWU*WEoE(8XE*n&kdxMcr+;@bC%7*q<$sjr?wnI`>GJ})%BWj3g zVaRU3bO}Uz-RQ!8pF;Ul49;0rSj3E}o$m7O z5deG_q+?gfcUfMW4FDc_#BR869kS_5fA~(34n0TMtL%VVim!iyU}1d9i>KyP4-j|> za=+_w3)`CB*TdS=cy1w3Tk(~Yd`oQ!Gbv{iJ!|Z|CxF?V=MjJhnbD+ zImPsh@1P^ZIs=$ZS&QYgH{R z?4?1YIqD$Vcn#r&Zl#EVLLl*P49dsf*HT9XuzNo3qTt*Wy30OQp7ByJ8P!~e__Y@i z2(%>qHtcBn9eWW4XWcL*n3)au)&K9uh)~KVAPi2_|}#7=A@`>cV!` zCHblEux@^mrI(+a|K@m#i!z8MI_UeEz^fvXpNOqaM!#;0Aa5+j<^QR z_pKi{&nc17z&DBp4h1?H%q)F}XPZ3vG>?z&Wf#Btt~Q7EwrY(8rBm4~?DC_WOO6A{itzKUidw7b z>pWA9b%$h!fFO!Gd+|?3*3w-1Damy7528$pqg*yc$vTs-Tb->BuYQBtu@NY<057JT zQr!P&AM9BE)2nUqjwev)x$<+StM>E3VdX~Gi-3ygz}$CYFC2m~w!pV5AFp|8 zNESf_0@IfUSDM$Y7RYf-{QoMHMTSW!z}6>oKB$1-=h)O<-H0Ryx+%2d$7w3LW9;+6 zdQKzyubB6!qj3Jn9Q&J3h@SqZI4H$HcwBZVKKsH#PgUPG0Zy5aJbUSy`?Sx|=R#&m zy0!K#?n~DBNbWdgp7P@W2wc5(3On%uv3j&XA2!e8Yd~Y%z^XEqUjg;D*Ah-KT7~Z2 z#c>6S7_TkhI_$w_J`Tx%ff(oN_tD|mMQZ2f{xicJ|H6OKWw{&LjW|F;=~HZ@jJ|}J zuuph7ts&m(J>2e+i&d9?RJr^0b2aDFpSULU3B}blE4rL|^-@0&DQxvPPOSW-BgeNfmlfHdm@1F%tk%lluT{D><~YS4Uy5$TvU=yW_5A_a zS8Dh!24m30oBZgGW(Zs?T@$&{X%oF9>yL3I0{Q*gI{g(QbYUCOO_NX@qF7LEXoyEQ zJbgJ+(f#3I69X3WP@Q{AdFRBakHkHD;~MIyRZ^yAJo&y5TQ~*uR%pP-~L+fV>6?0zV&h(eR+Qh#I;A7CGMps;$dV36pB?U*X9gBVo zY&wj0OFy4!|MNoVEy}u5J54y~$w(Git~Z37FNjErhVtV(j^X$Pc4Oh3n$#vufp*?UXQl zfkI=fV#{0IgyC$DO@Tqog_Wg8R>R@2OSR-KMZGn0T%PC0BePe(X{JU{WZ(uh9WLsWlr{J`! z1MBlJE1`x_5L+$zvwNhSdL$MT0A>l}>smU=9XHF_Bz+QKqS*E&;tMWf9|-@#*DXbk zLOABOO9r`uKAp4U?lKx&PcE+8A`b8Q8R+rp(Dzn=9B8Jt|_hHg1 z%bwvs)|mr7-~hNh#H)?u0pJ^$K&Q$VRpKV`_zx|`DyVP9o1j>UHZ$OLimEJ^e>U&@ zlVfi0J=|q#NKZBo`9gcuzq(A!k}C3i1)IAw+K~c7#6YB$IcrfD=_K73wW~)1e0bJJ zZ^xXSf^$?7acm}_9lwXtkmQZZL$B%1r{t7Jp_!sob_UCAL~9vM4%8c@ih#vgEH@SI z^qH}qyqP<_Ot)9wafoZ;shBV016rC}rMW+WT*t9Xnh;w@(ym4;Cg=B)jAK!etx zlE@YleH4cj3Ab_9NWmti?}^P`_zHuqe9lNJUo%0edC#m&h1Iyx2q+ApKD+wgq4Uiw zO}|)lnBlJO+-;d99I%&RdJ0yJh&p2?e-Wgfp9V$RcLm;U(JBU|jzb&Y^3*7(}KZ9fcu85=@0t zxNTr|AUNuCE@RDwWdZL$&$-xR8e#X2+o;cu$?y^it8Rr>SQra(8NErEp!%85ZefX- zcF&s+!jZXuq^{-PwQ+b9F-|+-;r76I+i_vi%)q0i_HUlcRn+SsfcvQRoH8_-XPZnc z;vDhNe*Q2~L5;Tch=av`_0aI~1rl?0wZ`QYsv>ScMNK-YK+wR@pAyEnZ4`U*wEh=j zEp=3UhMqEh`=IxYQ>rUUNf9BxN{>1%_v;^}i7L!ULcHoVAB=$4`*WtBN#)1SahVT=Evkp=y)=sJRE_WoxE{2c$bs^84scjk0?oZAe zYMehX%iUW>_TBo)5ruzD2G~$y2iJVaP<^h>Ic4Tu)P2EpdBw`J0>=j|Zqvu)QX(#* z>UcJ3>7Tp^SyTS61uWK`IOf~xxjNkCsV0_XUx5r|K>7mgxJ4SQy=o$RqRC|0>i?X>4PUtt#=s9C-owc)l7R~ffd!e&A`<1KPJf`26AVlf zJo)aarug8Qq?t&$s^2g{70h3-EPHpk;ZruN8E$)rJEPUeDem}kysN9$vc7_NCj&2a z5OeGg2|D?L2;bX2@C$x7a&Kcc_m5uIqJ-U~jqF_>gKKC(N690vyYG*F@7$cfydFWr z;z&>wD>{ytq-O+Wga-#Q7doCZ~Nl&sXVtvONGH7L*zR$H&RQ~aq6>H7^ zbTqO8G9VV`lE85{`13=b^fGAV%lSK~xz{#+hKnW$SVrJD~;Ol+}IFJ4VWBUJS5`}AKx>j3KWm>xmXgb_%w53{b6Bt zei{^Y1GX+c+aOapW?Wv^c^S_dF3APT{)1WJrjRY!jStTm5NfkOqD@M*nsk9Kc`JUi z7@`lMdRbc5ZHh7ef@b&qNO~FCQ2tXr1^o;jGJMSIou-o@LeuA}zw83q)@D3$oq@># zcm(i?w)E6x%6LNKkhlrUnS)iZ$zFv#(wm(3r!?EB-Zu%hZk4a(VFELI_uFB_1d-w; z@2OUWZqG3cTI=okdHY5rouI6b1A-Z5@I+LJH0fltYLL|c{hEs|GLv|RIe_t~pdk_>GvZAaTRCo71yi0#s zfh2Gm)2@!$UmwCZf$^}vu6ChPB*rZj5-3kcr@BdtsCpsu@;l;LL3tU6{~KDQQ)z{k z(VvuJ{pHQ>`-@sKGxy_OOaNurQh_l(*!?rD@3pl}6*q?iCMlyQUu9qqVAugTx6^Ts zF(OdreE*{?Rm3cF%TkKryp(rccr~G$Qy}(@Isb)Zf{1BTUC4XTXy#l*!Y1TeDz7jC zc=;qJ9sg+hO0_&{Yg@)`lHnW&OIjhDgD#<%dW5kC9IO^>Sh^vQ+-O!V`pl~X57F`% zg8N6x|4wXhMWHx>PbQ=~n6~bDhd>H)PZ8@601QZMvZ+a$Eti6e(>dNn)ZT#I(M8HQ zo_BQT*;e)dqa(5^! ztv#HXL>(2ND8l`E2?dg?*FwHH?@3tD@`AFrr)5E#dF%Bd%#Rx3SISYaZ$VI)JhO2Z zn|W0U*WL|h(unN_F(~Eyn_9`Y@>46y+${Vm9S+AhC5+%*t<)RBaY)p<9taDQj z-!h}BR~)g}V=8J_gny~xSQHStDZ|7Os*0bh^fkf%Mu5RC(Q$3Se#@+oQ5-EL$%qR7i*TxkOgQk;peQ}2tv(uKo99!FF{Ep4#V3moY z=BvXt>a8P2IVq``@%vXx+IQi_9UJn0>)wgoaLvJ;t?GeUaxMVRQxuCLys#0PTxy*B{STk8<_)(9xdtu-8-2#n%vWM< znS;-NHmtk)Qv}`gw1RTiLv6_2?5c)Qq2u)!cjc0#(I&jM;cctq(co87*<&+lZyGwf zmPZ2x3m;9*l=F$&oXl4j?Ib64HN&YHjrtm-O%8EmQM2kpEC8-d9J(!&j>;lsjK=R* zf6$5upK8?+XQvXh8(b~pFa7Ku1paut#JqjmZaosK$IFT=2#e0`DG3p--)i^GlwCd$ znud_=Eloc=B#f(tY&a<8Rw}VQ6;h+9JqLvvnsnX6YQ~8YtkenxKe;7f$HsxuscIFB z>(HMwS&UH3cKYk_vg94pA79!Rk6u#fCAXzN62)J@@$rc9!U4G5X-PfyU1XVL3f7T4 zkbk%~0mQ-5xP6xjNEvc6+nWXX)|2X3rUZ?JmW==dRAnLnU&Pz|KF)00-@$!9S_RYm z+%Y#)P7^r_k(*bCl!e`S-JY9HJ>MprP|c7>uwgf^YUVCYjL~0+sq|8vti#V*YzUgV z%*;o+%Mr6JU+bN zcIW@-?Q`@*k1hzO?)E6mVloe(tX9NUMCD)MlI@|P?BzyIipWiB(Ff}wW9d$r{MG^Wisel^`a z*3iZo3U$0N5yZ?C*Y;UcOZh3s1Hrv?2gQFXM3~#^s_ZJ}tW;3Av)wSkkWzse1RJ-E zs=pO`r9^2n-=s&*oIOU8M+-dc&*M8e zX`!cl{Lk`E9jhtR`a?;?h8A0c=p{=c>m`C1Vs4Xis-t4uGO$c5&ojJ_M8R$WE7Y_q z1i7205C%6jDF{kpF33DflD3EL^H!|H-v6WU_*Bvq7hu4mj$;l|U%5+iWZA0xtT2vv z_{5~GnFsYiViqsaxivJV-yp(2-APWK;>uy4fO=tKEzMe{QZNEsemp>}%@wmA@36-do3hJ-Bnp7Ac)^uSqY^-VJxhq1w zBB(?%h7`?6Oc;*1V> zRPJvVip->^n4|#V3A=1A`XnE*Q%?rI9?3AJ?mucPLRk&E^9>KdeWG0I-z01E3~RCW z-AQ}Wcn;J&W<9|=K`cL(^zG5QIathC97(3x)$fLD)sf(?igP4Ru+v+Q?dmTd&z zND|$1!BB2{w}3oG+FCy^IcR_upm8Fi9o(0OjG1b%%|}YQb$($AX)c%?qeK$M9##>nI>uvJCL8kf0w^_lSh@%&eWFg@N z?;A*oytjzAVsI|dfHc>r8Jk+)c%?c_XKMERy)B;mV#(YBhAA0X z@k?Qb=_Lq@O~~uDqTm(;>Pxng)XflnXma|G zX=cAQX1t?xh5EH~q@3YvVQ!qfPf&`ZYq{g8E&g_{VR0Ru3h^QunOn-ewl1*X#i*%Y z48iSTcRj?)7r3ycXDi?$0bT~E`(?n_dBBoI;T~cweB}p5c!+c+q+s3a3z#v9UI`Y9 zxWj5oASkU-wVs50!A%;ddSC}$AF|302@m@$pC;ht5(v&x05nu*o5b%ZMPD*I@af@o z!#ecr_($rl_uS#7FkJqPJF+Uq*W(GFTYRocX@mx&J;#z?c&#-hSgdL}ODAG>n zUZ@5=7g_z4h+cSEUIGi+NI=J(9n{zGDjr>j?zK}#mP|{L5ksEccyEiKRwIhtbuuD< z44L7{v|HAIak8CE2(cQb8a>U{h463eMxS29rt4w@d0!6lMtf2y>}nbZ)oBh05_bia*~}D z9?fLLPLzuT^e*1?1vUVnTm8!gu*Fi5Sqa}0eV-uyR3P|9A#O4L<+QnNshIMZs0YV2 zQ9^Mhzq_RJZQGR$yt4hSG_D!D`31LIV)i-;P31G~{CH7H)If_b@T>F>$zD_^I5p9d zZk=o@YQhER+L~4k4I8PMAyulf9-shu5gZR)kdub;8m<3L=)W{JN2_u08*E1N)DHg5 zgZlu9dJGi6+JDhZp2JiwLHe1lupb9?3x3lrJ|gcPlU){GM7g&i)5#Vz0$O*?v_i&Y+KLW-fVi~gc=r% zPUupg0UXjDZ&RU@0EMsfjUPy?^-`s`$*?ad$`eGJb$NVr<^ z!0Tp=9_`MTcrFV4`8Z;wj{3VEp5BX(ic!ZHIzv;?N;hXzqlx)UGsfCn-UBAWJX0UX znTU-)@FEW$&zcU;P{tt80(_xrZ|++XV*$1=1;RPYu+zJ9}GPyy^03&U4yc-B0aAJ4O_ zRjkTXm|-)A@@{Bw6aF2*eKx2#Ne3ff=s9V}aFuAQ<()H8?g SexlNo03`)=`EM^Q!v6=I+4PM7 literal 0 HcmV?d00001 diff --git a/app/data_chat.py b/app/data_chat.py new file mode 100644 index 0000000..f7c55f2 --- /dev/null +++ b/app/data_chat.py @@ -0,0 +1,36 @@ +import os + +from openai import AzureOpenAI + +# Set up credentials +# NOTE: When running locally, these have to be set in the environment +client = AzureOpenAI( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_KEY"), + api_version="2024-02-01", +) + +deployment_name = "sqlai" + + +def send_message(message: str) -> str: + """Send a message to the openai chat completion API and return the response. + + Parameters + ---------- + message : str + The user's message to be sent to the chat completion API. + + Returns + ------- + str + The content of the assistant's response message. + """ + response = client.chat.completions.create( + model=deployment_name, + messages=[ + {"role": "system", "content": "Du bist ein hilfreicher Assistent."}, + {"role": "user", "content": message}, + ], + ) + return response.choices[0].message.content