trytond-babi/table.py

744 lines
26 KiB
Python

import datetime as mdatetime
import html
import json
import logging
import re
import time
from datetime import datetime

import sql
import unidecode
from simpleeval import EvalWithCompoundTypes

from trytond import backend
from trytond.bus import notify
from trytond.config import config
from trytond.exceptions import UserError
from trytond.i18n import gettext
from trytond.model import (ModelView, ModelSQL, fields, Unique,
    DeactivableMixin, sequence_ordered)
from trytond.pool import Pool
from trytond.pyson import Bool, Eval, PYSONDecoder
from trytond.transaction import Transaction

from .babi import TimeoutChecker, TimeoutException, FIELD_TYPES, QUEUE_NAME
from .babi_eval import babi_eval
# Characters accepted as the first character of an internal name.
VALID_FIRST_SYMBOLS = 'abcdefghijklmnopqrstuvwxyz'
# Extra characters accepted after the first position.
VALID_NEXT_SYMBOLS = '_0123456789'
# Every character accepted anywhere in an internal name.
VALID_SYMBOLS = ''.join((VALID_FIRST_SYMBOLS, VALID_NEXT_SYMBOLS))

logger = logging.getLogger(__name__)
def convert_to_symbol(text):
    """Return an identifier-like symbol derived from *text*.

    The text is transliterated to ASCII with unidecode and lowercased.
    Characters in VALID_SYMBOLS are kept. Every run of other characters
    collapses into a single '_'. A leading '_' is added when the first
    character is not in VALID_FIRST_SYMBOLS. A single trailing '_' is
    removed when the symbol is longer than one character.

    Returns 'x' for empty input, or when transliteration leaves nothing.
    """
    if not text:
        return 'x'
    text = unidecode.unidecode(text).lower()
    # unidecode may drop characters it cannot transliterate, so a non-empty
    # input can become empty here; indexing text[0] would then fail.
    if not text:
        return 'x'
    symbol = '' if text[0] in VALID_FIRST_SYMBOLS else '_'
    for char in text:
        if char in VALID_SYMBOLS:
            symbol += char
        elif symbol[-1] != '_':
            # symbol is never empty here: when it starts as '' the first
            # character is valid and has already been appended.
            symbol += '_'
    if len(symbol) > 1 and symbol[-1] == '_':
        symbol = symbol[:-1]
    return symbol
def generate_html_table(records):
    """Render *records* (an iterable of rows) as an HTML ``<table>`` string.

    The first row is rendered with ``<th>`` cells, the rest with ``<td>``.
    Formatting rules:

    - None is shown as an italic NULL marker.
    - datetime, date and time values use ISO-like formats.
    - Numbers are right-aligned.
    - Any other value is converted with ``str()``, HTML-escaped and
      left-aligned.
    """
    parts = ['<table>']
    tag = 'th'
    for row in records:
        parts.append('<tr>')
        for cell in row:
            align = 'right'
            if cell is None:
                text = '<i>NULL</i>'
            # datetime must be tested before date: it is a date subclass.
            elif isinstance(cell, datetime):
                text = cell.strftime('%Y-%m-%d %H:%M:%S')
            elif isinstance(cell, mdatetime.date):
                text = cell.strftime('%Y-%m-%d')
            elif isinstance(cell, mdatetime.time):
                text = cell.strftime('%H:%M:%S')
            elif isinstance(cell, (float, int)):
                text = str(cell)
            else:
                # Arbitrary query data must not be interpreted as markup.
                text = html.escape(str(cell))
                align = 'left'
            parts.append(f"<{tag} align='{align}'>{text}</{tag}>")
        tag = 'td'
        parts.append('</tr>')
    parts.append('</table>')
    return ''.join(parts)
