# Source: xch-blockchain/chia/rpc/farmer_rpc_api.py (294 lines, 12 KiB, Python)

from __future__ import annotations
import dataclasses
import operator
from typing import Any, Callable, Dict, List, Optional
from typing_extensions import Protocol
from chia.farmer.farmer import Farmer
from chia.plot_sync.receiver import Receiver
from chia.protocols.harvester_protocol import Plot
from chia.rpc.rpc_server import Endpoint, EndpointResult
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.ints import uint32
from chia.util.paginator import Paginator
from chia.util.streamable import Streamable, streamable
from chia.util.ws_message import WsRpcMessage, create_payload_dict
class PaginatedRequestData(Protocol):
    """Structural type for requests that ask for one page of a node's data.

    ``paginated_plot_request`` reads exactly these three attributes from the
    request object it is handed.
    """

    @property
    def node_id(self) -> bytes32:
        """Node id of the request; echoed back hex-encoded in the response."""
        ...

    @property
    def page(self) -> uint32:
        """Page to fetch from the paginator."""
        ...

    @property
    def page_size(self) -> uint32:
        """Page size handed to the paginator."""
        ...
@streamable
@dataclasses.dataclass(frozen=True)
class FilterItem(Streamable):
    """A single attribute filter, evaluated against a plot by ``plot_matches_filter``."""

    # Name of the plot attribute to inspect (looked up with ``getattr``).
    key: str
    # Substring searched for in ``str(attribute)``; ``None`` matches only attributes that are ``None``.
    value: Optional[str]
@streamable
@dataclasses.dataclass(frozen=True)
class PlotInfoRequestData(Streamable):
    """Request payload parsed by ``FarmerRpcApi.get_harvester_plots_valid``."""

    # Node whose plot-sync receiver is queried.
    node_id: bytes32
    # Page to return.
    page: uint32
    # Page size handed to the paginator.
    page_size: uint32
    # Filters that must ALL match for a plot to be included; empty means no filtering.
    filter: List[FilterItem] = dataclasses.field(default_factory=list)
    # Plot attribute to sort by; ties are broken by ``plot_id``.
    sort_key: str = "filename"
    # Sort in descending order when true.
    reverse: bool = False
@streamable
@dataclasses.dataclass(frozen=True)
class PlotPathRequestData(Streamable):
    """Request payload parsed by ``FarmerRpcApi.paginated_plot_path_request``."""

    # Node whose plot-sync receiver is queried.
    node_id: bytes32
    # Page to return.
    page: uint32
    # Page size handed to the paginator.
    page_size: uint32
    # Substrings that must ALL occur in an entry for it to be included; empty means no filtering.
    filter: List[str] = dataclasses.field(default_factory=list)
    # Sort in descending order when true.
    reverse: bool = False
def paginated_plot_request(source: List[Any], request: PaginatedRequestData) -> Dict[str, object]:
    """Build the response dictionary for one page of ``source``.

    Args:
        source: The complete, already filtered and sorted, list of entries.
        request: Supplies ``node_id``, ``page`` and ``page_size``.

    Returns:
        A dict with the hex node id, the requested page number, the number of
        pages, the total number of entries and the entries of the requested page.
    """
    pages: Paginator = Paginator(source, request.page_size)

    # Filled in the same order the keys appear in the response.
    response: Dict[str, object] = {}
    response["node_id"] = request.node_id.hex()
    response["page"] = request.page
    response["page_count"] = pages.page_count()
    response["total_count"] = len(source)
    response["plots"] = pages.get_page(request.page)
    return response
def plot_matches_filter(plot: Plot, filter_item: FilterItem) -> bool:
    """Tell whether ``plot`` satisfies a single filter.

    The attribute named by ``filter_item.key`` is read from ``plot``. A filter
    value of ``None`` matches only when that attribute is ``None``; any other
    value matches when it is a substring of the attribute's ``str()`` form.

    Raises:
        AttributeError: If ``plot`` has no attribute named ``filter_item.key``.
    """
    actual = getattr(plot, filter_item.key)
    wanted = filter_item.value
    if wanted is not None:
        return wanted in str(actual)
    return actual is None
