Skip to content

Commit

Permalink
Update epoch_metric and recall in test for generating data with different rank (#2675)
Browse files Browse the repository at this point in the history

* Update epoch_metric and recall in test for generating data with different rank

Update `epoch_metric` and `recall`

* Update test_recall.py

* Update test_recall.py

Co-authored-by: Sadra Barikbin <[email protected]>
  • Loading branch information
puhuk and sadra-barikbin authored Aug 29, 2022
1 parent 02d4c81 commit 942af82
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 28 deletions.
22 changes: 12 additions & 10 deletions tests/ignite/metrics/test_epoch_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,34 +159,36 @@ def _test_distrib_integration(device=None):
device = idist.device() if idist.device().type != "xla" else "cpu"

rank = idist.get_rank()
torch.manual_seed(12)
torch.manual_seed(12 + rank)

n_iters = 60
s = 16
n_iters = 3
batch_size = 2
n_classes = 7

offset = n_iters * s
y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),), device=device)
y_preds = torch.rand(offset * idist.get_world_size(), n_classes, device=device)
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,), device=device)
y_preds = torch.rand(n_iters * batch_size, n_classes, device=device)

def update(engine, i):
return (
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
y_preds[i * batch_size : (i + 1) * batch_size, :],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)

def assert_data_fn(all_preds, all_targets):
assert all_preds.equal(y_preds), f"{all_preds.shape} vs {y_preds.shape}"
assert all_targets.equal(y_true), f"{all_targets.shape} vs {y_true.shape}"
return (all_preds.argmax(dim=1) == all_targets).sum().item()

ep_metric = EpochMetric(assert_data_fn, check_compute_fn=False, device=device)
ep_metric.attach(engine, "epm")

data = list(range(n_iters))

engine.run(data=data, max_epochs=3)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert engine.state.metrics["epm"] == (y_preds.argmax(dim=1) == y_true).sum().item()


Expand Down
40 changes: 22 additions & 18 deletions tests/ignite/metrics/test_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,22 +430,18 @@ def _test_distrib_integration_multiclass(device):

from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12)

def _test(average, n_epochs, metric_device):
n_iters = 60
s = 16
batch_size = 16
n_classes = 7

offset = n_iters * s
y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device)
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)

def update(engine, i):
return (
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
y_preds[i * batch_size : (i + 1) * batch_size, :],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)
Expand All @@ -457,6 +453,9 @@ def update(engine, i):
data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "re" in engine.state.metrics
assert re._updated is True
res = engine.state.metrics["re"]
Expand All @@ -475,7 +474,9 @@ def update(engine, i):
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(idist.device())
for _ in range(2):
rank = idist.get_rank()
for i in range(2):
torch.manual_seed(12 + rank + i)
for metric_device in metric_devices:
_test(average=False, n_epochs=1, metric_device=metric_device)
_test(average=False, n_epochs=2, metric_device=metric_device)
Expand All @@ -491,22 +492,20 @@ def _test_distrib_integration_multilabel(device):

from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12)

def _test(average, n_epochs, metric_device):
n_iters = 60
s = 16
batch_size = 16
n_classes = 7

offset = n_iters * s
y_true = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
y_preds = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
y_true = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device)
y_preds = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device)

def update(engine, i):
return (
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
y_true[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
y_preds[i * batch_size : (i + 1) * batch_size, ...],
y_true[i * batch_size : (i + 1) * batch_size, ...],
)

engine = Engine(update)
Expand All @@ -518,6 +517,9 @@ def update(engine, i):
data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "re" in engine.state.metrics
assert re._updated is True
res = engine.state.metrics["re"]
Expand All @@ -540,7 +542,9 @@ def update(engine, i):
metric_devices = ["cpu"]
if device.type != "xla":
metric_devices.append(idist.device())
for _ in range(2):
rank = idist.get_rank()
for i in range(2):
torch.manual_seed(12 + rank + i)
for metric_device in metric_devices:
_test(average=False, n_epochs=1, metric_device=metric_device)
_test(average=False, n_epochs=2, metric_device=metric_device)
Expand Down

0 comments on commit 942af82

Please sign in to comment.