NFT transfer/minting commands now validate the specified addresses (#12520)

* WIP to introduce a wallet AddressType enum to help validate
user-entered addresses for XCH/DID/NFTs etc.

* Support grabbing the current network's address prefix while also
allowing the current value to be updated (in case the config changes)

* Added tests for functionality in/used-by AddressType

* Removed code that would automatically resolve the selected network's address prefix.
is_valid_address() and ensure_valid_address() now require the allowed address types to be passed in.

* Check the length of the decoded address data

* Removed TXCH address type. "txch" address validation now requires a config to be passed into one of the validation functions.

* config is now required when using the address validation functions or calling AddressType.hrp().
root_path_and_config_with_address_prefix --> config_with_address_prefix fixture change.

* Update tests/conftest.py

Co-authored-by: Kyle Altendorf <sda@fstab.net>

* Update tests/conftest.py

Co-authored-by: Kyle Altendorf <sda@fstab.net>

Co-authored-by: Kyle Altendorf <sda@fstab.net>
This commit is contained in:
Jeff 2022-07-29 10:22:33 -07:00 committed by GitHub
parent c62a17358a
commit 4b8130ec60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 334 additions and 27 deletions

View File

@ -15,15 +15,15 @@ from chia.rpc.wallet_rpc_client import WalletRpcClient
from chia.server.start_wallet import SERVICE_NAME
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.bech32m import bech32_decode, decode_puzzle_hash, encode_puzzle_hash
from chia.util.config import load_config
from chia.util.config import load_config, selected_network_address_prefix
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.ints import uint16, uint32, uint64
from chia.wallet.did_wallet.did_info import DID_HRP
from chia.wallet.nft_wallet.nft_info import NFT_HRP, NFTInfo
from chia.wallet.nft_wallet.nft_info import NFTInfo
from chia.wallet.trade_record import TradeRecord
from chia.wallet.trading.offer import Offer
from chia.wallet.trading.trade_status import TradeStatus
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.address_type import AddressType, ensure_valid_address
from chia.wallet.util.transaction_type import TransactionType
from chia.wallet.util.wallet_types import WalletType
@ -101,7 +101,7 @@ async def get_name_for_wallet_id(
async def get_transaction(args: dict, wallet_client: WalletRpcClient, fingerprint: int) -> None:
transaction_id = bytes32.from_hexstr(args["tx_id"])
config = load_config(DEFAULT_ROOT_PATH, "config.yaml", SERVICE_NAME)
address_prefix = config["network_overrides"]["config"][config["selected_network"]]["address_prefix"]
address_prefix = selected_network_address_prefix(config)
tx: TransactionRecord = await wallet_client.get_transaction("this is unused", transaction_id=transaction_id)
try:
@ -141,7 +141,7 @@ async def get_transactions(args: dict, wallet_client: WalletRpcClient, fingerpri
)
config = load_config(DEFAULT_ROOT_PATH, "config.yaml", SERVICE_NAME)
address_prefix = config["network_overrides"]["config"][config["selected_network"]]["address_prefix"]
address_prefix = selected_network_address_prefix(config)
if len(txs) == 0:
print("There are no transactions to this address")
@ -586,7 +586,7 @@ async def print_balances(args: dict, wallet_client: WalletRpcClient, fingerprint
wallet_type = WalletType(args["type"])
summaries_response = await wallet_client.get_wallets(wallet_type)
config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
address_prefix = config["network_overrides"]["config"][config["selected_network"]]["address_prefix"]
address_prefix = selected_network_address_prefix(config)
is_synced: bool = await wallet_client.get_synced()
is_syncing: bool = await wallet_client.get_sync_status()
@ -784,8 +784,17 @@ async def create_nft_wallet(args: Dict, wallet_client: WalletRpcClient, fingerpr
async def mint_nft(args: Dict, wallet_client: WalletRpcClient, fingerprint: int) -> None:
wallet_id = args["wallet_id"]
royalty_address = args["royalty_address"]
target_address = args["target_address"]
config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
royalty_address = (
None
if not args["royalty_address"]
else ensure_valid_address(args["royalty_address"], allowed_types={AddressType.XCH}, config=config)
)
target_address = (
None
if not args["target_address"]
else ensure_valid_address(args["target_address"], allowed_types={AddressType.XCH}, config=config)
)
no_did_ownership = args["no_did_ownership"]
hash = args["hash"]
uris = args["uris"]
@ -866,7 +875,8 @@ async def transfer_nft(args: Dict, wallet_client: WalletRpcClient, fingerprint:
try:
wallet_id = args["wallet_id"]
nft_coin_id = args["nft_coin_id"]
target_address = args["target_address"]
config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
target_address = ensure_valid_address(args["target_address"], allowed_types={AddressType.XCH}, config=config)
fee: int = int(Decimal(args["fee"]) * units["chia"])
response = await wallet_client.transfer_nft(wallet_id, nft_coin_id, target_address, fee)
spend_bundle = response["spend_bundle"]
@ -875,11 +885,11 @@ async def transfer_nft(args: Dict, wallet_client: WalletRpcClient, fingerprint:
print(f"Failed to transfer NFT: {e}")
def print_nft_info(nft: NFTInfo) -> None:
def print_nft_info(nft: NFTInfo, *, config: Dict[str, Any]) -> None:
indent: str = " "
owner_did = None if nft.owner_did is None else encode_puzzle_hash(nft.owner_did, DID_HRP)
owner_did = None if nft.owner_did is None else encode_puzzle_hash(nft.owner_did, AddressType.DID.hrp(config))
print()
print(f"{'NFT identifier:'.ljust(26)} {encode_puzzle_hash(nft.launcher_id, NFT_HRP)}")
print(f"{'NFT identifier:'.ljust(26)} {encode_puzzle_hash(nft.launcher_id, AddressType.NFT.hrp(config))}")
print(f"{'Launcher coin ID:'.ljust(26)} {nft.launcher_id}")
print(f"{'Launcher puzhash:'.ljust(26)} {nft.launcher_puzhash}")
print(f"{'Current NFT coin ID:'.ljust(26)} {nft.nft_coin_id}")
@ -913,12 +923,13 @@ def print_nft_info(nft: NFTInfo) -> None:
async def list_nfts(args: Dict, wallet_client: WalletRpcClient, fingerprint: int) -> None:
wallet_id = args["wallet_id"]
try:
config = load_config(DEFAULT_ROOT_PATH, "config.yaml", SERVICE_NAME)
response = await wallet_client.list_nfts(wallet_id)
nft_list = response["nft_list"]
if len(nft_list) > 0:
for n in nft_list:
nft = NFTInfo.from_json_dict(n)
print_nft_info(nft)
print_nft_info(nft, config=config)
else:
print(f"No NFTs found for wallet with id {wallet_id} on key {fingerprint}")
except Exception as e:
@ -941,8 +952,9 @@ async def set_nft_did(args: Dict, wallet_client: WalletRpcClient, fingerprint: i
async def get_nft_info(args: Dict, wallet_client: WalletRpcClient, fingerprint: int) -> None:
nft_coin_id = args["nft_coin_id"]
try:
config = load_config(DEFAULT_ROOT_PATH, "config.yaml", SERVICE_NAME)
response = await wallet_client.get_nft_info(nft_coin_id)
nft_info = NFTInfo.from_json_dict(response["nft_info"])
print_nft_info(nft_info)
print_nft_info(nft_info, config=config)
except Exception as e:
print(f"Failed to get NFT info: {e}")

View File

@ -273,3 +273,8 @@ def override_config(config: Dict[str, Any], config_overrides: Optional[Dict[str,
for k, v in config_overrides.items():
add_property(new_config, k, v)
return new_config
def selected_network_address_prefix(config: Dict[str, Any]) -> str:
address_prefix = config["network_overrides"]["config"][config["selected_network"]]["address_prefix"]
return address_prefix

View File

@ -0,0 +1,53 @@
from enum import Enum
from typing import Any, Dict, Set
from chia.util.bech32m import bech32_decode, convertbits
from chia.util.config import selected_network_address_prefix
class AddressType(Enum):
XCH = "xch"
NFT = "nft"
DID = "did:chia:"
def hrp(self, config: Dict[str, Any]) -> str:
if self == AddressType.XCH:
# Special case to map XCH to the current network's address prefix
return selected_network_address_prefix(config)
return self.value
def expected_decoded_length(self) -> int:
# Current address types encode 32 bytes. If future address types vary in
# their length, this will need to be updated.
return 32
def is_valid_address(address: str, allowed_types: Set[AddressType], config: Dict[str, Any]) -> bool:
try:
ensure_valid_address(address, allowed_types=allowed_types, config=config)
return True
except ValueError:
return False
def ensure_valid_address(address: str, *, allowed_types: Set[AddressType], config: Dict[str, Any]) -> str:
hrp, b32data = bech32_decode(address)
if not b32data or not hrp:
raise ValueError(f"Invalid address: {address}")
# Match by prefix (hrp) and return the corresponding address type
address_type = next(
(addr_type for (addr_type, addr_hrp) in ((a, a.hrp(config)) for a in allowed_types) if addr_hrp == hrp),
None,
)
if address_type is None:
raise ValueError(
f"Invalid address: {address}. "
f"Expected an address with one of the following prefixes: {[t.hrp(config) for t in allowed_types]}"
)
decoded_data = convertbits(b32data, 5, 8, False)
if len(decoded_data) != address_type.expected_decoded_length():
raise ValueError(
f"Invalid address: {address}. "
f"Expected {address_type.expected_decoded_length()} bytes, got {len(decoded_data)}"
)
return address

View File

@ -7,7 +7,9 @@ import pytest
import pytest_asyncio
import tempfile
from typing import AsyncIterator, List, Tuple
from tests.setup_nodes import setup_node_and_wallet, setup_n_nodes, setup_two_nodes
from pathlib import Path
from typing import Any, AsyncIterator, Dict, List, Tuple
from chia.server.start_service import Service
# Set spawn after stdlib imports, but before other imports
@ -16,6 +18,7 @@ from chia.protocols import full_node_protocol
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo
from chia.util.config import create_default_chia_config, lock_and_load_config
from chia.util.ints import uint16
from tests.core.node_height import node_height_at_least
from tests.pools.test_pool_rpc import wallet_is_synced
@ -608,3 +611,33 @@ async def setup_sim():
sim_client = SimClient(sim)
await sim.farm_block()
return sim, sim_client
@pytest.fixture(scope="function")
def tmp_chia_root(tmp_path):
"""
Create a temp directory and populate it with an empty chia_root directory.
"""
path: Path = tmp_path / "chia_root"
path.mkdir(parents=True, exist_ok=True)
return path
@pytest.fixture(scope="function")
def root_path_populated_with_config(tmp_chia_root) -> Path:
"""
Create a temp chia_root directory and populate it with a default config.yaml.
Returns the chia_root path.
"""
root_path: Path = tmp_chia_root
create_default_chia_config(root_path)
return root_path
@pytest.fixture(scope="function")
def config_with_address_prefix(root_path_populated_with_config: Path, prefix: str) -> Dict[str, Any]:
updated_config: Dict[str, Any] = {}
with lock_and_load_config(root_path_populated_with_config, "config.yaml") as config:
if prefix is not None:
config["network_overrides"]["config"][config["selected_network"]]["address_prefix"] = prefix
return config

View File

@ -16,12 +16,13 @@ from chia.util.config import (
lock_and_load_config,
lock_config,
save_config,
selected_network_address_prefix,
)
from multiprocessing import Pool, Queue, TimeoutError
from pathlib import Path
from threading import Thread
from time import sleep
from typing import Dict, Optional
from typing import Any, Dict, Optional
# Commented-out lines are preserved to aid in debugging the multiprocessing tests
@ -140,17 +141,6 @@ def run_reader_and_writer_tasks(root_path: Path, default_config: Dict):
asyncio.run(create_reader_and_writer_tasks(root_path, default_config))
@pytest.fixture(scope="function")
def root_path_populated_with_config(tmpdir) -> Path:
"""
Create a temp directory and populate it with a default config.yaml.
Returns the root path containing the config.
"""
root_path: Path = Path(tmpdir)
create_default_chia_config(root_path)
return Path(root_path)
@pytest.fixture(scope="function")
def default_config_dict() -> Dict:
"""
@ -313,3 +303,30 @@ class TestConfig:
)
)
await asyncio.gather(*all_tasks)
@pytest.mark.parametrize("prefix", [None])
def test_selected_network_address_prefix_default_config(self, config_with_address_prefix: Dict[str, Any]) -> None:
"""
Temp config.yaml created using a default config. address_prefix is defaulted to "xch"
"""
config = config_with_address_prefix
prefix = selected_network_address_prefix(config)
assert prefix == "xch"
@pytest.mark.parametrize("prefix", ["txch"])
def test_selected_network_address_prefix_testnet_config(self, config_with_address_prefix: Dict[str, Any]) -> None:
"""
Temp config.yaml created using a modified config. address_prefix is set to "txch"
"""
config = config_with_address_prefix
prefix = selected_network_address_prefix(config)
assert prefix == "txch"
def test_selected_network_address_prefix_config_dict(self, default_config_dict: Dict[str, Any]) -> None:
"""
Modified config dictionary has address_prefix set to "customxch"
"""
config = default_config_dict
config["network_overrides"]["config"][config["selected_network"]]["address_prefix"] = "customxch"
prefix = selected_network_address_prefix(config)
assert prefix == "customxch"

View File

@ -0,0 +1,187 @@
from typing import Any, Dict
import pytest
from chia.wallet.util.address_type import AddressType, ensure_valid_address, is_valid_address
@pytest.mark.parametrize("prefix", [None])
def test_xch_hrp_for_default_config(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
assert AddressType.XCH.hrp(config) == "xch"
@pytest.mark.parametrize("prefix", ["txch"])
def test_txch_hrp_for_testnet(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
assert AddressType.XCH.hrp(config) == "txch"
@pytest.mark.parametrize("prefix", [None])
def test_is_valid_address_xch(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
valid = is_valid_address(
"xch1mnr0ygu7lvmk3nfgzmncfk39fwu0dv933yrcv97nd6pmrt7fzmhs8taffd", allowed_types={AddressType.XCH}, config=config
)
assert valid is True
@pytest.mark.parametrize("prefix", ["txch"])
def test_is_valid_address_txch(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
# TXCH address validation requires a config
valid = is_valid_address(
"txch1mnr0ygu7lvmk3nfgzmncfk39fwu0dv933yrcv97nd6pmrt7fzmhs2v6lg7",
allowed_types={AddressType.XCH},
config=config,
)
assert valid is True
@pytest.mark.parametrize("prefix", [None])
def test_is_valid_address_xch_bad_address(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
valid = is_valid_address(
"xch1mnr0ygu7lvmk3nfgzmncfk39fwu0dv933yrcv97nd6pmrt7fzmhs8xxxxx", allowed_types={AddressType.XCH}, config=config
)
assert valid is False
@pytest.mark.parametrize("prefix", [None])
def test_is_valid_address_nft(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
valid = is_valid_address(
"nft1mx2nkvml2eekjtqwdmxvmf3js8g083hpszzhkhtwvhcss8efqzhqtza773", allowed_types={AddressType.NFT}, config=config
)
assert valid is True
@pytest.mark.parametrize("prefix", ["txch"])
def test_is_valid_address_nft_with_testnet(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
valid = is_valid_address(
"nft1mx2nkvml2eekjtqwdmxvmf3js8g083hpszzhkhtwvhcss8efqzhqtza773", allowed_types={AddressType.NFT}, config=config
)
assert valid is True
@pytest.mark.parametrize("prefix", [None])
def test_is_valid_address_nft_bad_address(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
valid = is_valid_address(
"nft1mx2nkvml2eekjtqwdmxvmf3js8g083hpszzhkhtwvhcss8efqzhqtxxxxx", allowed_types={AddressType.NFT}, config=config
)
assert valid is False
@pytest.mark.parametrize("prefix", [None])
def test_is_valid_address_did(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
valid = is_valid_address(
"did:chia:14jxdtqcyp3gk8ka0678eq8mmtnktgpmp2vuqq3vtsl2e5qr7fyrsr9gsr7",
allowed_types={AddressType.DID},
config=config,
)
assert valid is True
@pytest.mark.parametrize("prefix", ["txch"])
def test_is_valid_address_did_with_testnet(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
valid = is_valid_address(
"did:chia:14jxdtqcyp3gk8ka0678eq8mmtnktgpmp2vuqq3vtsl2e5qr7fyrsr9gsr7",
allowed_types={AddressType.DID},
config=config,
)
assert valid is True
@pytest.mark.parametrize("prefix", [None])
def test_is_valid_address_did_bad_address(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
valid = is_valid_address(
"did:chia:14jxdtqcyp3gk8ka0678eq8mmtnktgpmp2vuqq3vtsl2e5qr7fyrsrxxxxx",
allowed_types={AddressType.DID},
config=config,
)
assert valid is False
@pytest.mark.parametrize("prefix", [None])
def test_ensure_valid_address_xch(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
address = ensure_valid_address(
"xch1mnr0ygu7lvmk3nfgzmncfk39fwu0dv933yrcv97nd6pmrt7fzmhs8taffd", allowed_types={AddressType.XCH}, config=config
)
assert address == "xch1mnr0ygu7lvmk3nfgzmncfk39fwu0dv933yrcv97nd6pmrt7fzmhs8taffd"
@pytest.mark.parametrize("prefix", ["txch"])
def test_ensure_valid_address_txch(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
address = ensure_valid_address(
"txch1mnr0ygu7lvmk3nfgzmncfk39fwu0dv933yrcv97nd6pmrt7fzmhs2v6lg7",
allowed_types={AddressType.XCH},
config=config,
)
assert address == "txch1mnr0ygu7lvmk3nfgzmncfk39fwu0dv933yrcv97nd6pmrt7fzmhs2v6lg7"
@pytest.mark.parametrize("prefix", [None])
def test_ensure_valid_address_xch_bad_address(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
with pytest.raises(ValueError):
ensure_valid_address(
"xch1mnr0ygu7lvmk3nfgzmncfk39fwu0dv933yrcv97nd6pmrt7fzmhs8xxxxx",
allowed_types={AddressType.XCH},
config=config,
)
@pytest.mark.parametrize("prefix", [None])
def test_ensure_valid_address_nft(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
address = ensure_valid_address(
"nft1mx2nkvml2eekjtqwdmxvmf3js8g083hpszzhkhtwvhcss8efqzhqtza773", allowed_types={AddressType.NFT}, config=config
)
assert address == "nft1mx2nkvml2eekjtqwdmxvmf3js8g083hpszzhkhtwvhcss8efqzhqtza773"
@pytest.mark.parametrize("prefix", [None])
def test_ensure_valid_address_nft_bad_address(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
with pytest.raises(ValueError):
ensure_valid_address(
"nft1mx2nkvml2eekjtqwdmxvmf3js8g083hpszzhkhtwvhcss8efqzhqtxxxxx",
allowed_types={AddressType.NFT},
config=config,
)
@pytest.mark.parametrize("prefix", [None])
def test_ensure_valid_address_did(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
address = ensure_valid_address(
"did:chia:14jxdtqcyp3gk8ka0678eq8mmtnktgpmp2vuqq3vtsl2e5qr7fyrsr9gsr7",
allowed_types={AddressType.DID},
config=config,
)
assert address == "did:chia:14jxdtqcyp3gk8ka0678eq8mmtnktgpmp2vuqq3vtsl2e5qr7fyrsr9gsr7"
@pytest.mark.parametrize("prefix", [None])
def test_ensure_valid_address_did_bad_address(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
with pytest.raises(ValueError):
ensure_valid_address(
"did:chia:14jxdtqcyp3gk8ka0678eq8mmtnktgpmp2vuqq3vtsl2e5qr7fyrsrxxxxx",
allowed_types={AddressType.DID},
config=config,
)
@pytest.mark.parametrize("prefix", [None])
def test_ensure_valid_address_bad_length(config_with_address_prefix: Dict[str, Any]) -> None:
config = config_with_address_prefix
with pytest.raises(ValueError):
ensure_valid_address("xch1qqqqqqqqqqqqqqqqwygzk5", allowed_types={AddressType.XCH}, config=config)