I'm currently trying to build a very basic graph matching / entity alignment model, but the results are poor on the train set and even worse on the test set. I also tried adding dropout to improve the results, but they became even worse. So I wanted to know if you could take a look and let me know what I could improve, or what you think doesn't really make sense in what I did.
What this model does is first run a two-layer Siamese GCN over my two graphs to get node embeddings, then train those embeddings with a margin-based loss on cosine-similarity matches, and finally use hits@1 and hits@10 to check how well the model aligns nodes on my test data.
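For context, this is roughly how I train it (a minimal sketch, not my exact script: `g1`/`g2` stand in for my two PyG graphs, the hidden/output sizes and Adam settings are placeholders, and the ground-truth alignment is the identity; the model class itself is below):

import torch

# g1 and g2 are torch_geometric.data.Data objects for the two graphs to align
model = SiameseGCN(input_dim=g1.x.size(1), hidden_dim=64, output_dim=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    emb1, emb2 = model(g1.x, g1.edge_index, g2.x, g2.edge_index)
    loss = model.siamese_loss(emb1, emb2)
    loss.backward()
    optimizer.step()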
Here's my model:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class SiameseGCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.4):
        super(SiameseGCN, self).__init__()
        # First GCN layer
        self.conv1 = GCNConv(input_dim, hidden_dim)
        # Second GCN layer
        self.conv2 = GCNConv(hidden_dim, output_dim)
        # Dropout layer, shared by both towers
        self.dropout = nn.Dropout(dropout)
    def forward(self, x1, edge_index1, x2, edge_index2):
        # Shared-weight GCN tower on the first graph
        x1 = self.conv1(x1, edge_index1)
        x1 = torch.relu(x1)
        # Dropout between the layers, not on the final embeddings
        x1 = self.dropout(x1)
        x1 = self.conv2(x1, edge_index1)
        # Same tower on the second graph
        x2 = self.conv1(x2, edge_index2)
        x2 = torch.relu(x2)
        x2 = self.dropout(x2)
        x2 = self.conv2(x2, edge_index2)
        return x1, x2
    def siamese_loss(self, embeddings1, embeddings2):
        # Cosine similarity matrix between all pairs of embeddings
        sim_matrix = torch.matmul(embeddings1, embeddings2.t()) / torch.matmul(
            torch.norm(embeddings1, dim=1, keepdim=True),
            torch.norm(embeddings2, dim=1, keepdim=True).t())
        # For each node in the second graph, index of its most similar node in the first
        _, indices = torch.max(sim_matrix, dim=0)
        # Matching accuracy, assuming the ground-truth alignment is the identity
        target = torch.arange(embeddings2.shape[0], device=embeddings2.device)
        correct_matches = torch.eq(target, indices).sum().item()
        incorrect_matches = embeddings2.shape[0] - correct_matches
        print("Number of correct matches: ", correct_matches)
        print("Number of incorrect matches: ", incorrect_matches)
        matching_acc = correct_matches / embeddings2.shape[0]
        # Distance between each graph-2 node and its predicted best match in graph 1...
        pred_dist = F.pairwise_distance(embeddings1[indices], embeddings2)
        # ...and between the ground-truth pairs (row i of graph 1 with row i of graph 2)
        true_dist = F.pairwise_distance(embeddings1, embeddings2)
        # Hinge loss with margin 1.0: pull true pairs closer than the predicted matches
        loss = torch.mean(torch.clamp(true_dist - pred_dist + 1.0, min=0.0))
        return loss
    def hits_at_k(self, embeddings1, embeddings2, k=1):
        # Cosine similarity matrix between all pairs of embeddings
        sim_matrix = torch.matmul(embeddings1, embeddings2.t()) / (
            torch.norm(embeddings1, dim=1, keepdim=True)
            * torch.norm(embeddings2, dim=1, keepdim=True).t())
        # Indices of the k most similar graph-2 nodes for each graph-1 node
        _, top_k = torch.topk(sim_matrix, k, dim=1)
        # A hit is when the ground-truth match (same index) appears in the top k
        target = torch.arange(embeddings1.shape[0], device=embeddings1.device).unsqueeze(1)
        correct_matches = torch.sum(torch.eq(top_k, target), dim=1).float()
        num_correct_matches = int(correct_matches.sum().item())
        top_k_acc = num_correct_matches / embeddings1.shape[0]
        print(f"Number of correct matches in top-{k}: {num_correct_matches}")
        return top_k_acc
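And this is how I evaluate on the test data (again a sketch; `g1_test`/`g2_test` are placeholders for my held-out graphs, aligned by the identity like the training ones):

model.eval()
with torch.no_grad():
    emb1, emb2 = model(g1_test.x, g1_test.edge_index, g2_test.x, g2_test.edge_index)
    hits1 = model.hits_at_k(emb1, emb2, k=1)
    hits10 = model.hits_at_k(emb1, emb2, k=10)
    print(f"hits@1: {hits1:.3f}, hits@10: {hits10:.3f}")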