trytond-ir_sequence_period/ir.py
2019-03-04 15:43:35 +01:00

238 lines
8.4 KiB
Python

# The COPYRIGHT file at the top level of this repository contains the full
# copyright notices and license terms.
from trytond.model import ModelSQL, ModelView, MatchMixin, fields
from trytond.pool import PoolMeta, Pool
from trytond.pyson import Eval, And, Bool
from trytond.transaction import Transaction
from trytond import backend
from sql import Literal, For
# Names exported by this module.
__all__ = ['Sequence', 'SequencePeriod', 'SequenceStrict']
# Result of the database backend's has_sequence() call, evaluated once at
# import time. The classes in this module test it (together with their
# _strict flag) to choose between SQL sequences and the
# number_next_internal column.
sql_sequence = backend.get('Database').has_sequence()
class Sequence(metaclass=PoolMeta):
    __name__ = 'ir.sequence'
    # Periods attached to the sequence, listed by ascending start date.
    # The field is hidden unless the sequence type is 'incremental'.
    periods = fields.One2Many('ir.sequence.period', 'sequence', 'Periods',
        states={'invisible': Eval('type') != 'incremental'},
        depends=['type'],
        order=[('start_date', 'ASC')])

    @classmethod
    def __setup__(cls):
        'Register the error messages raised by get_id'
        super(Sequence, cls).__setup__()
        cls._error_messages.update({
                'missing_date':
                'Must define a Date value in context for Sequence "%s".',
                'missing_period':
                'Cannot find a valid period for date "%s" in Sequence "%s".'
                })

    @classmethod
    def write(cls, sequences, values, *args):
        '''
        Write the sequences, then refresh the SQL sequence of each of their
        periods when SQL sequences are in use.

        Extra (records, values) pairs passed through *args are refreshed as
        well, not only the first group of records.
        '''
        super(Sequence, cls).write(sequences, values, *args)
        if sql_sequence and not cls._strict:
            # Walk every (records, values) pair, the same way
            # SequencePeriod.write does.
            actions = iter((sequences, values) + args)
            for records, _ in zip(actions, actions):
                for sequence in records:
                    for period in sequence.periods:
                        period.update_sql_sequence()

    @classmethod
    def get_id(cls, domain, _lock=False):
        '''
        Return the next sequence value for the domain.

        domain may be a sequence record, a record id or a search domain.
        When the sequence found has periods, the value comes from the first
        period matching the 'date' key of the transaction context;
        otherwise the call is delegated to the parent implementation.

        A user error is raised when no sequence matches the domain
        ('missing'), when the sequence has periods but the context holds no
        date ('missing_date'), or when no period matches ('missing_period').
        '''
        if isinstance(domain, cls):
            domain = domain.id
        if isinstance(domain, int):
            domain = [('id', '=', domain)]

        with Transaction().set_context(user=False, _check_access=False):
            with Transaction().set_user(0):
                try:
                    sequence, = cls.search(domain, limit=1)
                except (TypeError, ValueError):
                    # Unpacking an empty search result raises ValueError;
                    # TypeError is still caught as it was before.
                    cls.raise_user_error('missing')
                date = Transaction().context.get('date')
                if sequence.periods:
                    if not date:
                        cls.raise_user_error('missing_date',
                            sequence.rec_name)
                    pattern = {}
                    for period in sequence.periods:
                        if period.match(pattern):
                            return period.get_id(_lock=_lock)
                    cls.raise_user_error('missing_period', (date,
                            sequence.rec_name))
                return super(Sequence, cls).get_id(domain, _lock=_lock)
        return ''
