Add new_block_height to FeeEstimatorInterface (#14277)

* Add new_block_height to FeeEstimatorInterface

* Fancy-up create_test_block_record, rename test

* Integrate new interface method
This commit is contained in:
Adam Kelly 2023-01-05 13:34:53 -08:00 committed by GitHub
parent e1c986435a
commit 820493ffa8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 137 additions and 12 deletions

View file

@ -20,13 +20,18 @@ class BitcoinFeeEstimator(FeeEstimatorInterface):
fee_rate_estimator: SmartFeeEstimator
tracker: FeeTracker
last_mempool_info: FeeMempoolInfo
block_height: uint32
def __init__(self, fee_tracker: FeeTracker, smart_fee_estimator: SmartFeeEstimator) -> None:
self.fee_rate_estimator: SmartFeeEstimator = smart_fee_estimator
self.tracker: FeeTracker = fee_tracker
self.last_mempool_info: FeeMempoolInfo = EmptyFeeMempoolInfo
def new_block_height(self, block_height: uint32) -> None:
self.block_height = block_height
def new_block(self, block_info: FeeBlockInfo) -> None:
self.block_height = block_info.block_height
self.tracker.process_block(block_info.block_height, block_info.included_items)
def add_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:

View file

@ -6,11 +6,16 @@ from chia.full_node.fee_estimation import FeeBlockInfo, FeeMempoolInfo
from chia.types.clvm_cost import CLVMCost
from chia.types.fee_rate import FeeRate
from chia.types.mempool_item import MempoolItem
from chia.util.ints import uint32
class FeeEstimatorInterface(Protocol):
def new_block_height(self, block_height: uint32) -> None:
"""Called immediately when block height changes. Can be called multiple times before `new_block`"""
pass
def new_block(self, block_info: FeeBlockInfo) -> None:
"""A new block has been added to the blockchain"""
"""A new transaction block has been added to the blockchain"""
pass
def add_mempool_item(self, mempool_item_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:

View file

@ -567,6 +567,7 @@ class MempoolManager:
if self.peak == new_peak:
return []
assert new_peak.timestamp is not None
self.fee_estimator.new_block_height(new_peak.height)
included_items = []
use_optimization: bool = self.peak is not None and new_peak.prev_transaction_block_hash == self.peak.header_hash

View file

@ -31,14 +31,11 @@ async def zero_calls_get_coin_record(_: bytes32) -> Optional[CoinRecord]:
assert False
async def instantiate_mempool_manager(
get_coin_record: Callable[[bytes32], Awaitable[Optional[CoinRecord]]]
) -> MempoolManager:
mempool_manager = MempoolManager(get_coin_record, DEFAULT_CONSTANTS)
test_block_record = BlockRecord(
def create_test_block_record(*, height: uint32 = TEST_HEIGHT) -> BlockRecord:
return BlockRecord(
IDENTITY_PUZZLE_HASH,
IDENTITY_PUZZLE_HASH,
TEST_HEIGHT,
height,
uint128(0),
uint128(0),
uint8(0),
@ -52,7 +49,7 @@ async def instantiate_mempool_manager(
uint64(0),
uint8(0),
False,
uint32(TEST_HEIGHT - 1),
uint32(height - 1),
TEST_TIMESTAMP,
None,
uint64(0),
@ -62,6 +59,13 @@ async def instantiate_mempool_manager(
None,
None,
)
async def instantiate_mempool_manager(
get_coin_record: Callable[[bytes32], Awaitable[Optional[CoinRecord]]]
) -> MempoolManager:
mempool_manager = MempoolManager(get_coin_record, DEFAULT_CONSTANTS)
test_block_record = create_test_block_record()
await mempool_manager.new_peak(test_block_record, None)
return mempool_manager

View file

@ -1,7 +1,9 @@
from __future__ import annotations
from typing import Dict
import types
from typing import Dict, List
import pytest
from chia_rs import Coin
from chia.consensus.cost_calculator import NPCResult
@ -21,6 +23,11 @@ from chia.types.clvm_cost import CLVMCost
from chia.types.fee_rate import FeeRate
from chia.types.mempool_item import MempoolItem
from chia.util.ints import uint32, uint64
from tests.core.mempool.test_mempool_manager import (
create_test_block_record,
instantiate_mempool_manager,
zero_calls_get_coin_record,
)
def make_mempoolitem() -> MempoolItem:
@ -47,12 +54,18 @@ def make_mempoolitem() -> MempoolItem:
class FeeEstimatorInterfaceIntegrationVerificationObject(FeeEstimatorInterface):
add_mempool_item_called_count: int = 0
remove_mempool_item_called_count: int = 0
new_block_called_count: int = 0
current_block_height: int = 0
def new_block_height(self, block_height: uint32) -> None:
self.current_block_height: int = block_height
def new_block(self, block_info: FeeBlockInfo) -> None:
"""A new block has been added to the blockchain"""
pass
self.current_block_height = block_info.block_height
self.new_block_called_count += 1
def add_mempool_item(self, mempool_item_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def add_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
"""A MempoolItem (transaction and associated info) has been added to the mempool"""
self.add_mempool_item_called_count += 1
@ -124,4 +137,101 @@ def test_mempool_fee_estimator_remove_item() -> None:
def test_mempool_manager_fee_estimator_new_block() -> None:
pass
fee_estimator = FeeEstimatorInterfaceIntegrationVerificationObject()
mempool = Mempool(test_mempool_info, fee_estimator)
item = make_mempoolitem()
height = uint32(4)
included_items = [item]
mempool.fee_estimator.new_block(FeeBlockInfo(height, included_items))
assert mempool.fee_estimator.new_block_called_count == 1 # type: ignore[attr-defined]
def test_current_block_height_init() -> None:
fee_estimator = FeeEstimatorInterfaceIntegrationVerificationObject()
mempool = Mempool(test_mempool_info, fee_estimator)
assert mempool.fee_estimator.current_block_height == uint32(0) # type: ignore[attr-defined]
def test_current_block_height_add() -> None:
fee_estimator = FeeEstimatorInterfaceIntegrationVerificationObject()
mempool = Mempool(test_mempool_info, fee_estimator)
item = make_mempoolitem()
height = uint32(7)
fee_estimator.new_block_height(height)
mempool.add_to_pool(item)
assert mempool.fee_estimator.current_block_height == height # type: ignore[attr-defined]
def test_current_block_height_remove() -> None:
fee_estimator = FeeEstimatorInterfaceIntegrationVerificationObject()
mempool = Mempool(test_mempool_info, fee_estimator)
item = make_mempoolitem()
height = uint32(8)
fee_estimator.new_block_height(height)
mempool.add_to_pool(item)
mempool.remove_from_pool([item.name], MempoolRemoveReason.CONFLICT)
assert mempool.fee_estimator.current_block_height == height # type: ignore[attr-defined]
def test_current_block_height_new_block_height() -> None:
fee_estimator = FeeEstimatorInterfaceIntegrationVerificationObject()
mempool = Mempool(test_mempool_info, fee_estimator)
height = uint32(9)
mempool.fee_estimator.new_block_height(height)
assert mempool.fee_estimator.current_block_height == height # type: ignore[attr-defined]
def test_current_block_height_new_block() -> None:
fee_estimator = FeeEstimatorInterfaceIntegrationVerificationObject()
mempool = Mempool(test_mempool_info, fee_estimator)
height = uint32(10)
included_items: List[MempoolItem] = []
mempool.fee_estimator.new_block(FeeBlockInfo(height, included_items))
assert mempool.fee_estimator.current_block_height == height # type: ignore[attr-defined]
def test_current_block_height_new_height_then_new_block() -> None:
fee_estimator = FeeEstimatorInterfaceIntegrationVerificationObject()
mempool = Mempool(test_mempool_info, fee_estimator)
height = uint32(11)
included_items: List[MempoolItem] = []
fee_estimator.new_block_height(uint32(height - 1))
mempool.fee_estimator.new_block(FeeBlockInfo(height, included_items))
assert mempool.fee_estimator.current_block_height == height # type: ignore[attr-defined]
def test_current_block_height_new_block_then_new_height() -> None:
fee_estimator = FeeEstimatorInterfaceIntegrationVerificationObject()
mempool = Mempool(test_mempool_info, fee_estimator)
height = uint32(12)
included_items: List[MempoolItem] = []
fee_estimator.new_block_height(uint32(height - 1))
mempool.fee_estimator.new_block(FeeBlockInfo(height, included_items))
fee_estimator.new_block_height(uint32(height + 1))
assert mempool.fee_estimator.current_block_height == height + 1 # type: ignore[attr-defined]
@pytest.mark.asyncio
async def test_mm_new_peak_changes_fee_estimator_block_height() -> None:
mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record)
block2 = create_test_block_record(height=uint32(2))
await mempool_manager.new_peak(block2, None)
assert mempool_manager.mempool.fee_estimator.block_height == uint32(2) # type: ignore[attr-defined]
@pytest.mark.asyncio
async def test_mm_calls_new_block_height() -> None:
mempool_manager = await instantiate_mempool_manager(zero_calls_get_coin_record)
new_block_height_called = False
def test_new_block_height_called(self: FeeEstimatorInterface, height: uint32) -> None:
nonlocal new_block_height_called
new_block_height_called = True
# Replace new_block_height with test function
mempool_manager.fee_estimator.new_block_height = types.MethodType( # type: ignore[assignment]
test_new_block_height_called, mempool_manager.fee_estimator
)
block2 = create_test_block_record(height=uint32(2))
await mempool_manager.new_peak(block2, None)
assert new_block_height_called