import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.optim import Adam
from sklearn.cluster import KMeans

# Project-local helpers (adjust the import paths to your repository layout):
# GAT is the pretrained attentional autoencoder, utils holds the graph
# preprocessing helpers, and eva computes the clustering metrics.
# `args` and `device` are module-level globals set up in the __main__ block.
import utils
from model import GAT
from evaluation import eva


class DAEGC(nn.Module):
    def __init__(self, num_features, hidden_size, embedding_size, alpha, num_clusters, v=1):
        super(DAEGC, self).__init__()
        self.num_clusters = num_clusters
        self.v = v  # degrees of freedom of the Student's t kernel

        # Load the pretrained GAT autoencoder.
        self.gat = GAT(num_features, hidden_size, embedding_size, alpha)
        self.gat.load_state_dict(torch.load(args.pretrain_path, map_location='cpu'))

        # Trainable cluster centers, one row per cluster.
        self.cluster_layer = Parameter(torch.Tensor(num_clusters, embedding_size))
        torch.nn.init.xavier_normal_(self.cluster_layer.data)

    def forward(self, x, adj, M):
        A_pred, z = self.gat(x, adj, M)   # reconstructed adjacency and node embeddings
        q = self.get_Q(z)                 # soft cluster assignments
        return A_pred, z, q

    def get_Q(self, z):
        # Student's t kernel between embeddings and cluster centers,
        # normalized so each row of q sums to 1 (a soft assignment per node).
        q = 1.0 / (1.0 + torch.sum(torch.pow(z.unsqueeze(1) - self.cluster_layer, 2), 2) / self.v)
        q = q.pow((self.v + 1.0) / 2.0)
        q = (q.t() / torch.sum(q, 1)).t()
        return q


def target_distribution(q):
    # Sharpened target distribution p: square q, normalize per cluster,
    # then renormalize per node so rows again sum to 1.
    weight = q**2 / q.sum(0)
    return (weight.t() / weight.sum(1)).t()
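
# Illustrative sanity check (not part of the original script): both the soft
# assignments q and the target distribution p are row-stochastic, i.e. each
# node's assignment probabilities sum to 1. The toy tensor below is made up.
def _check_target_distribution():
    q = torch.softmax(torch.randn(5, 3), dim=1)   # fake assignments: 5 nodes, 3 clusters
    p = target_distribution(q)
    assert torch.allclose(q.sum(1), torch.ones(5))
    assert torch.allclose(p.sum(1), torch.ones(5))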

def trainer(dataset):
    model = DAEGC(num_features=args.input_dim, hidden_size=args.hidden_size,
                  embedding_size=args.embedding_size, alpha=args.alpha,
                  num_clusters=args.n_clusters).to(device)
    print(model)
    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Graph preprocessing: normalized adjacency, reconstruction target, and
    # the proximity matrix M used by the attention layers.
    dataset = utils.data_preprocessing(dataset)
    adj = dataset.adj.to(device)
    adj_label = dataset.adj_label.to(device)
    M = utils.get_M(adj).to(device)

    # Node features and ground-truth labels (labels are used only for evaluation).
    data = torch.Tensor(dataset.x).to(device)
    y = dataset.y.cpu().numpy()

    # Initialize the cluster centers with k-means on the pretrained embeddings.
    with torch.no_grad():
        _, z = model.gat(data, adj, M)

    kmeans = KMeans(n_clusters=args.n_clusters, n_init=20)
    y_pred = kmeans.fit_predict(z.data.cpu().numpy())
    model.cluster_layer.data = torch.tensor(kmeans.cluster_centers_).to(device)
    eva(y, y_pred, 'pretrain')

    for epoch in range(args.max_epoch):
        model.train()
        if epoch % args.update_interval == 0:
            # Refresh the soft-assignment snapshot Q every update_interval
            # epochs and report clustering metrics on its hard assignments.
            A_pred, z, Q = model(data, adj, M)
            q = Q.detach().data.cpu().numpy().argmax(1)  # hard assignments for evaluation
            eva(y, q, epoch)

        A_pred, z, q = model(data, adj, M)
        p = target_distribution(Q.detach())  # target built from the detached snapshot Q

        # Self-training KL loss pulls q toward the sharpened target p;
        # the reconstruction loss keeps the embedding faithful to the graph.
        kl_loss = F.kl_div(q.log(), p, reduction='batchmean')
        re_loss = F.binary_cross_entropy(A_pred.view(-1), adj_label.view(-1))
        loss = 10 * kl_loss + re_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
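
# Minimal, illustrative driver (not from the source): `trainer` relies on the
# module-level globals `args` and `device`, which the original script presumably
# builds from command-line arguments. The hyperparameter values and the dataset
# loader below are assumptions; substitute your own settings and loader.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        pretrain_path="pretrain_gat.pkl",   # hypothetical path to the pretrained GAT weights
        input_dim=3703, hidden_size=256, embedding_size=16, alpha=0.2,
        n_clusters=6, lr=1e-4, weight_decay=5e-3,
        max_epoch=200, update_interval=5,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # utils.get_dataset is an assumed project helper that returns a graph dataset
    # utils.data_preprocessing can consume; swap in your own loader if it differs.
    dataset = utils.get_dataset("Citeseer")[0]
    trainer(dataset)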