Skip to content

Commit

Permalink
feat: allow kafka in distributed query
Browse files Browse the repository at this point in the history
  • Loading branch information
billshitg committed Jan 4, 2024
1 parent 4d334a1 commit c704e02
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 26 deletions.
14 changes: 11 additions & 3 deletions pyTigerGraph/gds/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def __init__(
kafka_add_topic_per_epoch: bool = False,
callback_fn: Callable = None,
kafka_group_id: str = None,
kafka_topic: str = None
kafka_topic: str = None,
num_machines: int = 1,
num_segments: int = 20,
) -> None:
"""Base Class for data loaders.
Expand Down Expand Up @@ -291,6 +293,8 @@ def __init__(
)
# Initialize parameters for the query
self._payload = {}
self._payload["num_machines"] = num_machines
self._payload["num_segments"] = num_segments
if self.kafka_address_producer:
self._payload["kafka_address"] = self.kafka_address_producer
self._payload["kafka_topic_partitions"] = kafka_num_partitions
Expand Down Expand Up @@ -3659,7 +3663,9 @@ def __init__(
kafka_add_topic_per_epoch: bool = False,
callback_fn: Callable = None,
kafka_group_id: str = None,
kafka_topic: str = None
kafka_topic: str = None,
num_machines: int = 1,
num_segments: int = 20
) -> None:
"""NO DOC"""

Expand Down Expand Up @@ -3704,7 +3710,9 @@ def __init__(
kafka_add_topic_per_epoch,
callback_fn,
kafka_group_id,
kafka_topic
kafka_topic,
num_machines,
num_segments
)
# Resolve attributes
is_hetero = any(map(lambda x: isinstance(x, dict),
Expand Down
12 changes: 9 additions & 3 deletions pyTigerGraph/gds/gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,9 @@ def edgeNeighborLoader(
timeout: int = 300000,
callback_fn: Callable = None,
reinstall_query: bool = False,
distributed_query: bool = False
distributed_query: bool = False,
num_machines: int = 1,
num_segments: int = 20
) -> EdgeNeighborLoader:
"""Returns an `EdgeNeighborLoader` instance.
An `EdgeNeighborLoader` instance performs neighbor sampling from all edges in the graph in batches in the following manner:
Expand Down Expand Up @@ -1098,7 +1100,9 @@ def edgeNeighborLoader(
"delimiter": delimiter,
"timeout": timeout,
"callback_fn": callback_fn,
"distributed_query": distributed_query
"distributed_query": distributed_query,
"num_machines": num_machines,
"num_segments": num_segments
}
if self.kafkaConfig:
params.update(self.kafkaConfig)
Expand Down Expand Up @@ -1130,7 +1134,9 @@ def edgeNeighborLoader(
"delimiter": delimiter,
"timeout": timeout,
"callback_fn": callback_fn,
"distributed_query": distributed_query
"distributed_query": distributed_query,
"num_machines": num_machines,
"num_segments": num_segments
}
if self.kafkaConfig:
params.update(self.kafkaConfig)
Expand Down
70 changes: 50 additions & 20 deletions pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}(
SET<STRING> seed_types,
STRING delimiter,
INT num_chunks=2,
INT num_machines=1,
INT num_segments=20,
STRING kafka_address="",
STRING kafka_topic="",
INT kafka_topic_partitions=1,
Expand Down Expand Up @@ -50,8 +52,10 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}(
*/
SumAccum<INT> @tmp_id;
SumAccum<STRING> @@kafka_error;
UINT producer;
SetAccum<VERTEX> @seeds;
MapAccum<INT, MinAccum<UINT>> @@mid_to_vid; # This tmp accumulator maps machine ID to the smallest vertex ID on the machine.
MapAccum<INT, UINT> @@mid_to_producer;
SumAccum<UINT> @kafka_producer_id;

start = {v_types};
# Filter seeds if needed
Expand All @@ -71,13 +75,32 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}(

# If using kafka to export
IF kafka_address != "" THEN
# Initialize Kafka producer
producer = init_kafka_producer(
kafka_address, kafka_max_size, security_protocol,
sasl_mechanism, sasl_username, sasl_password, ssl_ca_location,
ssl_certificate_location, ssl_key_location, ssl_key_password,
ssl_endpoint_identification_algorithm, sasl_kerberos_service_name,
sasl_kerberos_keytab, sasl_kerberos_principal);
# We generate a vertex set that contains exactly one vertex per machine.
machine_set =
SELECT s
FROM start:s
ACCUM
INT mid = (getvid(s) >> num_segments & 31) % num_machines,
@@mid_to_vid += (mid -> getvid(s))
HAVING @@mid_to_vid.get((getvid(s) >> num_segments & 31) % num_machines) == getvid(s);
@@mid_to_vid.clear();
# Initialize Kafka producer per machine
res = SELECT s
FROM machine_set:s
ACCUM
INT mid = (getvid(s) >> num_segments & 31) % num_machines,
UINT producer = init_kafka_producer(
kafka_address, kafka_max_size, security_protocol,
sasl_mechanism, sasl_username, sasl_password, ssl_ca_location,
ssl_certificate_location, ssl_key_location, ssl_key_password,
ssl_endpoint_identification_algorithm, sasl_kerberos_service_name,
sasl_kerberos_keytab, sasl_kerberos_principal),
@@mid_to_producer += (mid -> producer);
res = SELECT s
FROM start:s
ACCUM
INT mid = (getvid(s) >> num_segments & 31) % num_machines,
s.@kafka_producer_id += @@mid_to_producer.get(mid);
END;

FOREACH chunk IN RANGE[0, num_chunks-1] DO
Expand Down Expand Up @@ -132,13 +155,13 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}(
INT part_num = (getvid(s)+getvid(t))%kafka_topic_partitions,
STRING batch_id = stringify(getvid(s))+"_"+e.type+"_"+stringify(getvid(t)),
SET<STRING> tmp_v_batch = @@v_batch.get(s) + @@v_batch.get(t),
INT kafka_errcode = write_to_kafka(producer, kafka_topic, part_num, "vertex_batch_"+batch_id, stringify(tmp_v_batch)),
INT kafka_errcode = write_to_kafka(s.@kafka_producer_id, kafka_topic, part_num, "vertex_batch_"+batch_id, stringify(tmp_v_batch)),
IF kafka_errcode!=0 THEN
@@kafka_error += ("Error sending vertex batch for "+batch_id+": "+stringify(kafka_errcode) + "\n")
END,
SET<STRING> tmp_e_batch = @@e_batch.get(s) + @@e_batch.get(t),
{EDGEATTRSKAFKA},
kafka_errcode = write_to_kafka(producer, kafka_topic, part_num, "edge_batch_"+batch_id, stringify(tmp_e_batch)),
kafka_errcode = write_to_kafka(s.@kafka_producer_id, kafka_topic, part_num, "edge_batch_"+batch_id, stringify(tmp_e_batch)),
IF kafka_errcode!=0 THEN
@@kafka_error += ("Error sending edge batch for "+batch_id+ ": "+ stringify(kafka_errcode) + "\n")
END
Expand Down Expand Up @@ -166,17 +189,24 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}(
END;

IF kafka_address != "" THEN
FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO
INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", "");
IF kafka_errcode!=0 THEN
@@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n");
END;
END;
res = SELECT s
FROM machine_set:s
WHERE (getvid(s) >> num_segments & 31) % num_machines == 0
ACCUM
FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO
INT kafka_errcode = write_to_kafka(s.@kafka_producer_id, kafka_topic, i, "STOP", ""),
IF kafka_errcode!=0 THEN
@@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n")
END
END;

INT kafka_errcode = close_kafka_producer(producer, kafka_timeout);
IF kafka_errcode!=0 THEN
@@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n");
END;
res = SELECT s
FROM machine_set:s
ACCUM
INT kafka_errcode = close_kafka_producer(s.@kafka_producer_id, kafka_timeout),
IF kafka_errcode!=0 THEN
@@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n")
END;
PRINT @@kafka_error as kafkaError;
END;
}
32 changes: 32 additions & 0 deletions tests/test_gds_EdgeNeighborLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ def test_iterate_pyg(self):
self.assertEqual(i, 1024)
self.assertLessEqual(batch_sizes[-1], 1024)

