diff --git a/ci/test_notebooks.sh b/ci/test_notebooks.sh index 00c63af..73cc315 100755 --- a/ci/test_notebooks.sh +++ b/ci/test_notebooks.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2020-2024, NVIDIA CORPORATION. +# Copyright (c) 2020-2025, NVIDIA CORPORATION. set -Eeuo pipefail @@ -7,11 +7,19 @@ set -Eeuo pipefail RAPIDS_VERSION="$(rapids-version)" +rapids-logger "Downloading artifacts from previous jobs" +CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) +PYTHON_CHANNEL=$(rapids-download-conda-from-s3 python) + rapids-logger "Generate notebook testing dependencies" rapids-dependency-file-generator \ --output conda \ --file-key test_notebooks \ - --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml + --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" \ + --prepend-channel dglteam/label/th23_cu118 \ + --prepend-channel "${CPP_CHANNEL}" \ + --prepend-channel "${PYTHON_CHANNEL}" \ +| tee env.yaml rapids-mamba-retry env create --yes -f env.yaml -n test @@ -22,16 +30,6 @@ set -u rapids-print-env -rapids-logger "Downloading artifacts from previous jobs" -CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) -PYTHON_CHANNEL=$(rapids-download-conda-from-s3 python) - -rapids-mamba-retry install \ - --channel "${CPP_CHANNEL}" \ - --channel "${PYTHON_CHANNEL}" \ - --channel dglteam/label/th23_cu118 \ - "cugraph-dgl=${RAPIDS_VERSION}" - NBTEST="$(realpath "$(dirname "$0")/utils/nbtest.sh")" NOTEBOOK_LIST="$(realpath "$(dirname "$0")/notebook_list.py")" EXITCODE=0 diff --git a/dependencies.yaml b/dependencies.yaml index 85bdac9..c5a423c 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -54,6 +54,7 @@ files: includes: - cuda_version - depends_on_pytorch + - depends_on_cugraph_dgl - py_version - test_notebook test_python: @@ -540,6 +541,12 @@ dependencies: - cugraph-cu11==25.2.*,>=0.0.0a0 - {matrix: null, packages: [*cugraph_unsuffixed]} + depends_on_cugraph_dgl: + common: + - output_types: conda + packages: + - cugraph-dgl==25.2.*,>=0.0.0a0 + depends_on_cudf: common: - output_types: conda diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py b/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py index ecc5100..16e106b 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -14,11 +14,10 @@ from __future__ import annotations import warnings -import tempfile from typing import Sequence, Optional, Union, List, Tuple, Iterator -from cugraph.gnn import UniformNeighborSampler, BiasedNeighborSampler, DistSampleWriter +from cugraph.gnn import UniformNeighborSampler, BiasedNeighborSampler from cugraph.utilities.utils import import_optional import cugraph_dgl @@ -124,7 +123,7 @@ def __init__( Can be either "dgl.Block" (default), or "cugraph_dgl.nn.SparseGraph". **kwargs Keyword arguments for the underlying cuGraph distributed sampler - and writer (directory, batches_per_partition, format, + and writer (batches_per_partition, format, local_seeds_per_call). 
""" @@ -165,18 +164,6 @@ def sample( ) -> Iterator[DGLSamplerOutput]: kwargs = dict(**self.__kwargs) - directory = kwargs.pop("directory", None) - if directory is None: - warnings.warn("Setting a directory to store samples is recommended.") - self._tempdir = tempfile.TemporaryDirectory() - directory = self._tempdir.name - - writer = DistSampleWriter( - directory=directory, - batches_per_partition=kwargs.pop("batches_per_partition", 256), - format=kwargs.pop("format", "parquet"), - ) - sampling_clx = ( UniformNeighborSampler if self.__prob_attr is None @@ -185,7 +172,7 @@ def sample( ds = sampling_clx( g._graph(self.edge_dir, prob_attr=self.__prob_attr), - writer, + writer=None, compression="CSR", fanout=self._reversed_fanout_vals, prior_sources_behavior="carryover", diff --git a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py index c47dda5..e14e6ca 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -70,6 +70,7 @@ def __clear_graph(self): self.__graph = None self.__vertex_offsets = None self.__weight_attr = None + self.__numeric_edge_types = None def _put_edge_index( self, @@ -240,6 +241,27 @@ def _vertex_offsets(self) -> Dict[str, int]: return dict(self.__vertex_offsets) + @property + def _vertex_offset_array(self) -> "torch.Tensor": + off = torch.tensor( + [self._vertex_offsets[k] for k in sorted(self._vertex_offsets.keys())], + dtype=torch.int64, + device="cuda", + ) + + return torch.concat( + [ + off, + torch.tensor( + list(self._num_vertices().values()), + device="cuda", + dtype=torch.int64, + ) + .sum() + .reshape((1,)), + ] + ) + @property def is_homogeneous(self) -> bool: return len(self._vertex_offsets) == 1 @@ -270,6 +292,38 @@ def __get_weight_tensor( return torch.concat(weights) + @property + def _numeric_edge_types(self) -> Tuple[List, "torch.Tensor", "torch.Tensor"]: + """ + Returns the canonical edge types in order (the 0th canonical type corresponds + to numeric edge type 0, etc.), along with the numeric source and destination + vertex types for each edge type. + """ + + if self.__numeric_edge_types is None: + sorted_keys = sorted( + list(self.__edge_indices.keys(leaves_only=True, include_nested=True)) + ) + + vtype_table = { + k: i for i, k in enumerate(sorted(self._vertex_offsets.keys())) + } + + srcs = [] + dsts = [] + + for can_etype in sorted_keys: + srcs.append(vtype_table[can_etype[0]]) + dsts.append(vtype_table[can_etype[2]]) + + self.__numeric_edge_types = ( + sorted_keys, + torch.tensor(srcs, device="cuda", dtype=torch.int32), + torch.tensor(dsts, device="cuda", dtype=torch.int32), + ) + + return self.__numeric_edge_types + def __get_edgelist(self): """ Returns diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py index 1e19aa7..304c360 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. 
# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,7 +17,6 @@ import argparse import os import warnings -import tempfile import time import json @@ -179,7 +178,6 @@ def run_train( fan_out, num_classes, wall_clock_start, - tempdir=None, num_layers=3, in_memory=False, seeds_per_call=-1, @@ -194,13 +192,9 @@ def run_train( from cugraph_pyg.loader import NeighborLoader ix_train = split_idx["train"].cuda() - train_path = None if in_memory else os.path.join(tempdir, f"train_{global_rank}") - if train_path: - os.mkdir(train_path) train_loader = NeighborLoader( data, input_nodes=ix_train, - directory=train_path, shuffle=True, drop_last=True, local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, @@ -208,13 +202,9 @@ def run_train( ) ix_test = split_idx["test"].cuda() - test_path = None if in_memory else os.path.join(tempdir, f"test_{global_rank}") - if test_path: - os.mkdir(test_path) test_loader = NeighborLoader( data, input_nodes=ix_test, - directory=test_path, shuffle=True, drop_last=True, local_seeds_per_call=80000, @@ -222,13 +212,9 @@ def run_train( ) ix_valid = split_idx["valid"].cuda() - valid_path = None if in_memory else os.path.join(tempdir, f"valid_{global_rank}") - if valid_path: - os.mkdir(valid_path) valid_loader = NeighborLoader( data, input_nodes=ix_valid, - directory=valid_path, shuffle=True, drop_last=True, local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, @@ -347,7 +333,6 @@ def parse_args(): parser.add_argument("--epochs", type=int, default=4) parser.add_argument("--batch_size", type=int, default=1024) parser.add_argument("--fan_out", type=int, default=30) - parser.add_argument("--tempdir_root", type=str, default=None) parser.add_argument("--dataset_root", type=str, default="datasets") parser.add_argument("--dataset", type=str, default="ogbn-products") parser.add_argument("--skip_partition", action="store_true") @@ -427,23 +412,21 @@ def parse_args(): ).to(device) model = DistributedDataParallel(model, device_ids=[local_rank]) - with tempfile.TemporaryDirectory(dir=args.tempdir_root) as tempdir: - run_train( - global_rank, - data, - split_idx, - world_size, - device, - model, - args.epochs, - args.batch_size, - args.fan_out, - meta["num_classes"], - wall_clock_start, - tempdir, - args.num_layers, - args.in_memory, - args.seeds_per_call, - ) + run_train( + global_rank, + data, + split_idx, + world_size, + device, + model, + args.epochs, + args.batch_size, + args.fan_out, + meta["num_classes"], + wall_clock_start, + args.num_layers, + args.in_memory, + args.seeds_per_call, + ) else: warnings.warn("This script should be run with 'torchrun`. Exiting.") diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py index 223b58f..736dede 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -13,7 +13,6 @@ import time import argparse -import tempfile import os import warnings @@ -96,22 +95,15 @@ def create_loader( input_nodes, replace, batch_size, - samples_dir, stage_name, local_seeds_per_call, ): - if samples_dir is not None: - directory = os.path.join(samples_dir, stage_name) - os.mkdir(directory) - else: - directory = None return NeighborLoader( data, num_neighbors=num_neighbors, input_nodes=input_nodes, replace=replace, batch_size=batch_size, - directory=directory, local_seeds_per_call=local_seeds_per_call, ) @@ -155,7 +147,6 @@ def parse_args(): parser.add_argument("--epochs", type=int, default=4) parser.add_argument("--batch_size", type=int, default=1024) parser.add_argument("--fan_out", type=int, default=30) - parser.add_argument("--tempdir_root", type=str, default=None) parser.add_argument("--dataset_root", type=str, default="datasets") parser.add_argument("--dataset", type=str, default="ogbn-products") parser.add_argument("--in_memory", action="store_true", default=False) @@ -177,60 +168,56 @@ def parse_args(): warnings.warn("Pruning test dataset for CI run.") split_idx["test"] = split_idx["test"][:1000] - with tempfile.TemporaryDirectory(dir=args.tempdir_root) as samples_dir: - loader_kwargs = { - "data": data, - "num_neighbors": [args.fan_out] * args.num_layers, - "replace": False, - "batch_size": args.batch_size, - "samples_dir": None if args.in_memory else samples_dir, - "local_seeds_per_call": None - if args.seeds_per_call <= 0 - else args.seeds_per_call, - } - - train_loader = create_loader( - input_nodes=split_idx["train"], - stage_name="train", - **loader_kwargs, - ) - - val_loader = create_loader( - input_nodes=split_idx["valid"], - stage_name="val", - **loader_kwargs, - ) - - test_loader = create_loader( - input_nodes=split_idx["test"], - stage_name="test", - **loader_kwargs, - ) - - model = torch_geometric.nn.models.GCN( - num_features, - args.hidden_channels, - args.num_layers, - num_classes, - ).to(device) - - optimizer = torch.optim.Adam( - model.parameters(), lr=args.lr, weight_decay=0.0005 - ) - - warmup_steps = 20 - - torch.cuda.synchronize() - prep_time = round(time.perf_counter() - wall_clock_start, 2) - print("Total time before training begins (prep_time)=", prep_time, "seconds") - print("Beginning training...") - for epoch in range(1, 1 + args.epochs): - train(epoch) - val_acc = test(val_loader, val_steps=100) - print(f"Val Acc: ~{val_acc:.4f}") - - test_acc = test(test_loader) - print(f"Test Acc: {test_acc:.4f}") - total_time = round(time.perf_counter() - wall_clock_start, 2) - print("Total Program Runtime (total_time) =", total_time, "seconds") - print("total_time - prep_time =", total_time - prep_time, "seconds") + loader_kwargs = { + "data": data, + "num_neighbors": [args.fan_out] * args.num_layers, + "replace": False, + "batch_size": args.batch_size, + "local_seeds_per_call": None + if args.seeds_per_call <= 0 + else args.seeds_per_call, + } + + train_loader = create_loader( + input_nodes=split_idx["train"], + stage_name="train", + **loader_kwargs, + ) + + val_loader = create_loader( + input_nodes=split_idx["valid"], + stage_name="val", + **loader_kwargs, + ) + + test_loader = create_loader( + input_nodes=split_idx["test"], + stage_name="test", + **loader_kwargs, + ) + + model = torch_geometric.nn.models.GCN( + num_features, + args.hidden_channels, + args.num_layers, + num_classes, + ).to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005) + + warmup_steps 
= 20 + + torch.cuda.synchronize() + prep_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total time before training begins (prep_time)=", prep_time, "seconds") + print("Beginning training...") + for epoch in range(1, 1 + args.epochs): + train(epoch) + val_acc = test(val_loader, val_steps=100) + print(f"Val Acc: ~{val_acc:.4f}") + + test_acc = test(test_loader) + print(f"Test Acc: {test_acc:.4f}") + total_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total Program Runtime (total_time) =", total_time, "seconds") + print("total_time - prep_time =", total_time - prep_time, "seconds") diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py index 42e3343..db335da 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,6 @@ import argparse import os -import tempfile import time import warnings @@ -80,9 +79,7 @@ def run_train( split_idx, num_classes, wall_clock_start, - tempdir=None, num_layers=3, - in_memory=False, seeds_per_call=-1, ): @@ -117,13 +114,9 @@ def run_train( dist.barrier() ix_train = torch.tensor_split(split_idx["train"], world_size)[rank].cuda() - train_path = None if in_memory else os.path.join(tempdir, f"train_{rank}") - if train_path: - os.mkdir(train_path) train_loader = NeighborLoader( (feature_store, graph_store), input_nodes=ix_train, - directory=train_path, shuffle=True, drop_last=True, local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, @@ -131,13 +124,9 @@ def run_train( ) ix_test = torch.tensor_split(split_idx["test"], world_size)[rank].cuda() - test_path = None if in_memory else os.path.join(tempdir, f"test_{rank}") - if test_path: - os.mkdir(test_path) test_loader = NeighborLoader( (feature_store, graph_store), input_nodes=ix_test, - directory=test_path, shuffle=True, drop_last=True, local_seeds_per_call=80000, @@ -145,13 +134,9 @@ def run_train( ) ix_valid = torch.tensor_split(split_idx["valid"], world_size)[rank].cuda() - valid_path = None if in_memory else os.path.join(tempdir, f"valid_{rank}") - if valid_path: - os.mkdir(valid_path) valid_loader = NeighborLoader( (feature_store, graph_store), input_nodes=ix_valid, - directory=valid_path, shuffle=True, drop_last=True, local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, @@ -271,10 +256,8 @@ def run_train( parser.add_argument("--epochs", type=int, default=4) parser.add_argument("--batch_size", type=int, default=1024) parser.add_argument("--fan_out", type=int, default=30) - parser.add_argument("--tempdir_root", type=str, default=None) parser.add_argument("--dataset_root", type=str, default="datasets") parser.add_argument("--dataset", type=str, default="ogbn-products") - parser.add_argument("--in_memory", action="store_true", default=False) parser.add_argument("--seeds_per_call", type=int, default=-1) parser.add_argument( @@ -315,25 +298,22 @@ def run_train( cugraph_id = cugraph_comms_create_unique_id() - with tempfile.TemporaryDirectory(dir=args.tempdir_root) as tempdir: - mp.spawn( - run_train, - args=( - data, - world_size, - cugraph_id, - model, - args.epochs, - args.batch_size, - args.fan_out, - split_idx, - 
dataset.num_classes, - wall_clock_start, - tempdir, - args.num_layers, - args.in_memory, - args.seeds_per_call, - ), - nprocs=world_size, - join=True, - ) + mp.spawn( + run_train, + args=( + data, + world_size, + cugraph_id, + model, + args.epochs, + args.batch_size, + args.fan_out, + split_idx, + dataset.num_classes, + wall_clock_start, + args.num_layers, + args.seeds_per_call, + ), + nprocs=world_size, + join=True, + ) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py index 77e2ac4..e2d7725 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -128,6 +128,11 @@ def __init__( (None, edge_label_index), ) + # Note reverse of standard convention here + if input_type is not None: + edge_label_index[0] += data[1]._vertex_offsets[input_type[0]] + edge_label_index[1] += data[1]._vertex_offsets[input_type[2]] + self.__input_data = torch_geometric.sampler.EdgeSamplerInput( input_id=torch.arange( edge_label_index[0].numel(), dtype=torch.int64, device="cuda" diff --git a/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py index 0805653..2effdab 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -188,11 +188,23 @@ def __init__( # Will eventually automatically convert these objects to cuGraph objects. raise NotImplementedError("Currently can't accept non-cugraph graphs") + feature_store, graph_store = data + if compression is None: - compression = "CSR" + compression = "CSR" if graph_store.is_homogeneous else "COO" elif compression not in ["CSR", "COO"]: raise ValueError("Invalid value for compression (expected 'CSR' or 'COO')") + if not graph_store.is_homogeneous: + if compression != "COO": + raise ValueError( + "Only COO format is supported for heterogeneous graphs!" + ) + if directory is not None: + raise ValueError( + "Writing to disk is not supported for heterogeneous graphs!" 
+ ) + writer = ( None if directory is None @@ -203,8 +215,6 @@ def __init__( ) ) - feature_store, graph_store = data - if weight_attr is not None: graph_store._set_weight_attr((feature_store, weight_attr)) @@ -221,6 +231,9 @@ def __init__( with_replacement=replace, local_seeds_per_call=local_seeds_per_call, biased=(weight_attr is not None), + heterogeneous=(not graph_store.is_homogeneous), + vertex_type_offsets=graph_store._vertex_offset_array, + num_edge_types=len(graph_store.get_all_edge_attrs()), ), (feature_store, graph_store), batch_size=batch_size, diff --git a/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py index 1da2c6d..961ac34 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -181,11 +181,23 @@ def __init__( # Will eventually automatically convert these objects to cuGraph objects. raise NotImplementedError("Currently can't accept non-cugraph graphs") + feature_store, graph_store = data + if compression is None: - compression = "CSR" + compression = "CSR" if graph_store.is_homogeneous else "COO" elif compression not in ["CSR", "COO"]: raise ValueError("Invalid value for compression (expected 'CSR' or 'COO')") + if not graph_store.is_homogeneous: + if compression != "COO": + raise ValueError( + "Only COO format is supported for heterogeneous graphs!" + ) + if directory is not None: + raise ValueError( + "Writing to disk is not supported for heterogeneous graphs!" + ) + writer = ( None if directory is None @@ -196,8 +208,6 @@ def __init__( ) ) - feature_store, graph_store = data - if weight_attr is not None: graph_store._set_weight_attr((feature_store, weight_attr)) @@ -214,6 +224,9 @@ def __init__( with_replacement=replace, local_seeds_per_call=local_seeds_per_call, biased=(weight_attr is not None), + heterogeneous=(not graph_store.is_homogeneous), + vertex_type_offsets=graph_store._vertex_offset_array, + num_edge_types=len(graph_store.get_all_edge_attrs()), ), (feature_store, graph_store), batch_size=batch_size, diff --git a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py index 4b236f7..926e3a8 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -109,6 +109,8 @@ def __init__( input_nodes, input_id, ) + if input_type is not None: + input_nodes += data[1]._vertex_offsets[input_type] self.__input_data = torch_geometric.sampler.NodeSamplerInput( input_id=torch.arange(len(input_nodes), dtype=torch.int64, device="cuda") diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py index bc3d4fd..629d56a 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. 
+# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Iterator, Union, Dict, Tuple +from typing import Optional, Iterator, Union, Dict, Tuple, List from cugraph.utilities.utils import import_optional from cugraph.gnn import DistSampler @@ -122,7 +122,7 @@ def __next__(self): elif isinstance(next_sample, torch_geometric.sampler.HeteroSamplerOutput): col = {} - for edge_type, col_idx in next_sample.col: + for edge_type, col_idx in next_sample.col.items(): sz = next_sample.edge[edge_type].numel() if sz == col_idx.numel(): col[edge_type] = col_idx @@ -153,6 +153,20 @@ def __next__(self): data.set_value_dict("num_sampled_edges", next_sample.num_sampled_edges) # TODO figure out how to set input_id for heterogeneous output + input_type, input_id = next_sample.metadata[0] + data[input_type].input_id = input_id + data[input_type].batch_size = input_id.size(0) + + if len(next_sample.metadata) == 2: + data[input_type].seed_time = next_sample.metadata[1] + elif len(next_sample.metadata) == 4: + ( + data[input_type].edge_label_index, + data[input_type].edge_label, + data[input_type].seed_time, + ) = next_sample.metadata[1:] + else: + raise ValueError("Invalid metadata") else: raise ValueError("Invalid output type") @@ -190,12 +204,18 @@ def __next__(self): self.__base_reader ) + lho_name = ( + "label_type_hop_offsets" + if "label_type_hop_offsets" in self.__raw_sample_data + else "label_hop_offsets" + ) + self.__raw_sample_data["input_offsets"] -= self.__raw_sample_data[ "input_offsets" ][0].clone() - self.__raw_sample_data["label_hop_offsets"] -= self.__raw_sample_data[ - "label_hop_offsets" - ][0].clone() + self.__raw_sample_data[lho_name] -= self.__raw_sample_data[lho_name][ + 0 + ].clone() self.__raw_sample_data["renumber_map_offsets"] -= self.__raw_sample_data[ "renumber_map_offsets" ][0].clone() @@ -216,6 +236,207 @@ def __iter__(self): return self +class HeterogeneousSampleReader(SampleReader): + """ + Subclass of SampleReader that reads heterogeneous output samples + produced by the cuGraph distributed sampler. + """ + + def __init__( + self, + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]], + src_types: "torch.Tensor", + dst_types: "torch.Tensor", + vertex_offsets: "torch.Tensor", + edge_types: List[Tuple[str, str, str]], + vertex_types: List[str], + ): + """ + Constructs a new HeterogeneousSampleReader + + Parameters + ---------- + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] + The iterator responsible for loading saved samples produced by + the cuGraph distributed sampler. + src_types: torch.Tensor + Integer source type for each integer edge type. + dst_types: torch.Tensor + Integer destination type for each integer edge type. + vertex_offsets: torch.Tensor + Vertex offsets for each vertex type. Used to de-offset vertices + outputted by the cuGraph sampler and return PyG-compliant vertex + IDs. + edge_types: List[Tuple[str, str, str]] + List of edge types in the graph in order, so they can be + mapped to numeric edge types. + vertex_types: List[str] + List of vertex types, in order so they can be mapped to + numeric vertex types. 
+ """ + + self.__src_types = src_types + self.__dst_types = dst_types + self.__edge_types = edge_types + self.__vertex_types = vertex_types + self.__num_vertex_types = len(vertex_types) + + self.__vertex_offsets = vertex_offsets + + super().__init__(base_reader) + + def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): + num_edge_types = self.__src_types.numel() + fanout_length = raw_sample_data["fanout"].numel() // num_edge_types + + num_sampled_nodes = [ + torch.zeros((fanout_length + 1,), dtype=torch.int64, device="cuda") + for _ in range(self.__num_vertex_types) + ] + + num_sampled_edges = {} + + node = {} + row = {} + col = {} + edge = {} + + input_type = None + + for etype in range(num_edge_types): + pyg_can_etype = self.__edge_types[etype] + + jx = self.__src_types[etype] + index * self.__num_vertex_types + map_ptr_src_beg = raw_sample_data["renumber_map_offsets"][jx] + map_ptr_src_end = raw_sample_data["renumber_map_offsets"][jx + 1] + + map_src = raw_sample_data["map"][map_ptr_src_beg:map_ptr_src_end] + node[pyg_can_etype[0]] = ( + map_src - self.__vertex_offsets[self.__src_types[etype]] + ).cpu() + + kx = self.__dst_types[etype] + index * self.__num_vertex_types + map_ptr_dst_beg = raw_sample_data["renumber_map_offsets"][kx] + map_ptr_dst_end = raw_sample_data["renumber_map_offsets"][kx + 1] + + map_dst = raw_sample_data["map"][map_ptr_dst_beg:map_ptr_dst_end] + node[pyg_can_etype[2]] = ( + map_dst - self.__vertex_offsets[self.__dst_types[etype]] + ).cpu() + + edge_ptr_beg = ( + index * num_edge_types * fanout_length + etype * fanout_length + ) + edge_ptr_end = ( + index * num_edge_types * fanout_length + (etype + 1) * fanout_length + ) + lho = raw_sample_data["label_type_hop_offsets"][ + edge_ptr_beg : edge_ptr_end + 1 + ] + + num_sampled_edges[pyg_can_etype] = (lho).diff() + + eid_i = raw_sample_data["edge_id"][lho[0] : lho[-1]] + + eirx = (index * num_edge_types) + etype + edge_id_ptr_beg = raw_sample_data["edge_renumber_map_offsets"][eirx] + edge_id_ptr_end = raw_sample_data["edge_renumber_map_offsets"][eirx + 1] + + emap = raw_sample_data["edge_renumber_map"][edge_id_ptr_beg:edge_id_ptr_end] + edge[pyg_can_etype] = emap[eid_i] + + col[pyg_can_etype] = raw_sample_data["majors"][lho[0] : lho[-1]] + row[pyg_can_etype] = raw_sample_data["minors"][lho[0] : lho[-1]] + + for hop in range(fanout_length): + vx = raw_sample_data["majors"][: lho[hop + 1]] + if vx.numel() > 0: + num_sampled_nodes[self.__dst_types[etype]][hop + 1] = torch.max( + num_sampled_nodes[self.__dst_types[etype]][hop + 1], + vx.max() + 1, + ) + + vy = raw_sample_data["minors"][: lho[hop + 1]] + if vy.numel() > 0: + num_sampled_nodes[self.__src_types[etype]][hop + 1] = torch.max( + num_sampled_nodes[self.__src_types[etype]][hop + 1], + vy.max() + 1, + ) + + ux = col[pyg_can_etype][: num_sampled_edges[pyg_can_etype][0]] + if ux.numel() > 0: + input_type = pyg_can_etype[2] # can only ever be 1 + + num_sampled_nodes[self.__dst_types[etype]][0] = torch.max( + num_sampled_nodes[self.__dst_types[etype]][0], + (ux.max() + 1).reshape((1,)), + ) + + if input_type is None: + raise ValueError("No input type found!") + + num_sampled_nodes = { + self.__vertex_types[i]: z.diff( + prepend=torch.zeros((1,), dtype=torch.int64, device="cuda") + ).cpu() + for i, z in enumerate(num_sampled_nodes) + } + num_sampled_edges = {k: v.cpu() for k, v in num_sampled_edges.items()} + + input_index = ( + input_type, + raw_sample_data["input_index"][ + raw_sample_data["input_offsets"][index] : raw_sample_data[ + 
"input_offsets" + ][index + 1] + ], + ) + + edge_inverse = ( + ( + raw_sample_data["edge_inverse"][ + (raw_sample_data["input_offsets"][index] * 2) : ( + raw_sample_data["input_offsets"][index + 1] * 2 + ) + ] + ) + if "edge_inverse" in raw_sample_data + else None + ) + + if edge_inverse is None: + metadata = ( + input_index, + None, # TODO this will eventually include time + ) + else: + metadata = ( + input_index, + edge_inverse.view(2, -1), + None, + None, # TODO this will eventually include time + ) + + return torch_geometric.sampler.HeteroSamplerOutput( + node=node, + row=row, + col=col, + edge=edge, + batch=None, + num_sampled_nodes=num_sampled_nodes, + num_sampled_edges=num_sampled_edges, + metadata=metadata, + ) + + def _decode(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): + if "major_offsets" in raw_sample_data: + raise ValueError( + "CSR format not currently supported for heterogeneous graphs" + ) + else: + return self.__decode_coo(raw_sample_data, index) + + class HomogeneousSampleReader(SampleReader): """ Subclass of SampleReader that reads homogeneous output samples @@ -465,10 +686,15 @@ def sample_from_nodes( ): return HomogeneousSampleReader(reader) else: - # TODO implement heterogeneous sampling - raise NotImplementedError( - "Sampling heterogeneous graphs is currently" - " unsupported in the non-dask API" + edge_types, src_types, dst_types = self.__graph_store._numeric_edge_types + + return HeterogeneousSampleReader( + reader, + src_types=src_types, + dst_types=dst_types, + edge_types=edge_types, + vertex_types=sorted(self.__graph_store._num_vertices().keys()), + vertex_offsets=self.__graph_store._vertex_offset_array, ) def sample_from_edges( @@ -533,8 +759,11 @@ def sample_from_edges( ): return HomogeneousSampleReader(reader) else: - # TODO implement heterogeneous sampling - raise NotImplementedError( - "Sampling heterogeneous graphs is currently" - " unsupported in the non-dask API" + edge_types, src_types, dst_types = self.__graph_store._numeric_edge_types + return HeterogeneousSampleReader( + reader, + src_types=src_types, + dst_types=dst_types, + edge_types=edge_types, + vertex_offsets=self.__graph_store._vertex_offset_array, ) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py index 8ee18a8..7938e6f 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -194,3 +194,64 @@ def test_link_neighbor_loader_negative_sampling_uneven(batch_size): elx = torch.tensor_split(elx, eix.numel() // batch_size, dim=1) for i, batch in enumerate(loader): assert batch.edge_label[0] == 1.0 + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg +def test_neighbor_loader_hetero_basic(): + src = torch.tensor([0, 1, 2, 4, 3, 4, 5, 5]) # paper + dst = torch.tensor([4, 5, 4, 3, 2, 1, 0, 1]) # paper + + asrc = torch.tensor([0, 1, 2, 3, 3, 0]) # author + adst = torch.tensor([0, 1, 2, 3, 4, 5]) # paper + + graph_store = GraphStore() + feature_store = TensorDictFeatureStore() + + graph_store[("paper", "cites", "paper"), "coo"] = [src, dst] + graph_store[("author", "writes", "paper"), "coo"] = [asrc, adst] + + from cugraph_pyg.loader import NeighborLoader + + loader = NeighborLoader( + (feature_store, graph_store), + num_neighbors=[1, 1, 1, 1], + input_nodes=("paper", torch.tensor([0, 1])), + batch_size=2, + ) + + out = next(iter(loader)) + + assert sorted(out["paper"].n_id.tolist()) == [0, 1, 4, 5] + assert sorted(out["author"].n_id.tolist()) == [0, 1, 3] + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg +def test_neighbor_loader_hetero_single_etype(): + src = torch.tensor([0, 1, 2, 4, 3, 4, 5, 5]) # paper + dst = torch.tensor([4, 5, 4, 3, 2, 1, 0, 1]) # paper + + asrc = torch.tensor([0, 1, 2, 3, 3, 0]) # author + adst = torch.tensor([0, 1, 2, 3, 4, 5]) # paper + + graph_store = GraphStore() + feature_store = TensorDictFeatureStore() + + graph_store[("paper", "cites", "paper"), "coo"] = [src, dst] + graph_store[("author", "writes", "paper"), "coo"] = [asrc, adst] + + from cugraph_pyg.loader import NeighborLoader + + loader = NeighborLoader( + (feature_store, graph_store), + num_neighbors=[0, 1, 0, 1], + input_nodes=("paper", torch.tensor([0, 1])), + batch_size=2, + ) + + out = next(iter(loader)) + + assert out["author"].n_id.numel() == 0 + assert out["author", "writes", "paper"].edge_index.numel() == 0 + assert out["author", "writes", "paper"].num_sampled_edges.tolist() == [0, 0]
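
For reference, the following is a minimal usage sketch (not part of the patch) of the heterogeneous NeighborLoader path enabled by this change. It mirrors test_neighbor_loader_hetero_basic above; it assumes a CUDA-capable environment with torch and cugraph-pyg installed, and that GraphStore and TensorDictFeatureStore are imported from cugraph_pyg.data as in the test module.

    # Minimal sketch, assuming cugraph_pyg.data exposes GraphStore and
    # TensorDictFeatureStore as used in the tests above.
    import torch

    from cugraph_pyg.data import GraphStore, TensorDictFeatureStore
    from cugraph_pyg.loader import NeighborLoader

    # Two edge types over two vertex types ("paper", "author"), stored as COO,
    # matching the toy graph in test_neighbor_loader_hetero_basic.
    graph_store = GraphStore()
    feature_store = TensorDictFeatureStore()

    graph_store[("paper", "cites", "paper"), "coo"] = [
        torch.tensor([0, 1, 2, 4, 3, 4, 5, 5]),
        torch.tensor([4, 5, 4, 3, 2, 1, 0, 1]),
    ]
    graph_store[("author", "writes", "paper"), "coo"] = [
        torch.tensor([0, 1, 2, 3, 3, 0]),
        torch.tensor([0, 1, 2, 3, 4, 5]),
    ]

    # Heterogeneous graphs are sampled in COO format and in memory (no
    # directory= argument), per the checks added in neighbor_loader.py above.
    loader = NeighborLoader(
        (feature_store, graph_store),
        num_neighbors=[1, 1, 1, 1],
        input_nodes=("paper", torch.tensor([0, 1])),
        batch_size=2,
    )

    batch = next(iter(loader))
    print(batch["paper"].n_id, batch["author"].n_id)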