class SequencePeriod(ModelSQL, ModelView, MatchMixin):
    '''Sequence period'''
    __name__ = 'ir.sequence.period'
    # Tested together with the module-level sql_sequence flag before any
    # SQL sequence is touched.
    _strict = False
    sequence = fields.Many2One('ir.sequence', 'Sequence', required=True,
        ondelete='CASCADE')
    start_date = fields.Date('Start date', required=True)
    end_date = fields.Date('End date', required=True)
    # Stored counter; shown only for incremental parent sequences and
    # required for them when SQL sequences are not available.
    number_next_internal = fields.Integer('Next Number',
        states={
            'invisible': ~Eval('_parent_sequence', {}).get('type').in_([
                    'incremental']),
            'required': And(Eval('_parent_sequence', {}).get('type').in_(
                    ['incremental']), not sql_sequence),
            })
    # Function field wrapping number_next_internal through the getter and
    # setter defined below.
    number_next = fields.Function(number_next_internal, 'get_number_next',
        'set_number_next')

    @staticmethod
    def default_number_next():
        'Default value of number_next'
        return 1

    def get_number_next(self, name):
        '''
        Getter of number_next.

        Returns None when the parent sequence is not incremental, the next
        number of the SQL sequence when SQL sequences are in use, and the
        stored number_next_internal otherwise.
        '''
        if self.sequence.type != 'incremental':
            return None
        if not sql_sequence or self._strict:
            return self.number_next_internal
        transaction = Transaction()
        return transaction.database.sequence_next_number(
            transaction.connection, self._sql_sequence_name)

    @classmethod
    def set_number_next(cls, periods, name, value):
        'Setter of number_next: store value in number_next_internal'
        super(SequencePeriod, cls).write(periods, {
                'number_next_internal': value,
                })

    @property
    def _sql_sequence_name(self):
        'Return SQL sequence name'
        return '%s_%s' % (self._table, self.id)

    @classmethod
    def create(cls, vlist):
        '''
        Create the periods and, when SQL sequences are in use, set up the
        SQL sequence of each one from its 'number_next' value (falling back
        to default_number_next).
        '''
        periods = super(SequencePeriod, cls).create(vlist)
        if sql_sequence and not cls._strict:
            for period, values in zip(periods, vlist):
                period.update_sql_sequence(
                    values.get('number_next', cls.default_number_next()))
        return periods

    @classmethod
    def write(cls, periods, values, *args):
        '''
        Write the periods and, when SQL sequences are in use, update the SQL
        sequence of every written period with the 'number_next' value of
        its group (None when the key is absent).
        '''
        super(SequencePeriod, cls).write(periods, values, *args)
        if not sql_sequence or cls._strict:
            return
        # Consume the flat argument list two items at a time:
        # (records, values), (records, values), ...
        actions = iter((periods, values) + args)
        for records, vals in zip(actions, actions):
            for period in records:
                period.update_sql_sequence(vals.get('number_next'))

    @classmethod
    def delete(cls, periods):
        'Delete the SQL sequences (when in use) before deleting the periods'
        if sql_sequence and not cls._strict:
            for period in periods:
                period.delete_sql_sequence()
        return super(SequencePeriod, cls).delete(periods)

    def create_sql_sequence(self, number_next=None):
        'Create the SQL sequence'
        if self.sequence.type != 'incremental':
            return
        if number_next is None:
            number_next = self.number_next
        if not sql_sequence:
            return
        transaction = Transaction()
        transaction.database.sequence_create(
            transaction.connection, self._sql_sequence_name,
            self.sequence.number_increment, number_next)

    def update_sql_sequence(self, number_next=None):
        'Update the SQL sequence'
        transaction = Transaction()
        exist = transaction.database.sequence_exist(
            transaction.connection, self._sql_sequence_name)
        if self.sequence.type != 'incremental':
            # NOTE(review): delete_sql_sequence returns early for
            # non-incremental sequences, so this call currently removes
            # nothing -- confirm whether that is intended.
            if exist:
                self.delete_sql_sequence()
        elif not exist:
            self.create_sql_sequence(number_next)
        else:
            if number_next is None:
                number_next = self.number_next
            transaction.database.sequence_update(
                transaction.connection, self._sql_sequence_name,
                self.sequence.number_increment, number_next)

    def delete_sql_sequence(self):
        'Delete the SQL sequence'
        if self.sequence.type == 'incremental':
            transaction = Transaction()
            transaction.database.sequence_delete(
                transaction.connection, self._sql_sequence_name)

    def get_id(self, _lock=False):
        '''
        Return the next value of the period: processed prefix of the parent
        sequence, padded number, processed suffix.

        With _lock set, the period row is selected FOR UPDATE NOWAIT when
        the database reports has_select_for(); otherwise database.lock is
        called on the period table.
        '''
        Sequence = Pool().get('ir.sequence')
        if _lock:
            transaction = Transaction()
            database = transaction.database
            connection = transaction.connection
            if database.has_select_for():
                table = self.__table__()
                query = table.select(Literal(1),
                    where=table.id == self.id,
                    for_=For('UPDATE', nowait=True))
                cursor = connection.cursor()
                cursor.execute(*query)
            else:
                database.lock(connection, self._table)
        date = Transaction().context.get('date')
        prefix = Sequence._process(self.sequence.prefix, date=date)
        number = self._get_sequence()
        suffix = Sequence._process(self.sequence.suffix, date=date)
        return '%s%s%s' % (prefix, number, suffix)

    def _get_sequence(self):
        '''
        Return the next number as a string zero-padded to the padding of
        the parent sequence, and advance the counter.

        Only incremental sequences are supported; any other type raises
        NotImplementedError.
        '''
        if self.sequence.type != 'incremental':
            raise NotImplementedError()
        if sql_sequence and not self._strict:
            cursor = Transaction().connection.cursor()
            cursor.execute('SELECT nextval(\'"%s"\')'
                % self._sql_sequence_name)
            number_next, = cursor.fetchone()
        else:
            # Read the current number first, then store the incremented one
            number_next = self.number_next_internal
            self.write([self], {
                    'number_next_internal': (number_next
                        + self.sequence.number_increment),
                    })
        return '%%0%sd' % self.sequence.padding % number_next

    def match(self, pattern):
        '''
        Return whether the 'date' key of the transaction context lies
        between start_date and end_date (both included).

        pattern is not used; False is returned when the context has no date.
        '''
        date = Transaction().context.get('date')
        if date:
            return self.start_date <= date <= self.end_date
        return False
class SequenceStrict(metaclass=PoolMeta):
    __name__ = 'ir.sequence.strict'
    # Needed because both sequence models share the same form view: the
    # field is declared here too, always invisible and computed by
    # get_periods.
    periods = fields.Function(
        fields.One2Many(
            'ir.sequence.period', None, 'Periods',
            states={'invisible': Bool(True)}),
        'get_periods')

    def get_periods(self, name=None):
        'Getter of periods: always an empty list'
        return list()