# trytond-ir_sequence_period/ir.py
#
# 242 lines
# 8.7 KiB
# Python

# The COPYRIGHT file at the top level of this repository contains the full
# copyright notices and license terms.
from trytond.model import ModelSQL, ModelView, MatchMixin, fields
from trytond.pool import PoolMeta, Pool
from trytond.pyson import Eval, And, Bool
from trytond.transaction import Transaction
from trytond.i18n import gettext
from trytond.ir.sequence import MissingError
from trytond import backend
from sql import Literal, For
__all__ = [
    'Sequence',
    'SequencePeriod',
    'SequenceStrict',
    ]
# Whether the database backend supports native SQL sequences. When it does,
# non-strict period counters are drawn from per-period SQL sequences instead
# of the number_next_internal column.
sql_sequence = backend.Database.has_sequence()
class Sequence(metaclass=PoolMeta):
    __name__ = 'ir.sequence'
    # Date ranges, each with its own counter, prefix and suffix. The field
    # is only shown for incremental sequences.
    periods = fields.One2Many('ir.sequence.period', 'sequence', 'Periods',
        states={'invisible': Eval('type') != 'incremental'},
        depends=['type'],
        order=[('start_date', 'ASC')])

    @classmethod
    def write(cls, sequences, values, *args):
        """Write the sequences, then resynchronise their periods' SQL
        sequences.

        ``args`` may hold further ``(records, values)`` pairs. The periods of
        the sequences in every pair are updated, not only those of the first
        one.
        """
        super(Sequence, cls).write(sequences, values, *args)
        if sql_sequence and not cls._strict:
            actions = iter((sequences, values) + args)
            for records, _values in zip(actions, actions):
                for sequence in records:
                    for period in sequence.periods:
                        # No explicit number: the period keeps its current
                        # next number but picks up the (possibly changed)
                        # type and number_increment of its sequence.
                        period.update_sql_sequence()

    def get_id(self, _lock=False):
        """Return the next formatted number of the sequence.

        If the sequence has periods, the number comes from the first period
        (ordered by start date) that matches the ``date`` of the transaction
        context. Otherwise the standard ``ir.sequence`` ``get`` is used.

        :param _lock: lock before drawing the number (forwarded to ``get``).
        :raises MissingError: if the sequence cannot be loaded, if it has
            periods but the context has no ``date``, or if no period matches
            that date.
        """
        cls = self.__class__
        # Run as root with access checks disabled while reading the sequence.
        with Transaction().set_context(user=False, _check_access=False):
            with Transaction().set_user(0):
                try:
                    sequence = cls(self.id)
                except TypeError:
                    raise MissingError(gettext('ir.msg_sequence_missing'))
                date = Transaction().context.get('date')
                if sequence.periods:
                    if not date:
                        raise MissingError(gettext(
                                'ir_sequence_period.sequence.msg_missing_date',
                                sequence=sequence.rec_name))
                    pattern = sequence._get_period_pattern()
                    for period in sequence.periods:
                        if period.match(pattern):
                            return period.get(_lock=_lock)
                    raise MissingError(gettext(
                            'ir_sequence_period.sequence.msg_missing_date',
                            date=date,
                            sequence=sequence.rec_name))
                # Bind the parent call to the record. The previous
                # super(Sequence, cls).get(...) called the parent's
                # instance-level get() without a record.
                # NOTE(review): assumes the trytond >= 6.0 API where
                # ir.sequence.get(self, _lock=False) is an instance method;
                # callers there use sequence.get(), so confirm whether this
                # override should be named get instead of get_id.
                return super(Sequence, self).get(_lock=_lock)

    def _get_period_pattern(self):
        "Return the pattern used to select the period matching the context."
        return {'date': Transaction().context.get('date')}
