diff --git a/chia/full_node/fee_estimator.py b/chia/full_node/fee_estimator.py index f10c6465b0..22c306693e 100644 --- a/chia/full_node/fee_estimator.py +++ b/chia/full_node/fee_estimator.py @@ -39,7 +39,7 @@ class SmartFeeEstimator: # get_bucket_index returns left (-1) bucket (-1). Start value is already -1 # We want +1 from the lowest bucket it failed at. Thus +3 max_val = len(self.fee_tracker.buckets) - 1 - start_index = min(get_bucket_index(self.fee_tracker.sorted_buckets, fail_bucket.start) + 3, max_val) + start_index = min(get_bucket_index(self.fee_tracker.buckets, fail_bucket.start) + 3, max_val) fee_val: float = self.fee_tracker.buckets[start_index] return fee_val diff --git a/chia/full_node/fee_estimator_constants.py b/chia/full_node/fee_estimator_constants.py index b5e484c576..d0c7cf2315 100644 --- a/chia/full_node/fee_estimator_constants.py +++ b/chia/full_node/fee_estimator_constants.py @@ -2,9 +2,9 @@ from __future__ import annotations MIN_FEE_RATE = 0 # Value of first bucket -INITIAL_STEP = 5 # First bucket after zero value -MAX_FEE_RATE = 40000000 # Mojo per 1000 cost unit -INFINITE_FEE_RATE = 1000000000 +INITIAL_STEP = 5.0 # First bucket after zero value +MAX_FEE_RATE = 40000000.0 # Mojo per 1000 cost unit +INFINITE_FEE_RATE = 1000000000.0 STEP_SIZE = 1.05 # bucket increase by 1.05 diff --git a/chia/full_node/fee_tracker.py b/chia/full_node/fee_tracker.py index 811e6ada47..58c90e0cfc 100644 --- a/chia/full_node/fee_tracker.py +++ b/chia/full_node/fee_tracker.py @@ -1,11 +1,10 @@ from __future__ import annotations import logging +from bisect import bisect_left from dataclasses import dataclass from typing import List, Optional, Tuple -from sortedcontainers import SortedDict - from chia.full_node.fee_estimate_store import FeeStore from chia.full_node.fee_estimator_constants import ( FEE_ESTIMATOR_VERSION, @@ -61,21 +60,10 @@ def get_estimate_time_intervals() -> List[uint64]: return [uint64(blocks * SECONDS_PER_BLOCK) for blocks in get_estimate_block_intervals()] -def get_bucket_index(sorted_buckets: SortedDict, fee_rate: float) -> int: - if fee_rate in sorted_buckets: - bucket_index = sorted_buckets[fee_rate] - else: - # Choose the bucket to the left if we do not have exactly this fee rate - bucket_index = sorted_buckets.bisect_left(fee_rate) - 1 - - return int(bucket_index) - - # Implementation of bitcoin core fee estimation algorithm # https://gist.github.com/morcos/d3637f015bc4e607e1fd10d8351e9f41 class FeeStat: # TxConfirmStats - buckets: List[float] - sorted_buckets: SortedDict # key is upper bound of bucket, val is index in buckets + buckets: List[float] # These elements represent the upper-bound of the range for the bucket # For each bucket xL # Count the total number of txs in each bucket @@ -111,7 +99,6 @@ class FeeStat: # TxConfirmStats def __init__( self, buckets: List[float], - sorted_buckets: SortedDict, max_periods: int, decay: float, scale: int, @@ -119,7 +106,6 @@ class FeeStat: # TxConfirmStats my_type: str, ): self.buckets = buckets - self.sorted_buckets = sorted_buckets self.confirmed_average = [[] for _ in range(0, max_periods)] self.failed_average = [[] for _ in range(0, max_periods)] self.decay = decay @@ -150,7 +136,7 @@ class FeeStat: # TxConfirmStats periods_to_confirm = int((blocks_to_confirm + self.scale - 1) / self.scale) fee_rate = item.fee_per_cost * 1000 - bucket_index = get_bucket_index(self.sorted_buckets, fee_rate) + bucket_index = get_bucket_index(self.buckets, fee_rate) for i in range(periods_to_confirm, len(self.confirmed_average)): self.confirmed_average[i - 1][bucket_index] += 1 @@ -173,7 +159,7 @@ class FeeStat: # TxConfirmStats self.unconfirmed_txs[block_height % len(self.unconfirmed_txs)][i] = 0 def new_mempool_tx(self, block_height: uint32, fee_rate: float) -> int: - bucket_index: int = get_bucket_index(self.sorted_buckets, fee_rate) + bucket_index: int = get_bucket_index(self.buckets, fee_rate) block_index = block_height % len(self.unconfirmed_txs) self.unconfirmed_txs[block_index][bucket_index] += 1 return bucket_index @@ -400,8 +386,32 @@ class FeeStat: # TxConfirmStats return result +def clamp(n: int, smallest: int, largest: int) -> int: + return max(smallest, min(n, largest)) + + +def get_bucket_index(buckets: List[float], fee_rate: float) -> int: + if len(buckets) < 1: + raise RuntimeError("get_bucket_index: buckets is invalid ({buckets})") + # Choose the bucket to the left if we do not have exactly this fee rate + # Python's list.bisect_left returns the index to insert a new element into a sorted list + bucket_index = bisect_left(buckets, fee_rate) - 1 + return clamp(bucket_index, 0, len(buckets) - 1) + + +def init_buckets() -> List[float]: + fee_rate = INITIAL_STEP + + buckets: List[float] = [] + while fee_rate < MAX_FEE_RATE: + buckets.append(fee_rate) + fee_rate = fee_rate * STEP_SIZE + + buckets.append(INFINITE_FEE_RATE) + return buckets + + class FeeTracker: - sorted_buckets: SortedDict short_horizon: FeeStat med_horizon: FeeStat long_horizon: FeeStat @@ -413,30 +423,13 @@ class FeeTracker: def __init__(self, fee_store: FeeStore): self.log = logging.Logger(__name__) - self.sorted_buckets = SortedDict() - self.buckets = [] self.latest_seen_height = uint32(0) self.first_recorded_height = uint32(0) self.fee_store = fee_store - fee_rate = 0.0 - index = 0 - - while fee_rate < MAX_FEE_RATE: - self.buckets.append(fee_rate) - self.sorted_buckets[fee_rate] = index - if fee_rate == 0: - fee_rate = INITIAL_STEP - else: - fee_rate = fee_rate * STEP_SIZE - index += 1 - self.buckets.append(INFINITE_FEE_RATE) - self.sorted_buckets[INFINITE_FEE_RATE] = index - - assert len(self.sorted_buckets.keys()) == len(self.buckets) + self.buckets = init_buckets() self.short_horizon = FeeStat( self.buckets, - self.sorted_buckets, SHORT_BLOCK_PERIOD, SHORT_DECAY, SHORT_SCALE, @@ -445,7 +438,6 @@ class FeeTracker: ) self.med_horizon = FeeStat( self.buckets, - self.sorted_buckets, MED_BLOCK_PERIOD, MED_DECAY, MED_SCALE, @@ -454,7 +446,6 @@ class FeeTracker: ) self.long_horizon = FeeStat( self.buckets, - self.sorted_buckets, LONG_BLOCK_PERIOD, LONG_DECAY, LONG_SCALE, @@ -525,14 +516,14 @@ class FeeTracker: self.log.info(f"Processing Item from pending pool: cost={item.cost} fee={item.fee}") fee_rate = item.fee_per_cost * 1000 - bucket_index: int = get_bucket_index(self.sorted_buckets, fee_rate) + bucket_index: int = get_bucket_index(self.buckets, fee_rate) self.short_horizon.new_mempool_tx(self.latest_seen_height, bucket_index) self.med_horizon.new_mempool_tx(self.latest_seen_height, bucket_index) self.long_horizon.new_mempool_tx(self.latest_seen_height, bucket_index) def remove_tx(self, item: MempoolItem) -> None: - bucket_index = get_bucket_index(self.sorted_buckets, item.fee_per_cost * 1000) + bucket_index = get_bucket_index(self.buckets, item.fee_per_cost * 1000) self.short_horizon.remove_tx(self.latest_seen_height, item, bucket_index) self.med_horizon.remove_tx(self.latest_seen_height, item, bucket_index) self.long_horizon.remove_tx(self.latest_seen_height, item, bucket_index) diff --git a/tests/fee_estimation/test_fee_estimation_unit_tests.py b/tests/fee_estimation/test_fee_estimation_unit_tests.py index 79ed5c9d21..0c8525b3a6 100644 --- a/tests/fee_estimation/test_fee_estimation_unit_tests.py +++ b/tests/fee_estimation/test_fee_estimation_unit_tests.py @@ -3,12 +3,15 @@ from __future__ import annotations import logging from typing import List +import pytest from chia_rs import Coin from chia.consensus.cost_calculator import NPCResult from chia.full_node.bitcoin_fee_estimator import create_bitcoin_fee_estimator from chia.full_node.fee_estimation import FeeBlockInfo +from chia.full_node.fee_estimator_constants import INFINITE_FEE_RATE, INITIAL_STEP from chia.full_node.fee_estimator_interface import FeeEstimatorInterface +from chia.full_node.fee_tracker import get_bucket_index, init_buckets from chia.simulator.block_tools import test_constants from chia.simulator.wallet_tools import WalletTool from chia.types.clvm_cost import CLVMCost @@ -145,3 +148,49 @@ def test_fee_estimation_inception() -> None: # Confirm that estimates start after block 4 assert e1 == [0, 0, 0, 2, 2, 2, 2] + + +def test_init_buckets() -> None: + buckets = init_buckets() + assert len(buckets) > 1 + assert buckets[0] == INITIAL_STEP + assert buckets[-1] == INFINITE_FEE_RATE + + +def test_get_bucket_index_empty_buckets() -> None: + buckets: List[float] = [] + for rate in [0.5, 1.0, 2.0]: + with pytest.raises(RuntimeError): + a = get_bucket_index(buckets, rate) + log.warning(a) + + +def test_get_bucket_index_fee_rate_too_high() -> None: + buckets = [0.5, 1.0, 2.0] + index = get_bucket_index(buckets, 3.0) + assert index == len(buckets) - 1 + + +def test_get_bucket_index_single_entry() -> None: + """Test single entry with low, equal and high keys""" + from sys import float_info + + e = float_info.epsilon * 10 + buckets = [1.0] + print() + print(buckets) + for rate, expected_index in ((0.5, 0), (1.0 - e, 0), (1.5, 0)): + result_index = get_bucket_index(buckets, rate) + print(rate, expected_index, result_index) + assert expected_index == result_index + + +def test_get_bucket_index() -> None: + from sys import float_info + + e = float_info.epsilon * 10 + buckets = [1.0, 2.0] + + for rate, expected_index in ((0.5, 0), (1.0 - e, 0), (1.5, 0), (2.0 - e, 0), (2.0 + e, 1), (2.1, 1)): + result_index = get_bucket_index(buckets, rate) + assert result_index == expected_index