def test_iterate_pyg_distributed(self):
    """Iterate an EdgeNeighborLoader running as a distributed query over Kafka.

    Checks that every batch is a PyG ``Data`` object carrying the expected
    attributes, that 11 batches are produced in total, and that all batches
    except possibly the last contain exactly ``batch_size`` (1024) seeds.
    """
    loader = EdgeNeighborLoader(
        graph=self.conn,
        v_in_feats=["x"],
        e_extra_feats=["is_train"],
        batch_size=1024,
        num_neighbors=10,
        num_hops=2,
        shuffle=True,
        filter_by=None,
        output_format="PyG",
        kafka_address="kafka:9092",
        distributed_query=True
    )
    seed_counts = []
    for batch in loader:
        self.assertIsInstance(batch, pygData)
        self.assertIn("x", batch)
        self.assertIn("is_seed", batch)
        self.assertIn("is_train", batch)
        # Each batch must contain at least one vertex and one edge.
        self.assertGreater(batch["x"].shape[0], 0)
        self.assertGreater(batch["edge_index"].shape[1], 0)
        seed_counts.append(int(batch["is_seed"].sum()))
    self.assertEqual(len(seed_counts), 11)
    # Every batch but the last is full; the last may be a remainder.
    for count in seed_counts[:-1]:
        self.assertEqual(count, 1024)
    self.assertLessEqual(seed_counts[-1], 1024)


def test_sasl_ssl(self):
loader = EdgeNeighborLoader(
graph=self.conn,
Expand Down Expand Up @@ -312,6 +343,7 @@ def test_iterate_pyg(self):
suite = unittest.TestSuite()
suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_init"))
suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_iterate_pyg"))
suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_iterate_pyg_distributed"))
# suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_sasl_ssl"))
suite.addTest(TestGDSEdgeNeighborLoaderREST("test_init"))
suite.addTest(TestGDSEdgeNeighborLoaderREST("test_iterate_pyg"))
Expand Down

0 comments on commit c704e02

Please sign in to comment.