Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification

背景消息传递模型(Message Passing Model)基于拉普拉斯平滑假设(领居是相似的) , 试图聚合图中的邻居的信息来获取足够的依据,以实现更鲁棒的半监督节点分类 。
图神经网络(Graph Neural Networks, GNN)和标签传播算法(Label Propagation, LPA)均为消息传递算法,其中GNN主要基于传播特征来提升预测效果,而LPA基于迭代式的标签传播来作预测 。
【Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification】一些工作要么用LPA对GNN预测结果做后处理,要么用LPA对GNN进行正则化 。但是 , 它们仍不能直接将GNN和LPA有效地整合到消息传递模型中 。
为解决这个问题,本文提出了统一消息传递模型(UNIMP)[1],它可以在训练和推理时结合特征和标签传播 。UniMP基于两个简单而有效的想法:

  • 将特征嵌入和标签嵌入同时作为输入信息进行传播
  • 随机掩码部分标签信息,并在训练时对其进行预测
UniMP在概念上统一了特征传播和标签传播,具有强大的经验能力 。
Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification

文章插图
实现关键部分
  • 将标签进行嵌入(原有的C类One-hot标签,通过线性变换成与原始节点特征相同的维度) 。
  • 然后,将标签嵌入和节点特征相加作为GNN输入 。
为避免训练时使用标签导致标签泄露,这里使用了掩码标签训练的策略 。每个Epoch随机将训练集中部分节点的标签置(掩码)0(视为训练监督信号),然后利用节点特征 \(\mathbf{X}\) 和 \(\mathbf{A}\)以及剩余的标签去预测被掩码的标签) 。
模型部分UniMP中使用了GraphTransformer(Transformer中的Q、K、V注意力形式,加上边特征),同时引入了H-GCN的门控残差机制来缓解过平滑 。
个人实验将标签作为输入,在ArixV数据集节点分类任务上,能在小数点后第2位提升接近2个点 。
在论文BOT[2]中也对标签作为输入做了阐述,其作者还发表了相应的论文来论证标签作为输入的有效性的原因 。
总结标签有效的直觉就是,在图上的节点分类任务中,邻居标签也是预测目标节点标签的关键特征(这也和标签传播的思想一致)
标签嵌入和掩码标签预测是提升节点分类任务简单有效的方法 。
参考文献
[1] Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification[2] Bag of Tricks for Node Classification with Graph Neural Networks
2022-10-29 11:10:13 星期六

    推荐阅读