add `datacases()` and `named_datacases()` (#15265)

* add `datacases()` and `named_datacases()`

* correct DataCasesProtocol

* add back the tests for testing the test utilities
This commit is contained in:
Kyle Altendorf 2023-05-12 13:58:45 -04:00 committed by GitHub
parent d1ffa43ea8
commit 217429a126
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 2 deletions

View File

@ -9,6 +9,10 @@ asyncio_mode = strict
markers = markers =
benchmark benchmark
data_layer: Mark as a data layer related test. data_layer: Mark as a data layer related test.
test_mark_a1: used in testing test utilities
test_mark_a2: used in testing test utilities
test_mark_b1: used in testing test utilities
test_mark_b2: used in testing test utilities
testpaths = tests testpaths = tests
filterwarnings = filterwarnings =
error error

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import contextlib import contextlib
import dataclasses import dataclasses
import enum import enum
import functools
import gc import gc
import math import math
import os import os
@ -13,10 +14,10 @@ from statistics import mean
from textwrap import dedent from textwrap import dedent
from time import thread_time from time import thread_time
from types import TracebackType from types import TracebackType
from typing import Any, Callable, Iterator, List, Optional, Type, Union from typing import Any, Callable, Collection, Iterator, List, Optional, Type, Union
import pytest import pytest
from typing_extensions import final from typing_extensions import Protocol, final
from tests.core.data_layer.util import ChiaRoot from tests.core.data_layer.util import ChiaRoot
@ -303,3 +304,31 @@ def closing_chia_root_popen(chia_root: ChiaRoot, args: List[str]) -> Iterator[su
process.wait(timeout=10) process.wait(timeout=10)
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
process.kill() process.kill()
# https://github.com/pytest-dev/pytest/blob/7.3.1/src/_pytest/mark/__init__.py#L45
Marks = Union[pytest.MarkDecorator, Collection[Union[pytest.MarkDecorator, pytest.Mark]]]
class DataCase(Protocol):
marks: Marks
@property
def id(self) -> str:
...
def datacases(*cases: DataCase, _name: str = "case") -> pytest.MarkDecorator:
return pytest.mark.parametrize(
argnames=_name,
argvalues=[pytest.param(case, id=case.id, marks=case.marks) for case in cases],
)
class DataCasesDecorator(Protocol):
def __call__(self, *cases: DataCase, _name: str = "case") -> pytest.MarkDecorator:
...
def named_datacases(name: str) -> DataCasesDecorator:
return functools.partial(datacases, _name=name)

View File

@ -0,0 +1,38 @@
from __future__ import annotations
from dataclasses import dataclass
import pytest
from tests.util.misc import Marks, datacases, named_datacases
@dataclass
class DataCase:
id: str
marks: Marks
sample_cases = [
DataCase(id="id_a", marks=[pytest.mark.test_mark_a1, pytest.mark.test_mark_a2]),
DataCase(id="id_b", marks=[pytest.mark.test_mark_b1, pytest.mark.test_mark_b2]),
]
def sample_result(name: str) -> pytest.MarkDecorator:
return pytest.mark.parametrize(
argnames=name,
argvalues=[pytest.param(case, id=case.id, marks=case.marks) for case in sample_cases],
)
def test_datacases() -> None:
result = datacases(*sample_cases)
assert result == sample_result(name="case")
def test_named_datacases() -> None:
result = named_datacases("Sharrilanda")(*sample_cases)
assert result == sample_result(name="Sharrilanda")