Commit 7e0de75

feat(EdgeNeighborLoader): update loader and gsql

billshitg committed Dec 22, 2023
1 parent 98cdc1f commit 7e0de75
Showing 2 changed files with 138 additions and 93 deletions.
90 changes: 50 additions & 40 deletions pyTigerGraph/gds/dataloaders.py
@@ -708,8 +708,8 @@ def _read_graph_data(
                 "Spektral is not installed. Please install it to use spektral output."
             )
         # Get raw data from queue and parse
-        vertex_buffer = []
-        edge_buffer = []
+        vertex_buffer = dict()
+        edge_buffer = dict()
         buffer_size = 0
         seeds = set()
         is_empty = False
@@ -724,19 +724,17 @@ def _read_graph_data(
                 if buffer_size > 0:
                     last_batch = True
             else:
-                vertex_buffer.extend(raw[0].strip().split("\n"))
-                edge_buffer.extend(raw[1].strip().split("\n"))
+                vertex_buffer.update({i.strip(): "" for i in raw[0].strip().splitlines()})
+                edge_buffer.update({i.strip(): "" for i in raw[1].strip().splitlines()})
                 seeds.add(raw[2])
                 buffer_size += 1
             if (buffer_size < batch_size) and (not last_batch):
                 continue
             try:
-                vertex_buffer_d = dict.fromkeys(vertex_buffer)
-                edge_buffer_d = dict.fromkeys(edge_buffer)
                 if seed_type:
-                    raw_data = (vertex_buffer_d.keys(), edge_buffer_d.keys(), seeds)
+                    raw_data = (vertex_buffer.keys(), edge_buffer.keys(), seeds)
                 else:
-                    raw_data = (vertex_buffer_d.keys(), edge_buffer_d.keys())
+                    raw_data = (vertex_buffer.keys(), edge_buffer.keys())
                 data = BaseLoader._parse_graph_data_to_df(
                     raw = raw_data,
                     v_in_feats = v_in_feats,
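Note on the buffer change above: a plain dict with placeholder values behaves as an insertion-ordered set in Python 3.7+, so duplicate rows are dropped incrementally as messages arrive instead of rebuilding a `dict.fromkeys(...)` index on every batch flush. A minimal sketch of the pattern (mock payloads, not the loader's actual wire format):

```python
# Insertion-ordered dedup with a dict; the values are unused placeholders.
vertex_buffer = dict()

# Hypothetical raw messages, each a newline-separated block of CSV rows.
for raw in ["v1,0.1\nv2,0.2", "v2,0.2\nv3,0.3"]:
    vertex_buffer.update({line.strip(): "" for line in raw.strip().splitlines()})

# Duplicates collapse; first-seen order is preserved.
assert list(vertex_buffer.keys()) == ["v1,0.1", "v2,0.2", "v3,0.3"]
```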
@@ -966,7 +964,7 @@ def _parse_vertex_data(
         # Read in vertex CSVs as dataframes
         # Each row is in format vid,v_in_feats,v_out_labels,v_extra_feats
         # or vtype,vid,v_in_feats,v_out_labels,v_extra_feats
-        v_file = (line.strip().split(delimiter) for line in raw if line)
+        v_file = (line.split(delimiter) for line in raw)
         # If seeds are given, create the is_seed column
         if seeds:
             seed_df = pd.DataFrame({
Expand Down Expand Up @@ -1030,7 +1028,7 @@ def _parse_edge_data(
# Read in edge CSVs as dataframes
# Each row is in format source_vid,target_vid,e_in_feats,e_out_labels,e_extra_feats
# or etype,source_vid,target_vid,e_in_feats,e_out_labels,e_extra_feats
e_file = (line.strip().split(delimiter) for line in raw if line)
e_file = (line.split(delimiter) for line in raw)
if not is_hetero:
e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats
if seeds:
Expand Down Expand Up @@ -1084,6 +1082,8 @@ def _parse_edge_data(
data[etype] = data[etype].merge(
tmp_df[["source", "target", "is_seed"]], on=["source", "target"], how="left")
data[etype].fillna({"is_seed": False}, inplace=True)
else:
data[etype]["is_seed"] = False
return data

@staticmethod
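The added `else` branch matters for heterogeneous graphs: an edge type with no seed edges in a batch previously came back without an `is_seed` column at all, leaving per-type dataframes with inconsistent schemas. A rough pandas analogue of the fix (hypothetical frames, not the loader's real parsing path):

```python
import pandas as pd

# One dataframe per edge type; "Cites" has no seed edges in this batch.
data = {"Cites": pd.DataFrame({"source": ["1"], "target": ["2"]})}
seed_pairs = {"Cites": set()}  # (source, target) pairs flagged as seeds

for etype, df in data.items():
    if seed_pairs[etype]:
        df["is_seed"] = [(s, t) in seed_pairs[etype]
                         for s, t in zip(df["source"], df["target"])]
    else:
        # Mirrors the added else-branch: pin the column so every
        # per-type frame carries the same schema downstream.
        df["is_seed"] = False
```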
Expand Down Expand Up @@ -3817,7 +3817,8 @@ def _install_query(self, force: bool = False):

if self.is_hetero:
# Multiple vertex types
print_query = ""
print_query_seed = ""
print_query_other = ""
for idx, vtype in enumerate(self._vtypes):
v_attr_names = (
self.v_in_feats.get(vtype, [])
@@ -3826,17 +3827,25 @@ def _install_query(self, force: bool = False):
                 )
                 v_attr_types = self._v_schema[vtype]
                 print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types)
-                print_query += """
+                print_query_seed += """
         {} s.type == "{}" THEN
-            @@v_batch += (s.type + delimiter + stringify(getvid(s)) {}+ "\\n")"""\
+            @@v_batch += (s->(s.type + delimiter + stringify(getvid(s)) {}+ "\\n"))"""\
                     .format("IF" if idx==0 else "ELSE IF", vtype,
                             "+ delimiter + " + print_attr if v_attr_names else "")
-            print_query += """
-        END"""
-            query_replace["{VERTEXATTRS}"] = print_query
+                print_query_other += """
+        {} s.type == "{}" THEN
+            @@v_batch += (tmp_seed->(s.type + delimiter + stringify(getvid(s)) {}+ "\\n"))"""\
+                    .format("IF" if idx==0 else "ELSE IF", vtype,
+                            "+ delimiter + " + print_attr if v_attr_names else "")
+            print_query_seed += """
+        END"""
+            print_query_other += """
+        END"""
+            query_replace["{SEEDVERTEXATTRS}"] = print_query_seed
+            query_replace["{OTHERVERTEXATTRS}"] = print_query_other
             # Multiple edge types
-            print_query_seed = ""
-            print_query_other = ""
+            print_query = ""
+            print_query_kafka = ""
             for idx, etype in enumerate(self._etypes):
                 e_attr_names = (
                     self.e_in_feats.get(etype, [])
@@ -3845,57 +3854,57 @@ def _install_query(self, force: bool = False):
                 )
                 e_attr_types = self._e_schema[etype]
                 print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types)
-                print_query_seed += """
+                print_query += """
         {} e.type == "{}" THEN
-            @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "1\\n")"""\
+            @@e_batch += (tmp_seed->(e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n"))"""\
                     .format("IF" if idx==0 else "ELSE IF", etype,
                             "+ delimiter + " + print_attr if e_attr_names else "")
-                print_query_other += """
+                print_query_kafka += """
         {} e.type == "{}" THEN
-            @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "0\\n")"""\
+            SET<STRING> tmp_e = (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n", ""),
+            tmp_e_batch = tmp_e_batch UNION tmp_e"""\
                     .format("IF" if idx==0 else "ELSE IF", etype,
-                            "+ delimiter + "+ print_attr if e_attr_names else "")
-            print_query_seed += """
+                            "+ delimiter + " + print_attr if e_attr_names else "")
+            print_query += """
         END"""
-            print_query_other += """
+            print_query_kafka += """
         END"""
-            query_replace["{SEEDEDGEATTRS}"] = print_query_seed
-            query_replace["{OTHEREDGEATTRS}"] = print_query_other
+            query_replace["{EDGEATTRS}"] = print_query
+            query_replace["{EDGEATTRSKAFKA}"] = print_query_kafka
         else:
             # Ignore vertex types
             v_attr_names = self.v_in_feats + self.v_out_labels + self.v_extra_feats
             v_attr_types = next(iter(self._v_schema.values()))
             print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types)
-            print_query = '@@v_batch += (stringify(getvid(s)) {}+ "\\n")'.format(
+            print_query_seed = '@@v_batch += (s->(stringify(getvid(s)) {}+ "\\n"))'.format(
                 "+ delimiter + " + print_attr if v_attr_names else ""
             )
-            query_replace["{VERTEXATTRS}"] = print_query
+            print_query_other = '@@v_batch += (tmp_seed->(stringify(getvid(s)) {}+ "\\n"))'.format(
+                "+ delimiter + " + print_attr if v_attr_names else ""
+            )
+            query_replace["{SEEDVERTEXATTRS}"] = print_query_seed
+            query_replace["{OTHERVERTEXATTRS}"] = print_query_other
             # Ignore edge types
             e_attr_names = self.e_in_feats + self.e_out_labels + self.e_extra_feats
             e_attr_types = next(iter(self._e_schema.values()))
             print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types)
-            print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "1\\n")'.format(
+            print_query = '@@e_batch += (tmp_seed->(stringify(getvid(s)) + delimiter + stringify(getvid(t)) {} + "\\n"))'.format(
                 "+ delimiter + " + print_attr if e_attr_names else ""
             )
-            query_replace["{SEEDEDGEATTRS}"] = print_query
-            print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "0\\n")'.format(
+            query_replace["{EDGEATTRS}"] = print_query
+            print_query = """SET<STRING> tmp_e = (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {} + "\\n", ""),
+            tmp_e_batch = tmp_e_batch UNION tmp_e""".format(
                 "+ delimiter + " + print_attr if e_attr_names else ""
             )
-            query_replace["{OTHEREDGEATTRS}"] = print_query
+            query_replace["{EDGEATTRSKAFKA}"] = print_query
         # Install query
         query_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             "gsql",
             "dataloaders",
             "edge_nei_loader.gsql",
         )
-        sub_query_path = os.path.join(
-            os.path.dirname(os.path.abspath(__file__)),
-            "gsql",
-            "dataloaders",
-            "edge_nei_loader_sub.gsql",
-        )
-        return install_query_files(self._graph, [sub_query_path, query_path], query_replace, force=force, distributed=[False, self.distributed_query])
+        return install_query_file(self._graph, query_path, query_replace, force=force, distributed=self.distributed_query)

     def _start(self) -> None:
         # Create task and result queues
Expand Down Expand Up @@ -3938,7 +3947,8 @@ def _start(self) -> None:
add_self_loop = self.add_self_loop,
delimiter = self.delimiter,
is_hetero = self.is_hetero,
callback_fn = self.callback_fn
callback_fn = self.callback_fn,
seed_type = "edge"
),
)
self._reader.start()
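Worth noting on the `_start` change: `seed_type = "edge"` tells the reader that the third field of each queue message is the batch's seed id, which is how `_read_graph_data` fills the `seeds` set it now forwards alongside the two buffers. A minimal sketch of that three-field contract (mock payloads; the layout follows the reader code above):

```python
# One message off the queue: (vertex payload, edge payload, seed id).
raw = ("v1,0.1\nv2,0.2\n", "v1,v2,0.5\n", "1_Cites_2")

seed_type = "edge"
seeds = {raw[2]}  # accumulated per batch by _read_graph_data

vertex_rows = [l.strip() for l in raw[0].strip().splitlines()]
edge_rows = [l.strip() for l in raw[1].strip().splitlines()]
# A truthy seed_type makes the parser receive a 3-tuple including seeds.
raw_data = (vertex_rows, edge_rows, seeds) if seed_type else (vertex_rows, edge_rows)
```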
141 changes: 88 additions & 53 deletions pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql
@@ -49,61 +49,123 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}(
ssl_ca_location: Path to CA certificate for verifying the Kafka broker key.
*/
    SumAccum<INT> @tmp_id;
+   SumAccum<STRING> @@kafka_error;
+   UINT producer;
+   SetAccum<VERTEX> @seeds;

    start = {v_types};
    # Filter seeds if needed
-   seeds = SELECT s
+   start = SELECT s
        FROM start:s -(seed_types:e)- v_types:t
        WHERE filter_by is NULL OR e.getAttr(filter_by, "BOOL")
        POST-ACCUM s.@tmp_id = getvid(s)
        POST-ACCUM t.@tmp_id = getvid(t);
    # Shuffle vertex ID if needed
    IF shuffle THEN
-       INT num_vertices = seeds.size();
+       INT num_vertices = start.size();
        res = SELECT s
-           FROM seeds:s
+           FROM start:s
            POST-ACCUM s.@tmp_id = floor(rand()*num_vertices)
            LIMIT 1;
    END;

-   # Generate batches
    # If using kafka to export
    IF kafka_address != "" THEN
-       SumAccum<STRING> @@kafka_error;

        # Initialize Kafka producer
-       UINT producer = init_kafka_producer(
+       producer = init_kafka_producer(
            kafka_address, kafka_max_size, security_protocol,
            sasl_mechanism, sasl_username, sasl_password, ssl_ca_location,
            ssl_certificate_location, ssl_key_location, ssl_key_password,
            ssl_endpoint_identification_algorithm, sasl_kerberos_service_name,
            sasl_kerberos_keytab, sasl_kerberos_principal);
-       FOREACH chunk IN RANGE[0, num_chunks-1] DO
+   END;
+
+   MapAccum<VERTEX, SetAccum<STRING>> @@v_batch;
+   MapAccum<VERTEX, SetAccum<STRING>> @@e_batch;
+
+   FOREACH chunk IN RANGE[0, num_chunks-1] DO
+       # Collect neighborhood data for each vertex
+       seed1 = SELECT s
+           FROM start:s -(seed_types:e)- v_types:t
+           WHERE (filter_by IS NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk);
+       seed2 = SELECT t
+           FROM start:s -(seed_types:e)- v_types:t
+           WHERE (filter_by IS NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk);
+       seeds = seed1 UNION seed2;
+       seeds = SELECT s
+           FROM seeds:s
+           POST-ACCUM
+               s.@seeds += s,
+               {SEEDVERTEXATTRS};
+       FOREACH hop IN RANGE[1, num_hops] DO
+           seeds = SELECT t
+               FROM seeds:s -(e_types:e)- v_types:t
+               SAMPLE num_neighbors EDGE WHEN s.outdegree() >= 1
+               ACCUM
+                   t.@seeds += s.@seeds,
+                   FOREACH tmp_seed in s.@seeds DO
+                       {EDGEATTRS}
+                   END;
+           seeds = SELECT s
+               FROM seeds:s
+               POST-ACCUM
+                   FOREACH tmp_seed in s.@seeds DO
+                       {OTHERVERTEXATTRS}
+                   END;
+       END;
+       # Clear all accums
+       all_v = {v_types};
+       res = SELECT s
+           FROM all_v:s
+           POST-ACCUM s.@seeds.clear()
+           LIMIT 1;
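The accumulators above implement a per-seed fan-out: `@seeds` carries, hop by hop, the set of seed vertices whose neighborhood each visited vertex belongs to, and `@@v_batch`/`@@e_batch` file every exported line under each owning seed. A rough Python analogue on a toy graph (illustrative only; edge sampling and attribute strings omitted):

```python
from collections import defaultdict

neighbors = {"a": ["b"], "b": ["c"], "c": ["a"]}  # toy directed graph
seed_vertices = ["a", "b"]
num_hops = 2

owners = {v: {v} for v in seed_vertices}   # per-vertex @seeds accumulator
v_batch = defaultdict(set)                 # @@v_batch: seed -> exported lines
for v in seed_vertices:
    v_batch[v].add(v)                      # {SEEDVERTEXATTRS}

frontier = list(seed_vertices)
for _ in range(num_hops):
    nxt = []
    for s in frontier:
        for t in neighbors.get(s, []):
            owners.setdefault(t, set()).update(owners[s])  # t.@seeds += s.@seeds
            nxt.append(t)
    for t in nxt:
        for seed in owners[t]:             # FOREACH tmp_seed IN s.@seeds
            v_batch[seed].add(t)           # {OTHERVERTEXATTRS}
    frontier = nxt
```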

+       # Generate output for each edge
+       # If use kafka to export
+       IF kafka_address != "" THEN
            res = SELECT s
-               FROM seeds:s -(seed_types:e)- v_types:t
+               FROM seed1:s -(seed_types:e)- v_types:t
                WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk)
                ACCUM
-                   STRING e_type = e.type,
-                   LIST<STRING> msg = edge_nei_loader_sub_{QUERYSUFFIX}(s, t, delimiter, num_hops, num_neighbors, e_types, v_types, e_type),
-                   BOOL is_first=True,
-                   FOREACH i in msg DO
-                       IF is_first THEN
-                           INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "vertex_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), i),
-                           IF kafka_errcode!=0 THEN
-                               @@kafka_error += ("Error sending vertex batch for " + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n")
-                           END,
-                           is_first = False
-                       ELSE
-                           INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), i),
-                           IF kafka_errcode!=0 THEN
-                               @@kafka_error += ("Error sending edge batch for " + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n")
-                           END
-                       END
-                   END
+                   INT part_num = (getvid(s)+getvid(t))%kafka_topic_partitions,
+                   STRING batch_id = stringify(getvid(s))+"_"+e.type+"_"+stringify(getvid(t)),
+                   SET<STRING> tmp_v_batch = @@v_batch.get(s) + @@v_batch.get(t),
+                   INT kafka_errcode = write_to_kafka(producer, kafka_topic, part_num, "vertex_batch_"+batch_id, stringify(tmp_v_batch)),
+                   IF kafka_errcode!=0 THEN
+                       @@kafka_error += ("Error sending vertex batch for "+batch_id+": "+stringify(kafka_errcode) + "\n")
+                   END,
+                   SET<STRING> tmp_e_batch = @@e_batch.get(s) + @@e_batch.get(t),
+                   {EDGEATTRSKAFKA},
+                   kafka_errcode = write_to_kafka(producer, kafka_topic, part_num, "edge_batch_"+batch_id, stringify(tmp_e_batch)),
+                   IF kafka_errcode!=0 THEN
+                       @@kafka_error += ("Error sending edge batch for "+batch_id+ ": "+ stringify(kafka_errcode) + "\n")
+                   END
                LIMIT 1;
+       # Else return as http response
+       ELSE
+           MapAccum<STRING, STRING> @@v_data;
+           MapAccum<STRING, STRING> @@e_data;
+           res = SELECT s
+               FROM seed1:s -(seed_types:e)- v_types:t
+               WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk)
+               ACCUM
+                   STRING batch_id = stringify(getvid(s))+"_"+e.type+"_"+stringify(getvid(t)),
+                   SET<STRING> tmp_v_batch = @@v_batch.get(s) + @@v_batch.get(t),
+                   @@v_data += (batch_id -> stringify(tmp_v_batch)),
+                   SET<STRING> tmp_e_batch = @@e_batch.get(s) + @@e_batch.get(t),
+                   {EDGEATTRSKAFKA},
+                   @@e_data += (batch_id -> stringify(tmp_e_batch))
+               LIMIT 1;
+
+           FOREACH (k,v) IN @@v_data DO
+               PRINT v as vertex_batch, @@e_data.get(k) as edge_batch, k AS seed;
+           END;
+       END;

END;
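Each printed row pairs one seed edge's vertex batch and edge batch with a `seed` key, the three-element shape the updated Python reader consumes as `raw[0]`, `raw[1]`, `raw[2]`. A hedged sketch of unpacking one such row (key names taken from the PRINT above; payload values are mock data):

```python
# One row of the HTTP response, as emitted by the PRINT statement above.
row = {
    "vertex_batch": "v1,0.1\nv2,0.2\n",
    "edge_batch": "v1,v2,0.5\n",
    "seed": "1_Cites_2",
}
raw = (row["vertex_batch"], row["edge_batch"], row["seed"])
# raw[2] is what _read_graph_data adds to its per-batch `seeds` set.
```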

+   IF kafka_address != "" THEN
FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO
INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", "");
IF kafka_errcode!=0 THEN
@@ -116,32 +178,5 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}(
            @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n");
        END;
        PRINT @@kafka_error as kafkaError;
-   # Else return as http response
-   ELSE
-       FOREACH chunk IN RANGE[0, num_chunks-1] DO
-           MapAccum<STRING, STRING> @@v_batch;
-           MapAccum<STRING, STRING> @@e_batch;
-
-           res = SELECT s
-               FROM seeds:s -(seed_types:e)- v_types:t
-               WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk)
-               ACCUM
-                   STRING e_type = e.type,
-                   LIST<STRING> msg = edge_nei_loader_sub_{QUERYSUFFIX}(s, t, delimiter, num_hops, num_neighbors, e_types, v_types, e_type),
-                   BOOL is_first=True,
-                   FOREACH i in msg DO
-                       IF is_first THEN
-                           @@v_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> i),
-                           is_first = False
-                       ELSE
-                           @@e_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> i)
-                       END
-                   END
-               LIMIT 1;
-
-           FOREACH (k,v) IN @@v_batch DO
-               PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch;
-           END;
-       END;
    END;
}
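On shutdown the query writes one empty message keyed "STOP" to every partition, so a consumer can retire a partition once that key appears. A minimal consumer-side sketch of the convention (assumes the kafka-python package; topic, broker, and partition count are placeholders):

```python
from kafka import KafkaConsumer  # assumption: kafka-python is installed

consumer = KafkaConsumer("loader_topic", bootstrap_servers="localhost:9092")
num_partitions = 2   # must match kafka_topic_partitions
stopped = set()

for msg in consumer:
    if msg.key == b"STOP":
        stopped.add(msg.partition)
        if len(stopped) == num_partitions:
            break    # every partition has signaled end-of-stream
        continue
    # Keys look like b"vertex_batch_<id>" / b"edge_batch_<id>"; values are CSV.
    handle = (msg.key, msg.value)
```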