class SequencePeriod(ModelSQL, ModelView, MatchMixin):
    '''Sequence period'''
    __name__ = 'ir.sequence.period'
    # Periods are never strict. When the backend supports SQL sequences,
    # each period draws its numbers from its own SQL sequence.
    _strict = False
    sequence = fields.Many2One('ir.sequence', 'Sequence', required=True,
        ondelete='CASCADE')
    start_date = fields.Date('Start date', required=True)
    end_date = fields.Date('End date', required=True)
    # Stored counter. It is the source of truth when no SQL sequence is used.
    number_next_internal = fields.Integer('Next Number',
        states={
            'invisible': ~Eval('_parent_sequence', {}).get('type').in_([
                    'incremental']),
            'required': And(Eval('_parent_sequence', {}).get('type').in_(
                    ['incremental']), not sql_sequence),
            })
    number_next = fields.Function(number_next_internal, 'get_number_next',
        'set_number_next')
    prefix = fields.Char('Prefix')
    suffix = fields.Char('Suffix')

    @staticmethod
    def default_number_next():
        "Return the first number of a new period."
        return 1

    def get_number_next(self, name):
        """Return the next number of the period.

        The value is read from the period's SQL sequence when one is used,
        and from number_next_internal otherwise. Returns None for
        non-incremental sequences.
        """
        if self.sequence.type != 'incremental':
            return
        transaction = Transaction()
        if sql_sequence and not self._strict:
            return transaction.database.sequence_next_number(
                transaction.connection, self._sql_sequence_name)
        else:
            return self.number_next_internal

    @classmethod
    def set_number_next(cls, periods, name, value):
        "Store value in number_next_internal."
        # Call the parent write directly so this setter does not trigger an
        # SQL sequence update. create()/write() perform that update
        # themselves using the number_next value.
        super(SequencePeriod, cls).write(periods, {
                'number_next_internal': value,
                })

    @property
    def _sql_sequence_name(self):
        'Return SQL sequence name'
        return '%s_%s' % (self._table, self.id)

    @classmethod
    def create(cls, vlist):
        "Create the periods and, if applicable, their SQL sequences."
        periods = super(SequencePeriod, cls).create(vlist)
        if sql_sequence and not cls._strict:
            for period, values in zip(periods, vlist):
                period.update_sql_sequence(values.get('number_next',
                        cls.default_number_next()))
        return periods

    @classmethod
    def write(cls, periods, values, *args):
        """Write the periods and resynchronise their SQL sequences.

        Every ``(records, values)`` pair is handled. A ``number_next`` in the
        values resets the SQL sequence to that number.
        """
        super(SequencePeriod, cls).write(periods, values, *args)
        if sql_sequence and not cls._strict:
            actions = iter((periods, values) + args)
            for records, record_values in zip(actions, actions):
                for period in records:
                    period.update_sql_sequence(
                        record_values.get('number_next'))

    @classmethod
    def delete(cls, periods):
        "Drop the periods' SQL sequences, then delete the periods."
        if sql_sequence and not cls._strict:
            for period in periods:
                period.delete_sql_sequence()
        return super(SequencePeriod, cls).delete(periods)

    def create_sql_sequence(self, number_next=None):
        """Create the SQL sequence of an incremental period.

        :param number_next: first number of the SQL sequence. If None, the
            stored number_next_internal is used, falling back to
            default_number_next().
        """
        transaction = Transaction()
        if self.sequence.type != 'incremental':
            return
        if number_next is None:
            # Do not read self.number_next here: with SQL sequences it
            # queries the very sequence that is about to be created.
            number_next = self.number_next_internal
            if number_next is None:
                number_next = self.default_number_next()
        if sql_sequence:
            transaction.database.sequence_create(transaction.connection,
                self._sql_sequence_name, self.sequence.number_increment,
                number_next)

    def update_sql_sequence(self, number_next=None):
        """Bring the SQL sequence in line with the period's sequence.

        - Drops the SQL sequence if the sequence is no longer incremental.
        - Creates it if it is missing.
        - Otherwise updates its increment and next number. When
          number_next is None, the current next number is kept.
        """
        transaction = Transaction()
        database = transaction.database
        connection = transaction.connection
        exist = database.sequence_exist(connection, self._sql_sequence_name)
        if self.sequence.type != 'incremental':
            if exist:
                # delete_sql_sequence() returns early for non-incremental
                # sequences, so drop the stale SQL sequence directly.
                database.sequence_delete(connection, self._sql_sequence_name)
            return
        if not exist:
            self.create_sql_sequence(number_next)
            return
        if number_next is None:
            number_next = self.number_next
        database.sequence_update(connection, self._sql_sequence_name,
            self.sequence.number_increment, number_next)

    def delete_sql_sequence(self):
        'Delete the SQL sequence (incremental sequences only)'
        transaction = Transaction()
        if self.sequence.type != 'incremental':
            return
        transaction.database.sequence_delete(
            transaction.connection, self._sql_sequence_name)

    def get(self, _lock=False):
        """Return the next formatted number of this period.

        The period's prefix/suffix are used when set, else those of its
        sequence. Both are passed through ir.sequence._process with the
        ``date`` of the transaction context.

        :param _lock: if True, lock before drawing the number. The period
            row is locked with SELECT ... FOR UPDATE NOWAIT when the database
            supports it; otherwise the whole table is locked.
        """
        Sequence = Pool().get('ir.sequence')
        cls = self.__class__
        transaction = Transaction()
        if _lock:
            database = transaction.database
            connection = transaction.connection
            if not database.has_select_for():
                database.lock(connection, self._table)
            else:
                table = self.__table__()
                query = table.select(Literal(1),
                    where=table.id == self.id,
                    for_=For('UPDATE', nowait=True))
                cursor = connection.cursor()
                cursor.execute(*query)
        date = transaction.context.get('date')
        return '%s%s%s' % (
            Sequence._process(self.prefix or self.sequence.prefix,
                date=date),
            cls._get_sequence(self),
            Sequence._process(self.suffix or self.sequence.suffix,
                date=date),
            )

    @classmethod
    def _get_sequence(cls, period):
        """Consume the next number of period and return it zero-padded.

        :raises NotImplementedError: for non-incremental sequence types.
        """
        if period.sequence.type == 'incremental':
            if sql_sequence and not cls._strict:
                cursor = Transaction().connection.cursor()
                # The name comes from the table name and the integer id,
                # not from user input.
                cursor.execute('SELECT nextval(\'"%s"\')'
                    % period._sql_sequence_name)
                number_next, = cursor.fetchone()
            else:
                # Pre-fetch number_next
                number_next = period.number_next_internal
                cls.write([period], {
                        'number_next_internal': (number_next
                            + period.sequence.number_increment),
                        })
            return '%%0%sd' % period.sequence.padding % number_next
        else:
            raise NotImplementedError()

    def match(self, pattern, match_none=False):
        """Return whether the period matches pattern.

        pattern must contain a truthy ``date`` within
        [start_date, end_date] (inclusive). The remaining keys are checked
        by MatchMixin.match.
        """
        pattern = pattern.copy()
        # 'date' is not a field of the period, so remove it before the
        # generic field matching.
        date = pattern.pop('date', None)
        if not date:
            return False
        return (self.start_date <= date <= self.end_date
            and super().match(pattern, match_none=match_none))
class SequenceStrict(metaclass=PoolMeta):
    __name__ = 'ir.sequence.strict'

    # ir.sequence and ir.sequence.strict share the same form view, so the
    # strict model also declares a "periods" field. The field is always
    # hidden and always empty.
    periods = fields.Function(
        fields.One2Many(
            'ir.sequence.period', None, 'Periods',
            states={
                'invisible': Bool(True),
                },
            ),
        'get_periods')

    def get_periods(self, name=None):
        "Return no periods: strict sequences do not support them."
        no_periods = []
        return no_periods