Added tests, improved serialization

This commit is contained in:
Mariano Sorgente 2019-08-05 14:38:16 +09:00
parent 212d5da74a
commit 784c8b4d8a
13 changed files with 283 additions and 163 deletions

View File

@ -11,8 +11,7 @@ and decoded by src.util.cbor_serialization.
filenames = ["src.protocols.farmer_protocol",
"src.protocols.plotter_protocol"]
mods = [importlib.import_module(filename)
for filename in filenames]
mods = [importlib.import_module(filename) for filename in filenames]
custom_tags_separate = [dict([(cls, cls.__tag__) for _, cls in mod.__dict__.items()
if hasattr(cls, "__tag__")]) for mod in mods]

View File

@ -1,4 +1,5 @@
from src.util.streamable import streamable, StreamableOptional
from typing import Optional
from src.util.streamable import streamable
from src.types.block_header import BlockHeader
from src.types.challenge import Challenge
from src.types.proof_of_space import ProofOfSpace
@ -8,7 +9,7 @@ from src.types.proof_of_time import ProofOfTime, ProofOfTimeOutput
@streamable
class FoliageBlock:
proof_of_space: ProofOfSpace
proof_of_time_output: StreamableOptional(ProofOfTimeOutput)
proof_of_time: StreamableOptional(ProofOfTime)
proof_of_time_output: Optional[ProofOfTimeOutput]
proof_of_time: Optional[ProofOfTime]
challenge: Challenge
header: BlockHeader

View File

@ -1,5 +1,6 @@
from typing import List
from blspy import PublicKey
from src.util.streamable import streamable, StreamableList
from src.util.streamable import streamable
from src.util.ints import uint8
@ -8,4 +9,4 @@ class ProofOfSpace:
pool_pubkey: PublicKey
plot_pubkey: PublicKey
size: uint8
proof: StreamableList(uint8)
proof: List[uint8]

View File

@ -1,4 +1,5 @@
from src.util.streamable import streamable, StreamableList
from typing import List
from src.util.streamable import streamable
from src.types.sized_bytes import bytes32
from src.types.classgroup import ClassgroupElement
from src.util.ints import uint8, uint64
@ -15,4 +16,4 @@ class ProofOfTimeOutput:
class ProofOfTime:
output: ProofOfTimeOutput
witness_type: uint8
witness: StreamableList(ClassgroupElement)
witness: List[ClassgroupElement]

View File

@ -5,7 +5,7 @@ from .bin_methods import BinMethods
def make_sized_bytes(size):
"""
Create a streamable type that subclasses "hexbytes" but requires instances
Create a streamable type that subclasses "bytes" but requires instances
to be a certain, fixed size.
"""
name = "bytes%d" % size

View File

@ -29,7 +29,8 @@ def default_encoder(encoder, value: Any):
def tag_hook(decoder, tag, shareable_index=None):
"""
If we find a custom tag, decode this. Otherwise, just return the tag (no decoding).
If we find a custom tag, decode this by calling the constructor with the
member data items. Otherwise, just return the tag (no decoding).
"""
for (cls, cls_tag) in custom_tags.items():
if tag.tag == cls_tag:

View File

