Add DataLayer clear pending roots interfaces (#15516)

* add clear pending roots cli and rpc

* add tests

* drop extracted changes

* remove extracted change

* tidy

* add --yes

* handle todo

* more todo
This commit is contained in:
Kyle Altendorf 2023-06-29 21:49:54 -04:00 committed by GitHub
parent f37b7dce1a
commit f66b521067
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 319 additions and 11 deletions

View file

@ -7,6 +7,8 @@ from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar, Unio
import click
from chia.types.blockchain_format.sized_bytes import bytes32
_T = TypeVar("_T")
@ -402,3 +404,25 @@ def check_plugins(
from chia.cmds.data_funcs import check_plugins_cmd
run(check_plugins_cmd(rpc_port=data_rpc_port))
@data_cmd.command(
"clear_pending_roots",
help="Clear pending roots that will not be published, associated data may not be recoverable",
)
@click.option("-i", "--id", "id_str", help="Store ID", type=str, required=True)
@click.confirmation_option(
prompt="Associated data may not be recoverable.\nAre you sure you want to remove the pending roots?",
)
@create_rpc_port_option()
def clear_pending_roots(id_str: str, data_rpc_port: int) -> None:
from chia.cmds.data_funcs import clear_pending_roots
store_id = bytes32.from_hexstr(id_str)
run(
clear_pending_roots(
rpc_port=data_rpc_port,
store_id=store_id,
)
)

View file

@ -3,13 +3,14 @@ from __future__ import annotations
import json
from decimal import Decimal
from pathlib import Path
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from chia.cmds.cmds_util import get_any_service_client
from chia.cmds.units import units
from chia.rpc.data_layer_rpc_client import DataLayerRpcClient
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.ints import uint64
@ -222,3 +223,16 @@ async def check_plugins_cmd(rpc_port: Optional[int]) -> None:
if client is not None:
res = await client.check_plugins()
print(json.dumps(res, indent=4, sort_keys=True))
async def clear_pending_roots(
store_id: bytes32,
rpc_port: Optional[int],
root_path: Path = DEFAULT_ROOT_PATH,
) -> Dict[str, Any]:
async with get_any_service_client(DataLayerRpcClient, rpc_port, root_path=root_path) as (client, _):
if client is not None:
result = await client.clear_pending_roots(store_id=store_id)
print(json.dumps(result, indent=4, sort_keys=True))
return result

View file

@ -299,6 +299,23 @@ class Root:
status=Status(row["status"]),
)
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> "Root":
return cls(
tree_id=bytes32.from_hexstr(marshalled["tree_id"]),
node_hash=None if marshalled["node_hash"] is None else bytes32.from_hexstr(marshalled["node_hash"]),
generation=marshalled["generation"],
status=Status(marshalled["status"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"tree_id": self.tree_id.hex(),
"node_hash": None if self.node_hash is None else self.node_hash.hex(),
"generation": self.generation,
"status": self.status.value,
}
node_type_to_class: Dict[NodeType, Union[Type[InternalNode], Type[TerminalNode]]] = {
NodeType.INTERNAL: InternalNode,
@ -617,6 +634,48 @@ class CancelOfferResponse:
}
@final
@dataclasses.dataclass(frozen=True)
class ClearPendingRootsRequest:
store_id: bytes32
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> ClearPendingRootsRequest:
return cls(
store_id=bytes32.from_hexstr(marshalled["store_id"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"store_id": self.store_id.hex(),
}
@final
@dataclasses.dataclass(frozen=True)
class ClearPendingRootsResponse:
success: bool
root: Optional[Root]
# tree_id: bytes32
# node_hash: Optional[bytes32]
# generation: int
# status: Status
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> ClearPendingRootsResponse:
return cls(
success=marshalled["success"],
root=None if marshalled["root"] is None else Root.unmarshal(marshalled["root"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"success": self.success,
"root": None if self.root is None else self.root.marshal(),
}
@dataclasses.dataclass(frozen=True)
class SyncStatus:
root_hash: bytes32

View file

@ -345,12 +345,17 @@ class DataStore:
return Root.from_row(row=row)
async def clear_pending_roots(self, tree_id: bytes32) -> None:
async def clear_pending_roots(self, tree_id: bytes32) -> Optional[Root]:
async with self.db_wrapper.writer() as writer:
await writer.execute(
"DELETE FROM root WHERE tree_id == :tree_id AND status == :status",
{"tree_id": tree_id, "status": Status.PENDING.value},
)
pending_root = await self.get_pending_root(tree_id=tree_id)
if pending_root is not None:
await writer.execute(
"DELETE FROM root WHERE tree_id == :tree_id AND status == :status",
{"tree_id": tree_id, "status": Status.PENDING.value},
)
return pending_root
async def shift_root_generations(self, tree_id: bytes32, shift_size: int) -> None:
async with self.db_wrapper.writer():

View file

@ -8,6 +8,8 @@ from chia.data_layer.data_layer_errors import OfferIntegrityError
from chia.data_layer.data_layer_util import (
CancelOfferRequest,
CancelOfferResponse,
ClearPendingRootsRequest,
ClearPendingRootsResponse,
MakeOfferRequest,
MakeOfferResponse,
Side,
@ -100,6 +102,7 @@ class DataLayerRpcApi:
"/cancel_offer": self.cancel_offer,
"/get_sync_status": self.get_sync_status,
"/check_plugins": self.check_plugins,
"/clear_pending_roots": self.clear_pending_roots,
}
async def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]]) -> List[WsRpcMessage]:
@ -436,3 +439,9 @@ class DataLayerRpcApi:
plugin_status = await self.service.check_plugins()
return plugin_status.marshal()
@marshal() # type: ignore[arg-type]
async def clear_pending_roots(self, request: ClearPendingRootsRequest) -> ClearPendingRootsResponse:
root = await self.service.data_store.clear_pending_roots(tree_id=request.store_id)
return ClearPendingRootsResponse(success=root is not None, root=root)

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Optional
from chia.data_layer.data_layer_util import ClearPendingRootsRequest
from chia.rpc.rpc_client import RpcClient
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint64
@ -120,3 +121,8 @@ class DataLayerRpcClient(RpcClient):
async def check_plugins(self) -> Dict[str, Any]:
response = await self.fetch("check_plugins", {})
return response
async def clear_pending_roots(self, store_id: bytes32) -> Dict[str, Any]:
request = ClearPendingRootsRequest(store_id=store_id)
response = await self.fetch("clear_pending_roots", request.marshal())
return response

View file

@ -8,8 +8,18 @@ import pytest
# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469
from _pytest.fixtures import SubRequest
from chia.data_layer.data_layer_util import ProofOfInclusion, ProofOfInclusionLayer, Side
from chia.data_layer.data_layer_util import (
ClearPendingRootsRequest,
ClearPendingRootsResponse,
ProofOfInclusion,
ProofOfInclusionLayer,
Root,
Side,
Status,
)
from chia.rpc.data_layer_rpc_util import MarshallableProtocol
from chia.types.blockchain_format.sized_bytes import bytes32
from tests.util.misc import Marks, datacases
pytestmark = pytest.mark.data_layer
@ -77,3 +87,47 @@ def test_proof_of_inclusion_is_valid(valid_proof_of_inclusion: ProofOfInclusion)
def test_proof_of_inclusion_is_invalid(invalid_proof_of_inclusion: ProofOfInclusion) -> None:
assert not invalid_proof_of_inclusion.valid()
@dataclasses.dataclass()
class RoundTripCase:
id: str
instance: MarshallableProtocol
marks: Marks = ()
@datacases(
RoundTripCase(
id="Root",
instance=Root(
tree_id=bytes32(b"\x00" * 32),
node_hash=bytes32(b"\x01" * 32),
generation=3,
status=Status.PENDING,
),
),
RoundTripCase(
id="ClearPendingRootsRequest",
instance=ClearPendingRootsRequest(store_id=bytes32(b"\x12" * 32)),
),
RoundTripCase(
id="ClearPendingRootsResponse success",
instance=ClearPendingRootsResponse(
success=True,
root=Root(
tree_id=bytes32(b"\x00" * 32),
node_hash=bytes32(b"\x01" * 32),
generation=3,
status=Status.PENDING,
),
),
),
RoundTripCase(
id="ClearPendingRootsResponse failure",
instance=ClearPendingRootsResponse(success=False, root=None),
),
)
def test_marshalling_round_trip(case: RoundTripCase) -> None:
marshalled = case.instance.marshal()
unmarshalled = type(case.instance).unmarshal(marshalled)
assert case.instance == unmarshalled

View file

@ -3,6 +3,10 @@ from __future__ import annotations
import asyncio
import contextlib
import copy
import enum
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
@ -10,12 +14,15 @@ from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
import pytest
import pytest_asyncio
from chia.cmds.data_funcs import clear_pending_roots
from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward
from chia.data_layer.data_layer import DataLayer
from chia.data_layer.data_layer_api import DataLayerAPI
from chia.data_layer.data_layer_errors import OfferIntegrityError
from chia.data_layer.data_layer_util import OfferStore, StoreProofs
from chia.data_layer.data_layer_util import OfferStore, Status, StoreProofs
from chia.data_layer.data_layer_wallet import DataLayerWallet, verify_offer
from chia.rpc.data_layer_rpc_api import DataLayerRpcApi
from chia.rpc.data_layer_rpc_client import DataLayerRpcClient
from chia.rpc.wallet_rpc_api import WalletRpcApi
from chia.server.start_data_layer import create_data_layer_service
from chia.server.start_service import Service
@ -42,13 +49,20 @@ wallet_and_port_tuple = Tuple[WalletNode, uint16]
two_wallets_with_port = Tuple[Tuple[wallet_and_port_tuple, wallet_and_port_tuple], FullNodeSimulator, BlockTools]
class InterfaceLayer(enum.Enum):
direct = enum.auto()
client = enum.auto()
funcs = enum.auto()
cli = enum.auto()
@contextlib.asynccontextmanager
async def init_data_layer(
async def init_data_layer_service(
wallet_rpc_port: uint16,
bt: BlockTools,
db_path: Path,
wallet_service: Optional[Service[WalletNode, WalletNodeAPI]] = None,
) -> AsyncIterator[DataLayer]:
) -> AsyncIterator[Service[DataLayer, DataLayerAPI]]:
config = bt.config
config["data_layer"]["wallet_peer"]["port"] = int(wallet_rpc_port)
# TODO: running the data server causes the RPC tests to hang at the end
@ -62,12 +76,23 @@ async def init_data_layer(
)
await service.start()
try:
yield service._api.data_layer
yield service
finally:
service.stop()
await service.wait_closed()
@contextlib.asynccontextmanager
async def init_data_layer(
wallet_rpc_port: uint16,
bt: BlockTools,
db_path: Path,
wallet_service: Optional[Service[WalletNode, WalletNodeAPI]] = None,
) -> AsyncIterator[DataLayer]:
async with init_data_layer_service(wallet_rpc_port, bt, db_path, wallet_service) as data_layer_service:
yield data_layer_service._api.data_layer
@pytest_asyncio.fixture(name="bare_data_layer_api")
async def bare_data_layer_api_fixture(tmp_path: Path, bt: BlockTools) -> AsyncIterator[DataLayerRpcApi]:
# we won't use this port, this fixture is for _just_ a data layer rpc
@ -1832,3 +1857,96 @@ async def test_get_sync_status(
assert sync_status["target_root_hash"] != sync_status["root_hash"]
assert sync_status["generation"] == 2
assert sync_status["target_generation"] == 3
@pytest.mark.parametrize(argnames="layer", argvalues=list(InterfaceLayer))
@pytest.mark.asyncio
async def test_clear_pending_roots(
self_hostname: str,
one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices,
tmp_path: Path,
layer: InterfaceLayer,
bt: BlockTools,
) -> None:
wallet_rpc_api, full_node_api, wallet_rpc_port, ph, bt = await init_wallet_and_node(
self_hostname, one_wallet_and_one_simulator_services
)
async with init_data_layer_service(wallet_rpc_port=wallet_rpc_port, bt=bt, db_path=tmp_path) as data_layer_service:
# NOTE: we don't need the service for direct... simpler to leave it in
assert data_layer_service.rpc_server is not None
rpc_port = data_layer_service.rpc_server.listen_port
data_layer = data_layer_service._api.data_layer
# test insert
data_rpc_api = DataLayerRpcApi(data_layer)
data_store = data_layer.data_store
tree_id = bytes32(range(32))
await data_store.create_tree(tree_id=tree_id, status=Status.COMMITTED)
key = b"\x01\x02"
value = b"abc"
await data_store.insert(
key=key,
value=value,
tree_id=tree_id,
reference_node_hash=None,
side=None,
status=Status.PENDING,
)
pending_root = await data_store.get_pending_root(tree_id=tree_id)
assert pending_root is not None
if layer == InterfaceLayer.direct:
cleared_root = await data_rpc_api.clear_pending_roots({"store_id": tree_id.hex()})
elif layer == InterfaceLayer.funcs:
cleared_root = await clear_pending_roots(
store_id=tree_id,
rpc_port=rpc_port,
root_path=bt.root_path,
)
elif layer == InterfaceLayer.cli:
args: List[str] = [
sys.executable,
"-m",
"chia",
"data",
"clear_pending_roots",
"--id",
tree_id.hex(),
"--data-rpc-port",
str(rpc_port),
"--yes",
]
process = await asyncio.create_subprocess_exec(
*args,
env={**os.environ, "CHIA_ROOT": str(bt.root_path)},
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
await process.wait()
assert process.stdout is not None
assert process.stderr is not None
stdout = await process.stdout.read()
cleared_root = json.loads(stdout)
stderr = await process.stderr.read()
assert process.returncode == 0
assert stderr == b""
elif layer == InterfaceLayer.client:
client = await DataLayerRpcClient.create(
self_hostname=self_hostname,
port=rpc_port,
root_path=bt.root_path,
net_config=bt.config,
)
try:
cleared_root = await client.clear_pending_roots(store_id=tree_id)
finally:
client.close()
await client.await_closed()
else: # pragma: no cover
assert False, "unhandled parametrization"
assert cleared_root == {"success": True, "root": pending_root.marshal()}

View file

@ -1266,3 +1266,22 @@ async def test_pending_roots(data_store: DataStore, tree_id: bytes32) -> None:
await data_store.clear_pending_roots(tree_id=tree_id)
pending_root = await data_store.get_pending_root(tree_id=tree_id)
assert pending_root is None
@pytest.mark.asyncio
async def test_clear_pending_roots_returns_root(data_store: DataStore, tree_id: bytes32) -> None:
key = b"\x01\x02"
value = b"abc"
await data_store.insert(
key=key,
value=value,
tree_id=tree_id,
reference_node_hash=None,
side=None,
status=Status.PENDING,
)
pending_root = await data_store.get_pending_root(tree_id=tree_id)
cleared_root = await data_store.clear_pending_roots(tree_id=tree_id)
assert cleared_root == pending_root