-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathreddit_quiver.py
160 lines (120 loc) · 5.53 KB
/
reddit_quiver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os.path as osp
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv
import quiver
# Load the Reddit dataset from ``../data/Reddit`` relative to this script and
# take its first (index 0) graph.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')
dataset = Reddit(path)
data = dataset[0]

# Flat index tensor of the nodes selected by the training mask.
train_idx = torch.nonzero(data.train_mask, as_tuple=False).view(-1)

################################
# Step 1: Using Quiver's sampler
################################
# The original PyG code sampled neighborhoods inside the loader itself:
#
# train_loader = NeighborSampler(data.edge_index, node_idx=data.train_mask,
#                                sizes=[25, 10], batch_size=1024, shuffle=True,
#                                num_workers=12)
#
# With Quiver, the loader only yields batches of seed node ids and the
# sampling is delegated to `quiver_sampler`.
train_loader = torch.utils.data.DataLoader(
    dataset=train_idx,
    batch_size=1024,
    shuffle=True,
    drop_last=True,
)  # Quiver

csr_topo = quiver.CSRTopo(data.edge_index)  # Quiver
quiver_sampler = quiver.pyg.GraphSageSampler(
    csr_topo,
    sizes=[25, 10],
    device=0,
    mode='GPU',
)  # Quiver

