diff --git a/chia/full_node/weight_proof.py b/chia/full_node/weight_proof.py index 7b18c646e6..6899d9a622 100644 --- a/chia/full_node/weight_proof.py +++ b/chia/full_node/weight_proof.py @@ -2,6 +2,7 @@ import asyncio import dataclasses import logging import math +import pathlib import random from concurrent.futures.process import ProcessPoolExecutor from typing import Dict, List, Optional, Tuple @@ -23,7 +24,7 @@ from chia.types.blockchain_format.classgroup import ClassgroupElement from chia.types.blockchain_format.sized_bytes import bytes32 from chia.types.blockchain_format.slots import ChallengeChainSubSlot, RewardChainSubSlot from chia.types.blockchain_format.sub_epoch_summary import SubEpochSummary -from chia.types.blockchain_format.vdf import VDFInfo +from chia.types.blockchain_format.vdf import VDFInfo, VDFProof from chia.types.end_of_slot_bundle import EndOfSubSlotBundle from chia.types.header_block import HeaderBlock from chia.types.weight_proof import ( @@ -58,6 +59,7 @@ class WeightProofHandler: self.constants = constants self.blockchain = blockchain self.lock = asyncio.Lock() + self._num_processes = 4 async def get_proof_of_weight(self, tip: bytes32) -> Optional[WeightProof]: @@ -107,10 +109,7 @@ class WeightProofHandler: return None summary_heights = self.blockchain.get_ses_heights() - # TODO: address hint error and remove ignore - # error: Argument 1 to "get_block_record_from_db" of "BlockchainInterface" has incompatible type - # "Optional[bytes32]"; expected "bytes32" [arg-type] - prev_ses_block = await self.blockchain.get_block_record_from_db(self.blockchain.height_to_hash(uint32(0))) # type: ignore[arg-type] # noqa: E501 + prev_ses_block = await self.blockchain.get_block_record_from_db(self.blockchain.height_to_hash(uint32(0))) if prev_ses_block is None: return None sub_epoch_data = self.get_sub_epoch_data(tip_rec.height, summary_heights) @@ -140,10 +139,7 @@ class WeightProofHandler: if _sample_sub_epoch(prev_ses_block.weight, ses_block.weight, weight_to_check): # type: ignore sample_n += 1 - # TODO: address hint error and remove ignore - # error: Argument 1 to "get_sub_epoch_challenge_segments" of "BlockchainInterface" has - # incompatible type "bytes32"; expected "uint32" [arg-type] - segments = await self.blockchain.get_sub_epoch_challenge_segments(ses_block.header_hash) # type: ignore[arg-type] # noqa: E501 + segments = await self.blockchain.get_sub_epoch_challenge_segments(ses_block.header_hash) if segments is None: segments = await self.__create_sub_epoch_segments(ses_block, prev_ses_block, uint32(sub_epoch_n)) if segments is None: @@ -151,10 +147,7 @@ class WeightProofHandler: f"failed while building segments for sub epoch {sub_epoch_n}, ses height {ses_height} " ) return None - # TODO: address hint error and remove ignore - # error: Argument 1 to "persist_sub_epoch_challenge_segments" of "BlockchainInterface" has - # incompatible type "bytes32"; expected "uint32" [arg-type] - await self.blockchain.persist_sub_epoch_challenge_segments(ses_block.header_hash, segments) # type: ignore[arg-type] # noqa: E501 + await self.blockchain.persist_sub_epoch_challenge_segments(ses_block.header_hash, segments) log.debug(f"sub epoch {sub_epoch_n} has {len(segments)} segments") sub_epoch_segments.extend(segments) prev_ses_block = ses_block @@ -195,10 +188,7 @@ class WeightProofHandler: if curr_height == 0: break # add to needed reward chain recent blocks - # TODO: address hint error and remove ignore - # error: Invalid index type "Optional[bytes32]" for "Dict[bytes32, HeaderBlock]"; expected type - # "bytes32" [index] - header_block = headers[self.blockchain.height_to_hash(curr_height)] # type: ignore[index] + header_block = headers[self.blockchain.height_to_hash(curr_height)] block_rec = blocks[header_block.header_hash] if header_block is None: log.error("creating recent chain failed") @@ -206,13 +196,10 @@ class WeightProofHandler: recent_chain.insert(0, header_block) if block_rec.sub_epoch_summary_included: ses_count += 1 - curr_height = uint32(curr_height - 1) + curr_height = uint32(curr_height - 1) # type: ignore blocks_n += 1 - # TODO: address hint error and remove ignore - # error: Invalid index type "Optional[bytes32]" for "Dict[bytes32, HeaderBlock]"; expected type "bytes32" - # [index] - header_block = headers[self.blockchain.height_to_hash(curr_height)] # type: ignore[index] + header_block = headers[self.blockchain.height_to_hash(curr_height)] recent_chain.insert(0, header_block) log.info( @@ -309,10 +296,7 @@ class WeightProofHandler: first = False else: height = height + uint32(1) # type: ignore - # TODO: address hint error and remove ignore - # error: Invalid index type "Optional[bytes32]" for "Dict[bytes32, HeaderBlock]"; expected type - # "bytes32" [index] - curr = header_blocks[self.blockchain.height_to_hash(height)] # type: ignore[index] + curr = header_blocks[self.blockchain.height_to_hash(height)] if curr is None: return None log.debug(f"next sub epoch starts at {height}") @@ -331,10 +315,7 @@ class WeightProofHandler: if end - curr_rec.height == batch_size - 1: blocks = await self.blockchain.get_block_records_in_range(curr_rec.height - batch_size, curr_rec.height) end = curr_rec.height - # TODO: address hint error and remove ignore - # error: Invalid index type "Optional[bytes32]" for "Dict[bytes32, BlockRecord]"; expected type - # "bytes32" [index] - curr_rec = blocks[self.blockchain.height_to_hash(uint32(curr_rec.height - 1))] # type: ignore[index] + curr_rec = blocks[self.blockchain.height_to_hash(uint32(curr_rec.height - 1))] return curr_rec.height async def _create_challenge_segment( @@ -447,10 +428,7 @@ class WeightProofHandler: curr.total_iters, ) tmp_sub_slots_data.append(ssd) - # TODO: address hint error and remove ignore - # error: Invalid index type "Optional[bytes32]" for "Dict[bytes32, HeaderBlock]"; expected type - # "bytes32" [index] - curr = header_blocks[self.blockchain.height_to_hash(uint32(curr.height + 1))] # type: ignore[index] + curr = header_blocks[self.blockchain.height_to_hash(uint32(curr.height + 1))] if len(tmp_sub_slots_data) > 0: sub_slots_data.extend(tmp_sub_slots_data) @@ -479,10 +457,7 @@ class WeightProofHandler: ) -> Tuple[Optional[List[SubSlotData]], uint32]: # gets all vdfs first sub slot after challenge block to last sub slot log.debug(f"slot end vdf start height {start_height}") - # TODO: address hint error and remove ignore - # error: Invalid index type "Optional[bytes32]" for "Dict[bytes32, HeaderBlock]"; expected type "bytes32" - # [index] - curr = header_blocks[self.blockchain.height_to_hash(start_height)] # type: ignore[index] + curr = header_blocks[self.blockchain.height_to_hash(start_height)] curr_header_hash = curr.header_hash sub_slots_data: List[SubSlotData] = [] tmp_sub_slots_data: List[SubSlotData] = [] @@ -501,10 +476,7 @@ class WeightProofHandler: tmp_sub_slots_data = [] tmp_sub_slots_data.append(self.handle_block_vdfs(curr, blocks)) - # TODO: address hint error and remove ignore - # error: Invalid index type "Optional[bytes32]" for "Dict[bytes32, HeaderBlock]"; expected type - # "bytes32" [index] - curr = header_blocks[self.blockchain.height_to_hash(uint32(curr.height + 1))] # type: ignore[index] + curr = header_blocks[self.blockchain.height_to_hash(uint32(curr.height + 1))] curr_header_hash = curr.header_hash if len(tmp_sub_slots_data) > 0: @@ -619,30 +591,42 @@ class WeightProofHandler: log.error("failed weight proof sub epoch sample validation") return False, uint32(0), [] - executor = ProcessPoolExecutor(1) + executor = ProcessPoolExecutor(4) constants, summary_bytes, wp_segment_bytes, wp_recent_chain_bytes = vars_to_bytes( self.constants, summaries, weight_proof ) - segment_validation_task = asyncio.get_running_loop().run_in_executor( - executor, _validate_sub_epoch_segments, constants, rng, wp_segment_bytes, summary_bytes - ) recent_blocks_validation_task = asyncio.get_running_loop().run_in_executor( executor, _validate_recent_blocks, constants, wp_recent_chain_bytes, summary_bytes ) - valid_segment_task = segment_validation_task + segments_validated, vdfs_to_validate = _validate_sub_epoch_segments( + constants, rng, wp_segment_bytes, summary_bytes + ) + if not segments_validated: + return False, uint32(0), [] + + vdf_chunks = chunks(vdfs_to_validate, self._num_processes) + vdf_tasks = [] + for chunk in vdf_chunks: + byte_chunks = [] + for vdf_proof, classgroup, vdf_info in chunk: + byte_chunks.append((bytes(vdf_proof), bytes(classgroup), bytes(vdf_info))) + + vdf_task = asyncio.get_running_loop().run_in_executor(executor, _validate_vdf_batch, constants, byte_chunks) + vdf_tasks.append(vdf_task) + + for vdf_task in vdf_tasks: + validated = await vdf_task + if not validated: + return False, uint32(0), [] + valid_recent_blocks_task = recent_blocks_validation_task valid_recent_blocks = await valid_recent_blocks_task if not valid_recent_blocks: log.error("failed validating weight proof recent blocks") return False, uint32(0), [] - valid_segments = await valid_segment_task - if not valid_segments: - log.error("failed validating weight proof sub epoch segments") - return False, uint32(0), [] - return True, self.get_fork_point(summaries), summaries def get_fork_point(self, received_summaries: List[SubEpochSummary]) -> uint32: @@ -837,6 +821,11 @@ def handle_end_of_slot( ) +def chunks(some_list, chunk_size): + chunk_size = max(1, chunk_size) + return (some_list[i : i + chunk_size] for i in range(0, len(some_list), chunk_size)) + + def compress_segments(full_segment_index, segments: List[SubEpochChallengeSegment]) -> List[SubEpochChallengeSegment]: compressed_segments = [] compressed_segments.append(segments[0]) @@ -961,6 +950,7 @@ def _validate_sub_epoch_segments( prev_ses: Optional[SubEpochSummary] = None segments_by_sub_epoch = map_segments_by_sub_epoch(sub_epoch_segments.challenge_segments) curr_ssi = constants.SUB_SLOT_ITERS_STARTING + vdfs_to_validate = [] for sub_epoch_n, segments in segments_by_sub_epoch.items(): prev_ssi = curr_ssi curr_difficulty, curr_ssi = _get_curr_diff_ssi(constants, sub_epoch_n, summaries) @@ -975,9 +965,10 @@ def _validate_sub_epoch_segments( log.error(f"failed reward_chain_hash validation sub_epoch {sub_epoch_n}") return False for idx, segment in enumerate(segments): - valid_segment, ip_iters, slot_iters, slots = _validate_segment( + valid_segment, ip_iters, slot_iters, slots, vdf_list = _validate_segment( constants, segment, curr_ssi, prev_ssi, curr_difficulty, prev_ses, idx == 0, sampled_seg_index == idx ) + vdfs_to_validate.extend(vdf_list) if not valid_segment: log.error(f"failed to validate sub_epoch {segment.sub_epoch_n} segment {idx} slots") return False @@ -986,7 +977,7 @@ def _validate_sub_epoch_segments( total_slot_iters += slot_iters total_slots += slots total_ip_iters += ip_iters - return True + return True, vdfs_to_validate def _validate_segment( @@ -998,37 +989,40 @@ def _validate_segment( ses: Optional[SubEpochSummary], first_segment_in_se: bool, sampled: bool, -) -> Tuple[bool, int, int, int]: +) -> Tuple[bool, int, int, int, List[Tuple[VDFProof, ClassgroupElement, VDFInfo]]]: ip_iters, slot_iters, slots = 0, 0, 0 after_challenge = False + to_validate = [] for idx, sub_slot_data in enumerate(segment.sub_slots): if sampled and sub_slot_data.is_challenge(): after_challenge = True required_iters = __validate_pospace(constants, segment, idx, curr_difficulty, ses, first_segment_in_se) if required_iters is None: - return False, uint64(0), uint64(0), uint64(0) + return False, uint64(0), uint64(0), uint64(0), [] assert sub_slot_data.signage_point_index is not None - ip_iters = ip_iters + calculate_ip_iters( + ip_iters = ip_iters + calculate_ip_iters( # type: ignore constants, curr_ssi, sub_slot_data.signage_point_index, required_iters ) - if not _validate_challenge_block_vdfs(constants, idx, segment.sub_slots, curr_ssi): - log.error(f"failed to validate challenge slot {idx} vdfs") - return False, uint64(0), uint64(0), uint64(0) + vdf_list = _get_challenge_block_vdfs(constants, idx, segment.sub_slots, curr_ssi) + to_validate.extend(vdf_list) elif sampled and after_challenge: - if not _validate_sub_slot_data(constants, idx, segment.sub_slots, curr_ssi): + validated, vdf_list = _validate_sub_slot_data(constants, idx, segment.sub_slots, curr_ssi) + if not validated: log.error(f"failed to validate sub slot data {idx} vdfs") - return False, uint64(0), uint64(0), uint64(0) - slot_iters = slot_iters + curr_ssi - slots = slots + uint64(1) - return True, ip_iters, slot_iters, slots + return False, uint64(0), uint64(0), uint64(0), [] + to_validate.extend(vdf_list) + slot_iters = slot_iters + curr_ssi # type: ignore + slots = slots + uint64(1) # type: ignore + return True, ip_iters, slot_iters, slots, to_validate -def _validate_challenge_block_vdfs( +def _get_challenge_block_vdfs( constants: ConsensusConstants, sub_slot_idx: int, sub_slots: List[SubSlotData], ssi: uint64, -) -> bool: +) -> List[Tuple[VDFProof, ClassgroupElement, VDFInfo]]: + to_validate = [] sub_slot_data = sub_slots[sub_slot_idx] if sub_slot_data.cc_signage_point is not None and sub_slot_data.cc_sp_vdf_info: assert sub_slot_data.signage_point_index @@ -1039,9 +1033,8 @@ def _validate_challenge_block_vdfs( sp_input = sub_slot_data_vdf_input( constants, sub_slot_data, sub_slot_idx, sub_slots, is_overflow, prev_ssd.is_end_of_slot(), ssi ) - if not sub_slot_data.cc_signage_point.is_valid(constants, sp_input, sub_slot_data.cc_sp_vdf_info): - log.error(f"failed to validate challenge chain signage point 2 {sub_slot_data.cc_sp_vdf_info}") - return False + to_validate.append((sub_slot_data.cc_signage_point, sp_input, sub_slot_data.cc_sp_vdf_info)) + assert sub_slot_data.cc_infusion_point assert sub_slot_data.cc_ip_vdf_info ip_input = ClassgroupElement.get_default_element() @@ -1057,10 +1050,9 @@ def _validate_challenge_block_vdfs( cc_ip_vdf_info = VDFInfo( sub_slot_data.cc_ip_vdf_info.challenge, ip_vdf_iters, sub_slot_data.cc_ip_vdf_info.output ) - if not sub_slot_data.cc_infusion_point.is_valid(constants, ip_input, cc_ip_vdf_info): - log.error(f"failed to validate challenge chain infusion point {sub_slot_data.cc_ip_vdf_info}") - return False - return True + to_validate.append((sub_slot_data.cc_infusion_point, ip_input, cc_ip_vdf_info)) + + return to_validate def _validate_sub_slot_data( @@ -1068,10 +1060,12 @@ def _validate_sub_slot_data( sub_slot_idx: int, sub_slots: List[SubSlotData], ssi: uint64, -) -> bool: +) -> Tuple[bool, List[Tuple[VDFProof, ClassgroupElement, VDFInfo]]]: + sub_slot_data = sub_slots[sub_slot_idx] assert sub_slot_idx > 0 prev_ssd = sub_slots[sub_slot_idx - 1] + to_validate = [] if sub_slot_data.is_end_of_slot(): if sub_slot_data.icc_slot_end is not None: input = ClassgroupElement.get_default_element() @@ -1079,9 +1073,7 @@ def _validate_sub_slot_data( assert prev_ssd.icc_ip_vdf_info input = prev_ssd.icc_ip_vdf_info.output assert sub_slot_data.icc_slot_end_info - if not sub_slot_data.icc_slot_end.is_valid(constants, input, sub_slot_data.icc_slot_end_info, None): - log.error(f"failed icc slot end validation {sub_slot_data.icc_slot_end_info} ") - return False + to_validate.append((sub_slot_data.icc_slot_end, input, sub_slot_data.icc_slot_end_info)) assert sub_slot_data.cc_slot_end_info assert sub_slot_data.cc_slot_end input = ClassgroupElement.get_default_element() @@ -1090,7 +1082,7 @@ def _validate_sub_slot_data( input = prev_ssd.cc_ip_vdf_info.output if not sub_slot_data.cc_slot_end.is_valid(constants, input, sub_slot_data.cc_slot_end_info): log.error(f"failed cc slot end validation {sub_slot_data.cc_slot_end_info}") - return False + return False, [] else: # find end of slot idx = sub_slot_idx @@ -1101,7 +1093,7 @@ def _validate_sub_slot_data( assert curr_slot.cc_slot_end if curr_slot.cc_slot_end.normalized_to_identity is True: log.debug(f"skip intermediate vdfs slot {sub_slot_idx}") - return True + return True, to_validate else: break idx += 1 @@ -1109,10 +1101,7 @@ def _validate_sub_slot_data( input = ClassgroupElement.get_default_element() if not prev_ssd.is_challenge() and prev_ssd.icc_ip_vdf_info is not None: input = prev_ssd.icc_ip_vdf_info.output - if not sub_slot_data.icc_infusion_point.is_valid(constants, input, sub_slot_data.icc_ip_vdf_info, None): - log.error(f"failed icc infusion point vdf validation {sub_slot_data.icc_slot_end_info} ") - return False - + to_validate.append((sub_slot_data.icc_infusion_point, input, sub_slot_data.icc_ip_vdf_info)) assert sub_slot_data.signage_point_index is not None if sub_slot_data.cc_signage_point: assert sub_slot_data.cc_sp_vdf_info @@ -1122,10 +1111,8 @@ def _validate_sub_slot_data( input = sub_slot_data_vdf_input( constants, sub_slot_data, sub_slot_idx, sub_slots, is_overflow, prev_ssd.is_end_of_slot(), ssi ) + to_validate.append((sub_slot_data.cc_signage_point, input, sub_slot_data.cc_sp_vdf_info)) - if not sub_slot_data.cc_signage_point.is_valid(constants, input, sub_slot_data.cc_sp_vdf_info): - log.error(f"failed cc signage point vdf validation {sub_slot_data.cc_sp_vdf_info}") - return False input = ClassgroupElement.get_default_element() assert sub_slot_data.cc_ip_vdf_info assert sub_slot_data.cc_infusion_point @@ -1139,10 +1126,9 @@ def _validate_sub_slot_data( cc_ip_vdf_info = VDFInfo( sub_slot_data.cc_ip_vdf_info.challenge, ip_vdf_iters, sub_slot_data.cc_ip_vdf_info.output ) - if not sub_slot_data.cc_infusion_point.is_valid(constants, input, cc_ip_vdf_info): - log.error(f"failed cc infusion point vdf validation {sub_slot_data.cc_slot_end_info}") - return False - return True + to_validate.append((sub_slot_data.cc_infusion_point, input, cc_ip_vdf_info)) + + return True, to_validate def sub_slot_data_vdf_input( @@ -1203,14 +1189,17 @@ def sub_slot_data_vdf_input( return cc_input -def _validate_recent_blocks(constants_dict: Dict, recent_chain_bytes: bytes, summaries_bytes: List[bytes]) -> bool: - constants, summaries = bytes_to_vars(constants_dict, summaries_bytes) - recent_chain: RecentChainData = RecentChainData.from_bytes(recent_chain_bytes) +def validate_recent_blocks( + constants: ConsensusConstants, + recent_chain: RecentChainData, + summaries: List[SubEpochSummary], + shutdown_file_path: Optional[pathlib.Path] = None, +) -> Tuple[bool, List[bytes]]: sub_blocks = BlockCache({}) first_ses_idx = _get_ses_idx(recent_chain.recent_chain_data) ses_idx = len(summaries) - len(first_ses_idx) ssi: uint64 = constants.SUB_SLOT_ITERS_STARTING - diff: Optional[uint64] = constants.DIFFICULTY_STARTING + diff: uint64 = constants.DIFFICULTY_STARTING last_blocks_to_validate = 100 # todo remove cap after benchmarks for summary in summaries[:ses_idx]: if summary.new_sub_slot_iters is not None: @@ -1219,10 +1208,11 @@ def _validate_recent_blocks(constants_dict: Dict, recent_chain_bytes: bytes, sum diff = summary.new_difficulty ses_blocks, sub_slots, transaction_blocks = 0, 0, 0 - challenge, prev_challenge = None, None + challenge, prev_challenge = recent_chain.recent_chain_data[0].reward_chain_block.pos_ss_cc_challenge_hash, None tip_height = recent_chain.recent_chain_data[-1].height prev_block_record = None deficit = uint8(0) + adjusted = False for idx, block in enumerate(recent_chain.recent_chain_data): required_iters = uint64(0) overflow = False @@ -1243,21 +1233,30 @@ def _validate_recent_blocks(constants_dict: Dict, recent_chain_bytes: bytes, sum if (challenge is not None) and (prev_challenge is not None): overflow = is_overflow_block(constants, block.reward_chain_block.signage_point_index) + if not adjusted: + prev_block_record = dataclasses.replace( + prev_block_record, deficit=deficit % constants.MIN_BLOCKS_PER_CHALLENGE_BLOCK + ) + assert prev_block_record is not None + sub_blocks.add_block_record(prev_block_record) + adjusted = True deficit = get_deficit(constants, deficit, prev_block_record, overflow, len(block.finished_sub_slots)) log.debug(f"wp, validate block {block.height}") if sub_slots > 2 and transaction_blocks > 11 and (tip_height - block.height < last_blocks_to_validate): - required_iters, error = validate_finished_header_block( + caluclated_required_iters, error = validate_finished_header_block( constants, sub_blocks, block, False, diff, ssi, ses_blocks > 2 ) if error is not None: log.error(f"block {block.header_hash} failed validation {error}") - return False + return False, [] + assert caluclated_required_iters is not None + required_iters = caluclated_required_iters else: required_iters = _validate_pospace_recent_chain( constants, block, challenge, diff, overflow, prev_challenge ) if required_iters is None: - return False + return False, [] curr_block_ses = None if not ses else summaries[ses_idx - 1] block_record = header_block_to_sub_block_record( @@ -1274,7 +1273,29 @@ def _validate_recent_blocks(constants_dict: Dict, recent_chain_bytes: bytes, sum ses_blocks += 1 prev_block_record = block_record - return True + if shutdown_file_path is not None and not shutdown_file_path.is_file(): + log.info(f"cancelling block {block.header_hash} validation, shutdown requested") + return False, [] + + return True, [bytes(sub) for sub in sub_blocks._block_records.values()] + + +def _validate_recent_blocks(constants_dict: Dict, recent_chain_bytes: bytes, summaries_bytes: List[bytes]) -> bool: + constants, summaries = bytes_to_vars(constants_dict, summaries_bytes) + recent_chain: RecentChainData = RecentChainData.from_bytes(recent_chain_bytes) + success, records = validate_recent_blocks(constants, recent_chain, summaries) + return success + + +def _validate_recent_blocks_and_get_records( + constants_dict: Dict, + recent_chain_bytes: bytes, + summaries_bytes: List[bytes], + shutdown_file_path: Optional[pathlib.Path] = None, +) -> Tuple[bool, List[bytes]]: + constants, summaries = bytes_to_vars(constants_dict, summaries_bytes) + recent_chain: RecentChainData = RecentChainData.from_bytes(recent_chain_bytes) + return validate_recent_blocks(constants, recent_chain, summaries, shutdown_file_path) def _validate_pospace_recent_chain( @@ -1386,7 +1407,7 @@ def __get_rc_sub_slot( new_diff = None if ses is None else ses.new_difficulty new_ssi = None if ses is None else ses.new_sub_slot_iters - ses_hash: Optional[bytes32] = None if ses is None else ses.get_hash() + ses_hash = None if ses is None else ses.get_hash() overflow = is_overflow_block(constants, first.signage_point_index) if overflow: if idx >= 2 and slots[idx - 2].cc_slot_end is not None and slots[idx - 1].cc_slot_end is not None: @@ -1473,7 +1494,7 @@ def _get_curr_diff_ssi(constants: ConsensusConstants, idx, summaries): return curr_difficulty, curr_ssi -def vars_to_bytes(constants, summaries, weight_proof): +def vars_to_bytes(constants: ConsensusConstants, summaries: List[SubEpochSummary], weight_proof: WeightProof): constants_dict = recurse_jsonify(dataclasses.asdict(constants)) wp_recent_chain_bytes = bytes(RecentChainData(weight_proof.recent_chain_data)) wp_segment_bytes = bytes(SubEpochSegments(weight_proof.sub_epoch_segments)) @@ -1524,13 +1545,13 @@ def _get_ses_idx(recent_reward_chain: List[HeaderBlock]) -> List[int]: def get_deficit( constants: ConsensusConstants, curr_deficit: uint8, - prev_block: BlockRecord, + prev_block: Optional[BlockRecord], overflow: bool, num_finished_sub_slots: int, ) -> uint8: if prev_block is None: if curr_deficit >= 1 and not (overflow and curr_deficit == constants.MIN_BLOCKS_PER_CHALLENGE_BLOCK): - curr_deficit -= 1 + curr_deficit = uint8(curr_deficit - 1) return curr_deficit return calculate_deficit(constants, uint32(prev_block.height + 1), prev_block, overflow, num_finished_sub_slots) @@ -1617,3 +1638,22 @@ def validate_total_iters( total_iters = uint128(prev_b.total_iters - prev_b.cc_ip_vdf_info.number_of_iterations) total_iters = uint128(total_iters + sub_slot_data.cc_ip_vdf_info.number_of_iterations) return total_iters == sub_slot_data.total_iters + + +def _validate_vdf_batch( + constants_dict, vdf_list: List[Tuple[bytes, bytes, bytes]], shutdown_file_path: Optional[pathlib.Path] = None +): + constants: ConsensusConstants = dataclass_from_dict(ConsensusConstants, constants_dict) + + for vdf_proof_bytes, class_group_bytes, info in vdf_list: + vdf = VDFProof.from_bytes(vdf_proof_bytes) + class_group = ClassgroupElement.from_bytes(class_group_bytes) + vdf_info = VDFInfo.from_bytes(info) + if not vdf.is_valid(constants, class_group, vdf_info): + return False + + if shutdown_file_path is not None and not shutdown_file_path.is_file(): + log.info("cancelling VDF validation, shutdown requested") + return False + + return True diff --git a/chia/rpc/wallet_rpc_api.py b/chia/rpc/wallet_rpc_api.py index 0843bbac9e..f88fb3836b 100644 --- a/chia/rpc/wallet_rpc_api.py +++ b/chia/rpc/wallet_rpc_api.py @@ -1,8 +1,7 @@ import asyncio import logging -import time from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Set +from typing import Callable, Dict, List, Optional, Tuple, Set, Any from blspy import PrivateKey, G1Element @@ -20,14 +19,14 @@ from chia.util.ints import uint32, uint64 from chia.util.keychain import KeyringIsLocked, bytes_to_mnemonic, generate_mnemonic from chia.util.path import path_from_root from chia.util.ws_message import WsRpcMessage, create_payload_dict -from chia.wallet.cc_wallet.cc_wallet import CCWallet -from chia.wallet.derive_keys import master_sk_to_singleton_owner_sk +from chia.wallet.cat_wallet.cat_constants import DEFAULT_CATS +from chia.wallet.cat_wallet.cat_wallet import CATWallet +from chia.wallet.derive_keys import master_sk_to_singleton_owner_sk, master_sk_to_wallet_sk_unhardened from chia.wallet.rl_wallet.rl_wallet import RLWallet from chia.wallet.derive_keys import master_sk_to_farmer_sk, master_sk_to_pool_sk, master_sk_to_wallet_sk from chia.wallet.did_wallet.did_wallet import DIDWallet from chia.wallet.trade_record import TradeRecord from chia.wallet.transaction_record import TransactionRecord -from chia.wallet.util.backup_utils import download_backup, get_backup_info, upload_backup from chia.wallet.util.trade_utils import trade_record_to_dict from chia.wallet.util.transaction_type import TransactionType from chia.wallet.util.wallet_types import AmountWithPuzzlehash, WalletType @@ -47,12 +46,12 @@ class WalletRpcApi: assert wallet_node is not None self.service = wallet_node self.service_name = "chia_wallet" + self.balance_cache: Dict[int, Any] = {} def get_routes(self) -> Dict[str, Callable]: return { # Key management "/log_in": self.log_in, - "/get_logged_in_fingerprint": self.get_logged_in_fingerprint, "/get_public_keys": self.get_public_keys, "/get_private_key": self.get_private_key, "/generate_mnemonic": self.generate_mnemonic, @@ -75,25 +74,25 @@ class WalletRpcApi: "/get_wallet_balance": self.get_wallet_balance, "/get_transaction": self.get_transaction, "/get_transactions": self.get_transactions, + "/get_transaction_count": self.get_transaction_count, "/get_next_address": self.get_next_address, "/send_transaction": self.send_transaction, "/send_transaction_multi": self.send_transaction_multi, - "/create_backup": self.create_backup, - "/get_transaction_count": self.get_transaction_count, "/get_farmed_amount": self.get_farmed_amount, "/create_signed_transaction": self.create_signed_transaction, "/delete_unconfirmed_transactions": self.delete_unconfirmed_transactions, # Coloured coins and trading - "/cc_set_name": self.cc_set_name, - "/cc_get_name": self.cc_get_name, - "/cc_spend": self.cc_spend, - "/cc_get_colour": self.cc_get_colour, + "/cat_set_name": self.cat_set_name, + "/cat_get_name": self.cat_get_name, + "/cat_spend": self.cat_spend, + "/cat_get_asset_id": self.cat_get_asset_id, "/create_offer_for_ids": self.create_offer_for_ids, "/get_discrepancies_for_offer": self.get_discrepancies_for_offer, "/respond_to_offer": self.respond_to_offer, "/get_trade": self.get_trade, "/get_all_trades": self.get_all_trades, "/cancel_trade": self.cancel_trade, + "/get_cat_list": self.get_cat_list, # DID Wallet "/did_update_recovery_ids": self.did_update_recovery_ids, "/did_get_pubkey": self.did_get_pubkey, @@ -138,7 +137,9 @@ class WalletRpcApi: """ if self.service is not None: self.service._close() - await self.service._await_closed() + peers_close_task: Optional[asyncio.Task] = await self.service._await_closed() + if peers_close_task is not None: + await peers_close_task ########################################################################################## # Key management @@ -154,51 +155,12 @@ class WalletRpcApi: return {"fingerprint": fingerprint} await self._stop_wallet() - log_in_type = request["type"] - recovery_host = request["host"] - testing = False - - if "testing" in self.service.config and self.service.config["testing"] is True: - testing = True - if log_in_type == "skip": - started = await self.service._start(fingerprint=fingerprint, skip_backup_import=True) - elif log_in_type == "restore_backup": - file_path = Path(request["file_path"]) - started = await self.service._start(fingerprint=fingerprint, backup_file=file_path) - else: - started = await self.service._start(fingerprint) - + started = await self.service._start(fingerprint) if started is True: return {"fingerprint": fingerprint} - elif testing is True and self.service.backup_initialized is False: - response = {"success": False, "error": "not_initialized"} - return response - elif self.service.backup_initialized is False: - backup_info = None - backup_path = None - try: - private_key = await self.service.get_key_for_fingerprint(fingerprint) - last_recovery = await download_backup(recovery_host, private_key) - backup_path = path_from_root(self.service.root_path, "last_recovery") - if backup_path.exists(): - backup_path.unlink() - backup_path.write_text(last_recovery) - backup_info = get_backup_info(backup_path, private_key) - backup_info["backup_host"] = recovery_host - backup_info["downloaded"] = True - except Exception as e: - log.error(f"error {e}") - response = {"success": False, "error": "not_initialized"} - if backup_info is not None: - response["backup_info"] = backup_info - response["backup_path"] = f"{backup_path}" - return response return {"success": False, "error": "Unknown Error"} - async def get_logged_in_fingerprint(self, request: Dict): - return {"fingerprint": self.service.logged_in_fingerprint} - async def get_public_keys(self, request: Dict): try: assert self.service.keychain_proxy is not None # An offering to the mypy gods @@ -270,15 +232,7 @@ class WalletRpcApi: await self.service.keychain_proxy.check_keys(self.service.root_path) except Exception as e: log.error(f"Failed to check_keys after adding a new key: {e}") - request_type = request["type"] - if request_type == "new_wallet": - started = await self.service._start(fingerprint=fingerprint, new_wallet=True) - elif request_type == "skip": - started = await self.service._start(fingerprint=fingerprint, skip_backup_import=True) - elif request_type == "restore_backup": - file_path = Path(request["file_path"]) - started = await self.service._start(fingerprint=fingerprint, backup_file=file_path) - + started = await self.service._start(fingerprint=fingerprint) if started is True: return {"fingerprint": fingerprint} raise ValueError("Failed to start") @@ -322,12 +276,17 @@ class WalletRpcApi: if found_farmer and found_pool: break - ph = encode_puzzle_hash(create_puzzlehash_for_pk(master_sk_to_wallet_sk(sk, uint32(i)).get_g1()), prefix) - - if ph == farmer_target: - found_farmer = True - if ph == pool_target: - found_pool = True + phs = [ + encode_puzzle_hash(create_puzzlehash_for_pk(master_sk_to_wallet_sk(sk, uint32(i)).get_g1()), prefix), + encode_puzzle_hash( + create_puzzlehash_for_pk(master_sk_to_wallet_sk_unhardened(sk, uint32(i)).get_g1()), prefix + ), + ] + for ph in phs: + if ph == farmer_target: + found_farmer = True + if ph == pool_target: + found_pool = True return found_farmer, found_pool @@ -347,19 +306,18 @@ class WalletRpcApi: if self.service.logged_in_fingerprint != fingerprint: await self._stop_wallet() - await self.service._start(fingerprint=fingerprint, skip_backup_import=True) + await self.service._start(fingerprint=fingerprint) - async with self.service.wallet_state_manager.lock: - wallets: List[WalletInfo] = await self.service.wallet_state_manager.get_all_wallet_info_entries() - for w in wallets: - wallet = self.service.wallet_state_manager.wallets[w.id] - unspent = await self.service.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(w.id) - balance = await wallet.get_confirmed_balance(unspent) - pending_balance = await wallet.get_unconfirmed_balance(unspent) + wallets: List[WalletInfo] = await self.service.wallet_state_manager.get_all_wallet_info_entries() + for w in wallets: + wallet = self.service.wallet_state_manager.wallets[w.id] + unspent = await self.service.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(w.id) + balance = await wallet.get_confirmed_balance(unspent) + pending_balance = await wallet.get_unconfirmed_balance(unspent) - if (balance + pending_balance) > 0: - walletBalance = True - break + if (balance + pending_balance) > 0: + walletBalance = True + break return { "fingerprint": fingerprint, @@ -393,11 +351,8 @@ class WalletRpcApi: async def get_height_info(self, request: Dict): assert self.service.wallet_state_manager is not None - peak = self.service.wallet_state_manager.peak - if peak is None: - return {"height": 0} - else: - return {"height": peak.height} + height = self.service.wallet_state_manager.blockchain.get_peak_height() + return {"height": height} async def get_network_info(self, request: Dict): assert self.service.wallet_state_manager is not None @@ -424,53 +379,37 @@ class WalletRpcApi: return {"wallets": wallets} - async def _create_backup_and_upload(self, host) -> None: - assert self.service.wallet_state_manager is not None - try: - if "testing" in self.service.config and self.service.config["testing"] is True: - return None - now = time.time() - file_name = f"backup_{now}" - path = path_from_root(self.service.root_path, file_name) - await self.service.wallet_state_manager.create_wallet_backup(path) - backup_text = path.read_text() - response = await upload_backup(host, backup_text) - success = response["success"] - if success is False: - log.error("Failed to upload backup to wallet backup service") - elif success is True: - log.info("Finished upload of the backup file") - except Exception as e: - log.error(f"Exception in upload backup. Error: {e}") - async def create_new_wallet(self, request: Dict): assert self.service.wallet_state_manager is not None wallet_state_manager = self.service.wallet_state_manager + + if await self.service.wallet_state_manager.synced() is False: + raise ValueError("Wallet needs to be fully synced.") main_wallet = wallet_state_manager.main_wallet - host = request["host"] fee = uint64(request.get("fee", 0)) - if request["wallet_type"] == "cc_wallet": + if request["wallet_type"] == "cat_wallet": + name = request.get("name", "CAT Wallet") if request["mode"] == "new": async with self.service.wallet_state_manager.lock: - cc_wallet: CCWallet = await CCWallet.create_new_cc( - wallet_state_manager, main_wallet, uint64(request["amount"]) + cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_state_manager, + main_wallet, + {"identifier": "genesis_by_id"}, + uint64(request["amount"]), + name, ) - colour = cc_wallet.get_colour() - asyncio.create_task(self._create_backup_and_upload(host)) - return { - "type": cc_wallet.type(), - "colour": colour, - "wallet_id": cc_wallet.id(), - } + asset_id = cat_wallet.get_asset_id() + self.service.wallet_state_manager.state_changed("wallet_created") + return {"type": cat_wallet.type(), "asset_id": asset_id, "wallet_id": cat_wallet.id()} elif request["mode"] == "existing": async with self.service.wallet_state_manager.lock: - cc_wallet = await CCWallet.create_wallet_for_cc( - wallet_state_manager, main_wallet, request["colour"] + cat_wallet = await CATWallet.create_wallet_for_cat( + wallet_state_manager, main_wallet, request["asset_id"] ) - asyncio.create_task(self._create_backup_and_upload(host)) - return {"type": cc_wallet.type()} + self.service.wallet_state_manager.state_changed("wallet_created") + return {"type": cat_wallet.type(), "asset_id": request["asset_id"], "wallet_id": cat_wallet.id()} else: # undefined mode pass @@ -487,7 +426,6 @@ class WalletRpcApi: uint64(int(request["amount"])), uint64(int(request["fee"])) if "fee" in request else uint64(0), ) - asyncio.create_task(self._create_backup_and_upload(host)) assert rl_admin.rl_info.admin_pubkey is not None return { "success": success, @@ -501,7 +439,6 @@ class WalletRpcApi: log.info("Create rl user wallet") async with self.service.wallet_state_manager.lock: rl_user: RLWallet = await RLWallet.create_rl_user(wallet_state_manager) - asyncio.create_task(self._create_backup_and_upload(host)) assert rl_user.rl_info.user_pubkey is not None return { "id": rl_user.id(), @@ -590,7 +527,7 @@ class WalletRpcApi: try: delayed_address = None if "p2_singleton_delayed_ph" in request: - delayed_address = bytes32.from_hexstr(request["p2_singleton_delayed_ph"]) + delayed_address = hexstr_to_bytes(request["p2_singleton_delayed_ph"]) tr, p2_singleton_puzzle_hash, launcher_id = await PoolWallet.create_new_pool_wallet_transaction( wallet_state_manager, main_wallet, @@ -623,28 +560,48 @@ class WalletRpcApi: assert self.service.wallet_state_manager is not None wallet_id = uint32(int(request["wallet_id"])) wallet = self.service.wallet_state_manager.wallets[wallet_id] - async with self.service.wallet_state_manager.lock: - unspent_records = await self.service.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(wallet_id) - balance = await wallet.get_confirmed_balance(unspent_records) - pending_balance = await wallet.get_unconfirmed_balance(unspent_records) - spendable_balance = await wallet.get_spendable_balance(unspent_records) - pending_change = await wallet.get_pending_change_balance() - max_send_amount = await wallet.get_max_send_amount(unspent_records) - unconfirmed_removals: Dict[ - bytes32, Coin - ] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id) + # If syncing return the last available info or 0s + syncing = self.service.wallet_state_manager.sync_mode + if syncing: + if wallet_id in self.balance_cache: + wallet_balance = self.balance_cache[wallet_id] + else: + wallet_balance = { + "wallet_id": wallet_id, + "confirmed_wallet_balance": 0, + "unconfirmed_wallet_balance": 0, + "spendable_balance": 0, + "pending_change": 0, + "max_send_amount": 0, + "unspent_coin_count": 0, + "pending_coin_removal_count": 0, + } + else: + async with self.service.wallet_state_manager.lock: + unspent_records = await self.service.wallet_state_manager.coin_store.get_unspent_coins_for_wallet( + wallet_id + ) + balance = await wallet.get_confirmed_balance(unspent_records) + pending_balance = await wallet.get_unconfirmed_balance(unspent_records) + spendable_balance = await wallet.get_spendable_balance(unspent_records) + pending_change = await wallet.get_pending_change_balance() + max_send_amount = await wallet.get_max_send_amount(unspent_records) - wallet_balance = { - "wallet_id": wallet_id, - "confirmed_wallet_balance": balance, - "unconfirmed_wallet_balance": pending_balance, - "spendable_balance": spendable_balance, - "pending_change": pending_change, - "max_send_amount": max_send_amount, - "unspent_coin_count": len(unspent_records), - "pending_coin_removal_count": len(unconfirmed_removals), - } + unconfirmed_removals: Dict[ + bytes32, Coin + ] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id) + wallet_balance = { + "wallet_id": wallet_id, + "confirmed_wallet_balance": balance, + "unconfirmed_wallet_balance": pending_balance, + "spendable_balance": spendable_balance, + "pending_change": pending_change, + "max_send_amount": max_send_amount, + "unspent_coin_count": len(unspent_records), + "pending_coin_removal_count": len(unconfirmed_removals), + } + self.balance_cache[wallet_id] = wallet_balance return {"wallet_balance": wallet_balance} @@ -656,7 +613,7 @@ class WalletRpcApi: raise ValueError(f"Transaction 0x{transaction_id.hex()} not found") return { - "transaction": tr, + "transaction": tr.to_json_dict_convenience(self.service.config), "transaction_id": tr.name, } @@ -685,6 +642,16 @@ class WalletRpcApi: "wallet_id": wallet_id, } + async def get_transaction_count(self, request: Dict) -> Dict: + assert self.service.wallet_state_manager is not None + + wallet_id = int(request["wallet_id"]) + count = await self.service.wallet_state_manager.tx_store.get_transaction_count_for_wallet(wallet_id) + return { + "count": count, + "wallet_id": wallet_id, + } + # this function is just here for backwards-compatibility. It will probably # be removed in the future async def get_initial_freeze_period(self, _: Dict): @@ -708,8 +675,8 @@ class WalletRpcApi: if wallet.type() == WalletType.STANDARD_WALLET: raw_puzzle_hash = await wallet.get_puzzle_hash(create_new) address = encode_puzzle_hash(raw_puzzle_hash, prefix) - elif wallet.type() == WalletType.COLOURED_COIN: - raw_puzzle_hash = await wallet.get_puzzle_hash(create_new) + elif wallet.type() == WalletType.CAT: + raw_puzzle_hash = await wallet.standard_wallet.get_puzzle_hash(create_new) address = encode_puzzle_hash(raw_puzzle_hash, prefix) else: raise ValueError(f"Wallet type {wallet.type()} cannot create puzzle hashes") @@ -728,25 +695,33 @@ class WalletRpcApi: wallet_id = int(request["wallet_id"]) wallet = self.service.wallet_state_manager.wallets[wallet_id] + if wallet.type() == WalletType.CAT: + raise ValueError("send_transaction does not work for CAT wallets") + if not isinstance(request["amount"], int) or not isinstance(request["fee"], int): raise ValueError("An integer amount or fee is required (too many decimals)") amount: uint64 = uint64(request["amount"]) puzzle_hash: bytes32 = decode_puzzle_hash(request["address"]) + + memos: Optional[bytes] = None + if "memos" in request: + memos = [mem.encode("utf-8") for mem in request["memos"]] + if "fee" in request: fee = uint64(request["fee"]) else: fee = uint64(0) async with self.service.wallet_state_manager.lock: - tx: TransactionRecord = await wallet.generate_signed_transaction(amount, puzzle_hash, fee) + tx: TransactionRecord = await wallet.generate_signed_transaction(amount, puzzle_hash, fee, memos=memos) await wallet.push_transaction(tx) # Transaction may not have been included in the mempool yet. Use get_transaction to check. return { - "transaction": tx, + "transaction": tx.to_json_dict_convenience(self.service.config), "transaction_id": tx.name, } - async def send_transaction_multi(self, request): + async def send_transaction_multi(self, request) -> Dict: assert self.service.wallet_state_manager is not None if await self.service.wallet_state_manager.synced() is False: @@ -756,21 +731,20 @@ class WalletRpcApi: wallet = self.service.wallet_state_manager.wallets[wallet_id] async with self.service.wallet_state_manager.lock: - transaction: TransactionRecord = (await self.create_signed_transaction(request, hold_lock=False))[ - "signed_tx" - ] - await wallet.push_transaction(transaction) + transaction: Dict = (await self.create_signed_transaction(request, hold_lock=False))["signed_tx"] + tr: TransactionRecord = TransactionRecord.from_json_dict_convenience(transaction) + await wallet.push_transaction(tr) # Transaction may not have been included in the mempool yet. Use get_transaction to check. - return { - "transaction": transaction, - "transaction_id": transaction.name, - } + return {"transaction": transaction, "transaction_id": tr.name} async def delete_unconfirmed_transactions(self, request): wallet_id = uint32(request["wallet_id"]) if wallet_id not in self.service.wallet_state_manager.wallets: raise ValueError(f"Wallet id {wallet_id} does not exist") + if await self.service.wallet_state_manager.synced() is False: + raise ValueError("Wallet needs to be fully synced.") + async with self.service.wallet_state_manager.lock: async with self.service.wallet_state_manager.tx_store.db_wrapper.lock: await self.service.wallet_state_manager.tx_store.db_wrapper.begin_transaction() @@ -782,41 +756,40 @@ class WalletRpcApi: await self.service.wallet_state_manager.tx_store.rebuild_tx_cache() return {} - async def get_transaction_count(self, request): - wallet_id = int(request["wallet_id"]) - count = await self.service.wallet_state_manager.tx_store.get_transaction_count_for_wallet(wallet_id) - return {"wallet_id": wallet_id, "count": count} - - async def create_backup(self, request): - assert self.service.wallet_state_manager is not None - file_path = Path(request["file_path"]) - await self.service.wallet_state_manager.create_wallet_backup(file_path) - return {} - ########################################################################################## # Coloured Coins and Trading ########################################################################################## - async def cc_set_name(self, request): + async def get_cat_list(self, request): + return {"cat_list": list(DEFAULT_CATS.values())} + + async def cat_set_name(self, request): assert self.service.wallet_state_manager is not None wallet_id = int(request["wallet_id"]) - wallet: CCWallet = self.service.wallet_state_manager.wallets[wallet_id] + wallet: CATWallet = self.service.wallet_state_manager.wallets[wallet_id] await wallet.set_name(str(request["name"])) return {"wallet_id": wallet_id} - async def cc_get_name(self, request): + async def cat_get_name(self, request): assert self.service.wallet_state_manager is not None wallet_id = int(request["wallet_id"]) - wallet: CCWallet = self.service.wallet_state_manager.wallets[wallet_id] + wallet: CATWallet = self.service.wallet_state_manager.wallets[wallet_id] name: str = await wallet.get_name() return {"wallet_id": wallet_id, "name": name} - async def cc_spend(self, request): + async def cat_spend(self, request): assert self.service.wallet_state_manager is not None + + if await self.service.wallet_state_manager.synced() is False: + raise ValueError("Wallet needs to be fully synced.") wallet_id = int(request["wallet_id"]) - wallet: CCWallet = self.service.wallet_state_manager.wallets[wallet_id] + wallet: CATWallet = self.service.wallet_state_manager.wallets[wallet_id] + puzzle_hash: bytes32 = decode_puzzle_hash(request["inner_address"]) + memos: List[bytes] = [] + if "memos" in request: + memos = [mem.encode("utf-8") for mem in request["memos"]] if not isinstance(request["amount"], int) or not isinstance(request["amount"], int): raise ValueError("An integer amount or fee is required (too many decimals)") amount: uint64 = uint64(request["amount"]) @@ -825,20 +798,23 @@ class WalletRpcApi: else: fee = uint64(0) async with self.service.wallet_state_manager.lock: - tx: TransactionRecord = await wallet.generate_signed_transaction([amount], [puzzle_hash], fee) - await wallet.push_transaction(tx) + txs: TransactionRecord = await wallet.generate_signed_transaction( + [amount], [puzzle_hash], fee, memos=[memos] + ) + for tx in txs: + await wallet.standard_wallet.push_transaction(tx) return { - "transaction": tx, + "transaction": tx.to_json_dict_convenience(self.service.config), "transaction_id": tx.name, } - async def cc_get_colour(self, request): + async def cat_get_asset_id(self, request): assert self.service.wallet_state_manager is not None wallet_id = int(request["wallet_id"]) - wallet: CCWallet = self.service.wallet_state_manager.wallets[wallet_id] - colour: str = wallet.get_colour() - return {"colour": colour, "wallet_id": wallet_id} + wallet: CATWallet = self.service.wallet_state_manager.wallets[wallet_id] + asset_id: str = wallet.get_asset_id() + return {"asset_id": asset_id, "wallet_id": wallet_id} async def create_offer_for_ids(self, request): assert self.service.wallet_state_manager is not None @@ -889,7 +865,7 @@ class WalletRpcApi: trade_mgr = self.service.wallet_state_manager.trade_manager - trade_id = bytes32.from_hexstr(request["trade_id"]) + trade_id = hexstr_to_bytes(request["trade_id"]) trade: Optional[TradeRecord] = await trade_mgr.get_trade_by_id(trade_id) if trade is None: raise ValueError(f"No trade with trade id: {trade_id.hex()}") @@ -914,7 +890,7 @@ class WalletRpcApi: wsm = self.service.wallet_state_manager secure = request["secure"] - trade_id = bytes32.from_hexstr(request["trade_id"]) + trade_id = hexstr_to_bytes(request["trade_id"]) async with self.service.wallet_state_manager.lock: if secure: @@ -923,31 +899,6 @@ class WalletRpcApi: await wsm.trade_manager.cancel_pending_offer(trade_id) return {} - async def get_backup_info(self, request: Dict): - file_path = Path(request["file_path"]) - sk = None - if "words" in request: - mnemonic = request["words"] - passphrase = "" - try: - assert self.service.keychain_proxy is not None # An offering to the mypy gods - sk = await self.service.keychain_proxy.add_private_key(" ".join(mnemonic), passphrase) - except KeyError as e: - return { - "success": False, - "error": f"The word '{e.args[0]}' is incorrect.'", - "word": e.args[0], - } - except Exception as e: - return {"success": False, "error": str(e)} - elif "fingerprint" in request: - sk, seed = await self._get_private_key(request["fingerprint"]) - - if sk is None: - raise ValueError("Unable to decrypt the backup file.") - backup_info = get_backup_info(file_path, sk) - return {"backup_info": backup_info} - ########################################################################################## # Distributed Identities ########################################################################################## @@ -1156,7 +1107,8 @@ class WalletRpcApi: "last_height_farmed": last_height_farmed, } - async def create_signed_transaction(self, request, hold_lock=True): + async def create_signed_transaction(self, request, hold_lock=True) -> Dict: + assert self.service.wallet_state_manager is not None if "additions" not in request or len(request["additions"]) < 1: raise ValueError("Specify additions list") @@ -1165,17 +1117,20 @@ class WalletRpcApi: assert amount_0 <= self.service.constants.MAX_COIN_AMOUNT puzzle_hash_0 = hexstr_to_bytes(additions[0]["puzzle_hash"]) if len(puzzle_hash_0) != 32: - raise ValueError(f"Address must be 32 bytes. {puzzle_hash_0}") + raise ValueError(f"Address must be 32 bytes. {puzzle_hash_0.hex()}") + + memos_0 = None if "memos" not in additions[0] else [mem.encode("utf-8") for mem in additions[0]["memos"]] additional_outputs: List[AmountWithPuzzlehash] = [] for addition in additions[1:]: receiver_ph = hexstr_to_bytes(addition["puzzle_hash"]) if len(receiver_ph) != 32: - raise ValueError(f"Address must be 32 bytes. {receiver_ph}") + raise ValueError(f"Address must be 32 bytes. {receiver_ph.hex()}") amount = uint64(addition["amount"]) if amount > self.service.constants.MAX_COIN_AMOUNT: raise ValueError(f"Coin amount cannot exceed {self.service.constants.MAX_COIN_AMOUNT}") - additional_outputs.append({"puzzlehash": receiver_ph, "amount": amount}) + memos = None if "memos" not in addition else [mem.encode("utf-8") for mem in addition["memos"]] + additional_outputs.append({"puzzlehash": receiver_ph, "amount": amount, "memos": memos}) fee = uint64(0) if "fee" in request: @@ -1203,6 +1158,7 @@ class WalletRpcApi: ignore_max_send_amount=True, primaries=additional_outputs, announcements_to_consume=coin_announcements, + memos=memos_0, ) else: signed_tx = await self.service.wallet_state_manager.main_wallet.generate_signed_transaction( @@ -1213,8 +1169,9 @@ class WalletRpcApi: ignore_max_send_amount=True, primaries=additional_outputs, announcements_to_consume=coin_announcements, + memos=memos_0, ) - return {"signed_tx": signed_tx} + return {"signed_tx": signed_tx.to_json_dict_convenience(self.service.config)} ########################################################################################## # Pool Wallet @@ -1228,14 +1185,15 @@ class WalletRpcApi: pool_wallet_info: PoolWalletInfo = await wallet.get_current_state() owner_pubkey = pool_wallet_info.current.owner_pubkey target_puzzlehash = None + + if await self.service.wallet_state_manager.synced() is False: + raise ValueError("Wallet needs to be fully synced.") + if "target_puzzlehash" in request: target_puzzlehash = bytes32(hexstr_to_bytes(request["target_puzzlehash"])) - # TODO: address hint error and remove ignore - # error: Argument 2 to "create_pool_state" has incompatible type "Optional[bytes32]"; expected "bytes32" - # [arg-type] new_target_state: PoolState = create_pool_state( FARMING_TO_POOL, - target_puzzlehash, # type: ignore[arg-type] + target_puzzlehash, owner_pubkey, request["pool_url"], uint32(request["relative_lock_height"]), @@ -1254,6 +1212,9 @@ class WalletRpcApi: wallet_id = uint32(request["wallet_id"]) wallet: PoolWallet = self.service.wallet_state_manager.wallets[wallet_id] + if await self.service.wallet_state_manager.synced() is False: + raise ValueError("Wallet needs to be fully synced.") + async with self.service.wallet_state_manager.lock: total_fee, tx = await wallet.self_pool(fee) # total_fee: uint64, tx: TransactionRecord return {"total_fee": total_fee, "transaction": tx} diff --git a/chia/server/server.py b/chia/server/server.py index ee47b31a13..f02f9705d7 100644 --- a/chia/server/server.py +++ b/chia/server/server.py @@ -787,13 +787,9 @@ class ChiaServer: def is_trusted_peer(self, peer: WSChiaConnection, trusted_peers: Dict) -> bool: if trusted_peers is None: return False - for trusted_peer in trusted_peers: - cert = self.root_path / trusted_peers[trusted_peer] - pem_cert = x509.load_pem_x509_certificate(cert.read_bytes()) - cert_bytes = pem_cert.public_bytes(encoding=serialization.Encoding.DER) - der_cert = x509.load_der_x509_certificate(cert_bytes) - peer_id = bytes32(der_cert.fingerprint(hashes.SHA256())) - if peer_id == peer.peer_node_id: - self.log.debug(f"trusted node {peer.peer_node_id} {peer.peer_host}") - return True - return False + if not self.config["testing"] and peer.peer_host == "127.0.0.1": + return True + if peer.peer_node_id.hex() not in trusted_peers: + return False + + return True diff --git a/chia/server/ws_connection.py b/chia/server/ws_connection.py index a58e43ecda..10f3894ca6 100644 --- a/chia/server/ws_connection.py +++ b/chia/server/ws_connection.py @@ -106,8 +106,9 @@ class WSChiaConnection: self.outbound_rate_limiter = RateLimiter(incoming=False, percentage_of_limit=outbound_rate_limit_percent) self.inbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=inbound_rate_limit_percent) - # Used by the Chia Seeder. + # Used by crawler/dns introducer self.version = None + self.protocol_version = "" async def perform_handshake(self, network_id: str, protocol_version: str, server_port: int, local_type: NodeType): if self.is_outbound: @@ -142,7 +143,7 @@ class WSChiaConnection: raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID) self.version = inbound_handshake.software_version - + self.protocol_version = inbound_handshake.protocol_version self.peer_server_port = inbound_handshake.server_port self.connection_type = NodeType(inbound_handshake.node_type) @@ -188,7 +189,7 @@ class WSChiaConnection: async def close(self, ban_time: int = 0, ws_close_code: WSCloseCode = WSCloseCode.OK, error: Optional[Err] = None): """ - Closes the connection, and finally calls the close_callback on the server, so the connection gets removed + Closes the connection, and finally calls the close_callback on the server, so the connections gets removed from the global list. """ @@ -275,12 +276,11 @@ class WSChiaConnection: self.log.error(f"Exception: {e}") self.log.error(f"Exception Stack: {error_stack}") - async def send_message(self, message: Message) -> bool: + async def send_message(self, message: Message): """Send message sends a message with no tracking / callback.""" if self.closed: - return False + return None await self.outgoing_queue.put(message) - return True def __getattr__(self, attr_name: str): # TODO KWARGS @@ -343,10 +343,7 @@ class WSChiaConnection: message = Message(message_no_id.type, request_id, message_no_id.data) - # TODO: address hint error and remove ignore - # error: Invalid index type "Optional[uint16]" for "Dict[bytes32, Event]"; expected type "bytes32" - # [index] - self.pending_requests[message.id] = event # type: ignore[index] + self.pending_requests[message.id] = event await self.outgoing_queue.put(message) # If the timeout passes, we set the event @@ -361,37 +358,24 @@ class WSChiaConnection: raise timeout_task = asyncio.create_task(time_out(message.id, timeout)) - # TODO: address hint error and remove ignore - # error: Invalid index type "Optional[uint16]" for "Dict[bytes32, Task[Any]]"; expected type "bytes32" - # [index] - self.pending_timeouts[message.id] = timeout_task # type: ignore[index] + self.pending_timeouts[message.id] = timeout_task await event.wait() - # TODO: address hint error and remove ignore - # error: No overload variant of "pop" of "MutableMapping" matches argument type "Optional[uint16]" - # [call-overload] - # note: Possible overload variants: - # note: def pop(self, key: bytes32) -> Event - # note: def [_T] pop(self, key: bytes32, default: Union[Event, _T] = ...) -> Union[Event, _T] - self.pending_requests.pop(message.id) # type: ignore[call-overload] + self.pending_requests.pop(message.id) result: Optional[Message] = None if message.id in self.request_results: - # TODO: address hint error and remove ignore - # error: Invalid index type "Optional[uint16]" for "Dict[bytes32, Message]"; expected type "bytes32" - # [index] - result = self.request_results[message.id] # type: ignore[index] + result = self.request_results[message.id] assert result is not None self.log.debug(f"<- {ProtocolMessageTypes(result.type).name} from: {self.peer_host}:{self.peer_port}") - # TODO: address hint error and remove ignore - # error: No overload variant of "pop" of "MutableMapping" matches argument type "Optional[uint16]" - # [call-overload] - # note: Possible overload variants: - # note: def pop(self, key: bytes32) -> Message - # note: def [_T] pop(self, key: bytes32, default: Union[Message, _T] = ...) -> Union[Message, _T] - self.request_results.pop(result.id) # type: ignore[call-overload] + self.request_results.pop(result.id) return result + async def reply_to_request(self, response: Message): + if self.closed: + return None + await self.outgoing_queue.put(response) + async def send_messages(self, messages: List[Message]): if self.closed: return None @@ -508,7 +492,7 @@ class WSChiaConnection: await asyncio.sleep(3) return None - # Used by the Chia Seeder. + # Used by crawler/dns introducer def get_version(self): return self.version diff --git a/chia/types/announcement.py b/chia/types/announcement.py index f12d231446..a224f9e54b 100644 --- a/chia/types/announcement.py +++ b/chia/types/announcement.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional from chia.types.blockchain_format.sized_bytes import bytes32 from chia.util.hash import std_hash @@ -8,9 +9,14 @@ from chia.util.hash import std_hash class Announcement: origin_info: bytes32 message: bytes + morph_bytes: Optional[bytes] = None # CATs morph their announcements and other puzzles may choose to do so too def name(self) -> bytes32: - return std_hash(bytes(self.origin_info + self.message)) + if self.morph_bytes is not None: + message = std_hash(self.morph_bytes + self.message) + else: + message = self.message + return std_hash(bytes(self.origin_info + message)) def __str__(self): return self.name().decode("utf-8") diff --git a/chia/types/coin_spend.py b/chia/types/coin_spend.py index ae15a08fd3..64f7d7ecac 100644 --- a/chia/types/coin_spend.py +++ b/chia/types/coin_spend.py @@ -1,8 +1,10 @@ from dataclasses import dataclass from typing import List +from blspy import G2Element from chia.types.blockchain_format.coin import Coin from chia.types.blockchain_format.program import SerializedProgram, INFINITE_COST +from chia.types.condition_opcodes import ConditionOpcode from chia.util.chain_utils import additions_for_solution, fee_for_solution from chia.util.streamable import Streamable, streamable @@ -25,3 +27,26 @@ class CoinSpend(Streamable): def reserved_fee(self) -> int: return fee_for_solution(self.puzzle_reveal, self.solution, INFINITE_COST) + + def hints(self) -> List[bytes]: + # import above was causing circular import issue + from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions + from chia.consensus.default_constants import DEFAULT_CONSTANTS + from chia.types.spend_bundle import SpendBundle + from chia.full_node.bundle_tools import simple_solution_generator + + bundle = SpendBundle([self], G2Element()) + generator = simple_solution_generator(bundle) + + npc_result = get_name_puzzle_conditions( + generator, INFINITE_COST, cost_per_byte=DEFAULT_CONSTANTS.COST_PER_BYTE, safe_mode=False + ) + h_list = [] + for npc in npc_result.npc_list: + for opcode, conditions in npc.conditions: + if opcode == ConditionOpcode.CREATE_COIN: + for condition in conditions: + if len(condition.vars) > 2 and condition.vars[2] != b"": + h_list.append(condition.vars[2]) + + return h_list diff --git a/chia/types/spend_bundle.py b/chia/types/spend_bundle.py index f8e9977cce..9dacec241d 100644 --- a/chia/types/spend_bundle.py +++ b/chia/types/spend_bundle.py @@ -2,17 +2,20 @@ import dataclasses import warnings from dataclasses import dataclass -from typing import List +from typing import List, Dict from blspy import AugSchemeMPL, G2Element +from clvm.casts import int_from_bytes from chia.consensus.default_constants import DEFAULT_CONSTANTS from chia.types.blockchain_format.coin import Coin from chia.types.blockchain_format.sized_bytes import bytes32 from chia.util.streamable import Streamable, dataclass_from_dict, recurse_jsonify, streamable from chia.wallet.util.debug_spend_bundle import debug_spend_bundle +from .blockchain_format.program import Program from .coin_spend import CoinSpend +from .condition_opcodes import ConditionOpcode @dataclass(frozen=True) @@ -77,6 +80,27 @@ class SpendBundle(Streamable): return result + def get_memos(self) -> Dict[bytes32, List[bytes]]: + """ + Retrieves the memos for additions in this spend_bundle, which are formatted as a list in the 3rd parameter of + CREATE_COIN. If there are no memos, the addition coin_id is not included. If they are not formatted as a list + of bytes, they are not included. This is expensive to call, it should not be used in full node code. + """ + memos: Dict[bytes32, List[bytes]] = {} + for coin_spend in self.coin_spends: + result = Program.from_bytes(bytes(coin_spend.puzzle_reveal)).run( + Program.from_bytes(bytes(coin_spend.solution)) + ) + for condition in result.as_python(): + if condition[0] == ConditionOpcode.CREATE_COIN and len(condition) >= 4: + # If only 3 elements (opcode + 2 args), there is no memo, this is ph, amount + coin_added = Coin(coin_spend.coin.name(), bytes32(condition[1]), int_from_bytes(condition[2])) + if type(condition[3]) != list: + # If it's not a list, it's not the correct format + continue + memos[coin_added.name()] = condition[3] + return memos + # Note that `coin_spends` used to have the bad name `coin_solutions`. # Some API still expects this name. For now, we accept both names. # diff --git a/chia/util/initial-config.yaml b/chia/util/initial-config.yaml index 81b3272c7d..3a5c9e6e09 100644 --- a/chia/util/initial-config.yaml +++ b/chia/util/initial-config.yaml @@ -73,6 +73,7 @@ network_overrides: &network_overrides default_full_node_port: 8444 testnet0: address_prefix: "txch" + default_full_node_port: 58444 testnet1: address_prefix: "txch" testnet2: diff --git a/chia/wallet/wallet_blockchain.py b/chia/wallet/wallet_blockchain.py index 184de48f8b..41df87c62c 100644 --- a/chia/wallet/wallet_blockchain.py +++ b/chia/wallet/wallet_blockchain.py @@ -114,7 +114,7 @@ class WalletBlockchain(BlockchainInterface): return ReceiveBlockResult.INVALID_BLOCK, Err.INVALID_POSPACE block_record: BlockRecord = block_to_block_record( - self.constants, self, required_iters, None, block, sub_slot_iters + self.constants, self, required_iters, None, block ) self.add_block_record(block_record) if self._peak is None: diff --git a/setup.py b/setup.py index 1ac9718dae..ac77412359 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ dependencies = [ "dnslib==0.9.14", # dns lib "typing-extensions==4.0.1", # typing backports like Protocol and TypedDict "zstd==1.5.0.4", + "packaging==21.0" ] upnp_dependencies = [ diff --git a/tests/clvm/benchmark_costs.py b/tests/clvm/benchmark_costs.py new file mode 100644 index 0000000000..45354506cd --- /dev/null +++ b/tests/clvm/benchmark_costs.py @@ -0,0 +1,16 @@ +from chia.types.blockchain_format.program import INFINITE_COST +from chia.types.spend_bundle import SpendBundle +from chia.types.generator_types import BlockGenerator +from chia.consensus.cost_calculator import calculate_cost_of_program, NPCResult +from chia.consensus.default_constants import DEFAULT_CONSTANTS +from chia.full_node.bundle_tools import simple_solution_generator +from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions + + +def cost_of_spend_bundle(spend_bundle: SpendBundle) -> int: + program: BlockGenerator = simple_solution_generator(spend_bundle) + npc_result: NPCResult = get_name_puzzle_conditions( + program, INFINITE_COST, cost_per_byte=DEFAULT_CONSTANTS.COST_PER_BYTE, safe_mode=True + ) + cost: int = calculate_cost_of_program(program.program, npc_result, DEFAULT_CONSTANTS.COST_PER_BYTE) + return cost diff --git a/tests/setup_nodes.py b/tests/setup_nodes.py index f447b69937..fae610c81f 100644 --- a/tests/setup_nodes.py +++ b/tests/setup_nodes.py @@ -178,7 +178,7 @@ async def setup_wallet_node( service = Service(**kwargs) - await service.start(new_wallet=True) + await service.start() yield service._node, service._node.server diff --git a/tests/wallet/cc_wallet/__init__.py b/tests/wallet/cat_wallet/__init__.py similarity index 100% rename from tests/wallet/cc_wallet/__init__.py rename to tests/wallet/cat_wallet/__init__.py diff --git a/tests/wallet/cc_wallet/test_trades.py b/tests/wallet/cat_wallet/dont_test_trades.py similarity index 71% rename from tests/wallet/cc_wallet/test_trades.py rename to tests/wallet/cat_wallet/dont_test_trades.py index 4510490b84..b607fbfe10 100644 --- a/tests/wallet/cc_wallet/test_trades.py +++ b/tests/wallet/cat_wallet/dont_test_trades.py @@ -7,7 +7,7 @@ import pytest from chia.simulator.simulator_protocol import FarmNewBlockProtocol from chia.types.peer_info import PeerInfo from chia.util.ints import uint16, uint64 -from chia.wallet.cc_wallet.cc_wallet import CCWallet +from chia.wallet.cat_wallet.cat_wallet import CATWallet from chia.wallet.trade_manager import TradeManager from chia.wallet.trading.trade_status import TradeStatus from tests.setup_nodes import setup_simulators_and_wallets @@ -63,40 +63,43 @@ async def wallets_prefarm(two_wallet_nodes): return wallet_node_0, wallet_node_1, full_node_api -class TestCCTrades: +class TestCATTrades: @pytest.mark.asyncio - async def test_cc_trade(self, wallets_prefarm): + async def test_cat_trade(self, wallets_prefarm): wallet_node_0, wallet_node_1, full_node = wallets_prefarm wallet_0 = wallet_node_0.wallet_state_manager.main_wallet wallet_1 = wallet_node_1.wallet_state_manager.main_wallet - cc_wallet: CCWallet = await CCWallet.create_new_cc(wallet_node_0.wallet_state_manager, wallet_0, uint64(100)) + cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node_0.wallet_state_manager, wallet_0, {"identifier": "genesis_by_id"}, uint64(100) + ) await asyncio.sleep(1) for i in range(1, buffer_blocks): await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) await time_out_assert(15, wallet_height_at_least, True, wallet_node_0, 27) - await time_out_assert(15, cc_wallet.get_confirmed_balance, 100) - await time_out_assert(15, cc_wallet.get_unconfirmed_balance, 100) + await time_out_assert(15, cat_wallet.get_confirmed_balance, 100) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 100) - assert cc_wallet.cc_info.my_genesis_checker is not None - colour = cc_wallet.get_colour() + assert cat_wallet.cat_info.my_tail is not None + asset_id = cat_wallet.get_asset_id() - cc_wallet_2: CCWallet = await CCWallet.create_wallet_for_cc( - wallet_node_1.wallet_state_manager, wallet_1, colour + cat_wallet_2: CATWallet = await CATWallet.create_wallet_for_cat( + wallet_node_1.wallet_state_manager, wallet_1, asset_id ) await asyncio.sleep(1) - assert cc_wallet.cc_info.my_genesis_checker == cc_wallet_2.cc_info.my_genesis_checker + assert cat_wallet.cat_info.my_tail == cat_wallet_2.cat_info.my_tail for i in range(0, buffer_blocks): await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) await time_out_assert(15, wallet_height_at_least, True, wallet_node_0, 31) - # send cc_wallet 2 a coin - cc_hash = await cc_wallet_2.get_new_inner_hash() - tx_record = await cc_wallet.generate_signed_transaction([uint64(1)], [cc_hash]) - await wallet_0.wallet_state_manager.add_pending_transaction(tx_record) - await asyncio.sleep(1) + # send cat_wallet 2 a coin + cat_hash = await cat_wallet_2.get_new_inner_hash() + tx_records = await cat_wallet.generate_signed_transaction([uint64(1)], [cat_hash]) + for tx_record in tx_records: + await wallet_0.wallet_state_manager.add_pending_transaction(tx_record) + await asyncio.sleep(1) for i in range(0, buffer_blocks): await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) await time_out_assert(15, wallet_height_at_least, True, wallet_node_0, 35) @@ -126,7 +129,7 @@ class TestCCTrades: assert offer is not None assert offer["chia"] == -10 - assert offer[colour] == 30 + assert offer[asset_id] == 30 success, trade, reason = await trade_manager_1.respond_to_offer(file_path) await asyncio.sleep(1) @@ -137,35 +140,37 @@ class TestCCTrades: await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) await time_out_assert(15, wallet_height_at_least, True, wallet_node_0, 39) - await time_out_assert(15, cc_wallet_2.get_confirmed_balance, 31) - await time_out_assert(15, cc_wallet_2.get_unconfirmed_balance, 31) + await time_out_assert(15, cat_wallet_2.get_confirmed_balance, 31) + await time_out_assert(15, cat_wallet_2.get_unconfirmed_balance, 31) trade_2 = await trade_manager_0.get_trade_by_id(trade_offer.trade_id) assert TradeStatus(trade_2.status) is TradeStatus.CONFIRMED @pytest.mark.asyncio - async def test_cc_trade_accept_with_zero(self, wallets_prefarm): + async def test_cat_trade_accept_with_zero(self, wallets_prefarm): wallet_node_0, wallet_node_1, full_node = wallets_prefarm wallet_0 = wallet_node_0.wallet_state_manager.main_wallet wallet_1 = wallet_node_1.wallet_state_manager.main_wallet - cc_wallet: CCWallet = await CCWallet.create_new_cc(wallet_node_0.wallet_state_manager, wallet_0, uint64(100)) + cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node_0.wallet_state_manager, wallet_0, {"identifier": "genesis_by_id"}, uint64(100) + ) await asyncio.sleep(1) for i in range(1, buffer_blocks): await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) - await time_out_assert(15, cc_wallet.get_confirmed_balance, 100) - await time_out_assert(15, cc_wallet.get_unconfirmed_balance, 100) + await time_out_assert(15, cat_wallet.get_confirmed_balance, 100) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 100) - assert cc_wallet.cc_info.my_genesis_checker is not None - colour = cc_wallet.get_colour() + assert cat_wallet.cat_info.my_tail is not None + asset_id = cat_wallet.get_asset_id() - cc_wallet_2: CCWallet = await CCWallet.create_wallet_for_cc( - wallet_node_1.wallet_state_manager, wallet_1, colour + cat_wallet_2: CATWallet = await CATWallet.create_wallet_for_cat( + wallet_node_1.wallet_state_manager, wallet_1, asset_id ) await asyncio.sleep(1) - assert cc_wallet.cc_info.my_genesis_checker == cc_wallet_2.cc_info.my_genesis_checker + assert cat_wallet.cat_info.my_tail == cat_wallet_2.cat_info.my_tail ph = await wallet_1.get_new_puzzlehash() for i in range(0, buffer_blocks): @@ -195,10 +200,10 @@ class TestCCTrades: assert success is True assert offer is not None - assert cc_wallet.get_colour() == cc_wallet_2.get_colour() + assert cat_wallet.get_asset_id() == cat_wallet_2.get_asset_id() assert offer["chia"] == -10 - assert offer[colour] == 30 + assert offer[asset_id] == 30 success, trade, reason = await trade_manager_1.respond_to_offer(file_path) await asyncio.sleep(1) @@ -208,14 +213,14 @@ class TestCCTrades: for i in range(0, buffer_blocks): await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) - await time_out_assert(15, cc_wallet_2.get_confirmed_balance, 30) - await time_out_assert(15, cc_wallet_2.get_unconfirmed_balance, 30) + await time_out_assert(15, cat_wallet_2.get_confirmed_balance, 30) + await time_out_assert(15, cat_wallet_2.get_unconfirmed_balance, 30) trade_2 = await trade_manager_0.get_trade_by_id(trade_offer.trade_id) assert TradeStatus(trade_2.status) is TradeStatus.CONFIRMED @pytest.mark.asyncio - async def test_cc_trade_with_multiple_colours(self, wallets_prefarm): - # This test start with CCWallet in both wallets. wall + async def test_cat_trade_with_multiple_asset_ids(self, wallets_prefarm): + # This test start with CATWallet in both wallets. wall # wallet1 {wallet_id: 2 = 70} # wallet2 {wallet_id: 2 = 30} @@ -223,33 +228,35 @@ class TestCCTrades: wallet_a = wallet_node_a.wallet_state_manager.main_wallet wallet_b = wallet_node_b.wallet_state_manager.main_wallet - # cc_a_2 = coloured coin, Alice, wallet id = 2 - cc_a_2 = wallet_node_a.wallet_state_manager.wallets[2] - cc_b_2 = wallet_node_b.wallet_state_manager.wallets[2] + # cat_a_2 = CAT, Alice, wallet id = 2 + cat_a_2 = wallet_node_a.wallet_state_manager.wallets[2] + cat_b_2 = wallet_node_b.wallet_state_manager.wallets[2] - cc_a_3: CCWallet = await CCWallet.create_new_cc(wallet_node_a.wallet_state_manager, wallet_a, uint64(100)) + cat_a_3: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node_a.wallet_state_manager, wallet_a, {"identifier": "genesis_by_id"}, uint64(100) + ) await asyncio.sleep(1) for i in range(0, buffer_blocks): await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) - await time_out_assert(15, cc_a_3.get_confirmed_balance, 100) - await time_out_assert(15, cc_a_3.get_unconfirmed_balance, 100) + await time_out_assert(15, cat_a_3.get_confirmed_balance, 100) + await time_out_assert(15, cat_a_3.get_unconfirmed_balance, 100) # store these for asserting change later - cc_balance = await cc_a_2.get_unconfirmed_balance() - cc_balance_2 = await cc_b_2.get_unconfirmed_balance() + cat_balance = await cat_a_2.get_unconfirmed_balance() + cat_balance_2 = await cat_b_2.get_unconfirmed_balance() - assert cc_a_3.cc_info.my_genesis_checker is not None - red = cc_a_3.get_colour() + assert cat_a_3.cat_info.my_tail is not None + red = cat_a_3.get_asset_id() for i in range(0, buffer_blocks): await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) - cc_b_3: CCWallet = await CCWallet.create_wallet_for_cc(wallet_node_b.wallet_state_manager, wallet_b, red) + cat_b_3: CATWallet = await CATWallet.create_wallet_for_cat(wallet_node_b.wallet_state_manager, wallet_b, red) await asyncio.sleep(1) - assert cc_a_3.cc_info.my_genesis_checker == cc_b_3.cc_info.my_genesis_checker + assert cat_a_3.cat_info.my_tail == cat_b_3.cat_info.my_tail for i in range(0, buffer_blocks): await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) @@ -279,11 +286,11 @@ class TestCCTrades: assert offer is not None assert offer["chia"] == -1000 - colour_2 = cc_a_2.get_colour() - colour_3 = cc_a_3.get_colour() + asset_id_2 = cat_a_2.get_asset_id() + asset_id_3 = cat_a_3.get_asset_id() - assert offer[colour_2] == 20 - assert offer[colour_3] == 50 + assert offer[asset_id_2] == 20 + assert offer[asset_id_3] == 50 success, trade, reason = await trade_manager_1.respond_to_offer(file_path) await asyncio.sleep(1) @@ -292,14 +299,14 @@ class TestCCTrades: for i in range(0, 10): await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) - await time_out_assert(15, cc_b_3.get_confirmed_balance, 50) - await time_out_assert(15, cc_b_3.get_unconfirmed_balance, 50) + await time_out_assert(15, cat_b_3.get_confirmed_balance, 50) + await time_out_assert(15, cat_b_3.get_unconfirmed_balance, 50) - await time_out_assert(15, cc_a_3.get_confirmed_balance, 50) - await time_out_assert(15, cc_a_3.get_unconfirmed_balance, 50) + await time_out_assert(15, cat_a_3.get_confirmed_balance, 50) + await time_out_assert(15, cat_a_3.get_unconfirmed_balance, 50) - await time_out_assert(15, cc_a_2.get_unconfirmed_balance, cc_balance - offer[colour_2]) - await time_out_assert(15, cc_b_2.get_unconfirmed_balance, cc_balance_2 + offer[colour_2]) + await time_out_assert(15, cat_a_2.get_unconfirmed_balance, cat_balance - offer[asset_id_2]) + await time_out_assert(15, cat_b_2.get_unconfirmed_balance, cat_balance_2 + offer[asset_id_2]) trade = await trade_manager_0.get_trade_by_id(trade_offer.trade_id) @@ -310,10 +317,10 @@ class TestCCTrades: @pytest.mark.asyncio async def test_create_offer_with_zero_val(self, wallets_prefarm): # Wallet A Wallet B - # CCWallet id 2: 50 CCWallet id 2: 50 - # CCWallet id 3: 50 CCWallet id 2: 50 + # CATWallet id 2: 50 CATWallet id 2: 50 + # CATWallet id 3: 50 CATWallet id 2: 50 # Wallet A will - # Wallet A will create a new CC and wallet B will create offer to buy that coin + # Wallet A will create a new CAT and wallet B will create offer to buy that coin wallet_node_a, wallet_node_b, full_node = wallets_prefarm wallet_a = wallet_node_a.wallet_state_manager.main_wallet @@ -321,20 +328,22 @@ class TestCCTrades: trade_manager_a: TradeManager = wallet_node_a.wallet_state_manager.trade_manager trade_manager_b: TradeManager = wallet_node_b.wallet_state_manager.trade_manager - cc_a_4: CCWallet = await CCWallet.create_new_cc(wallet_node_a.wallet_state_manager, wallet_a, uint64(100)) + cat_a_4: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node_a.wallet_state_manager, wallet_a, {"identifier": "genesis_by_id"}, uint64(100) + ) await asyncio.sleep(1) for i in range(0, buffer_blocks): await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) - await time_out_assert(15, cc_a_4.get_confirmed_balance, 100) + await time_out_assert(15, cat_a_4.get_confirmed_balance, 100) - colour = cc_a_4.get_colour() + asset_id = cat_a_4.get_asset_id() - cc_b_4: CCWallet = await CCWallet.create_wallet_for_cc(wallet_node_b.wallet_state_manager, wallet_b, colour) - cc_balance = await cc_a_4.get_confirmed_balance() - cc_balance_2 = await cc_b_4.get_confirmed_balance() - offer_dict = {1: -30, cc_a_4.id(): 50} + cat_b_4: CATWallet = await CATWallet.create_wallet_for_cat(wallet_node_b.wallet_state_manager, wallet_b, asset_id) + cat_balance = await cat_a_4.get_confirmed_balance() + cat_balance_2 = await cat_b_4.get_confirmed_balance() + offer_dict = {1: -30, cat_a_4.id(): 50} file = "test_offer_file.offer" file_path = Path(file) @@ -348,8 +357,8 @@ class TestCCTrades: for i in range(0, buffer_blocks): await full_node.farm_new_transaction_block(FarmNewBlockProtocol(token_bytes())) - await time_out_assert(15, cc_a_4.get_confirmed_balance, cc_balance - 50) - await time_out_assert(15, cc_b_4.get_confirmed_balance, cc_balance_2 + 50) + await time_out_assert(15, cat_a_4.get_confirmed_balance, cat_balance - 50) + await time_out_assert(15, cat_b_4.get_confirmed_balance, cat_balance_2 + 50) async def assert_func(): assert trade_a is not None @@ -367,11 +376,11 @@ class TestCCTrades: await time_out_assert(15, assert_func_b, TradeStatus.CONFIRMED.value) @pytest.mark.asyncio - async def test_cc_trade_cancel_insecure(self, wallets_prefarm): + async def test_cat_trade_cancel_insecure(self, wallets_prefarm): # Wallet A Wallet B - # CCWallet id 2: 50 CCWallet id 2: 50 - # CCWallet id 3: 50 CCWallet id 3: 50 - # CCWallet id 4: 40 CCWallet id 4: 60 + # CATWallet id 2: 50 CATWallet id 2: 50 + # CATWallet id 3: 50 CATWallet id 3: 50 + # CATWallet id 4: 40 CATWallet id 4: 60 # Wallet A will create offer, cancel it by deleting from db only wallet_node_a, wallet_node_b, full_node = wallets_prefarm wallet_a = wallet_node_a.wallet_state_manager.main_wallet @@ -414,11 +423,11 @@ class TestCCTrades: assert trade_a.status == TradeStatus.CANCELED.value @pytest.mark.asyncio - async def test_cc_trade_cancel_secure(self, wallets_prefarm): + async def test_cat_trade_cancel_secure(self, wallets_prefarm): # Wallet A Wallet B - # CCWallet id 2: 50 CCWallet id 2: 50 - # CCWallet id 3: 50 CCWallet id 3: 50 - # CCWallet id 4: 40 CCWallet id 4: 60 + # CATWallet id 2: 50 CATWallet id 2: 50 + # CATWallet id 3: 50 CATWallet id 3: 50 + # CATWallet id 4: 40 CATWallet id 4: 60 # Wallet A will create offer, cancel it by spending coins back to self wallet_node_a, wallet_node_b, full_node = wallets_prefarm diff --git a/tests/wallet/cat_wallet/test_cat_lifecycle.py b/tests/wallet/cat_wallet/test_cat_lifecycle.py new file mode 100644 index 0000000000..9fdcdf3063 --- /dev/null +++ b/tests/wallet/cat_wallet/test_cat_lifecycle.py @@ -0,0 +1,637 @@ +import pytest + +from typing import List, Tuple, Optional, Dict +from blspy import PrivateKey, AugSchemeMPL, G2Element +from clvm.casts import int_to_bytes + +from chia.clvm.spend_sim import SpendSim, SimClient +from chia.types.blockchain_format.program import Program +from chia.types.blockchain_format.coin import Coin +from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.types.spend_bundle import SpendBundle +from chia.types.coin_spend import CoinSpend +from chia.types.mempool_inclusion_status import MempoolInclusionStatus +from chia.util.errors import Err +from chia.util.ints import uint64 +from chia.wallet.lineage_proof import LineageProof +from chia.wallet.cat_wallet.cat_utils import ( + CAT_MOD, + SpendableCAT, + construct_cat_puzzle, + unsigned_spend_bundle_for_spendable_cats, +) +from chia.wallet.puzzles.tails import ( + GenesisById, + GenesisByPuzhash, + EverythingWithSig, + DelegatedLimitations, +) + +from tests.clvm.test_puzzles import secret_exponent_for_index +from tests.clvm.benchmark_costs import cost_of_spend_bundle + +acs = Program.to(1) +acs_ph = acs.get_tree_hash() +NO_LINEAGE_PROOF = LineageProof() + + +class TestCATLifecycle: + cost: Dict[str, int] = {} + + @pytest.fixture(scope="function") + async def setup_sim(self): + sim = await SpendSim.create() + sim_client = SimClient(sim) + await sim.farm_block() + return sim, sim_client + + async def do_spend( + self, + sim: SpendSim, + sim_client: SimClient, + tail: Program, + coins: List[Coin], + lineage_proofs: List[Program], + inner_solutions: List[Program], + expected_result: Tuple[MempoolInclusionStatus, Err], + reveal_limitations_program: bool = True, + signatures: List[G2Element] = [], + extra_deltas: Optional[List[int]] = None, + additional_spends: List[SpendBundle] = [], + limitations_solutions: Optional[List[Program]] = None, + cost_str: str = "", + ): + if limitations_solutions is None: + limitations_solutions = [Program.to([])] * len(coins) + if extra_deltas is None: + extra_deltas = [0] * len(coins) + + spendable_cat_list: List[SpendableCAT] = [] + for coin, innersol, proof, limitations_solution, extra_delta in zip( + coins, inner_solutions, lineage_proofs, limitations_solutions, extra_deltas + ): + spendable_cat_list.append( + SpendableCAT( + coin, + tail.get_tree_hash(), + acs, + innersol, + limitations_solution=limitations_solution, + lineage_proof=proof, + extra_delta=extra_delta, + limitations_program_reveal=tail if reveal_limitations_program else Program.to([]), + ) + ) + + spend_bundle: SpendBundle = unsigned_spend_bundle_for_spendable_cats( + CAT_MOD, + spendable_cat_list, + ) + agg_sig = AugSchemeMPL.aggregate(signatures) + result = await sim_client.push_tx( + SpendBundle.aggregate( + [ + *additional_spends, + spend_bundle, + SpendBundle([], agg_sig), # "Signing" the spend bundle + ] + ) + ) + assert result == expected_result + self.cost[cost_str] = cost_of_spend_bundle(spend_bundle) + await sim.farm_block() + + @pytest.mark.asyncio() + async def test_cat_mod(self, setup_sim): + sim, sim_client = setup_sim + + try: + tail = Program.to([]) + checker_solution = Program.to([]) + cat_puzzle: Program = construct_cat_puzzle(CAT_MOD, tail.get_tree_hash(), acs) + cat_ph: bytes32 = cat_puzzle.get_tree_hash() + await sim.farm_block(cat_ph) + starting_coin: Coin = (await sim_client.get_coin_records_by_puzzle_hash(cat_ph))[0].coin + + # Testing the eve spend + await self.do_spend( + sim, + sim_client, + tail, + [starting_coin], + [NO_LINEAGE_PROOF], + [ + Program.to( + [ + [51, acs.get_tree_hash(), starting_coin.amount - 3, [b"memo"]], + [51, acs.get_tree_hash(), 1], + [51, acs.get_tree_hash(), 2], + [51, 0, -113, tail, checker_solution], + ] + ) + ], + (MempoolInclusionStatus.SUCCESS, None), + limitations_solutions=[checker_solution], + cost_str="Eve Spend", + ) + + # There's 4 total coins at this point. A farming reward and the three children of the spend above. + + # Testing a combination of two + coins: List[Coin] = [ + record.coin + for record in (await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False)) + ] + coins = [coins[0], coins[1]] + await self.do_spend( + sim, + sim_client, + tail, + coins, + [NO_LINEAGE_PROOF] * 2, + [ + Program.to( + [ + [51, acs.get_tree_hash(), coins[0].amount + coins[1].amount], + [51, 0, -113, tail, checker_solution], + ] + ), + Program.to([[51, 0, -113, tail, checker_solution]]), + ], + (MempoolInclusionStatus.SUCCESS, None), + limitations_solutions=[checker_solution] * 2, + cost_str="Two CATs", + ) + + # Testing a combination of three + coins = [ + record.coin + for record in (await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False)) + ] + total_amount: uint64 = uint64(sum([c.amount for c in coins])) + await self.do_spend( + sim, + sim_client, + tail, + coins, + [NO_LINEAGE_PROOF] * 3, + [ + Program.to( + [ + [51, acs.get_tree_hash(), total_amount], + [51, 0, -113, tail, checker_solution], + ] + ), + Program.to([[51, 0, -113, tail, checker_solution]]), + Program.to([[51, 0, -113, tail, checker_solution]]), + ], + (MempoolInclusionStatus.SUCCESS, None), + limitations_solutions=[checker_solution] * 3, + cost_str="Three CATs", + ) + + # Spend with a standard lineage proof + parent_coin: Coin = coins[0] # The first one is the one we didn't light on fire + _, curried_args = cat_puzzle.uncurry() + _, _, innerpuzzle = curried_args.as_iter() + lineage_proof = LineageProof(parent_coin.parent_coin_info, innerpuzzle.get_tree_hash(), parent_coin.amount) + await self.do_spend( + sim, + sim_client, + tail, + [(await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False))[0].coin], + [lineage_proof], + [Program.to([[51, acs.get_tree_hash(), total_amount]])], + (MempoolInclusionStatus.SUCCESS, None), + reveal_limitations_program=False, + cost_str="Standard Lineage Check", + ) + + # Melt some value + await self.do_spend( + sim, + sim_client, + tail, + [(await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False))[0].coin], + [NO_LINEAGE_PROOF], + [ + Program.to( + [ + [51, acs.get_tree_hash(), total_amount - 1], + [51, 0, -113, tail, checker_solution], + ] + ) + ], + (MempoolInclusionStatus.SUCCESS, None), + extra_deltas=[-1], + limitations_solutions=[checker_solution], + cost_str="Melting Value", + ) + + # Mint some value + temp_p = Program.to(1) + temp_ph: bytes32 = temp_p.get_tree_hash() + await sim.farm_block(temp_ph) + acs_coin: Coin = (await sim_client.get_coin_records_by_puzzle_hash(temp_ph, include_spent_coins=False))[ + 0 + ].coin + acs_bundle = SpendBundle( + [ + CoinSpend( + acs_coin, + temp_p, + Program.to([]), + ) + ], + G2Element(), + ) + await self.do_spend( + sim, + sim_client, + tail, + [(await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False))[0].coin], + [NO_LINEAGE_PROOF], + [ + Program.to( + [ + [51, acs.get_tree_hash(), total_amount], + [51, 0, -113, tail, checker_solution], + ] + ) + ], # We subtracted 1 last time so it's normal now + (MempoolInclusionStatus.SUCCESS, None), + extra_deltas=[1], + additional_spends=[acs_bundle], + limitations_solutions=[checker_solution], + cost_str="Mint Value", + ) + + finally: + await sim.close() + + @pytest.mark.asyncio() + async def test_complex_spend(self, setup_sim): + sim, sim_client = setup_sim + + try: + tail = Program.to([]) + checker_solution = Program.to([]) + cat_puzzle: Program = construct_cat_puzzle(CAT_MOD, tail.get_tree_hash(), acs) + cat_ph: bytes32 = cat_puzzle.get_tree_hash() + await sim.farm_block(cat_ph) + await sim.farm_block(cat_ph) + + cat_records = await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False) + parent_of_mint = cat_records[0].coin + parent_of_melt = cat_records[1].coin + eve_to_mint = cat_records[2].coin + eve_to_melt = cat_records[3].coin + + # Spend two of them to make them non-eve + await self.do_spend( + sim, + sim_client, + tail, + [parent_of_mint, parent_of_melt], + [NO_LINEAGE_PROOF, NO_LINEAGE_PROOF], + [ + Program.to( + [ + [51, acs.get_tree_hash(), parent_of_mint.amount], + [51, 0, -113, tail, checker_solution], + ] + ), + Program.to( + [ + [51, acs.get_tree_hash(), parent_of_melt.amount], + [51, 0, -113, tail, checker_solution], + ] + ), + ], + (MempoolInclusionStatus.SUCCESS, None), + limitations_solutions=[checker_solution] * 2, + cost_str="Spend two eves", + ) + + # Make the lineage proofs for the non-eves + mint_lineage = LineageProof(parent_of_mint.parent_coin_info, acs_ph, parent_of_mint.amount) + melt_lineage = LineageProof(parent_of_melt.parent_coin_info, acs_ph, parent_of_melt.amount) + + # Find the two new coins + all_cats = await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False) + all_cat_coins = [cr.coin for cr in all_cats] + standard_to_mint = list(filter(lambda cr: cr.parent_coin_info == parent_of_mint.name(), all_cat_coins))[0] + standard_to_melt = list(filter(lambda cr: cr.parent_coin_info == parent_of_melt.name(), all_cat_coins))[0] + + # Do the complex spend + # We have both and eve and non-eve doing both minting and melting + await self.do_spend( + sim, + sim_client, + tail, + [eve_to_mint, eve_to_melt, standard_to_mint, standard_to_melt], + [NO_LINEAGE_PROOF, NO_LINEAGE_PROOF, mint_lineage, melt_lineage], + [ + Program.to( + [ + [51, acs.get_tree_hash(), eve_to_mint.amount + 13], + [51, 0, -113, tail, checker_solution], + ] + ), + Program.to( + [ + [51, acs.get_tree_hash(), eve_to_melt.amount - 21], + [51, 0, -113, tail, checker_solution], + ] + ), + Program.to( + [ + [51, acs.get_tree_hash(), standard_to_mint.amount + 21], + [51, 0, -113, tail, checker_solution], + ] + ), + Program.to( + [ + [51, acs.get_tree_hash(), standard_to_melt.amount - 13], + [51, 0, -113, tail, checker_solution], + ] + ), + ], + (MempoolInclusionStatus.SUCCESS, None), + limitations_solutions=[checker_solution] * 4, + extra_deltas=[13, -21, 21, -13], + cost_str="Complex Spend", + ) + finally: + await sim.close() + + @pytest.mark.asyncio() + async def test_genesis_by_id(self, setup_sim): + sim, sim_client = setup_sim + + try: + standard_acs = Program.to(1) + standard_acs_ph: bytes32 = standard_acs.get_tree_hash() + await sim.farm_block(standard_acs_ph) + + starting_coin: Coin = (await sim_client.get_coin_records_by_puzzle_hash(standard_acs_ph))[0].coin + tail: Program = GenesisById.construct([Program.to(starting_coin.name())]) + checker_solution: Program = GenesisById.solve([], {}) + cat_puzzle: Program = construct_cat_puzzle(CAT_MOD, tail.get_tree_hash(), acs) + cat_ph: bytes32 = cat_puzzle.get_tree_hash() + + await sim_client.push_tx( + SpendBundle( + [CoinSpend(starting_coin, standard_acs, Program.to([[51, cat_ph, starting_coin.amount]]))], + G2Element(), + ) + ) + await sim.farm_block() + + await self.do_spend( + sim, + sim_client, + tail, + [(await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False))[0].coin], + [NO_LINEAGE_PROOF], + [ + Program.to( + [ + [51, acs.get_tree_hash(), starting_coin.amount], + [51, 0, -113, tail, checker_solution], + ] + ) + ], + (MempoolInclusionStatus.SUCCESS, None), + limitations_solutions=[checker_solution], + cost_str="Genesis by ID", + ) + + finally: + await sim.close() + + @pytest.mark.asyncio() + async def test_genesis_by_puzhash(self, setup_sim): + sim, sim_client = setup_sim + + try: + standard_acs = Program.to(1) + standard_acs_ph: bytes32 = standard_acs.get_tree_hash() + await sim.farm_block(standard_acs_ph) + + starting_coin: Coin = (await sim_client.get_coin_records_by_puzzle_hash(standard_acs_ph))[0].coin + tail: Program = GenesisByPuzhash.construct([Program.to(starting_coin.puzzle_hash)]) + checker_solution: Program = GenesisByPuzhash.solve([], starting_coin.to_json_dict()) + cat_puzzle: Program = construct_cat_puzzle(CAT_MOD, tail.get_tree_hash(), acs) + cat_ph: bytes32 = cat_puzzle.get_tree_hash() + + await sim_client.push_tx( + SpendBundle( + [CoinSpend(starting_coin, standard_acs, Program.to([[51, cat_ph, starting_coin.amount]]))], + G2Element(), + ) + ) + await sim.farm_block() + + await self.do_spend( + sim, + sim_client, + tail, + [(await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False))[0].coin], + [NO_LINEAGE_PROOF], + [ + Program.to( + [ + [51, acs.get_tree_hash(), starting_coin.amount], + [51, 0, -113, tail, checker_solution], + ] + ) + ], + (MempoolInclusionStatus.SUCCESS, None), + limitations_solutions=[checker_solution], + cost_str="Genesis by Puzhash", + ) + + finally: + await sim.close() + + @pytest.mark.asyncio() + async def test_everything_with_signature(self, setup_sim): + sim, sim_client = setup_sim + + try: + sk = PrivateKey.from_bytes(secret_exponent_for_index(1).to_bytes(32, "big")) + tail: Program = EverythingWithSig.construct([Program.to(sk.get_g1())]) + checker_solution: Program = EverythingWithSig.solve([], {}) + cat_puzzle: Program = construct_cat_puzzle(CAT_MOD, tail.get_tree_hash(), acs) + cat_ph: bytes32 = cat_puzzle.get_tree_hash() + await sim.farm_block(cat_ph) + + # Test eve spend + # We don't sign any message data because CLVM 0 translates to b'' apparently + starting_coin: Coin = (await sim_client.get_coin_records_by_puzzle_hash(cat_ph))[0].coin + signature: G2Element = AugSchemeMPL.sign( + sk, (starting_coin.name() + sim.defaults.AGG_SIG_ME_ADDITIONAL_DATA) + ) + + await self.do_spend( + sim, + sim_client, + tail, + [starting_coin], + [NO_LINEAGE_PROOF], + [ + Program.to( + [ + [51, acs.get_tree_hash(), starting_coin.amount], + [51, 0, -113, tail, checker_solution], + ] + ) + ], + (MempoolInclusionStatus.SUCCESS, None), + limitations_solutions=[checker_solution], + signatures=[signature], + cost_str="Signature Issuance", + ) + + # Test melting value + coin: Coin = (await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False))[0].coin + signature = AugSchemeMPL.sign( + sk, (int_to_bytes(-1) + coin.name() + sim.defaults.AGG_SIG_ME_ADDITIONAL_DATA) + ) + + await self.do_spend( + sim, + sim_client, + tail, + [coin], + [NO_LINEAGE_PROOF], + [ + Program.to( + [ + [51, acs.get_tree_hash(), coin.amount - 1], + [51, 0, -113, tail, checker_solution], + ] + ) + ], + (MempoolInclusionStatus.SUCCESS, None), + extra_deltas=[-1], + limitations_solutions=[checker_solution], + signatures=[signature], + cost_str="Signature Melt", + ) + + # Test minting value + coin = (await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False))[0].coin + signature = AugSchemeMPL.sign(sk, (int_to_bytes(1) + coin.name() + sim.defaults.AGG_SIG_ME_ADDITIONAL_DATA)) + + # Need something to fund the minting + temp_p = Program.to(1) + temp_ph: bytes32 = temp_p.get_tree_hash() + await sim.farm_block(temp_ph) + acs_coin: Coin = (await sim_client.get_coin_records_by_puzzle_hash(temp_ph, include_spent_coins=False))[ + 0 + ].coin + acs_bundle = SpendBundle( + [ + CoinSpend( + acs_coin, + temp_p, + Program.to([]), + ) + ], + G2Element(), + ) + + await self.do_spend( + sim, + sim_client, + tail, + [coin], + [NO_LINEAGE_PROOF], + [ + Program.to( + [ + [51, acs.get_tree_hash(), coin.amount + 1], + [51, 0, -113, tail, checker_solution], + ] + ) + ], + (MempoolInclusionStatus.SUCCESS, None), + extra_deltas=[1], + limitations_solutions=[checker_solution], + signatures=[signature], + additional_spends=[acs_bundle], + cost_str="Signature Mint", + ) + + finally: + await sim.close() + + @pytest.mark.asyncio() + async def test_delegated_tail(self, setup_sim): + sim, sim_client = setup_sim + + try: + standard_acs = Program.to(1) + standard_acs_ph: bytes32 = standard_acs.get_tree_hash() + await sim.farm_block(standard_acs_ph) + + starting_coin: Coin = (await sim_client.get_coin_records_by_puzzle_hash(standard_acs_ph))[0].coin + sk = PrivateKey.from_bytes(secret_exponent_for_index(1).to_bytes(32, "big")) + tail: Program = DelegatedLimitations.construct([Program.to(sk.get_g1())]) + cat_puzzle: Program = construct_cat_puzzle(CAT_MOD, tail.get_tree_hash(), acs) + cat_ph: bytes32 = cat_puzzle.get_tree_hash() + + await sim_client.push_tx( + SpendBundle( + [CoinSpend(starting_coin, standard_acs, Program.to([[51, cat_ph, starting_coin.amount]]))], + G2Element(), + ) + ) + await sim.farm_block() + + # We're signing a different tail to use here + name_as_program = Program.to(starting_coin.name()) + new_tail: Program = GenesisById.construct([name_as_program]) + checker_solution: Program = DelegatedLimitations.solve( + [name_as_program], + { + "signed_program": { + "identifier": "genesis_by_id", + "args": [str(name_as_program)], + }, + "program_arguments": {}, + }, + ) + signature: G2Element = AugSchemeMPL.sign(sk, new_tail.get_tree_hash()) + + await self.do_spend( + sim, + sim_client, + tail, + [(await sim_client.get_coin_records_by_puzzle_hash(cat_ph, include_spent_coins=False))[0].coin], + [NO_LINEAGE_PROOF], + [ + Program.to( + [ + [51, acs.get_tree_hash(), starting_coin.amount], + [51, 0, -113, tail, checker_solution], + ] + ) + ], + (MempoolInclusionStatus.SUCCESS, None), + signatures=[signature], + limitations_solutions=[checker_solution], + cost_str="Delegated Genesis", + ) + + finally: + await sim.close() + + def test_cost(self): + import json + import logging + + log = logging.getLogger(__name__) + log.warning(json.dumps(self.cost)) diff --git a/tests/wallet/cat_wallet/test_cat_wallet.py b/tests/wallet/cat_wallet/test_cat_wallet.py new file mode 100644 index 0000000000..20c4539fd9 --- /dev/null +++ b/tests/wallet/cat_wallet/test_cat_wallet.py @@ -0,0 +1,751 @@ +import asyncio +from typing import List + +import pytest + +from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward +from chia.full_node.mempool_manager import MempoolManager +from chia.simulator.simulator_protocol import FarmNewBlockProtocol +from chia.types.blockchain_format.coin import Coin +from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.types.peer_info import PeerInfo +from chia.util.ints import uint16, uint32, uint64 +from chia.wallet.cat_wallet.cat_utils import construct_cat_puzzle +from chia.wallet.cat_wallet.cat_wallet import CATWallet +from chia.wallet.cat_wallet.cat_constants import DEFAULT_CATS +from chia.wallet.puzzles.cat_loader import CAT_MOD +from chia.wallet.transaction_record import TransactionRecord +from tests.setup_nodes import setup_simulators_and_wallets +from tests.time_out_assert import time_out_assert + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + + +async def tx_in_pool(mempool: MempoolManager, tx_id: bytes32): + tx = mempool.get_spendbundle(tx_id) + if tx is None: + return False + return True + + +class TestCATWallet: + @pytest.fixture(scope="function") + async def wallet_node(self): + async for _ in setup_simulators_and_wallets(1, 1, {}): + yield _ + + @pytest.fixture(scope="function") + async def two_wallet_nodes(self): + async for _ in setup_simulators_and_wallets(1, 2, {}): + yield _ + + @pytest.fixture(scope="function") + async def three_wallet_nodes(self): + async for _ in setup_simulators_and_wallets(1, 3, {}): + yield _ + + @pytest.mark.parametrize( + "trusted", + [True, False], + ) + @pytest.mark.asyncio + async def test_cat_creation(self, two_wallet_nodes, trusted): + num_blocks = 3 + full_nodes, wallets = two_wallet_nodes + full_node_api = full_nodes[0] + full_node_server = full_node_api.server + wallet_node, server_2 = wallets[0] + wallet = wallet_node.wallet_state_manager.main_wallet + + ph = await wallet.get_new_puzzlehash() + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} + + await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + funds = sum( + [ + calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) + for i in range(1, num_blocks - 1) + ] + ) + + await time_out_assert(15, wallet.get_confirmed_balance, funds) + + async with wallet_node.wallet_state_manager.lock: + cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100) + ) + # The next 2 lines are basically a noop, it just adds test coverage + cat_wallet = await CATWallet.create(wallet_node.wallet_state_manager, wallet, cat_wallet.wallet_info) + await wallet_node.wallet_state_manager.add_new_wallet(cat_wallet, cat_wallet.id()) + + tx_queue: List[TransactionRecord] = await wallet_node.wallet_state_manager.tx_store.get_not_sent() + tx_record = tx_queue[0] + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + + await time_out_assert(15, cat_wallet.get_confirmed_balance, 100) + await time_out_assert(15, cat_wallet.get_spendable_balance, 100) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 100) + + @pytest.mark.parametrize( + "trusted", + [True, False], + ) + @pytest.mark.asyncio + async def test_cat_spend(self, two_wallet_nodes, trusted): + num_blocks = 3 + full_nodes, wallets = two_wallet_nodes + full_node_api = full_nodes[0] + full_node_server = full_node_api.server + wallet_node, server_2 = wallets[0] + wallet_node_2, server_3 = wallets[1] + wallet = wallet_node.wallet_state_manager.main_wallet + wallet2 = wallet_node_2.wallet_state_manager.main_wallet + + ph = await wallet.get_new_puzzlehash() + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + wallet_node_2.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} + wallet_node_2.config["trusted_peers"] = {} + await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + await server_3.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + funds = sum( + [ + calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) + for i in range(1, num_blocks - 1) + ] + ) + + await time_out_assert(15, wallet.get_confirmed_balance, funds) + + async with wallet_node.wallet_state_manager.lock: + cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100) + ) + tx_queue: List[TransactionRecord] = await wallet_node.wallet_state_manager.tx_store.get_not_sent() + tx_record = tx_queue[0] + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + + await time_out_assert(15, cat_wallet.get_confirmed_balance, 100) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 100) + + assert cat_wallet.cat_info.limitations_program_hash is not None + asset_id = cat_wallet.get_asset_id() + + cat_wallet_2: CATWallet = await CATWallet.create_wallet_for_cat(wallet_node_2.wallet_state_manager, wallet2, asset_id) + + assert cat_wallet.cat_info.limitations_program_hash == cat_wallet_2.cat_info.limitations_program_hash + + cat_2_hash = await cat_wallet_2.get_new_inner_hash() + tx_records = await cat_wallet.generate_signed_transaction([uint64(60)], [cat_2_hash], fee=uint64(1)) + for tx_record in tx_records: + await wallet.wallet_state_manager.add_pending_transaction(tx_record) + if tx_record.spend_bundle is not None: + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + + await time_out_assert(15, cat_wallet.get_pending_change_balance, 40) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + await time_out_assert(30, wallet.get_confirmed_balance, funds * 2 - 101) + + await time_out_assert(15, cat_wallet.get_confirmed_balance, 40) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 40) + + await time_out_assert(30, cat_wallet_2.get_confirmed_balance, 60) + await time_out_assert(30, cat_wallet_2.get_unconfirmed_balance, 60) + + cat_hash = await cat_wallet.get_new_inner_hash() + tx_records = await cat_wallet_2.generate_signed_transaction([uint64(15)], [cat_hash]) + for tx_record in tx_records: + await wallet.wallet_state_manager.add_pending_transaction(tx_record) + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + await time_out_assert(15, cat_wallet.get_confirmed_balance, 55) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 55) + + @pytest.mark.parametrize( + "trusted", + [True, False], + ) + @pytest.mark.asyncio + async def test_get_wallet_for_asset_id(self, two_wallet_nodes, trusted): + num_blocks = 3 + full_nodes, wallets = two_wallet_nodes + full_node_api = full_nodes[0] + full_node_server = full_node_api.server + wallet_node, server_2 = wallets[0] + wallet = wallet_node.wallet_state_manager.main_wallet + + ph = await wallet.get_new_puzzlehash() + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} + await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + funds = sum( + [ + calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) + for i in range(1, num_blocks - 1) + ] + ) + + await time_out_assert(15, wallet.get_confirmed_balance, funds) + + async with wallet_node.wallet_state_manager.lock: + cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100) + ) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + + asset_id = cat_wallet.get_asset_id() + await cat_wallet.set_tail_program(bytes(cat_wallet.cat_info.my_tail).hex()) + assert await wallet_node.wallet_state_manager.get_wallet_for_asset_id(asset_id) == cat_wallet + + # Test that the a default CAT will initialize correctly + asset = DEFAULT_CATS[next(iter(DEFAULT_CATS))] + asset_id = asset["asset_id"] + cat_wallet_2 = await CATWallet.create_wallet_for_cat(wallet_node.wallet_state_manager, wallet, asset_id) + assert await cat_wallet_2.get_name() == asset["name"] + await cat_wallet_2.set_name("Test Name") + assert await cat_wallet_2.get_name() == "Test Name" + + @pytest.mark.parametrize( + "trusted", + [True, False], + ) + @pytest.mark.asyncio + async def test_cat_doesnt_see_eve(self, two_wallet_nodes, trusted): + num_blocks = 3 + full_nodes, wallets = two_wallet_nodes + full_node_api = full_nodes[0] + full_node_server = full_node_api.server + wallet_node, server_2 = wallets[0] + wallet_node_2, server_3 = wallets[1] + wallet = wallet_node.wallet_state_manager.main_wallet + wallet2 = wallet_node_2.wallet_state_manager.main_wallet + + ph = await wallet.get_new_puzzlehash() + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + wallet_node_2.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} + wallet_node_2.config["trusted_peers"] = {} + await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + await server_3.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + funds = sum( + [ + calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) + for i in range(1, num_blocks - 1) + ] + ) + + await time_out_assert(15, wallet.get_confirmed_balance, funds) + + async with wallet_node.wallet_state_manager.lock: + cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100) + ) + tx_queue: List[TransactionRecord] = await wallet_node.wallet_state_manager.tx_store.get_not_sent() + tx_record = tx_queue[0] + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + + await time_out_assert(15, cat_wallet.get_confirmed_balance, 100) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 100) + + assert cat_wallet.cat_info.limitations_program_hash is not None + asset_id = cat_wallet.get_asset_id() + + cat_wallet_2: CATWallet = await CATWallet.create_wallet_for_cat(wallet_node_2.wallet_state_manager, wallet2, asset_id) + + assert cat_wallet.cat_info.limitations_program_hash == cat_wallet_2.cat_info.limitations_program_hash + + cat_2_hash = await cat_wallet_2.get_new_inner_hash() + tx_records = await cat_wallet.generate_signed_transaction([uint64(60)], [cat_2_hash]) + for tx_record in tx_records: + await wallet.wallet_state_manager.add_pending_transaction(tx_record) + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + + await time_out_assert(15, cat_wallet.get_confirmed_balance, 40) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 40) + + await time_out_assert(15, cat_wallet_2.get_confirmed_balance, 60) + await time_out_assert(15, cat_wallet_2.get_unconfirmed_balance, 60) + + cc2_ph = await cat_wallet_2.get_new_cat_puzzle_hash() + tx_record = await wallet.wallet_state_manager.main_wallet.generate_signed_transaction(10, cc2_ph, 0) + await wallet.wallet_state_manager.add_pending_transaction(tx_record) + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + for i in range(0, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + + id = cat_wallet_2.id() + wsm = cat_wallet_2.wallet_state_manager + + async def query_and_assert_transactions(wsm, id): + all_txs = await wsm.tx_store.get_all_transactions_for_wallet(id) + return len(list(filter(lambda tx: tx.amount == 10, all_txs))) + + await time_out_assert(15, query_and_assert_transactions, 0, wsm, id) + await time_out_assert(15, wsm.get_confirmed_balance_for_wallet, 60, id) + await time_out_assert(15, cat_wallet_2.get_confirmed_balance, 60) + await time_out_assert(15, cat_wallet_2.get_unconfirmed_balance, 60) + + @pytest.mark.parametrize( + "trusted", + [True, False], + ) + @pytest.mark.asyncio + async def test_cat_spend_multiple(self, three_wallet_nodes, trusted): + num_blocks = 3 + full_nodes, wallets = three_wallet_nodes + full_node_api = full_nodes[0] + full_node_server = full_node_api.server + wallet_node_0, wallet_server_0 = wallets[0] + wallet_node_1, wallet_server_1 = wallets[1] + wallet_node_2, wallet_server_2 = wallets[2] + wallet_0 = wallet_node_0.wallet_state_manager.main_wallet + wallet_1 = wallet_node_1.wallet_state_manager.main_wallet + wallet_2 = wallet_node_2.wallet_state_manager.main_wallet + + ph = await wallet_0.get_new_puzzlehash() + if trusted: + wallet_node_0.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + wallet_node_1.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + wallet_node_2.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node_0.config["trusted_peers"] = {} + wallet_node_1.config["trusted_peers"] = {} + wallet_node_2.config["trusted_peers"] = {} + await wallet_server_0.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + await wallet_server_1.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + await wallet_server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + funds = sum( + [ + calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) + for i in range(1, num_blocks - 1) + ] + ) + + await time_out_assert(15, wallet_0.get_confirmed_balance, funds) + + async with wallet_node_0.wallet_state_manager.lock: + cat_wallet_0: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node_0.wallet_state_manager, wallet_0, {"identifier": "genesis_by_id"}, uint64(100) + ) + tx_queue: List[TransactionRecord] = await wallet_node_0.wallet_state_manager.tx_store.get_not_sent() + tx_record = tx_queue[0] + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + + await time_out_assert(15, cat_wallet_0.get_confirmed_balance, 100) + await time_out_assert(15, cat_wallet_0.get_unconfirmed_balance, 100) + + assert cat_wallet_0.cat_info.limitations_program_hash is not None + asset_id = cat_wallet_0.get_asset_id() + + cat_wallet_1: CATWallet = await CATWallet.create_wallet_for_cat( + wallet_node_1.wallet_state_manager, wallet_1, asset_id + ) + + cat_wallet_2: CATWallet = await CATWallet.create_wallet_for_cat( + wallet_node_2.wallet_state_manager, wallet_2, asset_id + ) + + assert cat_wallet_0.cat_info.limitations_program_hash == cat_wallet_1.cat_info.limitations_program_hash + assert cat_wallet_0.cat_info.limitations_program_hash == cat_wallet_2.cat_info.limitations_program_hash + + cat_1_hash = await cat_wallet_1.get_new_inner_hash() + cat_2_hash = await cat_wallet_2.get_new_inner_hash() + + tx_records = await cat_wallet_0.generate_signed_transaction([uint64(60), uint64(20)], [cat_1_hash, cat_2_hash]) + for tx_record in tx_records: + await wallet_0.wallet_state_manager.add_pending_transaction(tx_record) + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + + await time_out_assert(15, cat_wallet_0.get_confirmed_balance, 20) + await time_out_assert(15, cat_wallet_0.get_unconfirmed_balance, 20) + + await time_out_assert(30, cat_wallet_1.get_confirmed_balance, 60) + await time_out_assert(30, cat_wallet_1.get_unconfirmed_balance, 60) + + await time_out_assert(30, cat_wallet_2.get_confirmed_balance, 20) + await time_out_assert(30, cat_wallet_2.get_unconfirmed_balance, 20) + + cat_hash = await cat_wallet_0.get_new_inner_hash() + + tx_records = await cat_wallet_1.generate_signed_transaction([uint64(15)], [cat_hash]) + for tx_record in tx_records: + await wallet_1.wallet_state_manager.add_pending_transaction(tx_record) + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + + tx_records_2 = await cat_wallet_2.generate_signed_transaction([uint64(20)], [cat_hash]) + for tx_record in tx_records_2: + await wallet_2.wallet_state_manager.add_pending_transaction(tx_record) + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + + await time_out_assert(15, cat_wallet_0.get_confirmed_balance, 55) + await time_out_assert(15, cat_wallet_0.get_unconfirmed_balance, 55) + + await time_out_assert(30, cat_wallet_1.get_confirmed_balance, 45) + await time_out_assert(30, cat_wallet_1.get_unconfirmed_balance, 45) + + await time_out_assert(30, cat_wallet_2.get_confirmed_balance, 0) + await time_out_assert(30, cat_wallet_2.get_unconfirmed_balance, 0) + + txs = await wallet_1.wallet_state_manager.tx_store.get_transactions_between(cat_wallet_1.id(), 0, 100000) + print(len(txs)) + # Test with Memo + tx_records_3: TransactionRecord = await cat_wallet_1.generate_signed_transaction( + [uint64(30)], [cat_hash], memos=[[b"Markus Walburg"]] + ) + with pytest.raises(ValueError): + await cat_wallet_1.generate_signed_transaction( + [uint64(30)], [cat_hash], memos=[[b"too"], [b"many"], [b"memos"]] + ) + + for tx_record in tx_records_3: + await wallet_1.wallet_state_manager.add_pending_transaction(tx_record) + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + txs = await wallet_1.wallet_state_manager.tx_store.get_transactions_between(cat_wallet_1.id(), 0, 100000) + for tx in txs: + if tx.amount == 30: + memos = tx.get_memos() + assert len(memos) == 1 + assert b"Markus Walburg" in [v for v_list in memos.values() for v in v_list] + assert list(memos.keys())[0] in [a.name() for a in tx.spend_bundle.additions()] + + @pytest.mark.parametrize( + "trusted", + [True, False], + ) + @pytest.mark.asyncio + async def test_cat_max_amount_send(self, two_wallet_nodes, trusted): + num_blocks = 3 + full_nodes, wallets = two_wallet_nodes + full_node_api = full_nodes[0] + full_node_server = full_node_api.server + wallet_node, server_2 = wallets[0] + wallet_node_2, server_3 = wallets[1] + wallet = wallet_node.wallet_state_manager.main_wallet + + ph = await wallet.get_new_puzzlehash() + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + wallet_node_2.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} + wallet_node_2.config["trusted_peers"] = {} + await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + await server_3.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + funds = sum( + [ + calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) + for i in range(1, num_blocks - 1) + ] + ) + + await time_out_assert(15, wallet.get_confirmed_balance, funds) + + async with wallet_node.wallet_state_manager.lock: + cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100000) + ) + tx_queue: List[TransactionRecord] = await wallet_node.wallet_state_manager.tx_store.get_not_sent() + tx_record = tx_queue[0] + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + + await time_out_assert(15, cat_wallet.get_confirmed_balance, 100000) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 100000) + + assert cat_wallet.cat_info.limitations_program_hash is not None + + cat_2 = await cat_wallet.get_new_inner_puzzle() + cat_2_hash = cat_2.get_tree_hash() + amounts = [] + puzzle_hashes = [] + for i in range(1, 50): + amounts.append(uint64(i)) + puzzle_hashes.append(cat_2_hash) + spent_coint = (await cat_wallet.get_cat_spendable_coins())[0].coin + tx_records = await cat_wallet.generate_signed_transaction(amounts, puzzle_hashes, coins={spent_coint}) + for tx_record in tx_records: + await wallet.wallet_state_manager.add_pending_transaction(tx_record) + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + await asyncio.sleep(2) + + async def check_all_there(): + spendable = await cat_wallet.get_cat_spendable_coins() + spendable_name_set = set() + for record in spendable: + spendable_name_set.add(record.coin.name()) + puzzle_hash = construct_cat_puzzle(CAT_MOD, cat_wallet.cat_info.limitations_program_hash, cat_2).get_tree_hash() + for i in range(1, 50): + coin = Coin(spent_coint.name(), puzzle_hash, i) + if coin.name() not in spendable_name_set: + return False + return True + + await time_out_assert(15, check_all_there, True) + await asyncio.sleep(5) + max_sent_amount = await cat_wallet.get_max_send_amount() + + # 1) Generate transaction that is under the limit + under_limit_txs = None + try: + under_limit_txs = await cat_wallet.generate_signed_transaction( + [max_sent_amount - 1], + [ph], + ) + except ValueError: + assert ValueError + + assert under_limit_txs is not None + + # 2) Generate transaction that is equal to limit + at_limit_txs = None + try: + at_limit_txs = await cat_wallet.generate_signed_transaction( + [max_sent_amount], + [ph], + ) + except ValueError: + assert ValueError + + assert at_limit_txs is not None + + # 3) Generate transaction that is greater than limit + above_limit_txs = None + try: + above_limit_txs = await cat_wallet.generate_signed_transaction( + [max_sent_amount + 1], + [ph], + ) + except ValueError: + pass + + assert above_limit_txs is None + + @pytest.mark.parametrize( + "trusted", + [True, False], + ) + @pytest.mark.asyncio + async def test_cat_hint(self, two_wallet_nodes, trusted): + num_blocks = 3 + full_nodes, wallets = two_wallet_nodes + full_node_api = full_nodes[0] + full_node_server = full_node_api.server + wallet_node, server_2 = wallets[0] + wallet_node_2, server_3 = wallets[1] + wallet = wallet_node.wallet_state_manager.main_wallet + wallet2 = wallet_node_2.wallet_state_manager.main_wallet + + ph = await wallet.get_new_puzzlehash() + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + wallet_node_2.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} + wallet_node_2.config["trusted_peers"] = {} + await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + await server_3.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + funds = sum( + [ + calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) + for i in range(1, num_blocks - 1) + ] + ) + + await time_out_assert(15, wallet.get_confirmed_balance, funds) + + async with wallet_node.wallet_state_manager.lock: + cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100) + ) + tx_queue: List[TransactionRecord] = await wallet_node.wallet_state_manager.tx_store.get_not_sent() + tx_record = tx_queue[0] + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + + await time_out_assert(15, cat_wallet.get_confirmed_balance, 100) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 100) + assert cat_wallet.cat_info.limitations_program_hash is not None + + cat_2_hash = await wallet2.get_new_puzzlehash() + tx_records = await cat_wallet.generate_signed_transaction([uint64(60)], [cat_2_hash], memos=[[cat_2_hash]]) + + for tx_record in tx_records: + await wallet.wallet_state_manager.add_pending_transaction(tx_record) + + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + await time_out_assert(15, cat_wallet.get_confirmed_balance, 40) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 40) + + async def check_wallets(wallet_node): + return len(wallet_node.wallet_state_manager.wallets.keys()) + + await time_out_assert(10, check_wallets, 2, wallet_node_2) + cat_wallet_2 = wallet_node_2.wallet_state_manager.wallets[2] + + await time_out_assert(30, cat_wallet_2.get_confirmed_balance, 60) + await time_out_assert(30, cat_wallet_2.get_unconfirmed_balance, 60) + + cat_hash = await cat_wallet.get_new_inner_hash() + tx_records = await cat_wallet_2.generate_signed_transaction([uint64(15)], [cat_hash]) + for tx_record in tx_records: + await wallet.wallet_state_manager.add_pending_transaction(tx_record) + + await time_out_assert( + 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + ) + + for i in range(1, num_blocks): + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + await time_out_assert(15, cat_wallet.get_confirmed_balance, 55) + await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 55) + + # @pytest.mark.asyncio + + # async def test_cat_melt_and_mint(self, two_wallet_nodes): + # num_blocks = 3 + # full_nodes, wallets = two_wallet_nodes + # full_node_api = full_nodes[0] + # full_node_server = full_node_api.server + # wallet_node, server_2 = wallets[0] + # wallet_node_2, server_3 = wallets[1] + # wallet = wallet_node.wallet_state_manager.main_wallet + # + # ph = await wallet.get_new_puzzlehash() + # + # await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + # await server_3.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + # + # for i in range(1, num_blocks): + # await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + # + # funds = sum( + # [ + # calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) + # for i in range(1, num_blocks - 1) + # ] + # ) + # + # await time_out_assert(15, wallet.get_confirmed_balance, funds) + # + # async with wallet_node.wallet_state_manager.lock: + # cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( + # wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100000) + # ) + # tx_queue: List[TransactionRecord] = await wallet_node.wallet_state_manager.tx_store.get_not_sent() + # tx_record = tx_queue[0] + # await time_out_assert( + # 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() + # ) + # for i in range(1, num_blocks): + # await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) + # + # await time_out_assert(15, cat_wallet.get_confirmed_balance, 100000) + # await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 100000) diff --git a/tests/wallet/cc_wallet/test_cc_wallet.py b/tests/wallet/cc_wallet/test_cc_wallet.py deleted file mode 100644 index 0c4afe009b..0000000000 --- a/tests/wallet/cc_wallet/test_cc_wallet.py +++ /dev/null @@ -1,549 +0,0 @@ -import asyncio -from typing import List - -import pytest - -from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward -from chia.full_node.mempool_manager import MempoolManager -from chia.simulator.simulator_protocol import FarmNewBlockProtocol -from chia.types.blockchain_format.coin import Coin -from chia.types.blockchain_format.sized_bytes import bytes32 -from chia.types.peer_info import PeerInfo -from chia.util.ints import uint16, uint32, uint64 -from chia.wallet.cc_wallet.cc_utils import cc_puzzle_hash_for_inner_puzzle_hash -from chia.wallet.cc_wallet.cc_wallet import CCWallet -from chia.wallet.puzzles.cc_loader import CC_MOD -from chia.wallet.transaction_record import TransactionRecord -from chia.wallet.wallet_coin_record import WalletCoinRecord -from tests.setup_nodes import setup_simulators_and_wallets -from tests.time_out_assert import time_out_assert - - -@pytest.fixture(scope="module") -def event_loop(): - loop = asyncio.get_event_loop() - yield loop - - -async def tx_in_pool(mempool: MempoolManager, tx_id: bytes32): - tx = mempool.get_spendbundle(tx_id) - if tx is None: - return False - return True - - -class TestCCWallet: - @pytest.fixture(scope="function") - async def wallet_node(self): - async for _ in setup_simulators_and_wallets(1, 1, {}): - yield _ - - @pytest.fixture(scope="function") - async def two_wallet_nodes(self): - async for _ in setup_simulators_and_wallets(1, 2, {}): - yield _ - - @pytest.fixture(scope="function") - async def three_wallet_nodes(self): - async for _ in setup_simulators_and_wallets(1, 3, {}): - yield _ - - @pytest.mark.asyncio - async def test_colour_creation(self, two_wallet_nodes): - num_blocks = 3 - full_nodes, wallets = two_wallet_nodes - full_node_api = full_nodes[0] - full_node_server = full_node_api.server - wallet_node, server_2 = wallets[0] - wallet = wallet_node.wallet_state_manager.main_wallet - - ph = await wallet.get_new_puzzlehash() - - await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - funds = sum( - [ - calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) - for i in range(1, num_blocks - 1) - ] - ) - - await time_out_assert(15, wallet.get_confirmed_balance, funds) - - cc_wallet: CCWallet = await CCWallet.create_new_cc(wallet_node.wallet_state_manager, wallet, uint64(100)) - tx_queue: List[TransactionRecord] = await wallet_node.wallet_state_manager.tx_store.get_not_sent() - tx_record = tx_queue[0] - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) - - await time_out_assert(15, cc_wallet.get_confirmed_balance, 100) - await time_out_assert(15, cc_wallet.get_unconfirmed_balance, 100) - - @pytest.mark.asyncio - async def test_cc_spend(self, two_wallet_nodes): - num_blocks = 3 - full_nodes, wallets = two_wallet_nodes - full_node_api = full_nodes[0] - full_node_server = full_node_api.server - wallet_node, server_2 = wallets[0] - wallet_node_2, server_3 = wallets[1] - wallet = wallet_node.wallet_state_manager.main_wallet - wallet2 = wallet_node_2.wallet_state_manager.main_wallet - - ph = await wallet.get_new_puzzlehash() - - await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - await server_3.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - funds = sum( - [ - calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) - for i in range(1, num_blocks - 1) - ] - ) - - await time_out_assert(15, wallet.get_confirmed_balance, funds) - - cc_wallet: CCWallet = await CCWallet.create_new_cc(wallet_node.wallet_state_manager, wallet, uint64(100)) - tx_queue: List[TransactionRecord] = await wallet_node.wallet_state_manager.tx_store.get_not_sent() - tx_record = tx_queue[0] - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) - - await time_out_assert(15, cc_wallet.get_confirmed_balance, 100) - await time_out_assert(15, cc_wallet.get_unconfirmed_balance, 100) - - assert cc_wallet.cc_info.my_genesis_checker is not None - colour = cc_wallet.get_colour() - - cc_wallet_2: CCWallet = await CCWallet.create_wallet_for_cc(wallet_node_2.wallet_state_manager, wallet2, colour) - - assert cc_wallet.cc_info.my_genesis_checker == cc_wallet_2.cc_info.my_genesis_checker - - cc_2_hash = await cc_wallet_2.get_new_inner_hash() - tx_record = await cc_wallet.generate_signed_transaction([uint64(60)], [cc_2_hash]) - await wallet.wallet_state_manager.add_pending_transaction(tx_record) - - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - await time_out_assert(15, cc_wallet.get_confirmed_balance, 40) - await time_out_assert(15, cc_wallet.get_unconfirmed_balance, 40) - - await time_out_assert(30, cc_wallet_2.get_confirmed_balance, 60) - await time_out_assert(30, cc_wallet_2.get_unconfirmed_balance, 60) - - cc_hash = await cc_wallet.get_new_inner_hash() - tx_record = await cc_wallet_2.generate_signed_transaction([uint64(15)], [cc_hash]) - await wallet.wallet_state_manager.add_pending_transaction(tx_record) - - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - await time_out_assert(15, cc_wallet.get_confirmed_balance, 55) - await time_out_assert(15, cc_wallet.get_unconfirmed_balance, 55) - - @pytest.mark.asyncio - async def test_get_wallet_for_colour(self, two_wallet_nodes): - num_blocks = 3 - full_nodes, wallets = two_wallet_nodes - full_node_api = full_nodes[0] - full_node_server = full_node_api.server - wallet_node, server_2 = wallets[0] - wallet = wallet_node.wallet_state_manager.main_wallet - - ph = await wallet.get_new_puzzlehash() - - await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - funds = sum( - [ - calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) - for i in range(1, num_blocks - 1) - ] - ) - - await time_out_assert(15, wallet.get_confirmed_balance, funds) - - cc_wallet: CCWallet = await CCWallet.create_new_cc(wallet_node.wallet_state_manager, wallet, uint64(100)) - - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) - - colour = cc_wallet.get_colour() - assert await wallet_node.wallet_state_manager.get_wallet_for_colour(colour) == cc_wallet - - @pytest.mark.asyncio - async def test_generate_zero_val(self, two_wallet_nodes): - num_blocks = 4 - full_nodes, wallets = two_wallet_nodes - full_node_api = full_nodes[0] - full_node_server = full_node_api.server - wallet_node, server_2 = wallets[0] - wallet_node_2, server_3 = wallets[1] - wallet = wallet_node.wallet_state_manager.main_wallet - wallet2 = wallet_node_2.wallet_state_manager.main_wallet - - ph = await wallet.get_new_puzzlehash() - - await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - await server_3.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - funds = sum( - [ - calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) - for i in range(1, num_blocks - 1) - ] - ) - await time_out_assert(15, wallet.get_confirmed_balance, funds) - - cc_wallet: CCWallet = await CCWallet.create_new_cc(wallet_node.wallet_state_manager, wallet, uint64(100)) - await asyncio.sleep(1) - - ph = await wallet2.get_new_puzzlehash() - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - await time_out_assert(15, cc_wallet.get_confirmed_balance, 100) - await time_out_assert(15, cc_wallet.get_unconfirmed_balance, 100) - - assert cc_wallet.cc_info.my_genesis_checker is not None - colour = cc_wallet.get_colour() - - cc_wallet_2: CCWallet = await CCWallet.create_wallet_for_cc(wallet_node_2.wallet_state_manager, wallet2, colour) - await asyncio.sleep(1) - - assert cc_wallet.cc_info.my_genesis_checker == cc_wallet_2.cc_info.my_genesis_checker - - spend_bundle = await cc_wallet_2.generate_zero_val_coin() - await asyncio.sleep(1) - await time_out_assert(15, tx_in_pool, True, full_node_api.full_node.mempool_manager, spend_bundle.name()) - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - async def unspent_count(): - unspent: List[WalletCoinRecord] = list( - await cc_wallet_2.wallet_state_manager.get_spendable_coins_for_wallet(cc_wallet_2.id()) - ) - return len(unspent) - - await time_out_assert(15, unspent_count, 1) - unspent: List[WalletCoinRecord] = list( - await cc_wallet_2.wallet_state_manager.get_spendable_coins_for_wallet(cc_wallet_2.id()) - ) - assert unspent.pop().coin.amount == 0 - - @pytest.mark.asyncio - async def test_cc_spend_uncoloured(self, two_wallet_nodes): - num_blocks = 3 - full_nodes, wallets = two_wallet_nodes - full_node_api = full_nodes[0] - full_node_server = full_node_api.server - wallet_node, server_2 = wallets[0] - wallet_node_2, server_3 = wallets[1] - wallet = wallet_node.wallet_state_manager.main_wallet - wallet2 = wallet_node_2.wallet_state_manager.main_wallet - - ph = await wallet.get_new_puzzlehash() - - await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - await server_3.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - funds = sum( - [ - calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) - for i in range(1, num_blocks - 1) - ] - ) - - await time_out_assert(15, wallet.get_confirmed_balance, funds) - - cc_wallet: CCWallet = await CCWallet.create_new_cc(wallet_node.wallet_state_manager, wallet, uint64(100)) - tx_queue: List[TransactionRecord] = await wallet_node.wallet_state_manager.tx_store.get_not_sent() - tx_record = tx_queue[0] - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) - - await time_out_assert(15, cc_wallet.get_confirmed_balance, 100) - await time_out_assert(15, cc_wallet.get_unconfirmed_balance, 100) - - assert cc_wallet.cc_info.my_genesis_checker is not None - colour = cc_wallet.get_colour() - - cc_wallet_2: CCWallet = await CCWallet.create_wallet_for_cc(wallet_node_2.wallet_state_manager, wallet2, colour) - - assert cc_wallet.cc_info.my_genesis_checker == cc_wallet_2.cc_info.my_genesis_checker - - cc_2_hash = await cc_wallet_2.get_new_inner_hash() - tx_record = await cc_wallet.generate_signed_transaction([uint64(60)], [cc_2_hash]) - await wallet.wallet_state_manager.add_pending_transaction(tx_record) - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) - - await time_out_assert(15, cc_wallet.get_confirmed_balance, 40) - await time_out_assert(15, cc_wallet.get_unconfirmed_balance, 40) - - await time_out_assert(15, cc_wallet_2.get_confirmed_balance, 60) - await time_out_assert(15, cc_wallet_2.get_unconfirmed_balance, 60) - - cc2_ph = await cc_wallet_2.get_new_cc_puzzle_hash() - tx_record = await wallet.wallet_state_manager.main_wallet.generate_signed_transaction(10, cc2_ph, 0) - await wallet.wallet_state_manager.add_pending_transaction(tx_record) - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - for i in range(0, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) - - id = cc_wallet_2.id() - wsm = cc_wallet_2.wallet_state_manager - await time_out_assert(15, wsm.get_confirmed_balance_for_wallet, 70, id) - await time_out_assert(15, cc_wallet_2.get_confirmed_balance, 60) - await time_out_assert(15, cc_wallet_2.get_unconfirmed_balance, 60) - - @pytest.mark.asyncio - async def test_cc_spend_multiple(self, three_wallet_nodes): - num_blocks = 3 - full_nodes, wallets = three_wallet_nodes - full_node_api = full_nodes[0] - full_node_server = full_node_api.server - wallet_node_0, wallet_server_0 = wallets[0] - wallet_node_1, wallet_server_1 = wallets[1] - wallet_node_2, wallet_server_2 = wallets[2] - wallet_0 = wallet_node_0.wallet_state_manager.main_wallet - wallet_1 = wallet_node_1.wallet_state_manager.main_wallet - wallet_2 = wallet_node_2.wallet_state_manager.main_wallet - - ph = await wallet_0.get_new_puzzlehash() - - await wallet_server_0.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - await wallet_server_1.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - await wallet_server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - funds = sum( - [ - calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) - for i in range(1, num_blocks - 1) - ] - ) - - await time_out_assert(15, wallet_0.get_confirmed_balance, funds) - - cc_wallet_0: CCWallet = await CCWallet.create_new_cc(wallet_node_0.wallet_state_manager, wallet_0, uint64(100)) - tx_queue: List[TransactionRecord] = await wallet_node_0.wallet_state_manager.tx_store.get_not_sent() - tx_record = tx_queue[0] - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) - - await time_out_assert(15, cc_wallet_0.get_confirmed_balance, 100) - await time_out_assert(15, cc_wallet_0.get_unconfirmed_balance, 100) - - assert cc_wallet_0.cc_info.my_genesis_checker is not None - colour = cc_wallet_0.get_colour() - - cc_wallet_1: CCWallet = await CCWallet.create_wallet_for_cc( - wallet_node_1.wallet_state_manager, wallet_1, colour - ) - - cc_wallet_2: CCWallet = await CCWallet.create_wallet_for_cc( - wallet_node_2.wallet_state_manager, wallet_2, colour - ) - - assert cc_wallet_0.cc_info.my_genesis_checker == cc_wallet_1.cc_info.my_genesis_checker - assert cc_wallet_0.cc_info.my_genesis_checker == cc_wallet_2.cc_info.my_genesis_checker - - cc_1_hash = await cc_wallet_1.get_new_inner_hash() - cc_2_hash = await cc_wallet_2.get_new_inner_hash() - - tx_record = await cc_wallet_0.generate_signed_transaction([uint64(60), uint64(20)], [cc_1_hash, cc_2_hash]) - await wallet_0.wallet_state_manager.add_pending_transaction(tx_record) - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) - - await time_out_assert(15, cc_wallet_0.get_confirmed_balance, 20) - await time_out_assert(15, cc_wallet_0.get_unconfirmed_balance, 20) - - await time_out_assert(30, cc_wallet_1.get_confirmed_balance, 60) - await time_out_assert(30, cc_wallet_1.get_unconfirmed_balance, 60) - - await time_out_assert(30, cc_wallet_2.get_confirmed_balance, 20) - await time_out_assert(30, cc_wallet_2.get_unconfirmed_balance, 20) - - cc_hash = await cc_wallet_0.get_new_inner_hash() - - tx_record = await cc_wallet_1.generate_signed_transaction([uint64(15)], [cc_hash]) - await wallet_1.wallet_state_manager.add_pending_transaction(tx_record) - - tx_record_2 = await cc_wallet_2.generate_signed_transaction([uint64(20)], [cc_hash]) - await wallet_2.wallet_state_manager.add_pending_transaction(tx_record_2) - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record_2.spend_bundle.name() - ) - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) - - await time_out_assert(15, cc_wallet_0.get_confirmed_balance, 55) - await time_out_assert(15, cc_wallet_0.get_unconfirmed_balance, 55) - - await time_out_assert(30, cc_wallet_1.get_confirmed_balance, 45) - await time_out_assert(30, cc_wallet_1.get_unconfirmed_balance, 45) - - await time_out_assert(30, cc_wallet_2.get_confirmed_balance, 0) - await time_out_assert(30, cc_wallet_2.get_unconfirmed_balance, 0) - - @pytest.mark.asyncio - async def test_cc_max_amount_send(self, two_wallet_nodes): - num_blocks = 3 - full_nodes, wallets = two_wallet_nodes - full_node_api = full_nodes[0] - full_node_server = full_node_api.server - wallet_node, server_2 = wallets[0] - wallet_node_2, server_3 = wallets[1] - wallet = wallet_node.wallet_state_manager.main_wallet - - ph = await wallet.get_new_puzzlehash() - - await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - await server_3.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) - - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - funds = sum( - [ - calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) - for i in range(1, num_blocks - 1) - ] - ) - - await time_out_assert(15, wallet.get_confirmed_balance, funds) - - cc_wallet: CCWallet = await CCWallet.create_new_cc(wallet_node.wallet_state_manager, wallet, uint64(100000)) - tx_queue: List[TransactionRecord] = await wallet_node.wallet_state_manager.tx_store.get_not_sent() - tx_record = tx_queue[0] - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(32 * b"0")) - - await time_out_assert(15, cc_wallet.get_confirmed_balance, 100000) - await time_out_assert(15, cc_wallet.get_unconfirmed_balance, 100000) - - assert cc_wallet.cc_info.my_genesis_checker is not None - - cc_2_hash = await cc_wallet.get_new_inner_hash() - amounts = [] - puzzle_hashes = [] - for i in range(1, 50): - amounts.append(uint64(i)) - puzzle_hashes.append(cc_2_hash) - spent_coint = (await cc_wallet.get_cc_spendable_coins())[0].coin - tx_record = await cc_wallet.generate_signed_transaction(amounts, puzzle_hashes, coins={spent_coint}) - await wallet.wallet_state_manager.add_pending_transaction(tx_record) - - await time_out_assert( - 15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx_record.spend_bundle.name() - ) - - for i in range(1, num_blocks): - await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) - - await asyncio.sleep(2) - - async def check_all_there(): - spendable = await cc_wallet.get_cc_spendable_coins() - spendable_name_set = set() - for record in spendable: - spendable_name_set.add(record.coin.name()) - puzzle_hash = cc_puzzle_hash_for_inner_puzzle_hash(CC_MOD, cc_wallet.cc_info.my_genesis_checker, cc_2_hash) - for i in range(1, 50): - coin = Coin(spent_coint.name(), puzzle_hash, i) - if coin.name() not in spendable_name_set: - return False - return True - - await time_out_assert(15, check_all_there, True) - await asyncio.sleep(5) - max_sent_amount = await cc_wallet.get_max_send_amount() - - # 1) Generate transaction that is under the limit - under_limit_tx = None - try: - under_limit_tx = await cc_wallet.generate_signed_transaction( - [max_sent_amount - 1], - [ph], - ) - except ValueError: - assert ValueError - - assert under_limit_tx is not None - - # 2) Generate transaction that is equal to limit - at_limit_tx = None - try: - at_limit_tx = await cc_wallet.generate_signed_transaction( - [max_sent_amount], - [ph], - ) - except ValueError: - assert ValueError - - assert at_limit_tx is not None - - # 3) Generate transaction that is greater than limit - above_limit_tx = None - try: - above_limit_tx = await cc_wallet.generate_signed_transaction( - [max_sent_amount + 1], - [ph], - ) - except ValueError: - pass - - assert above_limit_tx is None diff --git a/tests/wallet/did_wallet/test_did.py b/tests/wallet/did_wallet/test_did.py index b1fb1ca793..69723a7d22 100644 --- a/tests/wallet/did_wallet/test_did.py +++ b/tests/wallet/did_wallet/test_did.py @@ -9,7 +9,7 @@ from chia.types.blockchain_format.program import Program from blspy import AugSchemeMPL from chia.types.spend_bundle import SpendBundle from chia.consensus.block_rewards import calculate_pool_reward, calculate_base_farmer_reward -from tests.time_out_assert import time_out_assert +from tests.time_out_assert import time_out_assert, time_out_assert_not_none pytestmark = pytest.mark.skip("TODO: Fix tests") @@ -444,7 +444,10 @@ class TestDIDWallet: test_info_list, test_message_spend_bundle, ) = await did_wallet_4.load_attest_files_for_recovery_spend(["test.attest"]) - await did_wallet_4.recovery_spend(coin, new_ph, test_info_list, pubkey, test_message_spend_bundle) + spend_bundle = await did_wallet_4.recovery_spend( + coin, new_ph, test_info_list, pubkey, test_message_spend_bundle + ) + await time_out_assert_not_none(15, full_node_1.full_node.mempool_manager.get_spendbundle, spend_bundle.name()) for i in range(1, num_blocks): await full_node_1.farm_new_transaction_block(FarmNewBlockProtocol(ph)) diff --git a/tests/wallet/rl_wallet/__init__.py b/tests/wallet/rl_wallet/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/wallet/rl_wallet/test_rl_rpc.py b/tests/wallet/rl_wallet/test_rl_rpc.py index 4b18addd28..3ba07a2a84 100644 --- a/tests/wallet/rl_wallet/test_rl_rpc.py +++ b/tests/wallet/rl_wallet/test_rl_rpc.py @@ -10,6 +10,7 @@ from chia.types.mempool_inclusion_status import MempoolInclusionStatus from chia.types.peer_info import PeerInfo from chia.util.bech32m import encode_puzzle_hash from chia.util.ints import uint16 +from chia.wallet.transaction_record import TransactionRecord from chia.wallet.util.wallet_types import WalletType from tests.setup_nodes import self_hostname, setup_simulators_and_wallets from tests.time_out_assert import time_out_assert @@ -27,7 +28,7 @@ async def is_transaction_in_mempool(user_wallet_id, api, tx_id: bytes32) -> bool val = await api.get_transaction({"wallet_id": user_wallet_id, "transaction_id": tx_id.hex()}) except ValueError: return False - for _, mis, _ in val["transaction"].sent_to: + for _, mis, _ in TransactionRecord.from_json_dict_convenience(val["transaction"]).sent_to: if ( MempoolInclusionStatus(mis) == MempoolInclusionStatus.SUCCESS or MempoolInclusionStatus(mis) == MempoolInclusionStatus.PENDING @@ -41,7 +42,7 @@ async def is_transaction_confirmed(user_wallet_id, api, tx_id: bytes32) -> bool: val = await api.get_transaction({"wallet_id": user_wallet_id, "transaction_id": tx_id.hex()}) except ValueError: return False - return val["transaction"].confirmed + return TransactionRecord.from_json_dict_convenience(val["transaction"]).confirmed async def check_balance(api, wallet_id): diff --git a/tests/wallet/rl_wallet/test_rl_wallet.py b/tests/wallet/rl_wallet/test_rl_wallet.py index 9836fbf82c..7a92cbe0eb 100644 --- a/tests/wallet/rl_wallet/test_rl_wallet.py +++ b/tests/wallet/rl_wallet/test_rl_wallet.py @@ -16,7 +16,7 @@ def event_loop(): yield loop -class TestCCWallet: +class TestCATWallet: @pytest.fixture(scope="function") async def two_wallet_nodes(self): async for _ in setup_simulators_and_wallets(1, 2, {}): diff --git a/tests/wallet/rpc/test_wallet_rpc.py b/tests/wallet/rpc/test_wallet_rpc.py index 4cd2983275..6849176eea 100644 --- a/tests/wallet/rpc/test_wallet_rpc.py +++ b/tests/wallet/rpc/test_wallet_rpc.py @@ -1,9 +1,15 @@ import asyncio from operator import attrgetter +from typing import Optional + +from blspy import G2Element + +from chia.types.coin_record import CoinRecord +from chia.types.coin_spend import CoinSpend +from chia.types.spend_bundle import SpendBundle from chia.util.config import load_config, save_config import logging -from pathlib import Path import pytest @@ -24,7 +30,6 @@ from chia.wallet.transaction_record import TransactionRecord from chia.wallet.transaction_sorting import SortKey from tests.setup_nodes import bt, setup_simulators_and_wallets, self_hostname from tests.time_out_assert import time_out_assert -from tests.util.rpc import validate_get_routes log = logging.getLogger(__name__) @@ -35,9 +40,14 @@ class TestWalletRpc: async for _ in setup_simulators_and_wallets(1, 2, {}): yield _ + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_make_transaction(self, two_wallet_nodes): + async def test_wallet_rpc(self, two_wallet_nodes, trusted): test_rpc_port = uint16(21529) + test_rpc_port_2 = uint16(21536) test_rpc_port_node = uint16(21530) num_blocks = 5 full_nodes, wallets = two_wallet_nodes @@ -51,6 +61,14 @@ class TestWalletRpc: ph_2 = await wallet_2.get_new_puzzlehash() await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + await server_3.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) + + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + wallet_node_2.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} + wallet_node_2.config["trusted_peers"] = {} for i in range(0, num_blocks): await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) @@ -66,6 +84,7 @@ class TestWalletRpc: ) wallet_rpc_api = WalletRpcApi(wallet_node) + wallet_rpc_api_2 = WalletRpcApi(wallet_node_2) config = bt.config hostname = config["self_hostname"] @@ -96,14 +115,25 @@ class TestWalletRpc: config, connect_to_daemon=False, ) + rpc_cleanup_2 = await start_rpc_server( + wallet_rpc_api_2, + hostname, + daemon_port, + test_rpc_port_2, + stop_node_cb, + bt.root_path, + config, + connect_to_daemon=False, + ) await time_out_assert(5, wallet.get_confirmed_balance, initial_funds) await time_out_assert(5, wallet.get_unconfirmed_balance, initial_funds) client = await WalletRpcClient.create(self_hostname, test_rpc_port, bt.root_path, config) - await validate_get_routes(client, wallet_rpc_api) + client_2 = await WalletRpcClient.create(self_hostname, test_rpc_port_2, bt.root_path, config) client_node = await FullNodeRpcClient.create(self_hostname, test_rpc_port_node, bt.root_path, config) try: + await time_out_assert(5, client.get_synced) addr = encode_puzzle_hash(await wallet_node_2.wallet_state_manager.main_wallet.get_new_puzzlehash(), "xch") tx_amount = 15600000 try: @@ -113,7 +143,7 @@ class TestWalletRpc: pass # Tests sending a basic transaction - tx = await client.send_transaction("1", tx_amount, addr) + tx = await client.send_transaction("1", tx_amount, addr, memos=["this is a basic tx"]) transaction_id = tx.name async def tx_in_mempool(): @@ -131,6 +161,13 @@ class TestWalletRpc: async def eventual_balance(): return (await client.get_wallet_balance("1"))["confirmed_wallet_balance"] + # Checks that the memo can be retrieved + tx_confirmed = await client.get_transaction("1", transaction_id) + assert tx_confirmed.confirmed + assert len(tx_confirmed.get_memos()) == 1 + assert [b"this is a basic tx"] in tx_confirmed.get_memos().values() + assert list(tx_confirmed.get_memos().keys())[0] in [a.name() for a in tx.spend_bundle.additions()] + await time_out_assert(5, eventual_balance, initial_funds_eventually - tx_amount) # Tests offline signing @@ -160,7 +197,7 @@ class TestWalletRpc: # Test basic transaction to one output signed_tx_amount = 888000 tx_res: TransactionRecord = await client.create_signed_transaction( - [{"amount": signed_tx_amount, "puzzle_hash": ph_3}] + [{"amount": signed_tx_amount, "puzzle_hash": ph_3, "memos": ["My memo"]}] ) assert tx_res.fee_amount == 0 @@ -188,7 +225,7 @@ class TestWalletRpc: assert coin_to_spend is not None tx_res = await client.create_signed_transaction( - [{"amount": 444, "puzzle_hash": ph_4}, {"amount": 999, "puzzle_hash": ph_5}], + [{"amount": 444, "puzzle_hash": ph_4, "memos": ["hhh"]}, {"amount": 999, "puzzle_hash": ph_5}], coins=[coin_to_spend], fee=100, ) @@ -205,11 +242,29 @@ class TestWalletRpc: await client.farm_block(encode_puzzle_hash(ph_2, "xch")) await asyncio.sleep(0.5) + found: bool = False + for addition in tx_res.spend_bundle.additions(): + if addition.amount == 444: + cr: Optional[CoinRecord] = await client_node.get_coin_record_by_name(addition.name()) + assert cr is not None + spend: CoinSpend = await client_node.get_puzzle_and_solution( + addition.parent_coin_info, cr.confirmed_block_index + ) + sb: SpendBundle = SpendBundle([spend], G2Element()) + assert sb.get_memos() == {addition.name(): [b"hhh"]} + found = True + assert found + new_balance = initial_funds_eventually - tx_amount - signed_tx_amount - 444 - 999 - 100 await time_out_assert(5, eventual_balance, new_balance) send_tx_res: TransactionRecord = await client.send_transaction_multi( - "1", [{"amount": 555, "puzzle_hash": ph_4}, {"amount": 666, "puzzle_hash": ph_5}], fee=200 + "1", + [ + {"amount": 555, "puzzle_hash": ph_4, "memos": ["FiMemo"]}, + {"amount": 666, "puzzle_hash": ph_5, "memos": ["SeMemo"]}, + ], + fee=200, ) assert send_tx_res is not None assert send_tx_res.fee_amount == 200 @@ -230,11 +285,92 @@ class TestWalletRpc: new_balance = new_balance - 555 - 666 - 200 await time_out_assert(5, eventual_balance, new_balance) + # Checks that the memo can be retrieved + tx_confirmed = await client.get_transaction("1", send_tx_res.name) + assert tx_confirmed.confirmed + assert len(tx_confirmed.get_memos()) == 2 + print(tx_confirmed.get_memos()) + assert [b"FiMemo"] in tx_confirmed.get_memos().values() + assert [b"SeMemo"] in tx_confirmed.get_memos().values() + assert list(tx_confirmed.get_memos().keys())[0] in [a.name() for a in send_tx_res.spend_bundle.additions()] + assert list(tx_confirmed.get_memos().keys())[1] in [a.name() for a in send_tx_res.spend_bundle.additions()] + + ############## + # CATS # + ############## + + # Creates a wallet and a CAT with 20 mojos + res = await client.create_new_cat_and_wallet(20) + assert res["success"] + cat_0_id = res["wallet_id"] + asset_id = bytes.fromhex(res["asset_id"]) + assert len(asset_id) > 0 + + bal_0 = await client.get_wallet_balance(cat_0_id) + assert bal_0["confirmed_wallet_balance"] == 0 + assert bal_0["pending_coin_removal_count"] == 1 + col = await client.get_cat_asset_id(cat_0_id) + assert col == asset_id + assert (await client.get_cat_name(cat_0_id)) == "CAT Wallet" + await client.set_cat_name(cat_0_id, "My cat") + assert (await client.get_cat_name(cat_0_id)) == "My cat" + + await asyncio.sleep(1) + for i in range(0, 5): + await client.farm_block(encode_puzzle_hash(ph_2, "xch")) + await asyncio.sleep(0.5) + + bal_0 = await client.get_wallet_balance(cat_0_id) + assert bal_0["confirmed_wallet_balance"] == 20 + assert bal_0["pending_coin_removal_count"] == 0 + assert bal_0["unspent_coin_count"] == 1 + + # Creates a second wallet with the same CAT + res = await client_2.create_wallet_for_existing_cat(asset_id) + assert res["success"] + cat_1_id = res["wallet_id"] + asset_id_1 = bytes.fromhex(res["asset_id"]) + assert asset_id_1 == asset_id + + await asyncio.sleep(1) + for i in range(0, 5): + await client.farm_block(encode_puzzle_hash(ph_2, "xch")) + await asyncio.sleep(0.5) + bal_1 = await client_2.get_wallet_balance(cat_1_id) + assert bal_1["confirmed_wallet_balance"] == 0 + + addr_0 = await client.get_next_address(cat_0_id, False) + addr_1 = await client_2.get_next_address(cat_1_id, False) + + assert addr_0 != addr_1 + + await client.cat_spend(cat_0_id, 4, addr_1, 0, ["the cat memo"]) + + await asyncio.sleep(1) + for i in range(0, 5): + await client.farm_block(encode_puzzle_hash(ph_2, "xch")) + await asyncio.sleep(0.5) + + bal_0 = await client.get_wallet_balance(cat_0_id) + bal_1 = await client_2.get_wallet_balance(cat_1_id) + + assert bal_0["confirmed_wallet_balance"] == 16 + assert bal_1["confirmed_wallet_balance"] == 4 + + # Keys and addresses + address = await client.get_next_address("1", True) assert len(address) > 10 - transactions = await client.get_transactions("1") - assert len(transactions) > 1 + all_transactions = await client.get_transactions("1") + some_transactions = await client.get_transactions("1", 0, 5) + some_transactions_2 = await client.get_transactions("1", 5, 10) + assert len(all_transactions) > 1 + assert some_transactions == all_transactions[len(all_transactions) - 5 : len(all_transactions)] + assert some_transactions_2 == all_transactions[len(all_transactions) - 10 : len(all_transactions) - 5] + + transaction_count = await client.get_transaction_count("1") + assert transaction_count == len(all_transactions) all_transactions = await client.get_transactions("1") # Test transaction pagination @@ -297,8 +433,6 @@ class TestWalletRpc: await client.log_in_and_skip(pks[1]) sk_dict = await client.get_private_key(pks[1]) assert sk_dict["fingerprint"] == pks[1] - fingerprint = await client.get_logged_in_fingerprint() - assert fingerprint == pks[1] # Add in reward addresses into farmer and pool for testing delete key checks # set farmer to first private key @@ -341,25 +475,22 @@ class TestWalletRpc: balance = await client.get_wallet_balance(wallets[0]["id"]) assert balance["unconfirmed_wallet_balance"] == 0 - test_wallet_backup_path = Path("test_wallet_backup_file") - await client.create_backup(test_wallet_backup_path) - assert test_wallet_backup_path.exists() - test_wallet_backup_path.unlink() - try: await client.send_transaction(wallets[0]["id"], 100, addr) raise Exception("Should not create tx if no balance") except ValueError: pass - + # Delete all keys await client.delete_all_keys() - assert len(await client.get_public_keys()) == 0 finally: # Checks that the RPC manages to stop the node client.close() + client_2.close() client_node.close() await client.await_closed() + await client_2.await_closed() await client_node.await_closed() await rpc_cleanup() + await rpc_cleanup_2() await rpc_cleanup_node() diff --git a/tests/wallet/simple_sync/test_simple_sync_protocol.py b/tests/wallet/simple_sync/test_simple_sync_protocol.py index e4e00fde9e..b45ec1e136 100644 --- a/tests/wallet/simple_sync/test_simple_sync_protocol.py +++ b/tests/wallet/simple_sync/test_simple_sync_protocol.py @@ -25,7 +25,7 @@ from chia.wallet.wallet_state_manager import WalletStateManager from tests.connection_utils import add_dummy_connection from tests.setup_nodes import self_hostname, setup_simulators_and_wallets, bt from tests.time_out_assert import time_out_assert -from tests.wallet.cc_wallet.test_cc_wallet import tx_in_pool +from tests.wallet.cat_wallet.test_cat_wallet import tx_in_pool from tests.wallet_tools import WalletTool @@ -176,10 +176,18 @@ class TestSimpleSyncProtocol: await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(puzzle_hash)) funds = sum( - [calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) for i in range(1, num_blocks)] + [ + calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) + for i in range(1, num_blocks + 1) + ] + ) + fn_amount = sum( + cr.coin.amount + for cr in await full_node_api.full_node.coin_store.get_coin_records_by_puzzle_hash(False, puzzle_hash) ) await time_out_assert(15, wallet.get_confirmed_balance, funds) + assert funds == fn_amount msg_1 = wallet_protocol.RegisterForPhUpdates([puzzle_hash], 0) msg_response_1 = await full_node_api.register_interest_in_puzzle_hash(msg_1, fake_wallet_peer) diff --git a/tests/wallet/simple_wallet/test_simple_wallet.py b/tests/wallet/simple_wallet/test_simple_wallet.py new file mode 100644 index 0000000000..f3a6906730 --- /dev/null +++ b/tests/wallet/simple_wallet/test_simple_wallet.py @@ -0,0 +1,87 @@ +import asyncio +import pytest +from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward +from chia.server.server import ChiaServer +from chia.simulator.simulator_protocol import FarmNewBlockProtocol +from chia.types.peer_info import PeerInfo +from chia.util.ints import uint16, uint32 +from chia.wallet.util.transaction_type import TransactionType +from chia.wallet.wallet_state_manager import WalletStateManager +from tests.setup_nodes import self_hostname, setup_simulators_and_wallets +from tests.time_out_assert import time_out_assert + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + + +class TestWalletSimulator: + @pytest.fixture(scope="function") + async def wallet_node(self): + async for _ in setup_simulators_and_wallets(1, 1, {}, simple_wallet=True): + yield _ + + @pytest.fixture(scope="function") + async def two_wallet_nodes(self): + async for _ in setup_simulators_and_wallets(1, 2, {}): + yield _ + + @pytest.fixture(scope="function") + async def two_wallet_nodes_five_freeze(self): + async for _ in setup_simulators_and_wallets(1, 2, {}): + yield _ + + @pytest.fixture(scope="function") + async def three_sim_two_wallets(self): + async for _ in setup_simulators_and_wallets(3, 2, {}): + yield _ + + @pytest.mark.asyncio + async def test_wallet_coinbase(self, wallet_node): + num_blocks = 10 + full_nodes, wallets = wallet_node + full_node_api = full_nodes[0] + server_1: ChiaServer = full_node_api.full_node.server + wallet_node, server_2 = wallets[0] + + wallet = wallet_node.wallet_state_manager.main_wallet + ph = await wallet.get_new_puzzlehash() + + await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), wallet_node.on_connect) + for i in range(0, num_blocks): + await full_node_api.farm_new_block(FarmNewBlockProtocol(ph)) + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) + + funds = sum( + [ + calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) + for i in range(1, num_blocks + 2) + ] + ) + + async def check_tx_are_pool_farm_rewards(): + wsm: WalletStateManager = wallet_node.wallet_state_manager + all_txs = await wsm.get_all_transactions(1) + expected_count = (num_blocks + 1) * 2 + if len(all_txs) != expected_count: + return False + pool_rewards = 0 + farm_rewards = 0 + + for tx in all_txs: + if tx.type == TransactionType.COINBASE_REWARD: + pool_rewards += 1 + elif tx.type == TransactionType.FEE_REWARD: + farm_rewards += 1 + + if pool_rewards != expected_count / 2: + return False + if farm_rewards != expected_count / 2: + return False + return True + + await time_out_assert(10, check_tx_are_pool_farm_rewards, True) + await time_out_assert(5, wallet.get_confirmed_balance, funds) diff --git a/tests/wallet/sync/test_wallet_sync.py b/tests/wallet/sync/test_wallet_sync.py index ce5ae216d2..5e5ba57840 100644 --- a/tests/wallet/sync/test_wallet_sync.py +++ b/tests/wallet/sync/test_wallet_sync.py @@ -16,7 +16,7 @@ from tests.time_out_assert import time_out_assert def wallet_height_at_least(wallet_node, h): - height = wallet_node.wallet_state_manager.blockchain._peak_height + height = wallet_node.wallet_state_manager.blockchain.get_peak_height() if height == h: return True return False @@ -47,14 +47,22 @@ class TestWalletSync: async for _ in setup_node_and_wallet(test_constants, starting_height=100): yield _ + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_basic_sync_wallet(self, wallet_node, default_400_blocks): + async def test_basic_sync_wallet(self, wallet_node, default_400_blocks, trusted): full_node_api, wallet_node, full_node_server, wallet_server = wallet_node for block in default_400_blocks: await full_node_api.full_node.respond_block(full_node_protocol.RespondBlock(block)) + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None) # The second node should eventually catch up to the first one, and have the @@ -73,27 +81,79 @@ class TestWalletSync: 100, wallet_height_at_least, True, wallet_node, len(default_400_blocks) + num_blocks - 5 - 1 ) + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_backtrack_sync_wallet(self, wallet_node, default_400_blocks): + async def test_almost_recent(self, wallet_node, default_1000_blocks, trusted): + # Tests the edge case of receiving funds right before the recent blocks in weight proof + full_node_api, wallet_node, full_node_server, wallet_server = wallet_node + for block in default_1000_blocks: + await full_node_api.full_node.respond_block(full_node_protocol.RespondBlock(block)) + + wallet = wallet_node.wallet_state_manager.main_wallet + ph = await wallet.get_new_puzzlehash() + + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} + + # Tests a reorg with the wallet + num_blocks = 20 + new_blocks = bt.get_consecutive_blocks( + num_blocks, block_list_input=default_1000_blocks, pool_reward_puzzle_hash=ph + ) + for i in range(1000, len(new_blocks)): + await full_node_api.full_node.respond_block(full_node_protocol.RespondBlock(new_blocks[i])) + + new_blocks = bt.get_consecutive_blocks( + test_constants.WEIGHT_PROOF_RECENT_BLOCKS + 10, block_list_input=new_blocks + ) + for i in range(1020, len(new_blocks)): + await full_node_api.full_node.respond_block(full_node_protocol.RespondBlock(new_blocks[i])) + + await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None) + + await time_out_assert(30, wallet.get_confirmed_balance, 20 * calculate_pool_reward(1000)) + + @pytest.mark.parametrize( + "trusted", + [True, False], + ) + @pytest.mark.asyncio + async def test_backtrack_sync_wallet(self, wallet_node, default_400_blocks, trusted): full_node_api, wallet_node, full_node_server, wallet_server = wallet_node for block in default_400_blocks[:20]: await full_node_api.full_node.respond_block(full_node_protocol.RespondBlock(block)) + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None) # The second node should eventually catch up to the first one, and have the # same tip at height num_blocks - 1. await time_out_assert(100, wallet_height_at_least, True, wallet_node, 19) - # Tests a reorg with the wallet + # Tests a reorg with the wallet + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_short_batch_sync_wallet(self, wallet_node, default_400_blocks): - + async def test_short_batch_sync_wallet(self, wallet_node, default_400_blocks, trusted): full_node_api, wallet_node, full_node_server, wallet_server = wallet_node for block in default_400_blocks[:200]: await full_node_api.full_node.respond_block(full_node_protocol.RespondBlock(block)) + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None) @@ -102,13 +162,21 @@ class TestWalletSync: await time_out_assert(100, wallet_height_at_least, True, wallet_node, 199) # Tests a reorg with the wallet + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_long_sync_wallet(self, wallet_node, default_1000_blocks, default_400_blocks): + async def test_long_sync_wallet(self, wallet_node, default_1000_blocks, default_400_blocks, trusted): full_node_api, wallet_node, full_node_server, wallet_server = wallet_node for block in default_400_blocks: await full_node_api.full_node.respond_block(full_node_protocol.RespondBlock(block)) + if trusted: + wallet_node.config["trusted_peers"] = {full_node_server.node_id: full_node_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None) @@ -122,7 +190,7 @@ class TestWalletSync: for block in default_1000_blocks: await full_node_api.full_node.respond_block(full_node_protocol.RespondBlock(block)) - log.info(f"wallet node height is {wallet_node.wallet_state_manager.blockchain._peak_height}") + log.info(f"wallet node height is {wallet_node.wallet_state_manager.blockchain.get_peak_height()}") await time_out_assert(600, wallet_height_at_least, True, wallet_node, len(default_1000_blocks) - 1) await disconnect_all_and_reconnect(wallet_server, full_node_server) @@ -138,8 +206,12 @@ class TestWalletSync: 600, wallet_height_at_least, True, wallet_node, len(default_1000_blocks) + num_blocks - 5 - 1 ) + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_reorg_sync(self, wallet_node_simulator, default_400_blocks): + async def test_wallet_reorg_sync(self, wallet_node_simulator, default_400_blocks, trusted): num_blocks = 5 full_nodes, wallets = wallet_node_simulator full_node_api = full_nodes[0] @@ -149,6 +221,11 @@ class TestWalletSync: wallet = wsm.main_wallet ph = await wallet.get_new_puzzlehash() + if trusted: + wallet_node.config["trusted_peers"] = {fn_server.node_id: fn_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} + await server_2.start_client(PeerInfo(self_hostname, uint16(fn_server._port)), None) # Insert 400 blocks @@ -182,8 +259,12 @@ class TestWalletSync: await time_out_assert(5, get_tx_count, 0, 1) await time_out_assert(5, wallet.get_confirmed_balance, 0) + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_reorg_get_coinbase(self, wallet_node_simulator, default_400_blocks): + async def test_wallet_reorg_get_coinbase(self, wallet_node_simulator, default_400_blocks, trusted): full_nodes, wallets = wallet_node_simulator full_node_api = full_nodes[0] wallet_node, server_2 = wallets[0] @@ -192,6 +273,11 @@ class TestWalletSync: wallet = wallet_node.wallet_state_manager.main_wallet ph = await wallet.get_new_puzzlehash() + if trusted: + wallet_node.config["trusted_peers"] = {fn_server.node_id: fn_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} + await server_2.start_client(PeerInfo(self_hostname, uint16(fn_server._port)), None) # Insert 400 blocks diff --git a/tests/wallet/test_backup.py b/tests/wallet/test_backup.py new file mode 100644 index 0000000000..c02c70ba28 --- /dev/null +++ b/tests/wallet/test_backup.py @@ -0,0 +1,87 @@ +# import asyncio +# from pathlib import Path +# from secrets import token_bytes +# +# import pytest +# +# from chia.consensus.block_rewards import calculate_pool_reward, calculate_base_farmer_reward +# from chia.simulator.simulator_protocol import FarmNewBlockProtocol +# from chia.types.peer_info import PeerInfo +# from chia.util.ints import uint16, uint32, uint64 +# from tests.setup_nodes import setup_simulators_and_wallets +# from chia.wallet.cat_wallet.cat_wallet import CATWallet +# from tests.time_out_assert import time_out_assert +# +# +# @pytest.fixture(scope="module") +# def event_loop(): +# loop = asyncio.get_event_loop() +# yield loop +# +# +# class TestCATWalletBackup: +# @pytest.fixture(scope="function") +# async def two_wallet_nodes(self): +# async for _ in setup_simulators_and_wallets(1, 1, {}): +# yield _ +# +# @pytest.mark.asyncio +# async def test_coin_backup(self, two_wallet_nodes): +# num_blocks = 3 +# full_nodes, wallets = two_wallet_nodes +# full_node_api = full_nodes[0] +# full_node_server = full_node_api.full_node.server +# wallet_node, server_2 = wallets[0] +# wallet = wallet_node.wallet_state_manager.main_wallet +# +# ph = await wallet.get_new_puzzlehash() +# +# await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None) +# for i in range(1, num_blocks): +# await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) +# +# funds = sum( +# [ +# calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) +# for i in range(1, num_blocks - 1) +# ] +# ) +# +# await time_out_assert(15, wallet.get_confirmed_balance, funds) +# +# cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( +# wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100)) +# +# for i in range(1, num_blocks): +# await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) +# +# await time_out_assert(15, cat_wallet.get_confirmed_balance, 100) +# await time_out_assert(15, cat_wallet.get_unconfirmed_balance, 100) +# +# # Write backup to file +# filename = f"test-backup-{token_bytes(16).hex()}" +# file_path = Path(filename) +# await wallet_node.wallet_state_manager.create_wallet_backup(file_path) +# +# # Close wallet and restart +# db_path = wallet_node.wallet_state_manager.db_path +# wallet_node._close() +# await wallet_node._await_closed() +# +# db_path.unlink() +# +# started = await wallet_node._start() +# assert started is False +# +# await wallet_node._start(backup_file=file_path) +# +# await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), wallet_node.on_connect) +# +# all_wallets = wallet_node.wallet_state_manager.wallets +# assert len(all_wallets) == 2 +# +# cat_wallet_from_backup = wallet_node.wallet_state_manager.wallets[2] +# +# await time_out_assert(15, cat_wallet_from_backup.get_confirmed_balance, 100) +# if file_path.exists(): +# file_path.unlink() diff --git a/tests/wallet/test_puzzle_store.py b/tests/wallet/test_puzzle_store.py index 976c4f147d..94fa358c70 100644 --- a/tests/wallet/test_puzzle_store.py +++ b/tests/wallet/test_puzzle_store.py @@ -43,6 +43,7 @@ class TestPuzzleStore: AugSchemeMPL.key_gen(token_bytes(32)).get_g1(), WalletType.STANDARD_WALLET, uint32(1), + False, ) ) derivation_recs.append( @@ -52,6 +53,7 @@ class TestPuzzleStore: AugSchemeMPL.key_gen(token_bytes(32)).get_g1(), WalletType.RATE_LIMITED, uint32(2), + False, ) ) assert await db.puzzle_hash_exists(derivation_recs[0].puzzle_hash) is False @@ -61,7 +63,7 @@ class TestPuzzleStore: assert len((await db.get_all_puzzle_hashes())) == 0 assert await db.get_last_derivation_path() is None assert await db.get_unused_derivation_path() is None - assert await db.get_derivation_record(0, 2) is None + assert await db.get_derivation_record(0, 2, False) is None await db.add_derivation_paths(derivation_recs) @@ -87,7 +89,7 @@ class TestPuzzleStore: assert len((await db.get_all_puzzle_hashes())) == 2000 assert await db.get_last_derivation_path() == 999 assert await db.get_unused_derivation_path() == 0 - assert await db.get_derivation_record(0, 2) == derivation_recs[1] + assert await db.get_derivation_record(0, 2, False) == derivation_recs[1] # Indeces up to 250 await db.set_used_up_to(249) diff --git a/tests/wallet/test_singleton_lifecycle_fast.py b/tests/wallet/test_singleton_lifecycle_fast.py index 19e60f3de3..89af12c7b5 100644 --- a/tests/wallet/test_singleton_lifecycle_fast.py +++ b/tests/wallet/test_singleton_lifecycle_fast.py @@ -270,13 +270,7 @@ def launcher_conditions_and_spend_bundle( puzzle_db.add_puzzle(launcher_puzzle) launcher_puzzle_hash = launcher_puzzle.get_tree_hash() launcher_coin = Coin(parent_coin_id, launcher_puzzle_hash, launcher_amount) - # TODO: address hint error and remove ignore - # error: Argument 1 to "singleton_puzzle" has incompatible type "bytes32"; expected "Program" [arg-type] - singleton_full_puzzle = singleton_puzzle( - launcher_coin.name(), # type: ignore[arg-type] - launcher_puzzle_hash, - initial_singleton_inner_puzzle, - ) + singleton_full_puzzle = singleton_puzzle(launcher_coin.name(), launcher_puzzle_hash, initial_singleton_inner_puzzle) puzzle_db.add_puzzle(singleton_full_puzzle) singleton_full_puzzle_hash = singleton_full_puzzle.get_tree_hash() message_program = Program.to([singleton_full_puzzle_hash, launcher_amount, metadata]) @@ -433,13 +427,7 @@ def spend_coin_to_singleton( assert_coin_spent(coin_store, launcher_coin) assert_coin_spent(coin_store, farmed_coin) - # TODO: address hint error and remove ignore - # error: Argument 1 to "singleton_puzzle" has incompatible type "bytes32"; expected "Program" [arg-type] - singleton_expected_puzzle = singleton_puzzle( - launcher_id, # type: ignore[arg-type] - launcher_puzzle_hash, - initial_singleton_puzzle, - ) + singleton_expected_puzzle = singleton_puzzle(launcher_id, launcher_puzzle_hash, initial_singleton_puzzle) singleton_expected_puzzle_hash = singleton_expected_puzzle.get_tree_hash() expected_singleton_coin = Coin(launcher_coin.name(), singleton_expected_puzzle_hash, launcher_amount) assert_coin_spent(coin_store, expected_singleton_coin, is_spent=False) diff --git a/tests/wallet/test_wallet.py b/tests/wallet/test_wallet.py index 23a534c3c7..84913b2e2e 100644 --- a/tests/wallet/test_wallet.py +++ b/tests/wallet/test_wallet.py @@ -16,7 +16,7 @@ from chia.wallet.wallet_node import WalletNode from chia.wallet.wallet_state_manager import WalletStateManager from tests.setup_nodes import self_hostname, setup_simulators_and_wallets from tests.time_out_assert import time_out_assert, time_out_assert_not_none -from tests.wallet.cc_wallet.test_cc_wallet import tx_in_pool +from tests.wallet.cat_wallet.test_cat_wallet import tx_in_pool @pytest.fixture(scope="module") @@ -28,7 +28,7 @@ def event_loop(): class TestWalletSimulator: @pytest.fixture(scope="function") async def wallet_node(self): - async for _ in setup_simulators_and_wallets(1, 1, {}): + async for _ in setup_simulators_and_wallets(1, 1, {}, True): yield _ @pytest.fixture(scope="function") @@ -38,21 +38,25 @@ class TestWalletSimulator: @pytest.fixture(scope="function") async def two_wallet_nodes(self): - async for _ in setup_simulators_and_wallets(1, 2, {}): + async for _ in setup_simulators_and_wallets(1, 2, {}, True): yield _ @pytest.fixture(scope="function") async def two_wallet_nodes_five_freeze(self): - async for _ in setup_simulators_and_wallets(1, 2, {}): + async for _ in setup_simulators_and_wallets(1, 2, {}, True): yield _ @pytest.fixture(scope="function") async def three_sim_two_wallets(self): - async for _ in setup_simulators_and_wallets(3, 2, {}): + async for _ in setup_simulators_and_wallets(3, 2, {}, True): yield _ + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_coinbase(self, wallet_node): + async def test_wallet_coinbase(self, wallet_node, trusted): num_blocks = 10 full_nodes, wallets = wallet_node full_node_api = full_nodes[0] @@ -61,6 +65,10 @@ class TestWalletSimulator: wallet = wallet_node.wallet_state_manager.main_wallet ph = await wallet.get_new_puzzlehash() + if trusted: + wallet_node.config["trusted_peers"] = {server_1.node_id: server_1.node_id} + else: + wallet_node.config["trusted_peers"] = {} await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None) for i in range(0, num_blocks): @@ -99,8 +107,12 @@ class TestWalletSimulator: await time_out_assert(10, check_tx_are_pool_farm_rewards, True) await time_out_assert(5, wallet.get_confirmed_balance, funds) + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_make_transaction(self, two_wallet_nodes): + async def test_wallet_make_transaction(self, two_wallet_nodes, trusted): num_blocks = 5 full_nodes, wallets = two_wallet_nodes full_node_api = full_nodes[0] @@ -109,6 +121,12 @@ class TestWalletSimulator: wallet_node_2, server_3 = wallets[1] wallet = wallet_node.wallet_state_manager.main_wallet ph = await wallet.get_new_puzzlehash() + if trusted: + wallet_node.config["trusted_peers"] = {server_1.node_id: server_1.node_id} + wallet_node_2.config["trusted_peers"] = {server_1.node_id: server_1.node_id} + else: + wallet_node.config["trusted_peers"] = {} + wallet_node_2.config["trusted_peers"] = {} await server_2.start_client(PeerInfo(self_hostname, uint16(server_1._port)), None) @@ -146,8 +164,12 @@ class TestWalletSimulator: await time_out_assert(5, wallet.get_confirmed_balance, new_funds - 10) await time_out_assert(5, wallet.get_unconfirmed_balance, new_funds - 10) + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_coinbase_reorg(self, wallet_node): + async def test_wallet_coinbase_reorg(self, wallet_node, trusted): num_blocks = 5 full_nodes, wallets = wallet_node full_node_api = full_nodes[0] @@ -155,7 +177,10 @@ class TestWalletSimulator: wallet_node, server_2 = wallets[0] wallet = wallet_node.wallet_state_manager.main_wallet ph = await wallet.get_new_puzzlehash() - + if trusted: + wallet_node.config["trusted_peers"] = {fn_server.node_id: fn_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} await server_2.start_client(PeerInfo(self_hostname, uint16(fn_server._port)), None) for i in range(0, num_blocks): await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph)) @@ -177,8 +202,12 @@ class TestWalletSimulator: await time_out_assert(5, wallet.get_confirmed_balance, funds) + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_send_to_three_peers(self, three_sim_two_wallets): + async def test_wallet_send_to_three_peers(self, three_sim_two_wallets, trusted): num_blocks = 10 full_nodes, wallets = three_sim_two_wallets @@ -196,6 +225,10 @@ class TestWalletSimulator: server_2 = full_node_2.server ph = await wallet_0.wallet_state_manager.main_wallet.get_new_puzzlehash() + if trusted: + wallet_0.config["trusted_peers"] = {server_0.node_id: server_0.node_id} + else: + wallet_0.config["trusted_peers"] = {} # wallet0 <-> sever0 await wallet_server_0.start_client(PeerInfo(self_hostname, uint16(server_0._port)), None) @@ -223,15 +256,19 @@ class TestWalletSimulator: # wallet0 <-> sever1 await wallet_server_0.start_client(PeerInfo(self_hostname, uint16(server_1._port)), wallet_0.on_connect) - await time_out_assert_not_none(5, full_node_1.mempool_manager.get_spendbundle, tx.spend_bundle.name()) + await time_out_assert_not_none(15, full_node_1.mempool_manager.get_spendbundle, tx.spend_bundle.name()) # wallet0 <-> sever2 await wallet_server_0.start_client(PeerInfo(self_hostname, uint16(server_2._port)), wallet_0.on_connect) - await time_out_assert_not_none(5, full_node_2.mempool_manager.get_spendbundle, tx.spend_bundle.name()) + await time_out_assert_not_none(15, full_node_2.mempool_manager.get_spendbundle, tx.spend_bundle.name()) + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_make_transaction_hop(self, two_wallet_nodes_five_freeze): + async def test_wallet_make_transaction_hop(self, two_wallet_nodes_five_freeze, trusted): num_blocks = 10 full_nodes, wallets = two_wallet_nodes_five_freeze full_node_api_0 = full_nodes[0] @@ -243,7 +280,12 @@ class TestWalletSimulator: wallet_0 = wallet_node_0.wallet_state_manager.main_wallet wallet_1 = wallet_node_1.wallet_state_manager.main_wallet ph = await wallet_0.get_new_puzzlehash() - + if trusted: + wallet_node_0.config["trusted_peers"] = {server_0.node_id: server_0.node_id} + wallet_node_1.config["trusted_peers"] = {server_0.node_id: server_0.node_id} + else: + wallet_node_0.config["trusted_peers"] = {} + wallet_node_1.config["trusted_peers"] = {} await wallet_0_server.start_client(PeerInfo(self_hostname, uint16(server_0._port)), None) await wallet_1_server.start_client(PeerInfo(self_hostname, uint16(server_0._port)), None) @@ -341,9 +383,12 @@ class TestWalletSimulator: # True, # ) # await _teardown_nodes(node_iters) - + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_make_transaction_with_fee(self, two_wallet_nodes): + async def test_wallet_make_transaction_with_fee(self, two_wallet_nodes, trusted): num_blocks = 5 full_nodes, wallets = two_wallet_nodes full_node_1 = full_nodes[0] @@ -351,7 +396,16 @@ class TestWalletSimulator: wallet_node_2, server_3 = wallets[1] wallet = wallet_node.wallet_state_manager.main_wallet ph = await wallet.get_new_puzzlehash() - + if trusted: + wallet_node.config["trusted_peers"] = { + full_node_1.full_node.server.node_id: full_node_1.full_node.server.node_id + } + wallet_node_2.config["trusted_peers"] = { + full_node_1.full_node.server.node_id: full_node_1.full_node.server.node_id + } + else: + wallet_node.config["trusted_peers"] = {} + wallet_node_2.config["trusted_peers"] = {} await server_2.start_client(PeerInfo(self_hostname, uint16(full_node_1.full_node.server._port)), None) for i in range(0, num_blocks): @@ -396,8 +450,12 @@ class TestWalletSimulator: await time_out_assert(5, wallet.get_confirmed_balance, new_funds - tx_amount - tx_fee) await time_out_assert(5, wallet.get_unconfirmed_balance, new_funds - tx_amount - tx_fee) + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_create_hit_max_send_amount(self, two_wallet_nodes): + async def test_wallet_create_hit_max_send_amount(self, two_wallet_nodes, trusted): num_blocks = 5 full_nodes, wallets = two_wallet_nodes full_node_1 = full_nodes[0] @@ -405,7 +463,16 @@ class TestWalletSimulator: wallet_node_2, server_3 = wallets[1] wallet = wallet_node.wallet_state_manager.main_wallet ph = await wallet.get_new_puzzlehash() - + if trusted: + wallet_node.config["trusted_peers"] = { + full_node_1.full_node.server.node_id: full_node_1.full_node.server.node_id + } + wallet_node_2.config["trusted_peers"] = { + full_node_1.full_node.server.node_id: full_node_1.full_node.server.node_id + } + else: + wallet_node.config["trusted_peers"] = {} + wallet_node_2.config["trusted_peers"] = {} await server_2.start_client(PeerInfo(self_hostname, uint16(full_node_1.full_node.server._port)), None) for i in range(0, num_blocks): @@ -479,8 +546,12 @@ class TestWalletSimulator: assert above_limit_tx is None + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_prevent_fee_theft(self, two_wallet_nodes): + async def test_wallet_prevent_fee_theft(self, two_wallet_nodes, trusted): num_blocks = 5 full_nodes, wallets = two_wallet_nodes full_node_1 = full_nodes[0] @@ -488,7 +559,16 @@ class TestWalletSimulator: wallet_node_2, server_3 = wallets[1] wallet = wallet_node.wallet_state_manager.main_wallet ph = await wallet.get_new_puzzlehash() - + if trusted: + wallet_node.config["trusted_peers"] = { + full_node_1.full_node.server.node_id: full_node_1.full_node.server.node_id + } + wallet_node_2.config["trusted_peers"] = { + full_node_1.full_node.server.node_id: full_node_1.full_node.server.node_id + } + else: + wallet_node.config["trusted_peers"] = {} + wallet_node_2.config["trusted_peers"] = {} await server_2.start_client(PeerInfo(self_hostname, uint16(full_node_1.full_node.server._port)), None) for i in range(0, num_blocks): @@ -537,6 +617,7 @@ class TestWalletSimulator: trade_id=None, type=uint32(TransactionType.OUTGOING_TX.value), name=name, + memos=list(stolen_sb.get_memos().items()), ) await wallet.push_transaction(stolen_tx) @@ -548,11 +629,14 @@ class TestWalletSimulator: # Funds have not decreased because stolen_tx was rejected outstanding_coinbase_rewards = 2000000000000 - await time_out_assert(5, wallet.get_confirmed_balance, funds + outstanding_coinbase_rewards) - await time_out_assert(5, wallet.get_confirmed_balance, funds + outstanding_coinbase_rewards) + await time_out_assert(20, wallet.get_confirmed_balance, funds + outstanding_coinbase_rewards) + @pytest.mark.parametrize( + "trusted", + [True, False], + ) @pytest.mark.asyncio - async def test_wallet_tx_reorg(self, two_wallet_nodes): + async def test_wallet_tx_reorg(self, two_wallet_nodes, trusted): num_blocks = 5 full_nodes, wallets = two_wallet_nodes full_node_api = full_nodes[0] @@ -565,6 +649,12 @@ class TestWalletSimulator: ph = await wallet.get_new_puzzlehash() ph2 = await wallet_2.get_new_puzzlehash() + if trusted: + wallet_node.config["trusted_peers"] = {fn_server.node_id: fn_server.node_id} + wallet_node_2.config["trusted_peers"] = {fn_server.node_id: fn_server.node_id} + else: + wallet_node.config["trusted_peers"] = {} + wallet_node_2.config["trusted_peers"] = {} await server_2.start_client(PeerInfo(self_hostname, uint16(fn_server._port)), None) await server_3.start_client(PeerInfo(self_hostname, uint16(fn_server._port)), None) diff --git a/tests/wallet/test_wallet_blockchain.py b/tests/wallet/test_wallet_blockchain.py new file mode 100644 index 0000000000..39cbaec540 --- /dev/null +++ b/tests/wallet/test_wallet_blockchain.py @@ -0,0 +1,117 @@ +import asyncio +import dataclasses +from pathlib import Path + +import aiosqlite +import pytest + +from chia.consensus.blockchain import ReceiveBlockResult +from chia.protocols import full_node_protocol +from chia.types.blockchain_format.vdf import VDFProof +from chia.types.weight_proof import WeightProof +from chia.util.db_wrapper import DBWrapper +from chia.util.generator_tools import get_block_header +from chia.wallet.key_val_store import KeyValStore +from chia.wallet.wallet_blockchain import WalletBlockchain +from tests.setup_nodes import test_constants, setup_node_and_wallet + + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + + +class TestWalletBlockchain: + @pytest.fixture(scope="function") + async def wallet_node(self): + async for _ in setup_node_and_wallet(test_constants): + yield _ + + @pytest.mark.asyncio + async def test_wallet_blockchain(self, wallet_node, default_1000_blocks): + full_node_api, wallet_node, full_node_server, wallet_server = wallet_node + + for block in default_1000_blocks[:600]: + await full_node_api.full_node.respond_block(full_node_protocol.RespondBlock(block)) + + res = await full_node_api.request_proof_of_weight( + full_node_protocol.RequestProofOfWeight( + default_1000_blocks[499].height + 1, default_1000_blocks[499].header_hash + ) + ) + res_2 = await full_node_api.request_proof_of_weight( + full_node_protocol.RequestProofOfWeight( + default_1000_blocks[460].height + 1, default_1000_blocks[460].header_hash + ) + ) + + res_3 = await full_node_api.request_proof_of_weight( + full_node_protocol.RequestProofOfWeight( + default_1000_blocks[505].height + 1, default_1000_blocks[505].header_hash + ) + ) + weight_proof: WeightProof = full_node_protocol.RespondProofOfWeight.from_bytes(res.data).wp + weight_proof_short: WeightProof = full_node_protocol.RespondProofOfWeight.from_bytes(res_2.data).wp + weight_proof_long: WeightProof = full_node_protocol.RespondProofOfWeight.from_bytes(res_3.data).wp + + db_filename = Path("wallet_store_test.db") + + if db_filename.exists(): + db_filename.unlink() + + db_connection = await aiosqlite.connect(db_filename) + db_wrapper = DBWrapper(db_connection) + store = await KeyValStore.create(db_wrapper) + chain = await WalletBlockchain.create( + store, test_constants, wallet_node.wallet_state_manager.weight_proof_handler + ) + try: + assert (await chain.get_peak_block()) is None + assert chain.get_peak_height() == 0 + assert chain.get_latest_timestamp() == 0 + + await chain.new_weight_proof(weight_proof) + assert (await chain.get_peak_block()) is not None + assert chain.get_peak_height() == 499 + assert chain.get_latest_timestamp() > 0 + + await chain.new_weight_proof(weight_proof_short) + assert chain.get_peak_height() == 499 + + await chain.new_weight_proof(weight_proof_long) + assert chain.get_peak_height() == 505 + + header_blocks = [] + for block in default_1000_blocks: + header_block = get_block_header(block, [], []) + header_blocks.append(header_block) + + res, err = await chain.receive_block(header_blocks[50]) + print(res, err) + assert res == ReceiveBlockResult.DISCONNECTED_BLOCK + + res, err = await chain.receive_block(header_blocks[400]) + print(res, err) + assert res == ReceiveBlockResult.ALREADY_HAVE_BLOCK + + res, err = await chain.receive_block(header_blocks[507]) + print(res, err) + assert res == ReceiveBlockResult.DISCONNECTED_BLOCK + + res, err = await chain.receive_block( + dataclasses.replace(header_blocks[506], challenge_chain_ip_proof=VDFProof(2, b"123", True)) + ) + assert res == ReceiveBlockResult.INVALID_BLOCK + + assert chain.get_peak_height() == 505 + + for block in header_blocks[506:]: + res, err = await chain.receive_block(block) + assert res == ReceiveBlockResult.NEW_PEAK + assert chain.get_peak_height() == block.height + + assert chain.get_peak_height() == 999 + finally: + await db_connection.close() + db_filename.unlink() diff --git a/tests/wallet/test_wallet_key_val_store.py b/tests/wallet/test_wallet_key_val_store.py new file mode 100644 index 0000000000..b653091511 --- /dev/null +++ b/tests/wallet/test_wallet_key_val_store.py @@ -0,0 +1,58 @@ +import asyncio +from pathlib import Path +import aiosqlite +import pytest + +from chia.types.full_block import FullBlock +from chia.types.header_block import HeaderBlock +from chia.util.db_wrapper import DBWrapper +from chia.wallet.key_val_store import KeyValStore +from tests.setup_nodes import bt + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + + +class TestWalletKeyValStore: + @pytest.mark.asyncio + async def test_store(self): + db_filename = Path("wallet_store_test.db") + + if db_filename.exists(): + db_filename.unlink() + + db_connection = await aiosqlite.connect(db_filename) + db_wrapper = DBWrapper(db_connection) + store = await KeyValStore.create(db_wrapper) + try: + blocks = bt.get_consecutive_blocks(20) + block: FullBlock = blocks[0] + block_2: FullBlock = blocks[1] + + assert (await store.get_object("a", FullBlock)) is None + await store.set_object("a", block) + assert await store.get_object("a", FullBlock) == block + await store.set_object("a", block) + assert await store.get_object("a", FullBlock) == block + await store.set_object("a", block_2) + await store.set_object("a", block_2) + assert await store.get_object("a", FullBlock) == block_2 + await store.remove_object("a") + assert (await store.get_object("a", FullBlock)) is None + + for block in blocks: + assert (await store.get_object(block.header_hash.hex(), FullBlock)) is None + await store.set_object(block.header_hash.hex(), block) + assert (await store.get_object(block.header_hash.hex(), FullBlock)) == block + + # Wrong type + await store.set_object("a", block_2) + with pytest.raises(Exception): + await store.get_object("a", HeaderBlock) + + finally: + await db_connection.close() + db_filename.unlink()