Added tests, improved serialization
This commit is contained in:
parent
212d5da74a
commit
784c8b4d8a
|
@ -11,8 +11,7 @@ and decoded by src.util.cbor_serialization.
|
|||
filenames = ["src.protocols.farmer_protocol",
|
||||
"src.protocols.plotter_protocol"]
|
||||
|
||||
mods = [importlib.import_module(filename)
|
||||
for filename in filenames]
|
||||
mods = [importlib.import_module(filename) for filename in filenames]
|
||||
|
||||
custom_tags_separate = [dict([(cls, cls.__tag__) for _, cls in mod.__dict__.items()
|
||||
if hasattr(cls, "__tag__")]) for mod in mods]
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from src.util.streamable import streamable, StreamableOptional
|
||||
from typing import Optional
|
||||
from src.util.streamable import streamable
|
||||
from src.types.block_header import BlockHeader
|
||||
from src.types.challenge import Challenge
|
||||
from src.types.proof_of_space import ProofOfSpace
|
||||
|
@ -8,7 +9,7 @@ from src.types.proof_of_time import ProofOfTime, ProofOfTimeOutput
|
|||
@streamable
|
||||
class FoliageBlock:
|
||||
proof_of_space: ProofOfSpace
|
||||
proof_of_time_output: StreamableOptional(ProofOfTimeOutput)
|
||||
proof_of_time: StreamableOptional(ProofOfTime)
|
||||
proof_of_time_output: Optional[ProofOfTimeOutput]
|
||||
proof_of_time: Optional[ProofOfTime]
|
||||
challenge: Challenge
|
||||
header: BlockHeader
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import List
|
||||
from blspy import PublicKey
|
||||
from src.util.streamable import streamable, StreamableList
|
||||
from src.util.streamable import streamable
|
||||
from src.util.ints import uint8
|
||||
|
||||
|
||||
|
@ -8,4 +9,4 @@ class ProofOfSpace:
|
|||
pool_pubkey: PublicKey
|
||||
plot_pubkey: PublicKey
|
||||
size: uint8
|
||||
proof: StreamableList(uint8)
|
||||
proof: List[uint8]
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from src.util.streamable import streamable, StreamableList
|
||||
from typing import List
|
||||
from src.util.streamable import streamable
|
||||
from src.types.sized_bytes import bytes32
|
||||
from src.types.classgroup import ClassgroupElement
|
||||
from src.util.ints import uint8, uint64
|
||||
|
@ -15,4 +16,4 @@ class ProofOfTimeOutput:
|
|||
class ProofOfTime:
|
||||
output: ProofOfTimeOutput
|
||||
witness_type: uint8
|
||||
witness: StreamableList(ClassgroupElement)
|
||||
witness: List[ClassgroupElement]
|
||||
|
|
|
@ -5,7 +5,7 @@ from .bin_methods import BinMethods
|
|||
|
||||
def make_sized_bytes(size):
|
||||
"""
|
||||
Create a streamable type that subclasses "hexbytes" but requires instances
|
||||
Create a streamable type that subclasses "bytes" but requires instances
|
||||
to be a certain, fixed size.
|
||||
"""
|
||||
name = "bytes%d" % size
|
||||
|
|
|
@ -29,7 +29,8 @@ def default_encoder(encoder, value: Any):
|
|||
|
||||
def tag_hook(decoder, tag, shareable_index=None):
|
||||
"""
|
||||
If we find a custom tag, decode this. Otherwise, just return the tag (no decoding).
|
||||
If we find a custom tag, decode this by calling the constructor with the
|
||||
member data items. Otherwise, just return the tag (no decoding).
|
||||
"""
|
||||
for (cls, cls_tag) in custom_tags.items():
|
||||
if tag.tag == cls_tag:
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import dataclasses
|
||||
from typing import Any
|
||||
from src.util.type_checking import ArgTypeChecker
|
||||
from typing import Any, Type
|
||||
from src.util.type_checking import strictdataclass
|
||||
|
||||
|
||||
def cbor_message(tag: int):
|
||||
|
@ -8,8 +7,7 @@ def cbor_message(tag: int):
|
|||
Decorator, converts a class into a data class, which checks all arguments to make sure
|
||||
they are the right type.
|
||||
"""
|
||||
def apply_cbor_code(cls: Any):
|
||||
cls1 = dataclasses.dataclass(_cls=cls, init=False, frozen=True)
|
||||
return type(cls.__name__, (cls1, ArgTypeChecker), {'__tag__': tag})
|
||||
def apply_cbor_code(cls: Any) -> Type:
|
||||
cls1 = strictdataclass(cls=cls)
|
||||
return type(cls.__name__, (cls1,), {'__tag__': tag})
|
||||
return apply_cbor_code
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import dataclasses
|
||||
from __future__ import annotations
|
||||
from blspy import PublicKey, Signature, PrependSignature
|
||||
from typing import Type, BinaryIO, get_type_hints, Any, Optional, List
|
||||
from src.util.ints import uint32, uint8
|
||||
from src.util.type_checking import ArgTypeChecker
|
||||
from typing import Type, BinaryIO, get_type_hints, Any, List
|
||||
from src.util.type_checking import strictdataclass, is_type_List, is_type_SpecificOptional
|
||||
from src.util.bin_methods import BinMethods
|
||||
from src.util.ints import uint32
|
||||
|
||||
|
||||
# TODO: Remove hack, this allows streaming these objects from binary
|
||||
|
@ -16,134 +16,80 @@ size_hints = {
|
|||
|
||||
def streamable(cls: Any):
|
||||
"""
|
||||
This is a decorator for class definitions. It applies the dataclasses.dataclass
|
||||
decorator, and also allows fields to be cast to their expected type. The resulting
|
||||
class also gets parse and stream for free, as long as all its constituent elements
|
||||
have it.
|
||||
This is a decorator for class definitions. It applies the strictdataclass decorator,
|
||||
which checks all types at construction. It also defines a simple serialization format,
|
||||
and adds parse, from bytes, stream, and serialize methods.
|
||||
|
||||
Serialization format:
|
||||
* Each field is serialized in order, by calling parse/serialize.
|
||||
* For Lists, there is a 4 byte prefix for the list length.
|
||||
* For Optionals, there is a one byte prefix, 1 iff object is present, 0 iff not.
|
||||
|
||||
All of the constituents must have parse/from_bytes, and stream/serialize and therefore
|
||||
be of fixed size. For example, int cannot be a constituent since it is not a fixed size,
|
||||
whereas uint32 can be.
|
||||
|
||||
This class is used for deterministic serialization and hashing, for consensus critical
|
||||
objects such as the block header.
|
||||
"""
|
||||
|
||||
class _Local:
|
||||
@classmethod
|
||||
def parse_one_item(cls: Type[cls.__name__], f_type: Type, f: BinaryIO):
|
||||
if is_type_List(f_type):
|
||||
inner_type: Type = f_type.__args__[0]
|
||||
full_list: List[inner_type] = []
|
||||
assert inner_type != List.__args__[0]
|
||||
list_size: uint32 = int.from_bytes(f.read(4), "big")
|
||||
for list_index in range(list_size):
|
||||
full_list.append(cls.parse_one_item(inner_type, f))
|
||||
return full_list
|
||||
if is_type_SpecificOptional(f_type):
|
||||
inner_type: Type = f_type.__args__[0]
|
||||
is_present: bool = f.read(1) == bytes([1])
|
||||
if is_present:
|
||||
return cls.parse_one_item(inner_type, f)
|
||||
else:
|
||||
return None
|
||||
if hasattr(f_type, "parse"):
|
||||
return f_type.parse(f)
|
||||
if hasattr(f_type, "from_bytes") and size_hints[f_type.__name__]:
|
||||
return f_type.from_bytes(f.read(size_hints[f_type.__name__]))
|
||||
else:
|
||||
raise RuntimeError(f"Type {f_type} does not have parse")
|
||||
|
||||
@classmethod
|
||||
def parse(cls: Type[cls.__name__], f: BinaryIO) -> cls.__name__:
|
||||
values = []
|
||||
for f_name, f_type in get_type_hints(cls).items():
|
||||
if hasattr(f_type, "parse"):
|
||||
values.append(f_type.parse(f))
|
||||
elif hasattr(f_type, "from_bytes") and size_hints[f_type.__name__]:
|
||||
values.append(f_type.from_bytes(f.read(size_hints[f_type.__name__])))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
for _, f_type in get_type_hints(cls).items():
|
||||
values.append(cls.parse_one_item(f_type, f))
|
||||
return cls(*values)
|
||||
|
||||
def stream_one_item(self, f_type: Type, item, f: BinaryIO) -> None:
|
||||
if is_type_List(f_type):
|
||||
assert is_type_List(type(item))
|
||||
f.write(uint32(len(item)).to_bytes(4, "big"))
|
||||
inner_type: Type = f_type.__args__[0]
|
||||
assert inner_type != List.__args__[0]
|
||||
for element in item:
|
||||
self.stream_one_item(inner_type, element, f)
|
||||
elif is_type_SpecificOptional(f_type):
|
||||
inner_type: Type = f_type.__args__[0]
|
||||
if item is None:
|
||||
f.write(bytes([0]))
|
||||
else:
|
||||
f.write(bytes([1]))
|
||||
self.stream_one_item(inner_type, item, f)
|
||||
elif hasattr(f_type, "stream"):
|
||||
item.stream(f)
|
||||
elif hasattr(f_type, "serialize"):
|
||||
f.write(item.serialize())
|
||||
else:
|
||||
raise NotImplementedError(f"can't stream {item}, {f_type}")
|
||||
|
||||
def stream(self, f: BinaryIO) -> None:
|
||||
for f_name, f_type in get_type_hints(self).items():
|
||||
v = getattr(self, f_name)
|
||||
if hasattr(f_type, "stream"):
|
||||
v.stream(f)
|
||||
elif hasattr(f_type, "serialize"):
|
||||
f.write(v.serialize())
|
||||
else:
|
||||
raise NotImplementedError(f"can't stream {v}, {f_name}")
|
||||
self.stream_one_item(f_type, getattr(self, f_name), f)
|
||||
|
||||
cls1 = dataclasses.dataclass(_cls=cls, init=False, frozen=True)
|
||||
return type(cls.__name__, (cls1, BinMethods, ArgTypeChecker, _Local), {})
|
||||
|
||||
|
||||
def StreamableList(the_type):
|
||||
"""
|
||||
This creates a streamable homogenous list of the given streamable object. It has
|
||||
a 32-bit unsigned prefix length, so lists are limited to a length of 2^32 - 1.
|
||||
"""
|
||||
|
||||
cls_name = "%sList" % the_type.__name__
|
||||
|
||||
def __init__(self, items: List[the_type]):
|
||||
self._items = tuple(items)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._items)
|
||||
|
||||
@classmethod
|
||||
def parse(cls: Type[cls_name], f: BinaryIO) -> cls_name:
|
||||
count = uint32.parse(f)
|
||||
items = []
|
||||
for _ in range(count):
|
||||
if hasattr(the_type, "parse"):
|
||||
items.append(the_type.parse(f))
|
||||
elif hasattr(the_type, "from_bytes") and size_hints[the_type.__name__]:
|
||||
items.append(the_type.from_bytes(f.read(size_hints[the_type.__name__])))
|
||||
else:
|
||||
raise ValueError("wrong type for %s" % the_type)
|
||||
return cls(items)
|
||||
|
||||
def stream(self, f: BinaryIO) -> None:
|
||||
count = uint32(len(self._items))
|
||||
count.stream(f)
|
||||
for item in self._items:
|
||||
if hasattr(type(item), "stream"):
|
||||
item.stream(f)
|
||||
elif hasattr(type(item), "serialize"):
|
||||
f.write(item.serialize())
|
||||
else:
|
||||
raise NotImplementedError(f"can't stream {type(item)}")
|
||||
|
||||
def __str__(self):
|
||||
return str(self._items)
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self._items)
|
||||
|
||||
namespace = dict(
|
||||
__init__=__init__, __iter__=__iter__, parse=parse,
|
||||
stream=stream, __str__=__str__, __repr__=__repr__)
|
||||
streamable_list_type = type(cls_name, (BinMethods,), namespace)
|
||||
return streamable_list_type
|
||||
|
||||
|
||||
def StreamableOptional(the_type):
|
||||
"""
|
||||
This creates a streamable optional of the given streamable object. It has
|
||||
a 1 byte big-endian prefix which is equal to 1 if the element is there,
|
||||
and 0 if the element is not there.
|
||||
"""
|
||||
|
||||
cls_name = "%sOptional" % the_type.__name__
|
||||
|
||||
def __init__(self, item: Optional[the_type]):
|
||||
self._item = item
|
||||
|
||||
@classmethod
|
||||
def parse(cls: Type[cls_name], f: BinaryIO) -> cls_name:
|
||||
is_present: bool = (uint8.parse(f) == 1)
|
||||
item: Optional[the_type] = None
|
||||
if is_present:
|
||||
if hasattr(the_type, "parse"):
|
||||
item = the_type.parse(f)
|
||||
elif hasattr(the_type, "from_bytes") and size_hints[the_type.__name__]:
|
||||
item = the_type.from_bytes(f.read(size_hints[the_type.__name__]))
|
||||
else:
|
||||
raise ValueError("wrong type for %s" % the_type)
|
||||
return cls(item)
|
||||
|
||||
def stream(self, f: BinaryIO) -> None:
|
||||
is_present: uint8 = uint8(1) if self._item else uint8(0)
|
||||
is_present.stream(f)
|
||||
if is_present == 1:
|
||||
if hasattr(type(self._item), "stream"):
|
||||
self._item.stream(f)
|
||||
elif hasattr(type(self._item), "serialize"):
|
||||
f.write(self._item.serialize())
|
||||
else:
|
||||
raise NotImplementedError(f"can't stream {type(self._item)}")
|
||||
|
||||
def __str__(self):
|
||||
return str(self._item)
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self._item)
|
||||
|
||||
namespace = dict(
|
||||
__init__=__init__, parse=parse,
|
||||
stream=stream, __str__=__str__, __repr__=__repr__)
|
||||
streamable_optional_type = type(cls_name, (BinMethods,), namespace)
|
||||
return streamable_optional_type
|
||||
cls1 = strictdataclass(cls)
|
||||
return type(cls.__name__, (cls1, BinMethods, _Local), {})
|
||||
|
|
|
@ -1,25 +1,59 @@
|
|||
from typing import Any, Type, get_type_hints
|
||||
from typing import Any, Type, get_type_hints, List, Union
|
||||
import dataclasses
|
||||
|
||||
|
||||
class ArgTypeChecker:
|
||||
def parse_item(self, a: Any, f_name: str, f_type: Type) -> Any:
|
||||
if hasattr(f_type, "__origin__") and f_type.__origin__ == list:
|
||||
return [self.parse_item(el, f_type.__args__[0].__name__, f_type.__args__[0]) for el in a]
|
||||
if not isinstance(a, f_type):
|
||||
try:
|
||||
a = f_type.from_bytes(a)
|
||||
except TypeError:
|
||||
a = f_type(a)
|
||||
if not isinstance(a, f_type):
|
||||
raise ValueError("wrong type for %s" % f_name)
|
||||
return a
|
||||
|
||||
def __init__(self, *args):
|
||||
fields = get_type_hints(self)
|
||||
la, lf = len(args), len(fields)
|
||||
if la != lf:
|
||||
raise ValueError("got %d and expected %d args" % (la, lf))
|
||||
for a, (f_name, f_type) in zip(args, fields.items()):
|
||||
object.__setattr__(self, f_name, self.parse_item(a, f_name, f_type))
|
||||
def is_type_List(f_type: Type) -> bool:
|
||||
return (hasattr(f_type, "__origin__") and f_type.__origin__ == list) or f_type == list
|
||||
|
||||
|
||||
def is_type_SpecificOptional(f_type) -> bool:
|
||||
"""
|
||||
Returns true for types such as Optional[T], but not Optional, or T.
|
||||
"""
|
||||
return (hasattr(f_type, "__origin__") and f_type.__origin__ == Union
|
||||
and f_type.__args__[1]() is None)
|
||||
|
||||
|
||||
def strictdataclass(cls: Any):
|
||||
class _Local():
|
||||
"""
|
||||
Dataclass where all fields must be type annotated, and type checking is performed
|
||||
at initialization, even recursively through Lists. Non-annotated fields are ignored.
|
||||
Also, for any fields which have a type with .from_bytes(bytes) or constructor(bytes),
|
||||
bytes can be passed in and the type can be constructed.
|
||||
"""
|
||||
def parse_item(self, item: Any, f_name: str, f_type: Type) -> Any:
|
||||
if is_type_List(f_type):
|
||||
collected_list: f_type = []
|
||||
inner_type: Type = f_type.__args__[0]
|
||||
assert inner_type != List.__args__[0]
|
||||
if not is_type_List(type(item)):
|
||||
raise ValueError(f"Wrong type for {f_name}, need a list.")
|
||||
for el in item:
|
||||
collected_list.append(self.parse_item(el, f_name, inner_type))
|
||||
return collected_list
|
||||
if is_type_SpecificOptional(f_type):
|
||||
if item is None:
|
||||
return None
|
||||
else:
|
||||
inner_type: Type = f_type.__args__[0]
|
||||
return self.parse_item(item, f_name, inner_type)
|
||||
if not isinstance(item, f_type):
|
||||
try:
|
||||
item = f_type(item)
|
||||
except (TypeError, AttributeError, ValueError):
|
||||
item = f_type.from_bytes(item)
|
||||
if not isinstance(item, f_type):
|
||||
raise ValueError(f"Wrong type for {f_name}")
|
||||
return item
|
||||
|
||||
def __init__(self, *args):
|
||||
fields = get_type_hints(self)
|
||||
la, lf = len(args), len(fields)
|
||||
if la != lf:
|
||||
raise ValueError("got %d and expected %d args" % (la, lf))
|
||||
for a, (f_name, f_type) in zip(args, fields.items()):
|
||||
object.__setattr__(self, f_name, self.parse_item(a, f_name, f_type))
|
||||
|
||||
cls1 = dataclasses.dataclass(_cls=cls, init=False, frozen=True)
|
||||
return type(cls.__name__, (cls1, _Local), {})
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
import unittest
|
||||
from typing import List, Optional
|
||||
from src.util.streamable import streamable
|
||||
from src.util.ints import uint32
|
||||
|
||||
|
||||
class TestStreamable(unittest.TestCase):
|
||||
def test_basic(self):
|
||||
@streamable
|
||||
class TestClass:
|
||||
a: uint32
|
||||
b: uint32
|
||||
c: List[uint32]
|
||||
d: List[List[uint32]]
|
||||
e: Optional[uint32]
|
||||
f: Optional[uint32]
|
||||
|
||||
a = TestClass(24, 352, [1, 2, 4], [[1, 2, 3], [3, 4]], 728, None)
|
||||
|
||||
b: bytes = a.serialize()
|
||||
assert a == TestClass.from_bytes(b)
|
||||
|
||||
def test_variablesize(self):
|
||||
@streamable
|
||||
class TestClass2:
|
||||
a: uint32
|
||||
b: uint32
|
||||
c: str
|
||||
|
||||
a = TestClass2(1, 2, "3")
|
||||
try:
|
||||
a.serialize()
|
||||
assert False
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
@streamable
|
||||
class TestClass3:
|
||||
a: int
|
||||
|
||||
b = TestClass3(1)
|
||||
try:
|
||||
b.serialize()
|
||||
assert False
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -0,0 +1,89 @@
|
|||
import unittest
|
||||
from src.util.type_checking import is_type_List, is_type_SpecificOptional, strictdataclass
|
||||
from src.util.ints import uint8
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
|
||||
|
||||
class TestIsTypeList(unittest.TestCase):
|
||||
def test_basic_list(self):
|
||||
a = [1, 2, 3]
|
||||
assert is_type_List(type(a))
|
||||
assert is_type_List(List)
|
||||
assert is_type_List(List[int])
|
||||
assert is_type_List(List[uint8])
|
||||
assert is_type_List(list)
|
||||
assert not is_type_List(Tuple)
|
||||
assert not is_type_List(tuple)
|
||||
assert not is_type_List(dict)
|
||||
|
||||
def test_not_lists(self):
|
||||
assert not is_type_List(Dict)
|
||||
|
||||
|
||||
class TestIsTypeSpecificOptional(unittest.TestCase):
|
||||
def test_basic_optional(self):
|
||||
assert is_type_SpecificOptional(Optional[int])
|
||||
assert is_type_SpecificOptional(Optional[Optional[int]])
|
||||
assert not is_type_SpecificOptional(List[int])
|
||||
|
||||
|
||||
class TestStrictClass(unittest.TestCase):
|
||||
def test_StrictDataClass(self):
|
||||
@strictdataclass
|
||||
class TestClass1:
|
||||
a: int
|
||||
b: str
|
||||
|
||||
good: TestClass1 = TestClass1(24, "!@12")
|
||||
assert TestClass1.__name__ == "TestClass1"
|
||||
assert good
|
||||
assert good.a == 24
|
||||
assert good.b == "!@12"
|
||||
good2 = TestClass1(52, bytes([1, 2, 3]))
|
||||
assert good2.b == str(bytes([1, 2, 3]))
|
||||
|
||||
def test_StrictDataClassBad(self):
|
||||
@strictdataclass
|
||||
class TestClass2:
|
||||
a: int
|
||||
b = 0
|
||||
|
||||
assert TestClass2(25)
|
||||
try:
|
||||
TestClass2(1, 2)
|
||||
assert False
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def test_StrictDataClassLists(self):
|
||||
@strictdataclass
|
||||
class TestClass:
|
||||
a: List[int]
|
||||
b: List[List[uint8]]
|
||||
|
||||
assert TestClass([1, 2, 3], [[uint8(200), uint8(25)], [uint8(25)]])
|
||||
try:
|
||||
TestClass([1, 2, 3], [[200, uint8(25)], [uint8(25)]])
|
||||
assert False
|
||||
except AssertionError:
|
||||
pass
|
||||
try:
|
||||
TestClass([1, 2, 3], [uint8(200), uint8(25)])
|
||||
assert False
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def test_StrictDataClassOptional(self):
|
||||
@strictdataclass
|
||||
class TestClass:
|
||||
a: Optional[int]
|
||||
b: Optional[int]
|
||||
c: Optional[Optional[int]]
|
||||
d: Optional[Optional[int]]
|
||||
|
||||
good = TestClass(12, None, 13, None)
|
||||
assert good
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue