Replace assert by AssertionError

This commit is contained in:
Sergio Morillo 2019-04-29 17:01:31 +02:00
parent 60add29c51
commit fddf107b78
4 changed files with 47 additions and 32 deletions

View file

@ -20,3 +20,8 @@
class RetrofixException(Exception):
pass
def raise_assert_error(value, message):
if not value:
raise AssertionError(message)

View file

@ -33,7 +33,7 @@ from datetime import datetime
from stdnum.es.ccc import is_valid
from .formatting import format_string, format_number
from .exception import RetrofixException
from .exception import RetrofixException, raise_assert_error
__all__ = ['Field', 'Char', 'Const', 'Account', 'Number', 'Numeric', 'Integer',
'Date', 'Selection', 'Boolean', 'SIGN_DEFAULT', 'SIGN_12', 'SIGN_N',
@ -47,9 +47,10 @@ class Field(object):
self._size = None
def set_from_file(self, value):
assert len(value) == self._size, ('Invalid length of field "%s". '
raise_assert_error(len(value) == self._size, (
'Invalid length of field "%s". '
'Expected "%d" but got "%d".' % (self._name, self._size,
len(value)))
len(value))))
return value
def get_for_file(self, value):
@ -87,8 +88,9 @@ class Const(Char):
self._const = const
def set_from_file(self, value):
assert value == self._const, ('Invalid value "%s" in Const field '
'"%s". Expected "%s".' % (value, self._name, self._const))
raise_assert_error(value == self._const, (
'Invalid value "%s" in Const field '
'"%s". Expected "%s".' % (value, self._name, self._const)))
return super(Const, self).set_from_file(value)
def get_for_file(self, value):
@ -98,8 +100,8 @@ class Const(Char):
return self._const
def set(self, value):
assert value == self._const, ('Invalid value for field "%s"'
% self._name)
raise_assert_error(value == self._const, ('Invalid value for field "%s"'
% self._name))
return super(Const, self).set(value)
@ -130,15 +132,15 @@ class Number(Char):
self._align = align
def set_from_file(self, value):
assert re.match('[0-9]*$', value), (
'Non-number value "%s" in field "%s"' % (value, self._name))
raise_assert_error(re.match('[0-9]*$', value), (
'Non-number value "%s" in field "%s"' % (value, self._name)))
return super(Number, self).set_from_file(value)
def set(self, value):
if value is None:
value = ''
assert re.match('[0-9]*$', value), (
'Non-number value "%s" in field "%s"' % (value, self._name))
raise_assert_error(re.match('[0-9]*$', value), (
'Non-number value "%s" in field "%s"' % (value, self._name)))
l = self._size - len(value)
if self._align == 'right':
@ -171,8 +173,9 @@ class Numeric(Field):
if self._sign == SIGN_N_BLANK:
return ' ' if value >= Decimal('0.0') else 'N'
if self._sign == SIGN_POSITIVE:
assert value >= Decimal('0.0'), ('Field "%s" must be >= 0.0 but '
'got "%.2f"' % (self._name, value))
raise_assert_error(value >= Decimal('0.0'), (
'Field "%s" must be >= 0.0 but '
'got "%.2f"' % (self._name, value)))
return ''
def set_from_file(self, value):
@ -193,9 +196,9 @@ class Numeric(Field):
value = Decimal('0')
sign = self.get_sign(value)
length = self._size - len(sign)
assert length > 0, ('Number formatting error. Field size '
'"%d" but only "%d" characters left for formatting field "%s".') % (
self._size, length, self._name)
raise_assert_error(length > 0, ('Number formatting error. Field size '
'"%d" but only "%d" characters left for formatting field "%s".'
) % (self._size, length, self._name))
return sign + format_number(abs(value), length, self._decimals)
def set(self, value):
@ -238,7 +241,7 @@ class Date(Field):
def set(self, value):
if value is not None:
assert value, datetime
raise_assert_error(value, datetime)
return super(Date, self).set(value)
@ -253,13 +256,15 @@ class Selection(Char):
return super(Selection, self).get_for_file(value)
def set_from_file(self, value):
assert value in self._keys, ('Value "%s" not found in selection field '
'"%s". Expected one of: %s' % (value, self._name, self._keys))
raise_assert_error(value in self._keys, (
'Value "%s" not found in selection field '
'"%s". Expected one of: %s' % (value, self._name, self._keys)))
return super(Selection, self).set_from_file(value)
def set(self, value):
assert value in self._values, ('Value "%s" not found in selection field '
'"%s". Expected one of: %s' % (value, self._name, self._values))
raise_assert_error(value in self._values, (
'Value "%s" not found in selection field '
'"%s". Expected one of: %s' % (value, self._name, self._values)))
value = self._values[value]
return super(Selection, self).set(value)

View file

@ -18,6 +18,7 @@
#
##############################################################################
from decimal import Decimal
from .exception import raise_assert_error
def format_string(text, length, fill=' ', align='<'):
@ -30,7 +31,8 @@ def format_string(text, length, fill=' ', align='<'):
if len(text) > length:
text = text[:length]
text = '{0:{1}{2}{3}s}'.format(text, fill, align, length)
assert len(text) == length, 'Formatted string must match the given length'
raise_assert_error(len(text) == length,
'Formatted string must match the given length')
return text
@ -41,6 +43,7 @@ def format_number(number, size, decimals=0):
length += 1
text = '{0:{1}{2}{3}.{4}f}'.format(number, '0', '>', length, decimals)
text = text.replace('.', '')
assert len(text) == size, ('Formatted number "%s" must match the given '
'length "%d". Got: "%s".' % (number, size, text))
raise_assert_error(len(text) == size, (
'Formatted number "%s" must match the given '
'length "%d". Got: "%s".' % (number, size, text)))
return text

View file

@ -20,7 +20,7 @@
from .fields import Field
from .exception import RetrofixException
from .exception import RetrofixException, raise_assert_error
BLANK = ' '
@ -49,7 +49,7 @@ class Record(object):
field = field()
field._size = size
field._name = name
assert name not in keys, 'Duplicate field name "%s".' % name
raise_assert_error(name not in keys, 'Duplicate field name "%s".' % name)
keys.add(name)
self._fields[name] = field
@ -60,13 +60,13 @@ class Record(object):
self._values[name] = self._fields[name].set_from_file(value)
def __getattr__(self, name):
assert name in self._fields, 'Field "%s" does not exist.' % name
raise_assert_error(name in self._fields, 'Field "%s" does not exist.' % name)
return self._fields[name].get(self._values.get(name))
def __setattr__(self, name, value):
if name.startswith('_'):
return super(Record, self).__setattr__(name, value)
assert name in self._fields, 'Field "%s" does not exist.' % name
raise_assert_error(name in self._fields, 'Field "%s" does not exist.' % name)
self._values[name] = self._fields[name].set(value)
def load(self, line, first_position=1):
@ -102,12 +102,14 @@ class Record(object):
value = self.get_for_file(name)
assert len(value) == length, ('Field "%s" should be of size "%d" '
raise_assert_error(len(value) == length, (
'Field "%s" should be of size "%d" '
'but got "%d" on record "%s".' % (name, length, len(value),
str(self)))
assert start >= current_position, ('Error writing field "%s". '
str(self))))
raise_assert_error(start >= current_position, (
'Error writing field "%s". '
'Start: %d, Current Position: %d' % (name, start,
current_position))
current_position)))
text += BLANK * (start - current_position)
text += value
current_position = len(text)