forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytest_shard_custom.py
67 lines (57 loc) · 2.25 KB
/
pytest_shard_custom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
Custom pytest shard plugin
https://github.com/AdamGleave/pytest-shard/blob/64610a08dac6b0511b6d51cf895d0e1040d162ad/pytest_shard/pytest_shard.py#L1
Modifications:
* shards are now 1 indexed instead of 0 indexed
* option for printing items in shard
"""
import hashlib
from _pytest.config.argparsing import Parser
def pytest_addoptions(parser: Parser):
    """Register the command-line options that control test sharding.

    Adds ``--shard-id``, ``--num-shards`` and ``--print-items`` to the
    ``shard`` option group of ``parser``.

    NOTE(review): pytest only auto-discovers a hook named ``pytest_addoption``
    (singular); this function is presumably called explicitly from a
    conftest ``pytest_addoption`` hook -- confirm before renaming it.
    """
    group = parser.getgroup("shard")
    group.addoption(
        "--shard-id",
        dest="shard_id",
        type=int,
        default=1,
        # Shard IDs are 1-indexed here (unlike the upstream pytest-shard
        # plugin), so spell out the valid range for users.
        help="1-indexed number of this shard, from 1 to --num-shards.",
    )
    group.addoption(
        "--num-shards",
        dest="num_shards",
        type=int,
        default=1,
        help="Total number of shards (at least 1).",
    )
    group.addoption(
        "--print-items",
        dest="print_items",
        action="store_true",
        default=False,
        help="Print out the items being tested in this shard.",
    )
class PytestShardPlugin:
    """Pytest plugin that keeps only the collected test items belonging to one shard.

    Each item is assigned to a shard by hashing its ``nodeid`` with SHA-256,
    so the assignment does not depend on collection order. Shard IDs are
    1-indexed (``1..num_shards``).
    """

    def __init__(self, config):
        # Stored for reference; the hook methods below read options from the
        # ``config`` argument pytest passes them, not from this attribute.
        self.config = config

    def pytest_report_collectionfinish(self, config, items) -> str:
        """Return a report line saying how many (and optionally which) items run in this shard.

        When the ``print_items`` option is set, the node IDs of all ``items``
        are appended, comma-separated.
        """
        msg = f"Running {len(items)} items in this shard"
        if config.getoption("print_items"):
            msg += ": " + ", ".join([item.nodeid for item in items])
        return msg

    def sha256hash(self, x: str) -> int:
        """Return the SHA-256 digest of ``x`` (UTF-8 encoded) as a little-endian integer.

        Unlike the built-in ``hash`` for ``str``, this is not randomized per
        process, so separate shard processes agree on every node ID's shard.
        """
        return int.from_bytes(hashlib.sha256(x.encode()).digest(), "little")

    def filter_items_by_shard(self, items, shard_id: int, num_shards: int):
        """Return the items of ``items`` that belong to shard ``shard_id`` of ``num_shards``.

        ``shard_id`` is 1-indexed: an item belongs to shard ``k`` when the hash
        of its ``nodeid`` modulo ``num_shards`` equals ``k - 1``. The relative
        order of the selected items is preserved. ``num_shards`` is not
        validated here and must be positive.
        """
        # Convert the 1-indexed shard ID to the 0-indexed residue once.
        target = shard_id - 1
        return [
            item
            for item in items
            if self.sha256hash(item.nodeid) % num_shards == target
        ]

    def pytest_collection_modifyitems(self, config, items):
        """Filter ``items`` in place so only this shard's tests are run.

        Reads the ``shard_id`` and ``num_shards`` options from ``config``.

        Raises:
            ValueError: if ``num_shards`` is less than 1, or ``shard_id`` is
                not in ``1..num_shards``.
        """
        shard_id = config.getoption("shard_id")
        shard_total = config.getoption("num_shards")
        # Check the shard count first so that e.g. ``--num-shards=0`` is
        # reported as a bad shard count instead of as a bad shard ID.
        if shard_total < 1:
            raise ValueError(f"{shard_total} is not a valid number of shards")
        if not 1 <= shard_id <= shard_total:
            raise ValueError(
                f"{shard_id} is not a valid shard ID out of {shard_total} total shards"
            )
        # Slice assignment mutates the list pytest passed in; rebinding the
        # local name would leave the collection unchanged.
        items[:] = self.filter_items_by_shard(items, shard_id, shard_total)