-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbase_gnn.py
145 lines (119 loc) · 4.58 KB
/
base_gnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch_geometric.nn import GATConv, GCNConv, SAGEConv, global_max_pool
from torch_geometric.data import Batch
from faknow.model.model import AbstractModel
class _BaseGNN(AbstractModel):
    """
    base gnn models for GCN, SAGE and GAT

    Subclasses must call ``super().__init__`` and then assign ``self.conv``,
    a graph convolution layer mapping ``feature_size`` -> ``hidden_size``.
    ``forward`` applies that layer with ReLU, max-pools the node embeddings
    of each graph, optionally concatenates an embedding of each graph's first
    node (the "news" node), and projects the result to 2 class logits.
    """

    def __init__(self, feature_size: int, hidden_size: int, concat=False):
        """
        Args:
            feature_size (int): dimension of input node feature
            hidden_size (int): dimension of the hidden (graph embedding) layer
            concat (bool): concat news embedding and graph embedding. Default=False
        """
        super().__init__()
        self.feature_size = feature_size
        self.hidden_size = hidden_size
        self.concat = concat
        if self.concat:
            # embeds the raw feature of each graph's first node (news node)
            self.fc0 = torch.nn.Linear(self.feature_size, self.hidden_size)
            # fuses [graph embedding, news embedding] back to hidden_size
            self.fc1 = torch.nn.Linear(self.hidden_size * 2, self.hidden_size)
        # final 2-way classifier
        self.fc2 = torch.nn.Linear(self.hidden_size, 2)

    @staticmethod
    def _first_node_index(batch: Tensor, num_graphs: int) -> Tensor:
        """
        Index of the first node belonging to each graph.

        Args:
            batch (Tensor): index of graph each node belongs to, shape=(num_nodes,)
            num_graphs (int): number of graphs in the batch

        Returns:
            Tensor: node indices, shape=(num_graphs,)
        """
        # nonzero(as_tuple=True)[0] is always 1-D, so a graph with a single
        # node works too (``.nonzero().squeeze()[0]`` collapsed that case to
        # a 0-dim tensor and raised IndexError).
        return torch.stack([
            (batch == idx).nonzero(as_tuple=True)[0][0]
            for idx in range(num_graphs)
        ])

    def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor,
                num_graphs: int):
        """
        Args:
            x (Tensor): node feature, shape=(num_nodes, feature_size)
            edge_index (Tensor): edge index, shape=(2, num_edges)
            batch (Tensor): index of graph each node belongs to, shape=(num_nodes,)
            num_graphs (int): number of graphs, a.k.a. batch_size

        Returns:
            Tensor: unnormalized logits of the two classes
                (index 1 presumably "fake"), shape=(num_graphs, 2)
        """
        edge_attr = None  # no edge features are used by GCN/SAGE/GAT here
        raw_x = x
        x = F.relu(self.conv(x, edge_index, edge_attr))
        x = global_max_pool(x, batch)
        # whether concat news embedding and graph embedding
        if self.concat:
            news = raw_x[self._first_node_index(batch, num_graphs)]
            news = F.relu(self.fc0(news))
            x = torch.cat([x, news], dim=1)
            x = F.relu(self.fc1(x))
        return self.fc2(x)

    def calculate_loss(self, data: Batch) -> torch.Tensor:
        """
        calculate loss via cross entropy

        Args:
            data (Batch): batch data; ``data.y`` holds the class labels

        Returns:
            torch.Tensor: loss
        """
        output = self.forward(data.x, data.edge_index, data.batch,
                              data.num_graphs)
        # functional form; identical to a default nn.CrossEntropyLoss()
        return F.cross_entropy(output, data.y)

    def predict(self, data_without_label: Batch) -> torch.Tensor:
        """
        predict the probability of being fake news

        Args:
            data_without_label (Batch): batch data

        Returns:
            Tensor: softmax probability, shape=(num_graphs, 2)
        """
        output = self.forward(data_without_label.x,
                              data_without_label.edge_index,
                              data_without_label.batch,
                              data_without_label.num_graphs)
        return F.softmax(output, dim=1)
class GCN(_BaseGNN):
    """
    Semi-Supervised Classification with Graph Convolutional Networks, ICLR 2017
    paper: https://openreview.net/forum?id=SJU4ayYgl
    code: https://github.com/safe-graph/GNN-FakeNews
    """

    def __init__(self, feature_size: int, hidden_size=128, *, concat=False):
        """
        Args:
            feature_size (int): dimension of input node feature
            hidden_size (int): dimension of hidden layer. Default=128
            concat (bool): concat news embedding and graph embedding,
                forwarded to the base model. Default=False
        """
        super().__init__(feature_size, hidden_size, concat)
        self.conv = GCNConv(self.feature_size, self.hidden_size)
class SAGE(_BaseGNN):
    """
    Inductive Representation Learning on Large Graphs, NeurIPS 2017
    paper: https://dl.acm.org/doi/10.5555/3294771.3294869
    code: https://github.com/safe-graph/GNN-FakeNews
    """

    def __init__(self, feature_size: int, hidden_size=128, *, concat=False):
        """
        Args:
            feature_size (int): dimension of input node feature
            hidden_size (int): dimension of hidden layer. Default=128
            concat (bool): concat news embedding and graph embedding,
                forwarded to the base model. Default=False
        """
        super().__init__(feature_size, hidden_size, concat)
        self.conv = SAGEConv(self.feature_size, self.hidden_size)
class GAT(_BaseGNN):
    """
    Graph Attention Networks, ICLR 2018
    paper: https://openreview.net/forum?id=rJXMpikCZ
    code: https://github.com/safe-graph/GNN-FakeNews
    """

    def __init__(self, feature_size: int, hidden_size=128, *, concat=False):
        """
        Args:
            feature_size (int): dimension of input node feature
            hidden_size (int): dimension of hidden layer. Default=128
            concat (bool): concat news embedding and graph embedding,
                forwarded to the base model. Default=False
        """
        super().__init__(feature_size, hidden_size, concat)
        # GATConv's default single attention head keeps the output at hidden_size
        self.conv = GATConv(self.feature_size, self.hidden_size)