Compare commits

...

4 Commits

Author SHA1 Message Date
Jeff Cruikshank 6b9efcddda
fix tests 2023-06-30 10:50:10 -07:00
Jeff Cruikshank 2847433573
linter fix 2023-06-30 10:50:09 -07:00
Jeff Cruikshank 8928d7a7a7
tests for get_wallet_addresses 2023-06-30 10:50:08 -07:00
Jeff Cruikshank 464b4b5433
add get_wallet_addresses RPC for deriving wallet addresses 2023-06-30 10:50:03 -07:00
2 changed files with 214 additions and 2 deletions

View File

@ -23,15 +23,18 @@ from typing_extensions import Protocol
from chia import __version__
from chia.cmds.init_funcs import check_keys, chia_full_version_str, chia_init
from chia.cmds.passphrase_funcs import default_passphrase, using_default_passphrase
from chia.consensus.coinbase import create_puzzlehash_for_pk
from chia.daemon.keychain_server import KeychainServer, keychain_commands
from chia.daemon.windows_signal import kill
from chia.plotters.plotters import get_available_plotters
from chia.plotting.util import add_plot_directory
from chia.server.server import ssl_context_for_server
from chia.util.bech32m import encode_puzzle_hash
from chia.util.beta_metrics import BetaMetricsLogger
from chia.util.chia_logging import initialize_service_logging
from chia.util.config import load_config
from chia.util.errors import KeychainCurrentPassphraseIsInvalid
from chia.util.ints import uint32
from chia.util.json_util import dict_to_json_str
from chia.util.keychain import Keychain, passphrase_requirements, supports_os_passphrase_storage
from chia.util.lock import Lockfile, LockfileError
@ -39,6 +42,7 @@ from chia.util.network import WebServer
from chia.util.service_groups import validate_service
from chia.util.setproctitle import setproctitle
from chia.util.ws_message import WsRpcMessage, create_payload, format_response
from chia.wallet.derive_keys import master_sk_to_wallet_sk, master_sk_to_wallet_sk_unhardened
io_pool_exc = ThreadPoolExecutor()
@ -389,6 +393,7 @@ class WebSocketServer:
"get_version": self.get_version,
"get_plotters": self.get_plotters,
"get_routes": self.get_routes,
"get_wallet_addresses": self.get_wallet_addresses,
}
async def is_keyring_locked(self, websocket: WebSocketResponse, request: Dict[str, Any]) -> Dict[str, Any]:
@ -561,6 +566,58 @@ class WebSocketServer:
response: Dict[str, Any] = {"success": True, "routes": routes}
return response
async def get_wallet_addresses(self, websocket: WebSocketResponse, request: Dict[str, Any]) -> Dict[str, Any]:
all_keys = Keychain().get_keys(include_secrets=True)
fingerprints = request.get("fingerprints", None)
index = request.get("index", 0)
count = request.get("count", 1)
non_observer_derivation = request.get("non_observer_derivation", False)
# if fingerprints is None, we want all keys, otherwise we want the keys that match the fingerprints
if fingerprints is None:
keys = all_keys
else:
keys_by_fingerprint = {key.fingerprint: key for key in all_keys}
keys = []
missing_fingerprints = set()
for fingerprint in fingerprints:
if fingerprint not in keys_by_fingerprint:
missing_fingerprints.add(fingerprint)
else:
keys.append(keys_by_fingerprint[fingerprint])
if len(keys) != len(fingerprints):
return {"success": False, "error": f"key(s) not found for fingerprint(s) {missing_fingerprints}"}
selected = self.net_config["selected_network"]
prefix = self.net_config["network_overrides"]["config"][selected]["address_prefix"]
wallet_addresses_by_fingerprint = {}
for key in keys:
address_entries = []
# we require access to the private key to generate wallet addresses
if key.secrets is None:
return {"success": False, "error": f"missing private key for key with fingerprint {key.fingerprint}"}
for i in range(index, index + count):
if non_observer_derivation:
sk = master_sk_to_wallet_sk(key.secrets.private_key, uint32(i))
else:
sk = master_sk_to_wallet_sk_unhardened(key.secrets.private_key, uint32(i))
wallet_address = encode_puzzle_hash(create_puzzlehash_for_pk(sk.get_g1()), prefix)
if non_observer_derivation:
hd_path = f"m/12381n/8444n/2n/{i}n"
else:
hd_path = f"m/12381/8444/2/{i}"
address_entries.append({"address": wallet_address, "hd_path": hd_path})
wallet_addresses_by_fingerprint[key.fingerprint] = address_entries
response: Dict[str, Any] = {"success": True, "wallet_addresses": wallet_addresses_by_fingerprint}
return response
async def _keyring_status_changed(self, keyring_status: Dict[str, Any], destination: str):
"""
Attempt to communicate with the GUI to inform it of any keyring status changes

View File

@ -20,12 +20,14 @@ from chia.daemon.keychain_server import (
from chia.daemon.server import WebSocketServer, plotter_log_path, service_plotter
from chia.server.outbound_message import NodeType
from chia.simulator.block_tools import BlockTools
from chia.simulator.keyring import TempKeyring
from chia.simulator.time_out_assert import time_out_assert, time_out_assert_custom_interval
from chia.types.peer_info import PeerInfo
from chia.util.config import load_config
from chia.util.ints import uint16
from chia.util.json_util import dict_to_json_str
from chia.util.keychain import Keychain, KeyData, supports_os_passphrase_storage
from chia.util.keyring_wrapper import DEFAULT_PASSPHRASE_IF_NO_MASTER_PASSPHRASE
from chia.util.keyring_wrapper import DEFAULT_PASSPHRASE_IF_NO_MASTER_PASSPHRASE, KeyringWrapper
from chia.util.ws_message import create_payload, create_payload_dict
from tests.core.node_height import node_height_at_least
from tests.util.misc import Marks, datacases
@ -60,6 +62,9 @@ class Daemon:
services: Dict[str, Union[List[Service], Service]]
connections: Dict[str, Optional[List[Any]]]
# Instance variables used by WebSocketServer.get_wallet_addresses()
net_config: Dict[str, Any]
def get_command_mapping(self) -> Dict[str, Any]:
return {
"get_routes": None,
@ -82,6 +87,11 @@ class Daemon:
cast(WebSocketServer, self), websocket=WebSocketResponse(), request=request
)
async def get_wallet_addresses(self, request: Dict[str, Any]) -> Dict[str, Any]:
return await WebSocketServer.get_wallet_addresses(
cast(WebSocketServer, self), websocket=WebSocketResponse(), request=request
)
test_key_data = KeyData.from_mnemonic(
"grief lock ketchup video day owner torch young work "
@ -90,6 +100,11 @@ test_key_data = KeyData.from_mnemonic(
)
test_key_data_no_secrets = replace(test_key_data, secrets=None)
test_key_data_2 = KeyData.from_mnemonic(
"banana boat fragile ghost fortune beyond aerobic access "
"hammer stable page grunt venture purse canyon discover "
"egg vivid spare immune awake code announce message"
)
success_response_data = {
"success": True,
@ -224,7 +239,7 @@ def assert_running_services_response(response_dict: Dict[str, Any], expected_res
@pytest.fixture(scope="session")
def mock_lonely_daemon():
# Mock daemon server without any registered services/connections
return Daemon(services={}, connections={})
return Daemon(services={}, connections={}, net_config={})
@pytest.fixture(scope="session")
@ -238,6 +253,7 @@ def mock_daemon_with_services():
"chia_plotter": [Service(True), Service(True)],
},
connections={},
net_config={},
)
@ -254,9 +270,31 @@ def mock_daemon_with_services_and_connections():
"apple": [1],
"banana": [1, 2],
},
net_config={},
)
@pytest.fixture(scope="function")
def get_keychain_for_function():
with TempKeyring() as keychain:
yield keychain
KeyringWrapper.cleanup_shared_instance()
@pytest.fixture(scope="function")
def mock_daemon_with_config_and_keys(get_keychain_for_function, root_path_populated_with_config):
root_path = root_path_populated_with_config
config = load_config(root_path, "config.yaml")
keychain = Keychain()
# populate the keychain with some test keys
keychain.add_private_key(test_key_data.mnemonic_str())
keychain.add_private_key(test_key_data_2.mnemonic_str())
# Mock daemon server with net_config set for mainnet
return Daemon(services={}, connections={}, net_config=config)
@pytest.mark.asyncio
async def test_daemon_simulation(self_hostname, daemon_simulation):
deamon_and_nodes, get_b_tools, bt = daemon_simulation
@ -438,6 +476,123 @@ async def test_get_routes(mock_lonely_daemon):
}
@pytest.mark.parametrize(
"rpc_request, pubkeys_only, expected_result",
[
# default case with no params -- returns first wallet address for each key
(
{},
False,
{
"success": True,
"wallet_addresses": {
test_key_data.fingerprint: [
{
"address": "xch1zze67l3jgxuvyaxhjhu7326sezxxve7lgzvq0497ddggzhff7c9s2pdcwh",
"hd_path": "m/12381/8444/2/0",
},
],
test_key_data_2.fingerprint: [
{
"address": "xch1fra5h0qnsezrxenjyslyxx7y4l268gq52m0rgenh58vn8f577uzswzvk4v",
"hd_path": "m/12381/8444/2/0",
}
],
},
},
),
# specifying a list of fingerprints will return the first wallet address for each listed key
(
{"fingerprints": [test_key_data.fingerprint]},
False,
{
"success": True,
"wallet_addresses": {
test_key_data.fingerprint: [
{
"address": "xch1zze67l3jgxuvyaxhjhu7326sezxxve7lgzvq0497ddggzhff7c9s2pdcwh",
"hd_path": "m/12381/8444/2/0",
},
],
},
},
),
# specifying count and index should return the correct wallet addresses
(
{"fingerprints": [test_key_data.fingerprint], "count": 2, "index": 1},
False,
{
"success": True,
"wallet_addresses": {
test_key_data.fingerprint: [
{
"address": "xch16jqcaguq27z8xvpu89j7eaqfzn6k89hdrrlm0rffku85n8n7m7sqqmmahh",
"hd_path": "m/12381/8444/2/1",
},
{
"address": "xch1955vj0gx5tqe7v5tceajn2p4z4pup8d4g2exs0cz4xjqses8ru6qu8zp3y",
"hd_path": "m/12381/8444/2/2",
},
]
},
},
),
# specifying non_observer_derivation=True should return addresses for hardened derivations
(
{"fingerprints": [test_key_data.fingerprint], "non_observer_derivation": True},
False,
{
"success": True,
"wallet_addresses": {
test_key_data.fingerprint: [
{
"address": "xch1k996a7h3agygjhqtrf0ycpa7wfd6k5ye2plkf54ukcmdj44gkqkq880l7n",
"hd_path": "m/12381n/8444n/2n/0n",
}
]
},
},
),
# specifying a list of fingerprints with one invalid fingerprint will return an error
(
{"fingerprints": [999999]},
False,
{
"success": False,
"error": "key(s) not found for fingerprint(s) {999999}",
},
),
(
{"fingerprints": [test_key_data.fingerprint]},
True,
{
"success": False,
"error": f"missing private key for key with fingerprint {test_key_data.fingerprint}",
},
),
],
)
@pytest.mark.asyncio
async def test_get_wallet_addresses(
mock_daemon_with_config_and_keys, monkeypatch, rpc_request, pubkeys_only, expected_result
):
daemon = mock_daemon_with_config_and_keys
original_get_keys = Keychain.get_keys
# monkeypatch Keychain.get_keys() to always call get_keys() with include_secrets=False
def mock_get_keys(self, include_secrets=False):
def wrapper(self, include_secrets):
return original_get_keys(self, include_secrets=False)
return wrapper
if pubkeys_only:
monkeypatch.setattr(Keychain, "get_keys", mock_get_keys(original_get_keys))
assert expected_result == await daemon.get_wallet_addresses(rpc_request)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_request, expected_result, expected_exception",