util: Improve `list_to_batches` (#15415)
* Make `list_to_batches` work with collections instead of lists only
* Move `to_batches` into `chia.util.misc`
* Only support `set` and `list`
* Drop `tests.generator.test_list_to_batches` exclusion
* Improve type restrictions and be more coverage friendly in tests
parent 5ad78bd669
commit ebc5f3c124
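At the call sites, the migration looks like this (a sketch distilled from the hunks below; `entries`, `batch_size`, and `process` are placeholder names, not identifiers from the diff):

    # Before: chia.util.generator_tools.list_to_batches yielded (remaining, batch) tuples.
    for remaining, batch in list_to_batches(entries, batch_size):
        process(batch, remaining)

    # After: chia.util.misc.to_batches yields Batch objects with .entries and .remaining.
    for batch in to_batches(entries, batch_size):
        process(batch.entries, batch.remaining)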
@@ -26,8 +26,8 @@ from chia.protocols.harvester_protocol import (
 from chia.protocols.protocol_message_types import ProtocolMessageTypes
 from chia.server.outbound_message import NodeType, make_msg
 from chia.server.ws_connection import WSChiaConnection
-from chia.util.generator_tools import list_to_batches
 from chia.util.ints import int16, uint32, uint64
+from chia.util.misc import to_batches

 log = logging.getLogger(__name__)

@@ -150,10 +150,10 @@ class Sender:
         self._messages.clear()
         if self._task is not None:
             self.sync_start(self._plot_manager.plot_count(), True)
-            for remaining, batch in list_to_batches(
+            for batch in to_batches(
                 list(self._plot_manager.plots.values()), self._plot_manager.refresh_parameter.batch_size
             ):
-                self.process_batch(batch, remaining)
+                self.process_batch(batch.entries, batch.remaining)
             self.sync_done([], 0)

     async def _wait_for_response(self) -> bool:
@@ -251,8 +251,8 @@ class Sender:
         if len(data) == 0:
             self._add_message(message_type, payload_type, [], True)
             return
-        for remaining, batch in list_to_batches(data, self._plot_manager.refresh_parameter.batch_size):
-            self._add_message(message_type, payload_type, batch, remaining == 0)
+        for batch in to_batches(data, self._plot_manager.refresh_parameter.batch_size):
+            self._add_message(message_type, payload_type, batch.entries, batch.remaining == 0)

     def sync_start(self, count: float, initial: bool) -> None:
         log.debug(f"sync_start {self}: count {count}, initial {initial}")
@@ -14,7 +14,7 @@ from chiapos import DiskProver
 from chia.consensus.pos_quality import UI_ACTUAL_SPACE_CONSTANT_FACTOR, _expected_plot_size
 from chia.plotting.cache import Cache, CacheEntry
 from chia.plotting.util import PlotInfo, PlotRefreshEvents, PlotRefreshResult, PlotsRefreshParameter, get_plot_filenames
-from chia.util.generator_tools import list_to_batches
+from chia.util.misc import to_batches

 log = logging.getLogger(__name__)

@@ -180,19 +180,19 @@ class PlotManager:
             for filename in filenames_to_remove:
                 del self.plot_filename_paths[filename]

-            for remaining, batch in list_to_batches(sorted(list(plot_paths)), self.refresh_parameter.batch_size):
-                batch_result: PlotRefreshResult = self.refresh_batch(batch, plot_directories)
+            for batch in to_batches(sorted(list(plot_paths)), self.refresh_parameter.batch_size):
+                batch_result: PlotRefreshResult = self.refresh_batch(batch.entries, plot_directories)
                 if not self._refreshing_enabled:
                     self.log.debug("refresh_plots: Aborted")
                     break
                 # Set the remaining files since `refresh_batch()` doesn't know them but we want to report it
-                batch_result.remaining = remaining
+                batch_result.remaining = batch.remaining
                 total_result.loaded += batch_result.loaded
                 total_result.processed += batch_result.processed
                 total_result.duration += batch_result.duration

                 self._refresh_callback(PlotRefreshEvents.batch_processed, batch_result)
-                if remaining == 0:
+                if batch.remaining == 0:
                     break
                 batch_sleep = self.refresh_parameter.batch_sleep_milliseconds
                 self.log.debug(f"refresh_plots: Sleep {batch_sleep} milliseconds")
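The loop above also shows the consumer pattern for progress reporting: `batch.remaining` counts the entries left after the current batch, so it reaches 0 exactly on the final batch. A minimal sketch, assuming the chia package with this change is importable; the plot names are made-up stand-in data:

    from chia.util.misc import to_batches

    plot_paths = [f"plot_{n}.plot" for n in range(10)]  # hypothetical stand-in data
    for batch in to_batches(plot_paths, batch_size=4):
        print(f"processing {len(batch.entries)} plots, {batch.remaining} remaining")
        # prints: 4 plots / 6 remaining, 4 plots / 2 remaining, 2 plots / 0 remaining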
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any, Iterator, List, Optional, Tuple
+from typing import List, Optional, Tuple

 from chiabip158 import PyBIP158

@@ -71,14 +71,3 @@ def tx_removals_and_additions(results: Optional[SpendBundleConditions]) -> Tuple
             additions.append(Coin(bytes32(spend.coin_id), bytes32(puzzle_hash), uint64(amount)))

     return removals, additions
-
-
-def list_to_batches(list_to_split: List[Any], batch_size: int) -> Iterator[Tuple[int, List[Any]]]:
-    if batch_size <= 0:
-        raise ValueError("list_to_batches: batch_size must be greater than 0.")
-    total_size = len(list_to_split)
-    if total_size == 0:
-        return iter(())
-    for batch_start in range(0, total_size, batch_size):
-        batch_end = min(batch_start + batch_size, total_size)
-        yield total_size - batch_end, list_to_split[batch_start:batch_end]
@@ -3,13 +3,16 @@ from __future__ import annotations
 import dataclasses
 import signal
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Sequence, Union
+from typing import Any, Collection, Dict, Generic, Iterator, List, Sequence, TypeVar, Union

 from chia.util.errors import InvalidPathError
 from chia.util.ints import uint16, uint32, uint64
 from chia.util.streamable import Streamable, recurse_jsonify, streamable

+T = TypeVar("T")
+
+
 @streamable
 @dataclasses.dataclass(frozen=True)
@@ -129,3 +132,36 @@ class UInt32Range(Streamable):
 class UInt64Range(Streamable):
     start: uint64 = uint64(0)
     stop: uint64 = uint64(uint64.MAXIMUM_EXCLUSIVE - 1)
+
+
+@dataclass(frozen=True)
+class Batch(Generic[T]):
+    remaining: int
+    entries: List[T]
+
+
+def to_batches(to_split: Collection[T], batch_size: int) -> Iterator[Batch[T]]:
+    if batch_size <= 0:
+        raise ValueError("to_batches: batch_size must be greater than 0.")
+    total_size = len(to_split)
+    if total_size == 0:
+        return iter(())
+
+    if isinstance(to_split, list):
+        for batch_start in range(0, total_size, batch_size):
+            batch_end = min(batch_start + batch_size, total_size)
+            yield Batch(total_size - batch_end, to_split[batch_start:batch_end])
+    elif isinstance(to_split, set):
+        processed = 0
+        entries = []
+        for entry in to_split:
+            entries.append(entry)
+            if len(entries) >= batch_size:
+                processed += len(entries)
+                yield Batch(total_size - processed, entries)
+                entries = []
+        if len(entries) > 0:
+            processed += len(entries)
+            yield Batch(total_size - processed, entries)
+    else:
+        raise ValueError(f"to_batches: Unsupported type {type(to_split)}")
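A quick demonstration of the new function's behavior, assuming `chia.util.misc` from this commit is on the path: lists are sliced, sets are drained in iteration order, empty collections yield nothing, and any other collection type raises:

    from chia.util.misc import to_batches

    for batch in to_batches([1, 2, 3, 4, 5], 2):
        print(batch.remaining, batch.entries)  # 3 [1, 2], then 1 [3, 4], then 0 [5]

    assert list(to_batches([], 3)) == []  # empty input: the generator yields nothing
    assert all(isinstance(b.entries, list) for b in to_batches({1, 2, 3}, 2))  # sets work too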
@@ -138,7 +138,6 @@ tests.core.util.test_keychain
 tests.core.util.test_keyring_wrapper
 tests.core.util.test_lru_cache
 tests.core.util.test_significant_bits
-tests.generator.test_list_to_batches
 tests.generator.test_scan
 tests.plotting.test_plot_manager
 tests.pools.test_pool_cmdline
@@ -1,33 +0,0 @@
-from __future__ import annotations
-
-import pytest
-
-from chia.util.generator_tools import list_to_batches
-
-
-def test_empty_lists():
-    # An empty list should return an empty iterator and skip the loop's body.
-    for _, _ in list_to_batches([], 1):
-        assert False
-
-
-def test_valid():
-    for k in range(1, 10):
-        test_list = [x for x in range(0, k)]
-        for i in range(1, len(test_list) + 1):  # Test batch_size 1 to 11 (length + 1)
-            checked = 0
-            for remaining, batch in list_to_batches(test_list, i):
-                assert remaining == max(len(test_list) - checked - i, 0)
-                assert len(batch) <= i
-                assert batch == test_list[checked : min(checked + i, len(test_list))]
-                checked += len(batch)
-            assert checked == len(test_list)
-
-
-def test_invalid_batch_sizes():
-    with pytest.raises(ValueError):
-        for _ in list_to_batches([], 0):
-            assert False
-    with pytest.raises(ValueError):
-        for _ in list_to_batches([], -1):
-            assert False
@@ -28,8 +28,8 @@ from chia.server.ws_connection import WSChiaConnection
 from chia.simulator.block_tools import BlockTools
 from chia.simulator.time_out_assert import time_out_assert
 from chia.types.blockchain_format.sized_bytes import bytes32
-from chia.util.generator_tools import list_to_batches
 from chia.util.ints import int16, uint64
+from chia.util.misc import to_batches
 from tests.plot_sync.util import start_harvester_service

 log = logging.getLogger(__name__)
@@ -107,8 +107,8 @@ class TestData:
             sync_id = self.plot_sync_sender._sync_id
             if len(loaded) == 0:
                 self.harvester.plot_sync_sender.process_batch([], 0)
-            for remaining, batch in list_to_batches(loaded, batch_size):
-                self.harvester.plot_sync_sender.process_batch(batch, remaining)
+            for batch in to_batches(loaded, batch_size):
+                self.harvester.plot_sync_sender.process_batch(batch.entries, batch.remaining)
             self.harvester.plot_sync_sender.sync_done(removed_paths, 0)

         await self.event_loop.run_in_executor(None, run_internal)
@@ -1,9 +1,11 @@
 from __future__ import annotations

+from typing import List
+
 import pytest

 from chia.util.errors import InvalidPathError
-from chia.util.misc import format_bytes, format_minutes, validate_directory_writable
+from chia.util.misc import format_bytes, format_minutes, to_batches, validate_directory_writable


 class TestMisc:
@@ -70,3 +72,44 @@ def test_validate_directory_writable(tmp_path) -> None:
     with pytest.raises(InvalidPathError, match="Directory not writable") as exc_info:
         validate_directory_writable(tmp_path)
     assert exc_info.value.path == tmp_path
+
+
+def test_empty_lists() -> None:
+    # An empty list should return an empty iterator and skip the loop's body.
+    empty: List[int] = []
+    with pytest.raises(StopIteration):
+        next(to_batches(empty, 1))
+
+
+@pytest.mark.parametrize("collection_type", [list, set])
+def test_valid(collection_type: type) -> None:
+    for k in range(1, 10):
+        test_collection = collection_type([x for x in range(0, k)])
+        for i in range(1, len(test_collection) + 1):  # Test batch_size 1 to 11 (length + 1)
+            checked = 0
+            for batch in to_batches(test_collection, i):
+                assert batch.remaining == max(len(test_collection) - checked - i, 0)
+                assert len(batch.entries) <= i
+                entries = []
+                for j, entry in enumerate(test_collection):
+                    if j < checked:
+                        continue
+                    if j >= min(checked + i, len(test_collection)):
+                        break
+                    entries.append(entry)
+                assert batch.entries == entries
+                checked += len(batch.entries)
+            assert checked == len(test_collection)
+
+
+def test_invalid_batch_sizes() -> None:
+    with pytest.raises(ValueError):
+        next(to_batches([], 0))
+
+    with pytest.raises(ValueError):
+        next(to_batches([], -1))
+
+
+def test_invalid_input_type() -> None:
+    with pytest.raises(ValueError, match="Unsupported type"):
+        next(to_batches(dict({1: 2}), 1))
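Since sets cannot be sliced, the parametrized test above rebuilds the expected batch contents by enumerating the collection in the same iteration order `to_batches` consumes it; for an unmodified set that order is stable within a single run, which keeps the list and set cases symmetric.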