Add AI completion for SQL queries.
This commit is contained in:
parent
f0d0a9a53c
commit
560d758fb7
67
table.py
67
table.py
|
@ -4,6 +4,7 @@ from datetime import datetime
|
|||
import logging
|
||||
import sql
|
||||
import unidecode
|
||||
import json
|
||||
from simpleeval import EvalWithCompoundTypes
|
||||
from trytond import backend
|
||||
from trytond.bus import notify
|
||||
|
@ -14,6 +15,7 @@ from trytond.model import (ModelView, ModelSQL, fields, Unique,
|
|||
from trytond.exceptions import UserError
|
||||
from trytond.i18n import gettext
|
||||
from trytond.pyson import Bool, Eval, PYSONDecoder
|
||||
from trytond.config import config
|
||||
from .babi import TimeoutChecker, TimeoutException, FIELD_TYPES, QUEUE_NAME
|
||||
from .babi_eval import babi_eval
|
||||
|
||||
|
@ -95,7 +97,6 @@ class Table(DeactivableMixin, ModelSQL, ModelView):
|
|||
fields_ = fields.One2Many('babi.field', 'table', 'Fields')
|
||||
query = fields.Text('Query', states={
|
||||
'invisible': ~Eval('type').in_(['query', 'table']),
|
||||
'required': Eval('type').in_(['query', 'table']),
|
||||
}, depends=['type'])
|
||||
timeout = fields.Integer('Timeout', required=True, states={
|
||||
'invisible': ~Eval('type').in_(['model', 'table']),
|
||||
|
@ -119,6 +120,12 @@ class Table(DeactivableMixin, ModelSQL, ModelView):
|
|||
'Requires', readonly=True)
|
||||
required_by = fields.One2Many('babi.table.dependency', 'table',
|
||||
'Required By', readonly=True)
|
||||
ai_request = fields.Text('AI Request', states={
|
||||
'invisible': Eval('type') == 'model',
|
||||
})
|
||||
ai_response = fields.Text('AI Response', readonly=True, states={
|
||||
'invisible': Eval('type') == 'model',
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def default_timeout():
|
||||
|
@ -135,6 +142,7 @@ class Table(DeactivableMixin, ModelSQL, ModelView):
|
|||
super().__setup__()
|
||||
cls._order.insert(0, ('name', 'ASC'))
|
||||
cls._buttons.update({
|
||||
'ai': {},
|
||||
'compute': {},
|
||||
})
|
||||
|
||||
|
@ -147,8 +155,11 @@ class Table(DeactivableMixin, ModelSQL, ModelView):
|
|||
|
||||
@classmethod
|
||||
def write(cls, *args):
|
||||
args = [x.copy() for x in args]
|
||||
actions = iter(args)
|
||||
for tables, values in zip(actions, actions):
|
||||
if 'ai_request' in values and not 'ai_response' in values:
|
||||
values['ai_response'] = None
|
||||
if 'internal_name' not in values:
|
||||
continue
|
||||
for table in tables:
|
||||
|
@ -308,6 +319,60 @@ class Table(DeactivableMixin, ModelSQL, ModelView):
|
|||
context = ev.eval(context)
|
||||
return context
|
||||
|
||||
@property
|
||||
def ai_sql_tables(self):
|
||||
return {'account_invoice', 'account_invoice_line'}
|
||||
|
||||
@classmethod
|
||||
@ModelView.button
|
||||
def ai(cls, tables):
|
||||
pool = Pool()
|
||||
Model = pool.get('ir.model')
|
||||
|
||||
cursor = Transaction().connection.cursor()
|
||||
|
||||
import openai
|
||||
openai.organization = config.get('openai', 'organization')
|
||||
openai.api_key = config.get('openai', 'api_key')
|
||||
for table in tables:
|
||||
sqltables = dict.fromkeys(table.ai_sql_tables)
|
||||
|
||||
t = sql.Table('columns', schema='information_schema')
|
||||
query = t.select(t.table_name, t.column_name,
|
||||
t.data_type)
|
||||
query.where = (t.table_schema == 'public') & \
|
||||
(t.table_name.in_(tuple(sqltables.keys())))
|
||||
cursor.execute(*query)
|
||||
for table_name, column_name, data_type in cursor.fetchall():
|
||||
sqltables[table_name] = sqltables[table_name] or {}
|
||||
sqltables[table_name][column_name] = data_type
|
||||
|
||||
request = '''Given the following tables:
|
||||
|
||||
%s
|
||||
|
||||
Write an SQL query that returns the following information:
|
||||
|
||||
%s
|
||||
|
||||
Query:''' % (json.dumps(sqltables), table.ai_request)
|
||||
messages = [{
|
||||
'role': 'system',
|
||||
'content': 'Always return an SQL query that is suitable for PostgreSQL',
|
||||
}, {
|
||||
'role': 'user',
|
||||
'content': request,
|
||||
}]
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages)
|
||||
if response.choices:
|
||||
query = response.choices[0].message.content
|
||||
if not table.query:
|
||||
table.query = query
|
||||
table.ai_response = query
|
||||
table.save()
|
||||
|
||||
@classmethod
|
||||
@ModelView.button
|
||||
def compute(cls, tables):
|
||||
|
|
|
@ -51,6 +51,11 @@
|
|||
<field name="string">Compute</field>
|
||||
<field name="model" search="[('model', '=', 'babi.table')]"/>
|
||||
</record>
|
||||
<record model="ir.model.button" id="babi_table_ai_button">
|
||||
<field name="name">ai</field>
|
||||
<field name="string">AI</field>
|
||||
<field name="model" search="[('model', '=', 'babi.table')]"/>
|
||||
</record>
|
||||
|
||||
<menuitem id="menu_babi_table" parent="menu_configuration" action="act_babi_table" sequence="30"/>
|
||||
|
||||
|
|
|
@ -34,10 +34,17 @@
|
|||
<field name="babi_raise_user_error"/>
|
||||
<label name="timeout"/>
|
||||
<field name="timeout"/>
|
||||
</page>
|
||||
</page>
|
||||
<page id="dependencies" string="Dependencies">
|
||||
<field name="requires" colspan="2" height="550"/>
|
||||
<field name="required_by" colspan="2" height="550"/>
|
||||
</page>
|
||||
<page name="ai_request" string="AI" col="3">
|
||||
<separator name="ai_request" string="What do you want to show?" colspan="2"/>
|
||||
<separator name="ai_reponse" string="AI Response:"/>
|
||||
<field name="ai_request" height="450"/>
|
||||
<button name="ai" string="AI >>" xexpand="0" xfill="0"/>
|
||||
<field name="ai_response" height="450"/>
|
||||
</page>
|
||||
</notebook>
|
||||
</form>
|
||||
|
|
Loading…
Reference in New Issue