Refactor get_bucket_index (#14270)

This commit is contained in:
Adam Kelly 2023-01-05 08:56:23 -08:00 committed by GitHub
parent f3a709e5e8
commit d658de4e08
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 22 deletions

View file

@ -5,7 +5,13 @@ from dataclasses import dataclass, field
from chia.full_node.fee_estimate import FeeEstimate, FeeEstimateGroup
from chia.full_node.fee_estimation import FeeMempoolInfo
from chia.full_node.fee_tracker import BucketResult, EstimateResult, FeeTracker, get_estimate_time_intervals
from chia.full_node.fee_tracker import (
BucketResult,
EstimateResult,
FeeTracker,
get_bucket_index,
get_estimate_time_intervals,
)
from chia.types.fee_rate import FeeRate
from chia.util.ints import uint32, uint64
@ -33,7 +39,7 @@ class SmartFeeEstimator:
# get_bucket_index returns left (-1) bucket (-1). Start value is already -1
# We want +1 from the lowest bucket it failed at. Thus +3
max_val = len(self.fee_tracker.buckets) - 1
start_index = min(self.fee_tracker.get_bucket_index(fail_bucket.start) + 3, max_val)
start_index = min(get_bucket_index(self.fee_tracker.sorted_buckets, fail_bucket.start) + 3, max_val)
fee_val: float = self.fee_tracker.buckets[start_index]
return fee_val

View file

@ -61,6 +61,16 @@ def get_estimate_time_intervals() -> List[uint64]:
return [uint64(blocks * SECONDS_PER_BLOCK) for blocks in get_estimate_block_intervals()]
def get_bucket_index(sorted_buckets: SortedDict, fee_rate: float) -> int:
if fee_rate in sorted_buckets:
bucket_index = sorted_buckets[fee_rate]
else:
# Choose the bucket to the left if we do not have exactly this fee rate
bucket_index = sorted_buckets.bisect_left(fee_rate) - 1
return int(bucket_index)
# Implementation of bitcoin core fee estimation algorithm
# https://gist.github.com/morcos/d3637f015bc4e607e1fd10d8351e9f41
class FeeStat: # TxConfirmStats
@ -133,15 +143,6 @@ class FeeStat: # TxConfirmStats
self.old_unconfirmed_txs = [0 for _ in range(0, len(buckets))]
def get_bucket_index(self, fee_rate: float) -> int:
if fee_rate in self.sorted_buckets:
bucket_index = self.sorted_buckets[fee_rate]
else:
# Choose the bucket to the left if we do not have exactly this fee rate
bucket_index = self.sorted_buckets.bisect_left(fee_rate) - 1
return int(bucket_index)
def tx_confirmed(self, blocks_to_confirm: int, item: MempoolItem) -> None:
if blocks_to_confirm < 1:
raise ValueError("tx_confirmed called with < 1 block to confirm")
@ -149,7 +150,7 @@ class FeeStat: # TxConfirmStats
periods_to_confirm = int((blocks_to_confirm + self.scale - 1) / self.scale)
fee_rate = item.fee_per_cost * 1000
bucket_index = self.get_bucket_index(fee_rate)
bucket_index = get_bucket_index(self.sorted_buckets, fee_rate)
for i in range(periods_to_confirm, len(self.confirmed_average)):
self.confirmed_average[i - 1][bucket_index] += 1
@ -172,7 +173,7 @@ class FeeStat: # TxConfirmStats
self.unconfirmed_txs[block_height % len(self.unconfirmed_txs)][i] = 0
def new_mempool_tx(self, block_height: uint32, fee_rate: float) -> int:
bucket_index: int = self.get_bucket_index(fee_rate)
bucket_index: int = get_bucket_index(self.sorted_buckets, fee_rate)
block_index = block_height % len(self.unconfirmed_txs)
self.unconfirmed_txs[block_index][bucket_index] += 1
return bucket_index
@ -514,16 +515,8 @@ class FeeTracker:
self.med_horizon.tx_confirmed(blocks_to_confirm, item)
self.long_horizon.tx_confirmed(blocks_to_confirm, item)
def get_bucket_index(self, fee_rate: float) -> int:
if fee_rate in self.sorted_buckets:
bucket_index = self.sorted_buckets[fee_rate]
else:
bucket_index = self.sorted_buckets.bisect_left(fee_rate) - 1
return int(bucket_index)
def remove_tx(self, item: MempoolItem) -> None:
bucket_index = self.get_bucket_index(item.fee_per_cost * 1000)
bucket_index = get_bucket_index(self.sorted_buckets, item.fee_per_cost * 1000)
self.short_horizon.remove_tx(self.latest_seen_height, item, bucket_index)
self.med_horizon.remove_tx(self.latest_seen_height, item, bucket_index)
self.long_horizon.remove_tx(self.latest_seen_height, item, bucket_index)