Refactor get_bucket_index (#14270)
This commit is contained in:
parent
f3a709e5e8
commit
d658de4e08
2 changed files with 21 additions and 22 deletions
|
@ -5,7 +5,13 @@ from dataclasses import dataclass, field
|
|||
|
||||
from chia.full_node.fee_estimate import FeeEstimate, FeeEstimateGroup
|
||||
from chia.full_node.fee_estimation import FeeMempoolInfo
|
||||
from chia.full_node.fee_tracker import BucketResult, EstimateResult, FeeTracker, get_estimate_time_intervals
|
||||
from chia.full_node.fee_tracker import (
|
||||
BucketResult,
|
||||
EstimateResult,
|
||||
FeeTracker,
|
||||
get_bucket_index,
|
||||
get_estimate_time_intervals,
|
||||
)
|
||||
from chia.types.fee_rate import FeeRate
|
||||
from chia.util.ints import uint32, uint64
|
||||
|
||||
|
@ -33,7 +39,7 @@ class SmartFeeEstimator:
|
|||
# get_bucket_index returns left (-1) bucket (-1). Start value is already -1
|
||||
# We want +1 from the lowest bucket it failed at. Thus +3
|
||||
max_val = len(self.fee_tracker.buckets) - 1
|
||||
start_index = min(self.fee_tracker.get_bucket_index(fail_bucket.start) + 3, max_val)
|
||||
start_index = min(get_bucket_index(self.fee_tracker.sorted_buckets, fail_bucket.start) + 3, max_val)
|
||||
|
||||
fee_val: float = self.fee_tracker.buckets[start_index]
|
||||
return fee_val
|
||||
|
|
|
@ -61,6 +61,16 @@ def get_estimate_time_intervals() -> List[uint64]:
|
|||
return [uint64(blocks * SECONDS_PER_BLOCK) for blocks in get_estimate_block_intervals()]
|
||||
|
||||
|
||||
def get_bucket_index(sorted_buckets: SortedDict, fee_rate: float) -> int:
|
||||
if fee_rate in sorted_buckets:
|
||||
bucket_index = sorted_buckets[fee_rate]
|
||||
else:
|
||||
# Choose the bucket to the left if we do not have exactly this fee rate
|
||||
bucket_index = sorted_buckets.bisect_left(fee_rate) - 1
|
||||
|
||||
return int(bucket_index)
|
||||
|
||||
|
||||
# Implementation of bitcoin core fee estimation algorithm
|
||||
# https://gist.github.com/morcos/d3637f015bc4e607e1fd10d8351e9f41
|
||||
class FeeStat: # TxConfirmStats
|
||||
|
@ -133,15 +143,6 @@ class FeeStat: # TxConfirmStats
|
|||
|
||||
self.old_unconfirmed_txs = [0 for _ in range(0, len(buckets))]
|
||||
|
||||
def get_bucket_index(self, fee_rate: float) -> int:
|
||||
if fee_rate in self.sorted_buckets:
|
||||
bucket_index = self.sorted_buckets[fee_rate]
|
||||
else:
|
||||
# Choose the bucket to the left if we do not have exactly this fee rate
|
||||
bucket_index = self.sorted_buckets.bisect_left(fee_rate) - 1
|
||||
|
||||
return int(bucket_index)
|
||||
|
||||
def tx_confirmed(self, blocks_to_confirm: int, item: MempoolItem) -> None:
|
||||
if blocks_to_confirm < 1:
|
||||
raise ValueError("tx_confirmed called with < 1 block to confirm")
|
||||
|
@ -149,7 +150,7 @@ class FeeStat: # TxConfirmStats
|
|||
periods_to_confirm = int((blocks_to_confirm + self.scale - 1) / self.scale)
|
||||
|
||||
fee_rate = item.fee_per_cost * 1000
|
||||
bucket_index = self.get_bucket_index(fee_rate)
|
||||
bucket_index = get_bucket_index(self.sorted_buckets, fee_rate)
|
||||
|
||||
for i in range(periods_to_confirm, len(self.confirmed_average)):
|
||||
self.confirmed_average[i - 1][bucket_index] += 1
|
||||
|
@ -172,7 +173,7 @@ class FeeStat: # TxConfirmStats
|
|||
self.unconfirmed_txs[block_height % len(self.unconfirmed_txs)][i] = 0
|
||||
|
||||
def new_mempool_tx(self, block_height: uint32, fee_rate: float) -> int:
|
||||
bucket_index: int = self.get_bucket_index(fee_rate)
|
||||
bucket_index: int = get_bucket_index(self.sorted_buckets, fee_rate)
|
||||
block_index = block_height % len(self.unconfirmed_txs)
|
||||
self.unconfirmed_txs[block_index][bucket_index] += 1
|
||||
return bucket_index
|
||||
|
@ -514,16 +515,8 @@ class FeeTracker:
|
|||
self.med_horizon.tx_confirmed(blocks_to_confirm, item)
|
||||
self.long_horizon.tx_confirmed(blocks_to_confirm, item)
|
||||
|
||||
def get_bucket_index(self, fee_rate: float) -> int:
|
||||
if fee_rate in self.sorted_buckets:
|
||||
bucket_index = self.sorted_buckets[fee_rate]
|
||||
else:
|
||||
bucket_index = self.sorted_buckets.bisect_left(fee_rate) - 1
|
||||
|
||||
return int(bucket_index)
|
||||
|
||||
def remove_tx(self, item: MempoolItem) -> None:
|
||||
bucket_index = self.get_bucket_index(item.fee_per_cost * 1000)
|
||||
bucket_index = get_bucket_index(self.sorted_buckets, item.fee_per_cost * 1000)
|
||||
self.short_horizon.remove_tx(self.latest_seen_height, item, bucket_index)
|
||||
self.med_horizon.remove_tx(self.latest_seen_height, item, bucket_index)
|
||||
self.long_horizon.remove_tx(self.latest_seen_height, item, bucket_index)
|
||||
|
|
Loading…
Reference in a new issue