class Table(DeactivableMixin, ModelSQL, ModelView):
    'BABI Table'
    __name__ = 'babi.table'
    name = fields.Char('Name', required=True)
    # How the SQL object is built by _compute(): 'model' fills a table from a
    # Tryton model, 'table' materializes `query` with CREATE TABLE AS, and
    # 'query' creates a view over `query`.
    type = fields.Selection([
            (None, ''),
            ('model', 'Model'),
            ('table', 'Table'),
            ('query', 'Query'),
            ], 'Type', required=True)
    internal_name = fields.Char('Internal Name', required=True)
    model = fields.Many2One('ir.model', 'Model', states={
            'invisible': Eval('type') != 'model',
            'required': Eval('type') == 'model',
            }, domain=[('babi_enabled', '=', True)])
    filter = fields.Many2One('babi.filter', 'Filter', domain=[
            ('model', '=', Eval('model')),
            ], states={
            'invisible': Eval('type') != 'model',
            }, depends=['model'])
    fields_ = fields.One2Many('babi.field', 'table', 'Fields')
    query = fields.Text('Query', states={
            'invisible': ~Eval('type').in_(['query', 'table']),
            }, depends=['type'])
    timeout = fields.Integer('Timeout', required=True, states={
            'invisible': ~Eval('type').in_(['model', 'table']),
            }, help='If table '
        'calculation should take more than the specified timeout (in seconds) '
        'the process will be stopped automatically.')
    preview_limit = fields.Integer('Preview Limit', required=True)
    preview = fields.Function(fields.Binary('Preview',
            filename='preview_filename'), 'get_preview')
    preview_filename = fields.Function(fields.Char('Preview Filename'),
        'get_preview_filename')
    babi_raise_user_error = fields.Boolean('Raise User Error',
        help='Will raise a UserError in case of an error in the table.')
    compute_error = fields.Text('Compute Error', states={
            'invisible': ~Bool(Eval('compute_error')),
            }, readonly=True)
    crons = fields.One2Many('ir.cron', 'babi_table', 'Schedulers', context={
            'babi_table': Eval('id'),
            }, depends=['id'])
    # Dependencies whose `required_by` is this table: tables this one reads.
    requires = fields.One2Many('babi.table.dependency', 'required_by',
        'Requires', readonly=True)
    # Dependencies whose `table` is this table: tables that read this one.
    required_by = fields.One2Many('babi.table.dependency', 'table',
        'Required By', readonly=True)
    ai_request = fields.Text('AI Request', states={
            'invisible': Eval('type') == 'model',
            })
    ai_response = fields.Text('AI Response', readonly=True, states={
            'invisible': Eval('type') == 'model',
            })

    @staticmethod
    def default_timeout():
        "Return the configured default timeout, or 30 seconds if unset."
        Config = Pool().get('babi.configuration')
        configuration = Config(1)
        return configuration.default_timeout or 30

    @staticmethod
    def default_preview_limit():
        return 10

    @classmethod
    def __setup__(cls):
        super().__setup__()
        cls._order.insert(0, ('name', 'ASC'))
        cls._buttons.update({
                'ai': {},
                'compute': {},
                })

    @classmethod
    def create(cls, vlist):
        tables = super().create(vlist)
        for table in tables:
            table.update_table_dependencies()
        return tables

    @classmethod
    def write(cls, *args):
        # Work on copies so the caller's values dictionaries are not mutated.
        args = [x.copy() for x in args]
        actions = iter(args)
        for tables, values in zip(actions, actions):
            # A new AI request invalidates the previous AI response.
            if 'ai_request' in values and 'ai_response' not in values:
                values['ai_response'] = None
            if 'internal_name' not in values:
                continue
            # Drop the SQL object under the old name before the rename.
            for table in tables:
                table._drop()
        super().write(*args)
        actions = iter(args)
        for tables, values in zip(actions, actions):
            for table in tables:
                table.update_table_dependencies()
                if 'internal_name' in values:
                    # Also drop any object already using the new name.
                    table._drop()

    @classmethod
    def delete(cls, tables):
        for table in tables:
            table._drop()
        super().delete(tables)

    def update_table_dependencies(self):
        """Rebuild this table's babi.table.dependency records.

        Every dependency where this table is either side is deleted. Two sets
        of records are then recreated:

        - one for each ``__name`` referenced by this table's query;
        - one for each other table whose query references this table.
        """
        pool = Pool()
        Dependency = pool.get('babi.table.dependency')
        Dependency.delete(self.requires)
        Dependency.delete(self.required_by)
        tables = {x.table_name: x for x in self.search([])}
        to_save = []
        required_tables = self.get_required_table_names()
        for name in required_tables:
            dependency = Dependency()
            dependency.required_by = self
            dependency.name = name
            # May be None when no table with that SQL name exists.
            dependency.table = tables.get(name)
            to_save.append(dependency)
        requiredby_tables = self.get_required_by_table_names() - required_tables
        for name in requiredby_tables:
            dependency = Dependency()
            dependency.required_by = tables.get(name)
            dependency.name = self.table_name
            dependency.table = self
            to_save.append(dependency)
        Dependency.save(to_save)

    def get_required_table_names(self):
        """Return the lowercased ``__``-prefixed identifiers in the query.

        Model tables have no query, so they return an empty set. Identifiers
        are found even when surrounded by punctuation, for example
        ``FROM __a, __b`` or ``(__a)``. Unquoted identifiers are folded to
        lowercase by PostgreSQL, so matches are lowercased too.
        """
        if self.type == 'model':
            return set()
        query = self.query or ''
        return {name.lower() for name in
            re.findall(r'(?<![A-Za-z0-9_])__[A-Za-z0-9_]+', query)}

    def get_required_by_table_names(self):
        "Return the SQL names of active table/query tables reading this one."
        # The ilike search is only a coarse prefilter ('_' is a LIKE
        # wildcard); get_required_table_names() does the exact check.
        tables = self.search([
                ('type', 'in', ['table', 'query']),
                ('query', 'ilike', '%' + self.table_name + '%'),
                ])
        res = set()
        for table in tables:
            if self.table_name in table.get_required_table_names():
                res.add(table.table_name)
        return res

    def get_preview(self, name):
        """Return an HTML document (bytes) previewing the first rows.

        The number of rows is `preview_limit`. If the query fails, the
        escaped error message is shown instead of the table.
        """
        start = time.time()
        try:
            records = self.execute_query(limit=self.preview_limit)
        except Exception as e:
            content = html.escape(str(e))
        else:
            elapsed = time.time() - start
            table = [[x.internal_name for x in self.fields_]]
            table.extend(records)
            content = '%(table)s<br/>%(elapsed).2fms' % {
                'table': generate_html_table(table),
                'elapsed': elapsed * 1000,
                }
        preview = '''<!DOCTYPE html>
<html>
<head>
<style>
table, th, td {
border: 1px solid black;
border-collapse: collapse;
padding: 5px;
}
* {
font-family: monospace;
}
</style>
</head>
<body>%s</body></html>
''' % content
        return preview.encode()

    def get_preview_filename(self, name):
        return self.internal_name + '.html'

    @classmethod
    def validate(cls, tables):
        super().validate(tables)
        for table in tables:
            table.check_internal_name()
            table.check_filter()

    def check_internal_name(self):
        "Raise UserError unless internal_name is a valid lowercase symbol."
        if self.internal_name[0] not in VALID_FIRST_SYMBOLS:
            raise UserError(gettext(
                    'babi.msg_invalid_internal_name_first_character',
                    table=self.rec_name, internal_name=self.internal_name))
        for symbol in self.internal_name:
            if symbol not in VALID_SYMBOLS:
                raise UserError(gettext('babi.msg_invalid_internal_name',
                        table=self.rec_name, internal_name=self.internal_name))

    def check_filter(self):
        "Tables cannot be computed with filters that need parameters."
        if self.filter and self.filter.parameters:
            raise UserError(gettext('babi.msg_filter_with_parameters',
                    table=self.rec_name))

    @fields.depends('name')
    def on_change_name(self):
        self.internal_name = convert_to_symbol(self.name)

    def get_python_filter(self):
        "Return the filter's python expression, or None if there is none."
        if self.filter and self.filter.python_expression:
            return self.filter.python_expression
        return None

    def get_domain_filter(self):
        """Return the filter domain as a Python list ('[]' if no filter).

        A domain containing '__' is treated as PYSON and decoded first.
        """
        domain = '[]'
        if self.filter and self.filter.domain:
            domain = self.filter.domain
            if '__' in domain:
                domain = str(PYSONDecoder().decode(domain))
        # NOTE(review): eval() executes the stored filter text with full
        # builtins available; this is only safe while babi.filter records
        # can be edited solely by trusted users.
        return eval(domain, {
                'datetime': mdatetime,
                'false': False,
                'true': True,
                })

    def get_context(self):
        """Return the filter context evaluated as a Python value.

        Returns None when the filter has no context. date('YYYY-MM-DD') and
        datetime('YYYY-MM-DD') helpers are available inside the expression.
        """
        if not (self.filter and self.filter.context):
            return None
        # Filters with parameters are rejected (check_filter and _compute),
        # so there are no placeholders to substitute. Table does not define
        # a replace_parameters() method, so it must not be called here.
        context = self.filter.context
        ev = EvalWithCompoundTypes(names={}, functions={
                'date': lambda x: datetime.strptime(x, '%Y-%m-%d').date(),
                'datetime': lambda x: datetime.strptime(x, '%Y-%m-%d'),
                })
        return ev.eval(context)

    @property
    def ai_sql_tables(self):
        return {'account_invoice', 'account_invoice_line'}

    @classmethod
    @ModelView.button
    def ai(cls, tables):
        """Ask the OpenAI API for an SQL query matching each ai_request.

        The column layout of `ai_sql_tables` is sent along with the request.
        The answer is stored in ai_response. It is copied into query only
        when the table has no query yet. Tables for which the API returns
        no choices are left untouched.
        """
        # Imported lazily: the dependency is only needed by this button.
        from openai import OpenAI
        cursor = Transaction().connection.cursor()
        client = OpenAI(
            organization=config.get('openai', 'organization'),
            api_key=config.get('openai', 'api_key'),
            )
        columns_table = sql.Table('columns', schema='information_schema')
        for table in tables:
            sqltables = dict.fromkeys(table.ai_sql_tables)
            query = columns_table.select(columns_table.table_name,
                columns_table.column_name, columns_table.data_type)
            query.where = ((columns_table.table_schema == 'public')
                & (columns_table.table_name.in_(tuple(sqltables.keys()))))
            cursor.execute(*query)
            for table_name, column_name, data_type in cursor.fetchall():
                sqltables[table_name] = sqltables[table_name] or {}
                sqltables[table_name][column_name] = data_type
            request = '''Given the following tables:
%s
Write an SQL query that returns the following information:
%s
Query:''' % (json.dumps(sqltables), table.ai_request)
            messages = [{
                    'role': 'system',
                    'content': 'Always return an SQL query that is suitable for PostgreSQL',
                    }, {
                    'role': 'user',
                    'content': request,
                    }]
            response = client.chat.completions.create(model="gpt-3.5-turbo",
                messages=messages)
            if not response.choices:
                # Without this guard the python-sql Select above would be
                # stored as the query/response text.
                continue
            answer = response.choices[0].message.content
            if not table.query:
                table.query = answer
            table.ai_response = answer
            table.save()

    @classmethod
    @ModelView.button
    def compute(cls, tables):
        "Enqueue the computation of each table on the babi queue."
        with Transaction().set_context(queue_name=QUEUE_NAME):
            for table in tables:
                cls.__queue__._compute(table)

    @property
    def table_name(self):
        # Add a suffix to the table name to prevent removing production tables
        return '__' + self.internal_name

    def get_query(self, fields=None, where=None, groupby=None, limit=None):
        """Build a SELECT over this table's SQL object.

        For 'query' tables the source is a subquery; otherwise it is the
        table or view named table_name. The result is ordered by *fields*
        when they are given.
        """
        if self.type == 'query':
            source = '(%s) AS a' % self._stripped_query
        else:
            source = self.table_name
        parts = ['SELECT', ', '.join(fields) if fields else '*', 'FROM',
            source]
        if where:
            # NOTE(review): context values are interpolated verbatim into
            # the SQL text; callers must only pass trusted where/context.
            parts.append('WHERE %s' % where.format(**Transaction().context))
        if groupby:
            parts.append('GROUP BY %s' % ', '.join(groupby))
        if fields:
            parts.append('ORDER BY %s' % ', '.join(fields))
        if limit:
            parts.append('LIMIT %d' % limit)
        # Joining with spaces keeps every clause separated; previously
        # 'ORDER BY ...' and 'LIMIT' were glued together.
        return ' '.join(parts)

    def execute_query(self, fields=None, where=None, groupby=None, timeout=None,
            limit=None):
        """Run get_query() in a new transaction and return all rows.

        *timeout* is in seconds and defaults to 10. Returns [] when a
        non-'query' table has not been created yet.
        """
        if timeout is None:
            timeout = 10
        if (self.type != 'query'
                and not backend.TableHandler.table_exist(self.table_name)):
            return []
        with Transaction().new_transaction() as transaction:
            cursor = transaction.connection.cursor()
            cursor.execute('SET statement_timeout TO %s;' % int(timeout * 1000))
            query = self.get_query(fields, where=where, groupby=groupby,
                limit=limit)
            cursor.execute(query)
            records = cursor.fetchall()
            cursor.execute('SET statement_timeout TO 0;')
        return records

    def timeout_exception(self):
        raise TimeoutException

    def _compute(self, processed=None):
        """Build this table's SQL object, then recompute dependent tables.

        *processed* is the chain of tables already computed in this run and
        is used to detect circular dependencies. Any error rolls back the
        transaction and is stored in compute_error.
        """
        if processed is None:
            processed = []
        if self in processed:
            seq = ' > '.join([x.rec_name for x in processed] + [self.rec_name])
            raise UserError(gettext('babi.msg_circular_dependency',
                    sequence=seq))
        try:
            if self.type == 'model':
                if not self.fields_:
                    raise UserError(gettext('babi.msg_table_no_fields',
                            table=self.name))
                if self.filter and self.filter.parameters:
                    raise UserError(gettext('babi.msg_filter_with_parameters',
                            table=self.rec_name))
                self._compute_model()
            elif self.type == 'table':
                self._compute_table()
            elif self.type == 'query':
                self._compute_query()
            for dependency in self.required_by:
                dependency.required_by._compute(processed + [self])
        except Exception as e:
            # In case there is a create view error or SQL typo,
            # we do rollback to obtain a value from the gettext()
            Transaction().connection.rollback()
            notify(gettext('babi.msg_table_failed', table=self.rec_name))
            self.compute_error = str(e)
            self.save()
            return
        self.compute_error = None
        self.save()
        notify(gettext('babi.msg_table_successful', table=self.rec_name))

    def update_fields(self, field_names):
        """Synchronize fields_ with the column names *field_names*.

        Fields whose column disappeared are deleted. Kept fields get their
        column position as sequence. Missing columns get new fields, created
        in column order.
        """
        pool = Pool()
        Field = pool.get('babi.field')
        # First occurrence wins, matching list.index().
        positions = {}
        for position, field_name in enumerate(field_names):
            positions.setdefault(field_name, position)
        to_save = []
        to_delete = []
        for field in self.fields_:
            if field.internal_name not in positions:
                to_delete.append(field)
                continue
            field.sequence = positions[field.internal_name]
            to_save.append(field)
        existing_fields = {x.internal_name for x in to_save}
        # Iterate the dict (column order) rather than a set difference so
        # new fields are created deterministically.
        for field_name, position in positions.items():
            if field_name in existing_fields:
                continue
            field = Field()
            field.table = self
            field.name = field_name
            field.internal_name = field_name
            field.sequence = position
            to_save.append(field)
        Field.save(to_save)
        Field.delete(to_delete)

    @property
    def _stripped_query(self):
        "The query without surrounding whitespace and trailing ';'."
        if self.query:
            return self.query.strip().rstrip(';')
        return ''

    def _drop(self):
        """Drop the table or view named table_name, if it exists.

        On PostgreSQL, information_schema (schema 'public') tells whether
        the object is a view or a table, and the right DROP ... CASCADE is
        used. Other backends only issue DROP TABLE IF EXISTS.
        """
        cursor = Transaction().connection.cursor()
        if backend.name != 'postgresql':
            cursor.execute('DROP TABLE IF EXISTS %s' % self.table_name)
            return
        cursor.execute("SELECT table_type FROM information_schema.tables "
            "WHERE table_name=%s AND table_schema='public'", (self.table_name,))
        record = cursor.fetchone()
        if not record:
            return
        if record[0] == 'VIEW':
            cursor.execute('DROP VIEW IF EXISTS "%s" CASCADE' % self.table_name)
        else:
            cursor.execute('DROP TABLE IF EXISTS %s CASCADE' % self.table_name)

    def _compute_query(self):
        "Refresh fields_ from the query's columns and (re)create the view."
        with Transaction().new_transaction() as transaction:
            cursor = transaction.connection.cursor()
            # We must use a subquery because the _stripped_query may contain a
            # LIMIT clause
            cursor.execute('SELECT * FROM (%s) AS subquery LIMIT 1' %
                self._stripped_query)
            field_names = [x[0] for x in cursor.description]
        self.update_fields(field_names)
        cursor = Transaction().connection.cursor()
        self._drop()
        cursor.execute('CREATE VIEW "%s" AS %s' % (self.table_name, self._stripped_query))

    def _compute_table(self):
        "Materialize the query into a table and refresh fields_ from it."
        with Transaction().new_transaction() as transaction:
            cursor = transaction.connection.cursor()
            if backend.name == 'postgresql':
                cascade = 'CASCADE'
            else:
                cascade = ''
            # _drop() runs first so that a view left by an earlier 'query'
            # computation is removed: PostgreSQL refuses DROP TABLE on a view
            # even with IF EXISTS. (Presumably _drop() uses this new
            # transaction, as Transaction() returns it inside the block.)
            self._drop()
            cursor.execute('DROP TABLE IF EXISTS "%s" %s;' % (self.table_name, cascade))
            cursor.execute('CREATE TABLE "%s" AS %s' % (self.table_name,
                    self._stripped_query))
            cursor.execute('SELECT * FROM "%s" LIMIT 1' % self.table_name)
            field_names = [x[0] for x in cursor.description]
        self.update_fields(field_names)

    def _compute_model(self):
        """Fill table_name from records of the configured model.

        Records are read in pages of 2000. Each kept record is evaluated
        against every field's expression. The whole computation runs in a
        new transaction and honours `timeout` through TimeoutChecker.
        """
        Model = Pool().get(self.model.model)
        page_size = 2000
        with Transaction().new_transaction() as transaction:
            cursor = transaction.connection.cursor()
            if backend.name == 'postgresql':
                cascade = 'CASCADE'
            else:
                cascade = ''
            # Remove a view left by a previous 'query' computation first;
            # DROP TABLE alone fails on views in PostgreSQL.
            self._drop()
            cursor.execute('DROP TABLE IF EXISTS "%s" %s' % (self.table_name, cascade))
            babi_fields = list(self.fields_)
            column_definitions = ', '.join('"%s" %s' % (
                    field.internal_name, field.sql_type())
                for field in babi_fields)
            cursor.execute('CREATE TABLE IF NOT EXISTS "%s" (%s);' % (
                    self.table_name, column_definitions))

            checker = TimeoutChecker(self.timeout, self.timeout_exception)
            domain = self.get_domain_filter()
            context = self.get_context()
            if not context:
                context = {}
            else:
                assert isinstance(context, dict)
            context['_datetime'] = None
            # This is needed when execute the wizard to calculate the report, to
            # ensure the company rule is used.
            context['_check_access'] = True
            python_filter = self.get_python_filter()

            table = sql.Table(self.table_name)
            columns = [sql.Column(table, x.internal_name) for x in babi_fields]
            # Keep each expression paired with its field so error messages
            # name the field that actually failed.
            expressions = [(x, x.expression.expression) for x in babi_fields]

            def search_page(page):
                "Return the *page*-th batch of records matching domain."
                with Transaction().set_context(**context):
                    try:
                        return Model.search(domain, offset=page * page_size,
                            limit=page_size)
                    except Exception as message:
                        if self.babi_raise_user_error:
                            raise UserError(gettext(
                                    'babi.create_data_exception',
                                    error=repr(message)))
                        raise

            index = 0
            count = 0
            records = search_page(index)
            while records:
                checker.check()
                logger.info('Calculated %s, %s records in %s seconds',
                    self.model.model, count, checker.elapsed)
                to_insert = []
                for record in records:
                    if python_filter and not babi_eval(python_filter, record,
                            convert_none=None):
                        continue
                    values = []
                    for field, expression in expressions:
                        try:
                            values.append(babi_eval(expression, record,
                                    convert_none=None))
                        except Exception as message:
                            notify(gettext('babi.msg_compute_table_exception',
                                    table=self.name, field=field.name,
                                    record=record.id, error=repr(message)),
                                priority=1)
                            if self.babi_raise_user_error:
                                raise UserError(gettext(
                                        'babi.msg_compute_table_exception',
                                        table=self.name,
                                        field=field.name,
                                        record=record.id,
                                        error=repr(message)))
                            raise
                    to_insert.append(values)
                # The python filter may reject a whole page; an INSERT with
                # no rows is not valid SQL.
                if to_insert:
                    cursor.execute(*table.insert(columns=columns,
                            values=to_insert))
                index += 1
                count += len(records)
                records = search_page(index)
            logger.info('Calculated %s, %s records in %s seconds',
                self.model.model, count, checker.elapsed)
