Skip to content

Commit

Permalink
tests(EdgeNeighborLoader): test distributed query
Browse files Browse the repository at this point in the history
  • Loading branch information
billshitg committed Feb 12, 2024
1 parent c704e02 commit 0ae8be8
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions tests/test_gds_EdgeNeighborLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,53 @@ def test_iterate_pyg(self):
self.assertEqual(i, 100)
self.assertLessEqual(batch_sizes[-1], 100)

def test_iterate_pyg_distributed(self):
loader = EdgeNeighborLoader(
graph=self.conn,
v_in_feats={"v0": ["x", "y"], "v2": ["x"]},
e_extra_feats={"v2v0":["is_train"], "v0v0":[], "v2v2":[]},
e_seed_types=["v2v0"],
batch_size=100,
num_neighbors=5,
num_hops=2,
shuffle=True,
filter_by=None,
output_format="PyG",
kafka_address="kafka:9092",
distributed_query=True
)
num_batches = 0
batch_sizes = []
for data in loader:
# print(num_batches, data)
self.assertIsInstance(data, pygHeteroData)
self.assertGreater(data["v0"]["x"].shape[0], 0)
self.assertGreater(data["v2"]["x"].shape[0], 0)
self.assertTrue(
data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0
and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943
)
self.assertEqual(
data['v2', 'v2v0', 'v0']["edge_index"].shape[1],
data['v2', 'v2v0', 'v0']["is_train"].shape[0]
)
if ('v0', 'v0v0', 'v0') in data.edge_types:
self.assertTrue(
data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0
and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710
)
if ('v2', 'v2v2', 'v2') in data.edge_types:
self.assertTrue(
data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0
and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966
)
num_batches += 1
batch_sizes.append(int(data['v2', 'v2v0', 'v0']["is_seed"].sum()))
self.assertEqual(num_batches, 10)
for i in batch_sizes[:-1]:
self.assertEqual(i, 100)
self.assertLessEqual(batch_sizes[-1], 100)


if __name__ == "__main__":
suite = unittest.TestSuite()
Expand All @@ -352,5 +399,6 @@ def test_iterate_pyg(self):
suite.addTest(TestGDSHeteroEdgeNeighborLoaderREST("test_iterate_pyg"))
suite.addTest(TestGDSHeteroEdgeNeighborLoaderKafka("test_init"))
suite.addTest(TestGDSHeteroEdgeNeighborLoaderKafka("test_iterate_pyg"))
suite.addTest(TestGDSHeteroEdgeNeighborLoaderKafka("test_iterate_pyg_distributed"))
runner = unittest.TextTestRunner(verbosity=2, failfast=True)
runner.run(suite)

0 comments on commit 0ae8be8

Please sign in to comment.