class FarmerRpcApi:
    """RPC endpoints and websocket notifications backed by a ``Farmer`` service."""

    # Websocket destination for every state change that ``_state_changed`` forwards.
    # Changes not listed here produce no message.
    _STATE_CHANGE_DESTINATIONS: Dict[str, str] = {
        "new_signage_point": "wallet_ui",
        "new_farming_info": "wallet_ui",
        "harvester_update": "wallet_ui",
        "harvester_removed": "wallet_ui",
        "submitted_partial": "metrics",
        "proof": "metrics",
    }

    def __init__(self, farmer: Farmer):
        """Store the farmer service that every endpoint delegates to."""
        self.service = farmer
        self.service_name = "chia_farmer"

    def get_routes(self) -> Dict[str, Endpoint]:
        """Return the mapping from RPC route to the coroutine that serves it."""
        return {
            "/get_signage_point": self.get_signage_point,
            "/get_signage_points": self.get_signage_points,
            "/get_reward_targets": self.get_reward_targets,
            "/set_reward_targets": self.set_reward_targets,
            "/get_pool_state": self.get_pool_state,
            "/set_payout_instructions": self.set_payout_instructions,
            "/get_harvesters": self.get_harvesters,
            "/get_harvesters_summary": self.get_harvesters_summary,
            "/get_harvester_plots_valid": self.get_harvester_plots_valid,
            "/get_harvester_plots_invalid": self.get_harvester_plots_invalid,
            "/get_harvester_plots_keys_missing": self.get_harvester_plots_keys_missing,
            "/get_harvester_plots_duplicates": self.get_harvester_plots_duplicates,
            "/get_pool_login_link": self.get_pool_login_link,
        }

    async def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]]) -> List[WsRpcMessage]:
        """Translate a farmer state change into websocket messages.

        Args:
            change: Name of the state change.
            change_data: Data attached to the change, or ``None``.

        Returns:
            A single-message list for known changes that carry data, otherwise
            an empty list.

        Raises:
            ValueError: For ``new_signage_point`` when the signage point named
                by ``change_data["sp_hash"]`` is not known to the service.
        """
        destination = self._STATE_CHANGE_DESTINATIONS.get(change)
        # TODO: maybe something better than silently dropping changes that carry no data?
        if change_data is None or destination is None:
            return []

        payload: Dict[str, Any] = change_data
        if change == "new_signage_point":
            # This change only carries the hash; send the full signage point with its proofs.
            payload = await self.get_signage_point({"sp_hash": change_data["sp_hash"].hex()})
        return [create_payload_dict(change, payload, self.service_name, destination)]

    def _signage_point_response(self, sp: Any) -> Dict[str, Any]:
        """Serialise one signage point together with the proofs of space recorded for it."""
        pospaces = self.service.proofs_of_space.get(sp.challenge_chain_sp, [])
        return {
            "signage_point": {
                "challenge_hash": sp.challenge_hash,
                "challenge_chain_sp": sp.challenge_chain_sp,
                "reward_chain_sp": sp.reward_chain_sp,
                "difficulty": sp.difficulty,
                "sub_slot_iters": sp.sub_slot_iters,
                "signage_point_index": sp.signage_point_index,
            },
            "proofs": pospaces,
        }

    async def get_signage_point(self, request: Dict) -> EndpointResult:
        """Return the signage point whose ``challenge_chain_sp`` equals ``request["sp_hash"]`` (hex).

        Raises:
            ValueError: If no stored signage point has that hash.
        """
        sp_hash = hexstr_to_bytes(request["sp_hash"])
        for sps in self.service.sps.values():
            for sp in sps:
                if sp.challenge_chain_sp == sp_hash:
                    return self._signage_point_response(sp)
        raise ValueError(f"Signage point {sp_hash.hex()} not found")

    async def get_signage_points(self, _: Dict) -> EndpointResult:
        """Return every stored signage point, each with its proofs of space."""
        result: List[Dict[str, Any]] = [
            self._signage_point_response(sp) for sps in self.service.sps.values() for sp in sps
        ]
        return {"signage_points": result}

    async def get_reward_targets(self, request: Dict) -> EndpointResult:
        """Forward to ``Farmer.get_reward_targets``.

        ``request["search_for_private_key"]`` is required;
        ``request["max_ph_to_search"]`` defaults to 500.
        """
        search_for_private_key = request["search_for_private_key"]
        max_ph_to_search = request.get("max_ph_to_search", 500)
        return await self.service.get_reward_targets(search_for_private_key, max_ph_to_search)

    async def set_reward_targets(self, request: Dict) -> EndpointResult:
        """Forward the optional ``farmer_target`` / ``pool_target`` to the service.

        A target missing from the request is passed on as ``None``.
        """
        self.service.set_reward_targets(request.get("farmer_target"), request.get("pool_target"))
        return {}

    def get_pool_contract_puzzle_hash_plot_count(self, pool_contract_puzzle_hash: bytes32) -> int:
        """Count, across all plot-sync receivers, the plots with the given pool contract puzzle hash."""
        return sum(
            plot.pool_contract_puzzle_hash == pool_contract_puzzle_hash
            for receiver in self.service.plot_sync_receivers.values()
            for plot in receiver.plots().values()
        )

    async def get_pool_state(self, _: Dict) -> EndpointResult:
        """Return a copy of each pool's state, extended with its hex puzzle hash and plot count."""
        pools_list = []
        for p2_singleton_puzzle_hash, pool_dict in self.service.pool_state.items():
            # Copy so the extra keys never leak into the service's own state.
            pool_state = pool_dict.copy()
            pool_state["p2_singleton_puzzle_hash"] = p2_singleton_puzzle_hash.hex()
            pool_state["plot_count"] = self.get_pool_contract_puzzle_hash_plot_count(p2_singleton_puzzle_hash)
            pools_list.append(pool_state)
        return {"pool_state": pools_list}

    async def set_payout_instructions(self, request: Dict) -> EndpointResult:
        """Set ``request["payout_instructions"]`` for the pool identified by ``request["launcher_id"]`` (hex)."""
        launcher_id: bytes32 = bytes32.from_hexstr(request["launcher_id"])
        await self.service.set_payout_instructions(launcher_id, request["payout_instructions"])
        return {}

    async def get_harvesters(self, _: Dict) -> EndpointResult:
        """Return the result of ``Farmer.get_harvesters(False)``."""
        return await self.service.get_harvesters(False)

    async def get_harvesters_summary(self, _: Dict[str, object]) -> EndpointResult:
        """Return the result of ``Farmer.get_harvesters(True)``.

        NOTE(review): the flag presumably selects the summary form - confirm against ``Farmer.get_harvesters``.
        """
        return await self.service.get_harvesters(True)

    async def get_harvester_plots_valid(self, request_dict: Dict[str, object]) -> EndpointResult:
        """Return one page of a harvester's plots, filtered and sorted.

        The request is parsed as ``PlotInfoRequestData``. A plot is kept only if it
        matches every filter item; the survivors are sorted by ``sort_key`` and
        then by ``plot_id``.

        Raises:
            KeyError: If ``sort_key`` names one of the optional attributes that
                cannot be sorted on.
        """
        # TODO: Consider having an extra List[PlotInfo] in Receiver to avoid rebuilding the list for each call
        request = PlotInfoRequestData.from_json_dict(request_dict)
        plot_list = [
            plot
            for plot in self.service.get_receiver(request.node_id).plots().values()
            if all(plot_matches_filter(plot, filter_item) for filter_item in request.filter)
        ]
        restricted_sort_keys: List[str] = ["pool_contract_puzzle_hash", "pool_public_key", "plot_public_key"]
        # These attributes are optional, so their values may not be comparable with each other.
        if request.sort_key in restricted_sort_keys:
            raise KeyError(f"Can't sort by optional attributes: {restricted_sort_keys}")
        # plot_id is unique, so using it as the secondary key makes the order deterministic.
        plot_list = sorted(plot_list, key=operator.attrgetter(request.sort_key, "plot_id"), reverse=request.reverse)
        return paginated_plot_request(plot_list, request)

    def paginated_plot_path_request(
        self, source_func: Callable[[Receiver], List[str]], request_dict: Dict[str, object]
    ) -> Dict[str, object]:
        """Return one page of the string entries that ``source_func`` extracts from a receiver.

        The request is parsed as ``PlotPathRequestData``. An entry is kept only if
        it contains every filter string; the survivors are sorted, descending
        when ``reverse`` is set.
        """
        request: PlotPathRequestData = PlotPathRequestData.from_json_dict(request_dict)
        receiver = self.service.get_receiver(request.node_id)
        matching = (
            plot for plot in source_func(receiver) if all(filter_item in plot for filter_item in request.filter)
        )
        return paginated_plot_request(sorted(matching, reverse=request.reverse), request)

    async def get_harvester_plots_invalid(self, request_dict: Dict[str, object]) -> EndpointResult:
        """Return one page of the entries reported by ``Receiver.invalid``."""
        return self.paginated_plot_path_request(Receiver.invalid, request_dict)

    async def get_harvester_plots_keys_missing(self, request_dict: Dict[str, object]) -> EndpointResult:
        """Return one page of the entries reported by ``Receiver.keys_missing``."""
        return self.paginated_plot_path_request(Receiver.keys_missing, request_dict)

    async def get_harvester_plots_duplicates(self, request_dict: Dict[str, object]) -> EndpointResult:
        """Return one page of the entries reported by ``Receiver.duplicates``."""
        return self.paginated_plot_path_request(Receiver.duplicates, request_dict)

    async def get_pool_login_link(self, request: Dict) -> EndpointResult:
        """Generate a login link for the pool identified by ``request["launcher_id"]`` (hex).

        Raises:
            ValueError: If the service could not generate a link.
        """
        launcher_id: bytes32 = bytes32(hexstr_to_bytes(request["launcher_id"]))
        login_link: Optional[str] = await self.service.generate_login_link(launcher_id)
        if login_link is None:
            raise ValueError(f"Failed to generate login link for {launcher_id.hex()}")
        return {"login_link": login_link}