# Loader used by `SAGE.inference`: `sizes=[-1]` and `node_idx=None` are
# passed so that evaluation is not restricted to a node subset or a sampled
# fan-out.
subgraph_loader = NeighborSampler(
    data.edge_index,
    node_idx=None,
    sizes=[-1],
    batch_size=1024,
    shuffle=False,
    num_workers=12,
)
class SAGE(torch.nn.Module):
    """Two-layer GraphSAGE model built from ``SAGEConv`` layers.

    Args:
        in_channels: Input size handed to the first ``SAGEConv``.
        hidden_channels: Output size of the first ``SAGEConv`` and input size
            of the second.
        out_channels: Output size of the second ``SAGEConv``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.convs = torch.nn.ModuleList([
            SAGEConv(in_channels, hidden_channels),
            SAGEConv(hidden_channels, out_channels),
        ])
        # Derived from the layer list so the count cannot drift out of sync.
        self.num_layers = len(self.convs)

    def forward(self, x, adjs):
        """Apply the convolutions to one sampled mini-batch.

        Args:
            x: Features of the sampled nodes.
            adjs: One ``(edge_index, e_id, size)`` triple per layer, where
                ``edge_index`` holds the bipartite edges, ``e_id`` the index
                of the original edges and ``size`` the shape of the bipartite
                graph.

        Returns:
            The last layer's output after ``log_softmax`` over the final
            dimension.
        """
        # Target nodes are also included in the source nodes so that one can
        # easily apply skip-connections or add self-loops.
        last = self.num_layers - 1
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            if i != last:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x.log_softmax(dim=-1)

    def inference(self, x_all, loader=None, target_device=None):
        """Compute the output of every layer for all nodes, layer by layer.

        Representations are computed one layer at a time over every batch the
        loader yields, rather than pushing each batch through all layers.

        Args:
            x_all: Node features; must support ``size(0)`` and indexing with
                the ``n_id`` values produced by the loader.
            loader: Iterable yielding ``(batch_size, n_id, adj)`` triples.
                Defaults to the module-level ``subgraph_loader``.
            target_device: Device each batch is moved to for the convolution.
                Defaults to the module-level ``device``.

        Returns:
            The concatenated (CPU) outputs of the last layer. No softmax is
            applied.
        """
        # Fall back to the script's globals so existing callers that pass
        # only `x_all` keep working unchanged.
        if loader is None:
            loader = subgraph_loader
        if target_device is None:
            target_device = device

        last = self.num_layers - 1
        # The context manager closes the bar even if a batch raises.
        with tqdm(total=x_all.size(0) * self.num_layers) as pbar:
            pbar.set_description('Evaluating')
            for i in range(self.num_layers):
                xs = []
                for batch_size, n_id, adj in loader:
                    edge_index, _, size = adj.to(target_device)
                    x = x_all[n_id].to(target_device)
                    x_target = x[:size[1]]
                    x = self.convs[i]((x, x_target), edge_index)
                    if i != last:
                        x = F.relu(x)
                    # Collect on the CPU between layers.
                    xs.append(x.cpu())
                    pbar.update(batch_size)
                # The outputs of this layer are the inputs of the next one.
                x_all = torch.cat(xs, dim=0)
        return x_all
# Use CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 256 hidden channels; input/output sizes come from the dataset.
model = SAGE(dataset.num_features, 256, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

################################
# Step 2: Using Quiver's Feature
################################
# The original PyG code simply moved the feature matrix to the device:
#
# x = data.x.to(device)
#
# With Quiver, the features are wrapped in a `quiver.Feature` that is filled
# from the CPU tensor.
x = quiver.Feature(
    rank=0,
    device_list=[0],
    device_cache_size="4G",
    cache_policy="device_replicate",
    csr_topo=csr_topo,
)  # Quiver
x.from_cpu_tensor(data.x)  # Quiver

# Labels, squeezed and moved to the training device.
y = data.y.squeeze().to(device)
def train(epoch):
    """Run one training epoch using Quiver for sampling and feature lookup.

    Args:
        epoch: Epoch number; only used to label the progress bar.

    Returns:
        A ``(loss, approx_acc)`` tuple: the mean of the per-batch losses and
        the fraction of processed seed nodes predicted correctly during the
        epoch. Both are ``nan`` when the loader yields no batches.
    """
    model.train()

    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    num_batches = 0

    # The context manager closes the bar even if a batch raises.
    with tqdm(total=int(data.train_mask.sum())) as pbar:
        pbar.set_description(f'Epoch {epoch:02d}')

        ############################################
        # Step 3: Training the PyG Model with Quiver
        ############################################
        # for batch_size, n_id, adjs in train_loader:  # Original PyG Code
        for seeds in train_loader:  # Quiver
            n_id, batch_size, adjs = quiver_sampler.sample(seeds)  # Quiver

            # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
            adjs = [adj.to(device) for adj in adjs]
            # Labels of the seed nodes, which come first in `n_id`.
            target = y[n_id[:batch_size]]

            optimizer.zero_grad()
            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, target)
            loss.backward()
            optimizer.step()

            total_loss += float(loss)
            total_correct += int(out.argmax(dim=-1).eq(target).sum())
            total_seen += batch_size
            num_batches += 1
            pbar.update(batch_size)

    if num_batches == 0:
        # Nothing was processed (e.g. fewer seeds than one full batch while
        # the loader drops incomplete batches); avoid dividing by zero.
        return float('nan'), float('nan')

    loss = total_loss / num_batches
    # Divide by the seeds actually seen: the loader may drop the last partial
    # batch, so the full training-mask count would understate the accuracy.
    approx_acc = total_correct / total_seen
    return loss, approx_acc
@torch.no_grad()
def test():
    """Evaluate the model on all nodes.

    Returns:
        A list ``[train_acc, val_acc, test_acc]`` with the accuracy over the
        nodes selected by the train, validation and test masks respectively.
    """
    model.eval()

    predictions = model.inference(x).argmax(dim=-1, keepdim=True)
    labels = y.cpu().unsqueeze(-1)

    def accuracy(mask):
        # Correct predictions under the mask divided by the mask's node count.
        hits = predictions[mask].eq(labels[mask]).sum()
        return int(hits) / int(mask.sum())

    splits = (data.train_mask, data.val_mask, data.test_mask)
    return [accuracy(mask) for mask in splits]
# Train for ten epochs; after each one, report the training metrics followed
# by the accuracies returned by `test()`.
for epoch in range(1, 11):
    loss, acc = train(epoch)
    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')

    train_acc, val_acc, test_acc = test()
    print(
        f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
        f'Test: {test_acc:.4f}'
    )