class Field(sequence_ordered(), ModelSQL, ModelView):
    'BABI Field'
    __name__ = 'babi.field'
    table = fields.Many2One('babi.table', 'Table', required=True,
        ondelete='CASCADE')
    name = fields.Char('Name', required=True)
    # Column name used in the SQL object created for the table.
    internal_name = fields.Char('Internal Name', required=True)
    # Only used (and required) for tables of type 'model'.
    expression = fields.Many2One('babi.expression', 'Expression', states={
            'invisible': Eval('table_type') != 'model',
            'required': Eval('table_type') == 'model'
            }, domain=[
            ('model', '=', Eval('model')),
            ], depends=['model'])
    model = fields.Function(fields.Many2One('ir.model', 'Model'),
        'on_change_with_model')
    type = fields.Function(fields.Selection(FIELD_TYPES, 'Type'),
        'on_change_with_type')
    table_type = fields.Function(fields.Selection([
                ('model', 'Model'),
                ('table', 'Table'),
                ('query', 'Query'),
                ], 'Table Type'), 'on_change_with_table_type')

    @fields.depends('expression')
    def on_change_with_type(self, name=None):
        if self.expression:
            return self.expression.ttype

    @fields.depends('table', '_parent_table.type')
    def on_change_with_table_type(self, name=None):
        if self.table:
            return self.table.type

    @classmethod
    def __setup__(cls):
        super().__setup__()
        t = cls.__table__()
        cls._sql_constraints += [
            ('table_internal_name_uniq', Unique(t, t.table, t.internal_name),
                'Field must be unique per Table'),
            ]
        cls.__access__.add('table')

    @classmethod
    def validate(cls, babi_fields):
        super().validate(babi_fields)
        for babi_field in babi_fields:
            babi_field.check_internal_name()

    def sql_type(self):
        """Return the SQL column type for this field's expression type.

        Raises KeyError for expression types without a mapping.
        """
        mapping = {
            'char': 'VARCHAR',
            'integer': 'INTEGER',
            'float': 'FLOAT',
            'numeric': 'NUMERIC',
            'boolean': 'BOOLEAN',
            'many2one': 'INTEGER',
            'date': 'DATE',
            # PostgreSQL has no DATETIME type; TIMESTAMP is accepted by both
            # PostgreSQL and SQLite.
            'datetime': 'TIMESTAMP',
            }
        return mapping[self.expression.ttype]

    def check_internal_name(self):
        "Raise UserError unless internal_name is a valid lowercase symbol."
        if self.internal_name[0] not in VALID_FIRST_SYMBOLS:
            raise UserError(gettext('babi.msg_invalid_field_internal_name',
                    field=self.name, internal_name=self.internal_name))
        for symbol in self.internal_name:
            if symbol not in VALID_SYMBOLS:
                raise UserError(gettext('babi.msg_invalid_field_internal_name',
                        field=self.name, internal_name=self.internal_name))

    @fields.depends('name')
    def on_change_name(self):
        self.internal_name = convert_to_symbol(self.name)

    @fields.depends('name', 'expression', methods=['on_change_name'])
    def on_change_expression(self):
        if self.expression:
            self.name = self.expression.name
            self.on_change_name()

    @fields.depends('table', '_parent_table.model')
    def on_change_with_model(self, name=None):
        if self.table and self.table.model:
            return self.table.model.id
class TableDependency(ModelSQL, ModelView):
    'BABI Table Dependency'
    __name__ = 'babi.table.dependency'
    # Table whose query references `name`; its deletion removes this link.
    required_by = fields.Many2One(
        'babi.table', 'Required By',
        required=True, ondelete='CASCADE')
    # Referenced SQL name, kept even when no babi.table matches it.
    name = fields.Char('Name')
    # Table that `name` resolves to, if any; cleared when it is deleted.
    table = fields.Many2One(
        'babi.table', 'Requires',
        ondelete='SET NULL')

    @classmethod
    def __setup__(cls):
        super(TableDependency, cls).__setup__()
        access = cls.__access__
        access.add('table')