@ -1,6 +1,5 @@
import dataclasses
from typing import Any
from src.util.type_checking import ArgTypeChecker
from typing import Any, Type
from src.util.type_checking import strictdataclass
def cbor_message(tag: int):
@ -8,8 +7,7 @@ def cbor_message(tag: int):
Decorator, converts a class into a data class, which checks all arguments to make sure
they are the right type.
"""
def apply_cbor_code(cls: Any):
cls1 = dataclasses.dataclass(_cls=cls, init=False, frozen=True)
return type(cls.__name__, (cls1, ArgTypeChecker), {'__tag__': tag})
def apply_cbor_code(cls: Any) -> Type:
cls1 = strictdataclass(cls=cls)
return type(cls.__name__, (cls1,), {'__tag__': tag})
return apply_cbor_code

View File

@ -1,9 +1,9 @@
import dataclasses
from __future__ import annotations
from blspy import PublicKey, Signature, PrependSignature
from typing import Type, BinaryIO, get_type_hints, Any, Optional, List
from src.util.ints import uint32, uint8
from src.util.type_checking import ArgTypeChecker
from typing import Type, BinaryIO, get_type_hints, Any, List
from src.util.type_checking import strictdataclass, is_type_List, is_type_SpecificOptional
from src.util.bin_methods import BinMethods
from src.util.ints import uint32
# TODO: Remove hack, this allows streaming these objects from binary
@ -16,134 +16,80 @@ size_hints = {
def streamable(cls: Any):
"""
This is a decorator for class definitions. It applies the dataclasses.dataclass
decorator, and also allows fields to be cast to their expected type. The resulting
class also gets parse and stream for free, as long as all its constituent elements
have it.
This is a decorator for class definitions. It applies the strictdataclass decorator,
which checks all types at construction. It also defines a simple serialization format,
and adds parse, from bytes, stream, and serialize methods.
Serialization format:
* Each field is serialized in order, by calling parse/serialize.
* For Lists, there is a 4 byte prefix for the list length.
* For Optionals, there is a one byte prefix, 1 iff object is present, 0 iff not.
All of the constituents must have parse/from_bytes, and stream/serialize and therefore
be of fixed size. For example, int cannot be a constituent since it is not a fixed size,
whereas uint32 can be.
This class is used for deterministic serialization and hashing, for consensus critical
objects such as the block header.
"""
class _Local:
@classmethod
def parse_one_item(cls: Type[cls.__name__], f_type: Type, f: BinaryIO):
if is_type_List(f_type):
inner_type: Type = f_type.__args__[0]
full_list: List[inner_type] = []
assert inner_type != List.__args__[0]
list_size: uint32 = int.from_bytes(f.read(4), "big")
for list_index in range(list_size):
full_list.append(cls.parse_one_item(inner_type, f))
return full_list
if is_type_SpecificOptional(f_type):
inner_type: Type = f_type.__args__[0]
is_present: bool = f.read(1) == bytes([1])
if is_present:
return cls.parse_one_item(inner_type, f)
else:
return None
if hasattr(f_type, "parse"):
return f_type.parse(f)
if hasattr(f_type, "from_bytes") and size_hints[f_type.__name__]:
return f_type.from_bytes(f.read(size_hints[f_type.__name__]))
else:
raise RuntimeError(f"Type {f_type} does not have parse")
@classmethod
def parse(cls: Type[cls.__name__], f: BinaryIO) -> cls.__name__:
values = []
for f_name, f_type in get_type_hints(cls).items():
if hasattr(f_type, "parse"):
values.append(f_type.parse(f))
elif hasattr(f_type, "from_bytes") and size_hints[f_type.__name__]:
values.append(f_type.from_bytes(f.read(size_hints[f_type.__name__])))
else:
raise NotImplementedError
for _, f_type in get_type_hints(cls).items():
values.append(cls.parse_one_item(f_type, f))
return cls(*values)
def stream_one_item(self, f_type: Type, item, f: BinaryIO) -> None:
if is_type_List(f_type):
assert is_type_List(type(item))
f.write(uint32(len(item)).to_bytes(4, "big"))
inner_type: Type = f_type.__args__[0]
assert inner_type != List.__args__[0]
for element in item:
self.stream_one_item(inner_type, element, f)
elif is_type_SpecificOptional(f_type):
inner_type: Type = f_type.__args__[0]
if item is None:
f.write(bytes([0]))
else:
f.write(bytes([1]))
self.stream_one_item(inner_type, item, f)
elif hasattr(f_type, "stream"):
item.stream(f)
elif hasattr(f_type, "serialize"):
f.write(item.serialize())
else:
raise NotImplementedError(f"can't stream {item}, {f_type}")
def stream(self, f: BinaryIO) -> None:
for f_name, f_type in get_type_hints(self).items():
v = getattr(self, f_name)
if hasattr(f_type, "stream"):
v.stream(f)
elif hasattr(f_type, "serialize"):
f.write(v.serialize())
else:
raise NotImplementedError(f"can't stream {v}, {f_name}")
self.stream_one_item(f_type, getattr(self, f_name), f)
cls1 = dataclasses.dataclass(_cls=cls, init=False, frozen=True)
return type(cls.__name__, (cls1, BinMethods, ArgTypeChecker, _Local), {})
def StreamableList(the_type):
"""
This creates a streamable homogenous list of the given streamable object. It has
a 32-bit unsigned prefix length, so lists are limited to a length of 2^32 - 1.
"""
cls_name = "%sList" % the_type.__name__
def __init__(self, items: List[the_type]):
self._items = tuple(items)
def __iter__(self):
return iter(self._items)
@classmethod
def parse(cls: Type[cls_name], f: BinaryIO) -> cls_name:
count = uint32.parse(f)
items = []
for _ in range(count):
if hasattr(the_type, "parse"):
items.append(the_type.parse(f))
elif hasattr(the_type, "from_bytes") and size_hints[the_type.__name__]:
items.append(the_type.from_bytes(f.read(size_hints[the_type.__name__])))
else:
raise ValueError("wrong type for %s" % the_type)
return cls(items)
def stream(self, f: BinaryIO) -> None:
count = uint32(len(self._items))
count.stream(f)
for item in self._items:
if hasattr(type(item), "stream"):
item.stream(f)
elif hasattr(type(item), "serialize"):
f.write(item.serialize())
else:
raise NotImplementedError(f"can't stream {type(item)}")
def __str__(self):
return str(self._items)
def __repr__(self):
return repr(self._items)
namespace = dict(
__init__=__init__, __iter__=__iter__, parse=parse,
stream=stream, __str__=__str__, __repr__=__repr__)
streamable_list_type = type(cls_name, (BinMethods,), namespace)
return streamable_list_type
def StreamableOptional(the_type):
"""
This creates a streamable optional of the given streamable object. It has
a 1 byte big-endian prefix which is equal to 1 if the element is there,
and 0 if the element is not there.
"""
cls_name = "%sOptional" % the_type.__name__
def __init__(self, item: Optional[the_type]):
self._item = item
@classmethod
def parse(cls: Type[cls_name], f: BinaryIO) -> cls_name:
is_present: bool = (uint8.parse(f) == 1)
item: Optional[the_type] = None
if is_present:
if hasattr(the_type, "parse"):
item = the_type.parse(f)
elif hasattr(the_type, "from_bytes") and size_hints[the_type.__name__]:
item = the_type.from_bytes(f.read(size_hints[the_type.__name__]))
else:
raise ValueError("wrong type for %s" % the_type)
return cls(item)
def stream(self, f: BinaryIO) -> None:
is_present: uint8 = uint8(1) if self._item else uint8(0)
is_present.stream(f)
if is_present == 1:
if hasattr(type(self._item), "stream"):
self._item.stream(f)
elif hasattr(type(self._item), "serialize"):
f.write(self._item.serialize())
else:
raise NotImplementedError(f"can't stream {type(self._item)}")
def __str__(self):
return str(self._item)
def __repr__(self):
return repr(self._item)
namespace = dict(
__init__=__init__, parse=parse,
stream=stream, __str__=__str__, __repr__=__repr__)
streamable_optional_type = type(cls_name, (BinMethods,), namespace)
return streamable_optional_type
cls1 = strictdataclass(cls)
return type(cls.__name__, (cls1, BinMethods, _Local), {})

View File

@ -1,25 +1,59 @@
from typing import Any, Type, get_type_hints
from typing import Any, Type, get_type_hints, List, Union
import dataclasses
class ArgTypeChecker:
def parse_item(self, a: Any, f_name: str, f_type: Type) -> Any:
if hasattr(f_type, "__origin__") and f_type.__origin__ == list:
return [self.parse_item(el, f_type.__args__[0].__name__, f_type.__args__[0]) for el in a]
if not isinstance(a, f_type):
try:
a = f_type.from_bytes(a)
except TypeError:
a = f_type(a)
if not isinstance(a, f_type):
raise ValueError("wrong type for %s" % f_name)
return a
def __init__(self, *args):
fields = get_type_hints(self)
la, lf = len(args), len(fields)
if la != lf:
raise ValueError("got %d and expected %d args" % (la, lf))
for a, (f_name, f_type) in zip(args, fields.items()):
object.__setattr__(self, f_name, self.parse_item(a, f_name, f_type))
def is_type_List(f_type: Type) -> bool:
return (hasattr(f_type, "__origin__") and f_type.__origin__ == list) or f_type == list
def is_type_SpecificOptional(f_type) -> bool:
"""
Returns true for types such as Optional[T], but not Optional, or T.
"""
return (hasattr(f_type, "__origin__") and f_type.__origin__ == Union
and f_type.__args__[1]() is None)
def strictdataclass(cls: Any):
class _Local():
"""
Dataclass where all fields must be type annotated, and type checking is performed
at initialization, even recursively through Lists. Non-annotated fields are ignored.
Also, for any fields which have a type with .from_bytes(bytes) or constructor(bytes),
bytes can be passed in and the type can be constructed.
"""
def parse_item(self, item: Any, f_name: str, f_type: Type) -> Any:
if is_type_List(f_type):
collected_list: f_type = []
inner_type: Type = f_type.__args__[0]
assert inner_type != List.__args__[0]
if not is_type_List(type(item)):
raise ValueError(f"Wrong type for {f_name}, need a list.")
for el in item:
collected_list.append(self.parse_item(el, f_name, inner_type))
return collected_list
if is_type_SpecificOptional(f_type):
if item is None:
return None
else:
inner_type: Type = f_type.__args__[0]
return self.parse_item(item, f_name, inner_type)
if not isinstance(item, f_type):
try:
item = f_type(item)
except (TypeError, AttributeError, ValueError):
item = f_type.from_bytes(item)
if not isinstance(item, f_type):
raise ValueError(f"Wrong type for {f_name}")
return item
def __init__(self, *args):
fields = get_type_hints(self)
la, lf = len(args), len(fields)
if la != lf:
raise ValueError("got %d and expected %d args" % (la, lf))
for a, (f_name, f_type) in zip(args, fields.items()):
object.__setattr__(self, f_name, self.parse_item(a, f_name, f_type))
cls1 = dataclasses.dataclass(_cls=cls, init=False, frozen=True)
return type(cls.__name__, (cls1, _Local), {})

0
tests/__init__.py Normal file
View File

0
tests/util/__init__.py Normal file
View File

View File

@ -0,0 +1,50 @@
import unittest
from typing import List, Optional
from src.util.streamable import streamable
from src.util.ints import uint32
class TestStreamable(unittest.TestCase):
def test_basic(self):
@streamable
class TestClass:
a: uint32
b: uint32
c: List[uint32]
d: List[List[uint32]]
e: Optional[uint32]
f: Optional[uint32]
a = TestClass(24, 352, [1, 2, 4], [[1, 2, 3], [3, 4]], 728, None)
b: bytes = a.serialize()
assert a == TestClass.from_bytes(b)
def test_variablesize(self):
@streamable
class TestClass2:
a: uint32
b: uint32
c: str
a = TestClass2(1, 2, "3")
try:
a.serialize()
assert False
except NotImplementedError:
pass
@streamable
class TestClass3:
a: int
b = TestClass3(1)
try:
b.serialize()
assert False
except NotImplementedError:
pass
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,89 @@
import unittest
from src.util.type_checking import is_type_List, is_type_SpecificOptional, strictdataclass
from src.util.ints import uint8
from typing import List, Dict, Tuple, Optional
class TestIsTypeList(unittest.TestCase):
def test_basic_list(self):
a = [1, 2, 3]
assert is_type_List(type(a))
assert is_type_List(List)
assert is_type_List(List[int])
assert is_type_List(List[uint8])
assert is_type_List(list)
assert not is_type_List(Tuple)
assert not is_type_List(tuple)
assert not is_type_List(dict)
def test_not_lists(self):
assert not is_type_List(Dict)
class TestIsTypeSpecificOptional(unittest.TestCase):
def test_basic_optional(self):
assert is_type_SpecificOptional(Optional[int])
assert is_type_SpecificOptional(Optional[Optional[int]])
assert not is_type_SpecificOptional(List[int])
class TestStrictClass(unittest.TestCase):
def test_StrictDataClass(self):
@strictdataclass
class TestClass1:
a: int
b: str
good: TestClass1 = TestClass1(24, "!@12")
assert TestClass1.__name__ == "TestClass1"
assert good
assert good.a == 24
assert good.b == "!@12"
good2 = TestClass1(52, bytes([1, 2, 3]))
assert good2.b == str(bytes([1, 2, 3]))
def test_StrictDataClassBad(self):
@strictdataclass
class TestClass2:
a: int
b = 0
assert TestClass2(25)
try:
TestClass2(1, 2)
assert False
except ValueError:
pass
def test_StrictDataClassLists(self):
@strictdataclass
class TestClass:
a: List[int]
b: List[List[uint8]]
assert TestClass([1, 2, 3], [[uint8(200), uint8(25)], [uint8(25)]])
try:
TestClass([1, 2, 3], [[200, uint8(25)], [uint8(25)]])
assert False
except AssertionError:
pass
try:
TestClass([1, 2, 3], [uint8(200), uint8(25)])
assert False
except ValueError:
pass
def test_StrictDataClassOptional(self):
@strictdataclass
class TestClass:
a: Optional[int]
b: Optional[int]
c: Optional[Optional[int]]
d: Optional[Optional[int]]
good = TestClass(12, None, 13, None)
assert good
if __name__ == '__main__':
unittest.main()