util: Improve `list_to_batches` (#15415)

* Make `list_to_batches` work with collections instead of lists only

* Move `to_batches` into `chia.util.misc`

* Only support `set` and `list`

* Drop the `tests.generator.test_list_to_batches` exclusion

* Improve type restrictions and make the tests more coverage friendly
dustinface 2023-06-06 19:00:07 +02:00 committed by GitHub
parent 5ad78bd669
commit ebc5f3c124
8 changed files with 95 additions and 61 deletions
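
For reference, the replacement helper yields `Batch` objects with named `entries` and `remaining` fields instead of the old `(remaining, batch)` tuples. A minimal usage sketch of the new API (imports as introduced in this diff):

```python
from chia.util.misc import to_batches

# Batches of size 2 over 5 entries: `remaining` counts the entries
# still left after the current batch and reaches 0 on the last one.
for batch in to_batches([1, 2, 3, 4, 5], 2):
    print(batch.remaining, batch.entries)
# 3 [1, 2]
# 1 [3, 4]
# 0 [5]
```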

View File

@@ -26,8 +26,8 @@ from chia.protocols.harvester_protocol import (
 from chia.protocols.protocol_message_types import ProtocolMessageTypes
 from chia.server.outbound_message import NodeType, make_msg
 from chia.server.ws_connection import WSChiaConnection
-from chia.util.generator_tools import list_to_batches
 from chia.util.ints import int16, uint32, uint64
+from chia.util.misc import to_batches
 
 log = logging.getLogger(__name__)
@@ -150,10 +150,10 @@ class Sender:
         self._messages.clear()
         if self._task is not None:
             self.sync_start(self._plot_manager.plot_count(), True)
-            for remaining, batch in list_to_batches(
+            for batch in to_batches(
                 list(self._plot_manager.plots.values()), self._plot_manager.refresh_parameter.batch_size
             ):
-                self.process_batch(batch, remaining)
+                self.process_batch(batch.entries, batch.remaining)
             self.sync_done([], 0)
 
     async def _wait_for_response(self) -> bool:
@@ -251,8 +251,8 @@ class Sender:
         if len(data) == 0:
             self._add_message(message_type, payload_type, [], True)
             return
-        for remaining, batch in list_to_batches(data, self._plot_manager.refresh_parameter.batch_size):
-            self._add_message(message_type, payload_type, batch, remaining == 0)
+        for batch in to_batches(data, self._plot_manager.refresh_parameter.batch_size):
+            self._add_message(message_type, payload_type, batch.entries, batch.remaining == 0)
 
     def sync_start(self, count: float, initial: bool) -> None:
         log.debug(f"sync_start {self}: count {count}, initial {initial}")

View File

@@ -14,7 +14,7 @@ from chiapos import DiskProver
 from chia.consensus.pos_quality import UI_ACTUAL_SPACE_CONSTANT_FACTOR, _expected_plot_size
 from chia.plotting.cache import Cache, CacheEntry
 from chia.plotting.util import PlotInfo, PlotRefreshEvents, PlotRefreshResult, PlotsRefreshParameter, get_plot_filenames
-from chia.util.generator_tools import list_to_batches
+from chia.util.misc import to_batches
 
 log = logging.getLogger(__name__)
@@ -180,19 +180,19 @@ class PlotManager:
             for filename in filenames_to_remove:
                 del self.plot_filename_paths[filename]
 
-        for remaining, batch in list_to_batches(sorted(list(plot_paths)), self.refresh_parameter.batch_size):
-            batch_result: PlotRefreshResult = self.refresh_batch(batch, plot_directories)
+        for batch in to_batches(sorted(list(plot_paths)), self.refresh_parameter.batch_size):
+            batch_result: PlotRefreshResult = self.refresh_batch(batch.entries, plot_directories)
             if not self._refreshing_enabled:
                 self.log.debug("refresh_plots: Aborted")
                 break
             # Set the remaining files since `refresh_batch()` doesn't know them but we want to report it
-            batch_result.remaining = remaining
+            batch_result.remaining = batch.remaining
             total_result.loaded += batch_result.loaded
             total_result.processed += batch_result.processed
             total_result.duration += batch_result.duration
             self._refresh_callback(PlotRefreshEvents.batch_processed, batch_result)
-            if remaining == 0:
+            if batch.remaining == 0:
                 break
             batch_sleep = self.refresh_parameter.batch_sleep_milliseconds
             self.log.debug(f"refresh_plots: Sleep {batch_sleep} milliseconds")

View File

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, Iterator, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 from chiabip158 import PyBIP158
@@ -71,14 +71,3 @@ def tx_removals_and_additions(results: Optional[SpendBundleConditions]) -> Tuple
             additions.append(Coin(bytes32(spend.coin_id), bytes32(puzzle_hash), uint64(amount)))
     return removals, additions
-
-
-def list_to_batches(list_to_split: List[Any], batch_size: int) -> Iterator[Tuple[int, List[Any]]]:
-    if batch_size <= 0:
-        raise ValueError("list_to_batches: batch_size must be greater than 0.")
-    total_size = len(list_to_split)
-    if total_size == 0:
-        return iter(())
-    for batch_start in range(0, total_size, batch_size):
-        batch_end = min(batch_start + batch_size, total_size)
-        yield total_size - batch_end, list_to_split[batch_start:batch_end]

View File

@@ -3,13 +3,16 @@ from __future__ import annotations
 import dataclasses
 import signal
 import sys
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Sequence, Union
+from typing import Any, Collection, Dict, Generic, Iterator, List, Sequence, TypeVar, Union
 
 from chia.util.errors import InvalidPathError
 from chia.util.ints import uint16, uint32, uint64
 from chia.util.streamable import Streamable, recurse_jsonify, streamable
 
+T = TypeVar("T")
+
 
 @streamable
 @dataclasses.dataclass(frozen=True)
@streamable
@dataclasses.dataclass(frozen=True)
@@ -129,3 +132,36 @@ class UInt32Range(Streamable):
 class UInt64Range(Streamable):
     start: uint64 = uint64(0)
     stop: uint64 = uint64(uint64.MAXIMUM_EXCLUSIVE - 1)
+
+
+@dataclass(frozen=True)
+class Batch(Generic[T]):
+    remaining: int
+    entries: List[T]
+
+
+def to_batches(to_split: Collection[T], batch_size: int) -> Iterator[Batch[T]]:
+    if batch_size <= 0:
+        raise ValueError("to_batches: batch_size must be greater than 0.")
+    total_size = len(to_split)
+    if total_size == 0:
+        return iter(())
+
+    if isinstance(to_split, list):
+        for batch_start in range(0, total_size, batch_size):
+            batch_end = min(batch_start + batch_size, total_size)
+            yield Batch(total_size - batch_end, to_split[batch_start:batch_end])
+    elif isinstance(to_split, set):
+        processed = 0
+        entries = []
+        for entry in to_split:
+            entries.append(entry)
+            if len(entries) >= batch_size:
+                processed += len(entries)
+                yield Batch(total_size - processed, entries)
+                entries = []
+        if len(entries) > 0:
+            processed += len(entries)
+            yield Batch(total_size - processed, entries)
+    else:
+        raise ValueError(f"to_batches: Unsupported type {type(to_split)}")

View File

@@ -138,7 +138,6 @@ tests.core.util.test_keychain
 tests.core.util.test_keyring_wrapper
 tests.core.util.test_lru_cache
 tests.core.util.test_significant_bits
-tests.generator.test_list_to_batches
 tests.generator.test_scan
 tests.plotting.test_plot_manager
 tests.pools.test_pool_cmdline

View File

@@ -1,33 +0,0 @@
-from __future__ import annotations
-
-import pytest
-
-from chia.util.generator_tools import list_to_batches
-
-
-def test_empty_lists():
-    # An empty list should return an empty iterator and skip the loop's body.
-    for _, _ in list_to_batches([], 1):
-        assert False
-
-
-def test_valid():
-    for k in range(1, 10):
-        test_list = [x for x in range(0, k)]
-        for i in range(1, len(test_list) + 1):  # Test batch_size 1 to 11 (length + 1)
-            checked = 0
-            for remaining, batch in list_to_batches(test_list, i):
-                assert remaining == max(len(test_list) - checked - i, 0)
-                assert len(batch) <= i
-                assert batch == test_list[checked : min(checked + i, len(test_list))]
-                checked += len(batch)
-            assert checked == len(test_list)
-
-
-def test_invalid_batch_sizes():
-    with pytest.raises(ValueError):
-        for _ in list_to_batches([], 0):
-            assert False
-    with pytest.raises(ValueError):
-        for _ in list_to_batches([], -1):
-            assert False

View File

@@ -28,8 +28,8 @@ from chia.server.ws_connection import WSChiaConnection
 from chia.simulator.block_tools import BlockTools
 from chia.simulator.time_out_assert import time_out_assert
 from chia.types.blockchain_format.sized_bytes import bytes32
-from chia.util.generator_tools import list_to_batches
 from chia.util.ints import int16, uint64
+from chia.util.misc import to_batches
 from tests.plot_sync.util import start_harvester_service
 
 log = logging.getLogger(__name__)
@@ -107,8 +107,8 @@ class TestData:
             sync_id = self.plot_sync_sender._sync_id
             if len(loaded) == 0:
                 self.harvester.plot_sync_sender.process_batch([], 0)
-            for remaining, batch in list_to_batches(loaded, batch_size):
-                self.harvester.plot_sync_sender.process_batch(batch, remaining)
+            for batch in to_batches(loaded, batch_size):
+                self.harvester.plot_sync_sender.process_batch(batch.entries, batch.remaining)
             self.harvester.plot_sync_sender.sync_done(removed_paths, 0)
 
         await self.event_loop.run_in_executor(None, run_internal)

View File

@@ -1,9 +1,11 @@
 from __future__ import annotations
 
+from typing import List
+
 import pytest
 
 from chia.util.errors import InvalidPathError
-from chia.util.misc import format_bytes, format_minutes, validate_directory_writable
+from chia.util.misc import format_bytes, format_minutes, to_batches, validate_directory_writable
 
 
 class TestMisc:
@@ -70,3 +72,44 @@ def test_validate_directory_writable(tmp_path) -> None:
     with pytest.raises(InvalidPathError, match="Directory not writable") as exc_info:
         validate_directory_writable(tmp_path)
     assert exc_info.value.path == tmp_path
+
+
+def test_empty_lists() -> None:
+    # An empty list should return an empty iterator and skip the loop's body.
+    empty: List[int] = []
+    with pytest.raises(StopIteration):
+        next(to_batches(empty, 1))
+
+
+@pytest.mark.parametrize("collection_type", [list, set])
+def test_valid(collection_type: type) -> None:
+    for k in range(1, 10):
+        test_collection = collection_type([x for x in range(0, k)])
+        for i in range(1, len(test_collection) + 1):  # Test batch_size 1 to 11 (length + 1)
+            checked = 0
+            for batch in to_batches(test_collection, i):
+                assert batch.remaining == max(len(test_collection) - checked - i, 0)
+                assert len(batch.entries) <= i
+                entries = []
+                for j, entry in enumerate(test_collection):
+                    if j < checked:
+                        continue
+                    if j >= min(checked + i, len(test_collection)):
+                        break
+                    entries.append(entry)
+                assert batch.entries == entries
+                checked += len(batch.entries)
+            assert checked == len(test_collection)
+
+
+def test_invalid_batch_sizes() -> None:
+    with pytest.raises(ValueError):
+        next(to_batches([], 0))
+    with pytest.raises(ValueError):
+        next(to_batches([], -1))
+
+
+def test_invalid_input_type() -> None:
+    with pytest.raises(ValueError, match="Unsupported type"):
+        next(to_batches(dict({1: 2}), 1))