图神经网络之预训练大模型结合:ERNIESage在链接预测任务应用

天天见闻 天天见闻 2023-03-19 软件 阅读: 76
摘要: 图形的节点属性由文本组成,该模型可以同时对文本语义和图形结构信息进行建模。中心节点文本与所有邻居节点文本拼接在一起,并且只能使用ID特征的GraphSAGE对图中的结构信息进行建模。ERNIESage在PGL的消息传递范式中很容易实现。“模型文件”和“构造”部分,将导入的文件组成一个地物。文本显示为该节点的node feature(节点特征),dataset/-此文件夹包含数据读取的代码。
1.ERNIESage运行实例介绍(1.8x版本)

本项目原链接:https://aistudio.baidu.com/aistudio/projectdetail/5097085?contributionType=1

本项目主要是为了直接提供一个可以运行ERNIESage模型的环境,

https://github.com/PaddlePaddle/PGL/blob/develop/examples/erniesage/README.md

在很多工业应用中,往往出现如下图所示的一种特殊的图:Text Graph。顾名思义,图的节点属性由文本构成,而边的构建提供了结构信息。如搜索场景下的Text Graph,节点可由搜索词、网页标题、网页正文来表达,用户反馈和超链信息则可构成边关系。

ERNIESage 由PGL团队提出,是ERNIE SAmple aggreGatE的简称,该模型可以同时建模文本语义与图结构信息,有效提升 Text Graph 的应用效果。其中 ERNIE 是百度推出的基于知识增强的持续学习语义理解框架。

ERNIESage 是 ERNIE 与 GraphSAGE 碰撞的结果,是 ERNIE SAmple aggreGatE 的简称,它的结构如下图所示,主要思想是通过 ERNIE 作为聚合函数(Aggregators),建模自身节点和邻居节点的语义与结构关系。ERNIESage 对于文本的建模是构建在邻居聚合的阶段,中心节点文本会与所有邻居节点文本进行拼接;然后通过预训练的 ERNIE 模型进行消息汇聚,捕捉中心节点以及邻居节点之间的相互关系;最后使用 ERNIESage 搭配独特的邻居互相看不见的 Attention Mask 和独立的 Position Embedding 体系,就可以轻松构建 TextGraph 中句子之间以及词之间的关系。

使用ID特征的GraphSAGE只能够建模图的结构信息,而单独的ERNIE只能处理文本信息。通过PGL搭建的图与文本的桥梁,ERNIESage能够很简单的把GraphSAGE以及ERNIE的优点结合一起。以下面TextGraph的场景,ERNIESage的效果能够比单独的ERNIE以及GraphSAGE模型都要好。

ERNIESage可以很轻松地在PGL中的消息传递范式中进行实现,目前PGL在github上提供了3个版本的ERNIESage模型:

  • ERNIESage v1: ERNIE 作用于text graph节点上;
  • ERNIESage v2: ERNIE 作用在text graph的边上;
  • ERNIESage v3: ERNIE 作用于一阶邻居及起边上;

主要会针对ERNIESageV1和ERNIESageV2版本进行一个介绍。

1.1算法实现

可能有同学对于整个项目代码文件都不太了解,因此这里会做一个比较简单的讲解。

核心部分包含:

  • 数据集部分
  1. data.txt - 简单的输入文件,格式为每行query \t answer,可作简单的运行实例使用。
  • 模型文件和配置部分
  1. ernie_config.json - ERNIE模型的配置文件。
  2. vocab.txt - ERNIE模型所使用的词表。
  3. ernie_base_ckpt/ - ERNIE模型参数。
  4. config/ - ERNIESage模型的配置文件,包含了三个版本的配置文件。
  • 代码部分
  1. local_run.sh - 入口文件,通过该入口可完成预处理、训练、infer三个步骤。
  2. preprocessing文件夹 - 包含dump_graph.py, tokenization.py。在预处理部分,我们首先需要进行建图,将输入的文件构建成一张图。由于我们所研究的是Text Graph,因此节点都是文本,我们将文本表示为该节点对应的node feature(节点特征),处理文本的时候需要进行切字,再映射为对应的token id。
  3. dataset/ - 该文件夹包含了数据ready的代码,以便于我们在训练的时候将训练数据以batch的方式读入。
  4. models/ - 包含了ERNIESage模型核心代码。
  5. train.py - 模型训练入口文件。
  6. learner.py - 分布式训练代码,通过train.py调用。
  7. infer.py - infer代码,用于infer出节点对应的embedding。
  • 评价部分
  1. build_dev.py - 用于将我们的验证集修改为需要的格式。
  2. mrr.py - 计算MRR值。

要在这个项目中运行模型其实很简单,只要运行下方的入口命令就ok啦!但是,需要注意的是,由于ERNIESage模型比较大,所以如果AIStudio中的CPU版本运行模型容易出问题。因此,在运行部署环境时,建议选择GPU的环境。

另外,如果提示出现了GPU空间不足等问题,我们可以通过调小对应yaml文件中的batch_size来调整,也可以修改ERNIE模型的配置文件ernie_config.json,将num_hidden_layers设小一些。在这里,我仅提供了ERNIESageV2版本的gpu运行过程,如果同学们想运行其他版本的模型,可以根据需要修改下方的命令。

运行完毕后,会产生较多的文件,这里进行简单的解释。

  1. workdir/ - 这个文件夹主要会存储和图相关的数据信息。
  2. output/ - 主要的输出文件夹,包含了以下内容:(1)模型文件,根据config文件中的save_per_step可调整保存模型的频率,如果设置得比较大则可能训练过程中不会保存模型; (2)last文件夹,保存了停止训练时的模型参数,在infer阶段我们会使用这部分模型参数;(3)part-0文件,infer之后的输入文件中所有节点的Embedding输出。

为了可以比较清楚地知道Embedding的效果,我们直接通过MRR简单判断一下data.txt计算出来的Embedding结果,此处将data.txt同时作为训练集和验证集。

1.2 核心模型代码讲解

首先,我们可以通过查看models/model_factory.py来判断在本项目有多少种ERNIESage模型。

from models.base import BaseGNNModel
from models.ernie import ErnieModel
from models.erniesage_v1 import ErnieSageModelV1
from models.erniesage_v2 import ErnieSageModelV2
from models.erniesage_v3 import ErnieSageModelV3 class Model(object):
@classmethod
def factory(cls, config):
name = config.model_type
if name == "BaseGNNModel":
return BaseGNNModel(config)
if name == "ErnieModel":
return ErnieModel(config)
if name == "ErnieSageModelV1":
return ErnieSageModelV1(config)
if name == "ErnieSageModelV2":
return ErnieSageModelV2(config)
if name == "ErnieSageModelV3":
return ErnieSageModelV3(config)
else:
raise ValueError

可以看到一共有ERNIESage模型一共有3个版本,另外我们也提供了基本的GNN模型和ERNIE模型,感兴趣的同学可以自行查阅。

接下来,我主要会针对ERNIESageV1和ERNIESageV2这两个版本的模型进行关键部分的讲解,主要的不同其实就是消息传递机制(Message Passing)部分的不同。

# ERNIESageV1的Message Passing代码
# 查找路径:erniesage_v1.py(__call__中的self.gnn_layers) -> base.py(BaseNet类中的gnn_layers方法) -> message_passing.py # erniesage_v1.py
def __call__(self, graph_wrappers):
inputs = self.build_inputs()
feature = self.build_embedding(graph_wrappers, inputs[-1]) # 将节点的文本信息利用ERNIE模型建模,生成对应的Embedding作为feature
features = self.gnn_layers(graph_wrappers, feature) # GNN模型的主要不同,消息传递机制入口
outputs = [self.take_final_feature(features[-1], i, "final_fc") for i in inputs[:-1]]
src_real_index = L.gather(graph_wrappers[0].node_feat['index'], inputs[0])
outputs.append(src_real_index)
return inputs, outputs # base.py -> BaseNet
def gnn_layers(self, graph_wrappers, feature):
features = [feature]
initializer = None
fc_lr = self.config.lr / 0.001
for i in range(self.config.num_layers):
if i == self.config.num_layers - 1:
act = None
else:
act = "leaky_relu"
feature = get_layer(
self.config.layer_type, # 对于ERNIESageV1, 其layer_type="graphsage_sum",可以到config文件夹中查看
graph_wrappers[i],
feature,
self.config.hidden_size,
act,
initializer,
learning_rate=fc_lr,
name="%s_%s" % (self.config.layer_type, i))
features.append(feature)
return features # message_passing.py
def graphsage_sum(gw, feature, hidden_size, act, initializer, learning_rate, name):
"""doc"""
msg = gw.send(copy_send, nfeat_list=[("h", feature)]) # Send
neigh_feature = gw.recv(msg, sum_recv) # Recv
self_feature = feature
self_feature = fluid.layers.fc(self_feature,
hidden_size,
act=act,
param_attr=fluid.ParamAttr(name=name + "_l.w_0", initializer=initializer,
learning_rate=learning_rate),
bias_attr=name+"_l.b_0"
)
neigh_feature = fluid.layers.fc(neigh_feature,
hidden_size,
act=act,
param_attr=fluid.ParamAttr(name=name + "_r.w_0", initializer=initializer,
learning_rate=learning_rate),
bias_attr=name+"_r.b_0"
)
output = fluid.layers.concat([self_feature, neigh_feature], axis=1)
output = fluid.layers.l2_normalize(output, axis=1)
return output

通过上述代码片段可以看到,关键的消息传递机制代码就是graphsage_sum函数,其中send、recv部分如下。

def copy_send(src_feat, dst_feat, edge_feat):
"""doc"""
return src_feat["h"] msg = gw.send(copy_send, nfeat_list=[("h", feature)]) # Send
neigh_feature = gw.recv(msg, sum_recv) # Recv

通过代码可以看到,ERNIESageV1版本,其主要是针对节点邻居,直接将当前节点的邻居节点特征求和。再看到graphsage_sum函数中,将邻居节点特征进行求和后,得到了neigh_feature。随后,我们将节点本身的特征self_feature和邻居聚合特征neigh_feature通过fc层后,直接concat起来,从而得到了当前gnn layer层的feature输出。

1.2.2ERNIESageV2关键代码

ERNIESageV2的消息传递机制代码主要在erniesage_v2.py和message_passing.py,相对ERNIESageV1来说,代码会相对长了一些。

为了使得大家对下面有关ERNIE模型的部分能够有所了解,这里先贴出ERNIE的主模型框架图。

具体的代码解释可以直接看注释。

# ERNIESageV2的Message Passing代码

# 下面的函数都在erniesage_v2.py的ERNIESageV2类中
# ERNIESageV2的调用函数
def __call__(self, graph_wrappers):
inputs = self.build_inputs()
feature = inputs[-1]
features = self.gnn_layers(graph_wrappers, feature)
outputs = [self.take_final_feature(features[-1], i, "final_fc") for i in inputs[:-1]]
src_real_index = L.gather(graph_wrappers[0].node_feat['index'], inputs[0])
outputs.append(src_real_index)
return inputs, outputs # 进入self.gnn_layers函数
def gnn_layers(self, graph_wrappers, feature):
features = [feature] initializer = None
fc_lr = self.config.lr / 0.001 for i in range(self.config.num_layers):
if i == self.config.num_layers - 1:
act = None
else:
act = "leaky_relu" feature = self.gnn_layer(
graph_wrappers[i],
feature,
self.config.hidden_size,
act,
initializer,
learning_rate=fc_lr,
name="%s_%s" % ("erniesage_v2", i))
features.append(feature)
return features
接下来会进入ERNIESageV2主要的代码部分。

可以看到,在ernie_send函数用于将我们的邻居信息发送到当前节点。在ERNIESageV1中,我们在Send阶段对邻居节点通过ERNIE模型得到Embedding后,再直接求和,实际上当前节点和邻居节点之间的文本信息在消息传递过程中是没有直接交互的,直到最后才**concat**起来;而ERNIESageV2中,在Send阶段,源节点和目标节点的信息会直接concat起来,通过ERNIE模型得到一个统一的Embedding,这样就得到了源节点和目标节点的一个信息交互过程,这个部分可以查看下面的ernie_send函数。

gnn_layer函数中包含了三个函数:
1. ernie_send: 将src和dst节点对应文本concat后,过Ernie后得到需要的msg,更加具体的解释可以看下方代码注释。
2. build_position_ids: 主要是为了创建位置ID,提供给Ernie,从而可以产生position embeddings。
3. erniesage_v2_aggregator: gnn_layer的入口函数,包含了消息传递机制,以及聚合后的消息feature处理过程。
# 进入self.gnn_layer函数
def gnn_layer(self, gw, feature, hidden_size, act, initializer, learning_rate, name):
def build_position_ids(src_ids, dst_ids): # 此函数用于创建位置ID,可以对应到ERNIE框架图中的Position Embeddings
# ...
pass
def ernie_send(src_feat, dst_feat, edge_feat):
"""doc"""
# input_ids,可以对应到ERNIE框架图中的Token Embeddings
cls = L.fill_constant_batch_size_like(src_feat["term_ids"], [-1, 1, 1], "int64", 1)
src_ids = L.concat([cls, src_feat["term_ids"]], 1)
dst_ids = dst_feat["term_ids"]
term_ids = L.concat([src_ids, dst_ids], 1) # sent_ids,可以对应到ERNIE框架图中的Segment Embeddings
sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1) # position_ids,可以对应到ERNIE框架图中的Position Embeddings
position_ids = build_position_ids(src_ids, dst_ids) term_ids.stop_gradient = True
sent_ids.stop_gradient = True
ernie = ErnieModel( # ERNIE模型
term_ids, sent_ids, position_ids,
config=self.config.ernie_config)
feature = ernie.get_pooled_output() # 得到发送过来的msg,该msg是由src节点和dst节点的文本特征一起过ERNIE后得到的embedding
return feature
def erniesage_v2_aggregator(gw, feature, hidden_size, act, initializer, learning_rate, name):
feature = L.unsqueeze(feature, [-1])
msg = gw.send(ernie_send, nfeat_list=[("term_ids", feature)]) # Send
neigh_feature = gw.recv(msg, lambda feat: F.layers.sequence_pool(feat, pool_type="sum")) # Recv,直接将发送来的msg根据dst节点来相加。 # 接下来的部分和ERNIESageV1类似,将self_feature和neigh_feature通过concat、normalize后得到需要的输出。
term_ids = feature
cls = L.fill_constant_batch_size_like(term_ids, [-1, 1, 1], "int64", 1)
term_ids = L.concat([cls, term_ids], 1)
term_ids.stop_gradient = True
ernie = ErnieModel(
term_ids, L.zeros_like(term_ids),
config=self.config.ernie_config)
self_feature = ernie.get_pooled_output()
self_feature = L.fc(self_feature,
hidden_size,
act=act,
param_attr=F.ParamAttr(name=name + "_l.w_0",
learning_rate=learning_rate),
bias_attr=name+"_l.b_0"
)
neigh_feature = L.fc(neigh_feature,
hidden_size,
act=act,
param_attr=F.ParamAttr(name=name + "_r.w_0",
learning_rate=learning_rate),
bias_attr=name+"_r.b_0"
)
output = L.concat([self_feature, neigh_feature], axis=1)
output = L.l2_normalize(output, axis=1)
return output
return erniesage_v2_aggregator(gw, feature, hidden_size, act, initializer, learning_rate, name)
2.总结

通过以上两个版本的模型代码简单的讲解,我们可以知道他们的不同点,其实主要就是在消息传递机制的部分有所不同。ERNIESageV1版本只作用在text graph的节点上,在传递消息(Send阶段)时只考虑了邻居本身的文本信息;而ERNIESageV2版本则作用在了边上,在Send阶段同时考虑了当前节点和其邻居节点的文本信息,达到更好的交互效果。

总结

以上是真正的电脑专家为你收集整理的图神经网络之预训练大模型结合:ERNIESage在链接预测任务应用的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得真正的电脑专家网站内容还不错,欢迎将真正的电脑专家推荐给好友。

其他相关
思科交换机命令大全,含巡检命令,网工建议收藏!

思科交换机命令大全,含巡检命令,网工建议收藏!

作者: 天天见闻 时间:2022-09-25 阅读: 192
switch#logging synchronous 阻止控制台信息覆盖命令行上的输入。switch#switchport trunk encapsulation dotlq 设置vlan中继协议。switch#no switchport mode 或 禁用干线。switch#no spanning-tree vlan priority 将交换机的优先级恢复默认值。switch#spanning-tree vlan cost0-200000000 指定端口成本。switch#spanning-tree portfast 配置速端口如pc机。起用了该命令后在进行优先级更改。switch#show spanning-tree interface detail 浏览详细生成树端口配置信息。switch#interface range fastemet0/1-2 将fastemet0/1和0/2口捆绑...
勇士消息更新:格林伤情出炉,还对新星表不满,名记看衰交易可能

勇士消息更新:格林伤情出炉,还对新星表不满,名记看衰交易可能

作者: 天天见闻 时间:2022-02-23 阅读: 1738
扣篮大赛毫无疑问是全明星周末最能吸引球迷的赛事,而本赛季将由勇士悍将胡安-托斯卡诺-安德森参加扣篮大赛,不过勇士核心内线德雷蒙德-格林对此并不高兴。在自己节目中,格林表达对勇士仅有一名球员参加扣篮大赛表示不满,因为他认为乔纳森-库明加也应该参加格林,但后者对此并不感兴趣:。另一方面,格林是能入选全明星替补席阵容的球员,不过格林在接受采访时明确表示自己因为伤病问题无法出战全明星:...
一级调研员和二级调研员的区别

一级调研员和二级调研员的区别

作者: 天天见闻 时间:2022-02-26 阅读: 799
调研员是国家公务员的一种,属于非领导职务,只在市(地)级以上国家行政机关设置,当任县处级副职领导职务或者副调研员四年以上才可以有调研员的资格。中国实行公务员职务与职级并行制度二级调研员二级调研员,根据公务员职位类别和职责设置公务员领导职务、职级序列。调研员一般不负责实际的工作,国家公务员的级别是15级,调研员属于7到10级,助理调研员属于8到11级。...
A4纸标准打印像素是 多少。?

A4纸标准打印像素是 多少。?

作者: 天天见闻 时间:2022-02-28 阅读: 736
A4纸标准打印像素是 多少.?打满一张A4纸大概多少像素?A4纸 使用像素有多大 - : A4纸大小为21cm*29.7cm.若采用普通屏幕分辨率72dpi 则需要的总像素为595*842=500990像素,也就是50万像素左右. 若需要进行普通答应,采用的分辨率为150dpi 则需要的总像素为1240*1754=2174960,约为217万像素. 当然,若像素更高,则打印质量能更好,不过500万像素基本上是极限了,现在高分辨打印的分辨率为300dpi....
《900重案追凶》在线观看和下载

《900重案追凶》在线观看和下载

作者: 天天见闻 时间:2022-03-01 阅读: 541
《900重案追凶》在线观看和下载基本信息900重案追凶幕后花絮首播时每集播映完毕时,观众可致电热线竞猜每集问题赢取奖品。900重案追凶资源介绍900重案追凶在线观看资源:目前清晰度未知,请您自己尝试后再决定是否需要。900重案追凶下载资源:目前清晰度未知,请您自己尝试后再决定是否需要。2,算是一部高评分电视剧作品,亲们,能有此分数的电影也为数不多呀,推荐大家值得观看。...
海运集装箱的货物种类、货柜尺寸及估算装箱尺寸方法

海运集装箱的货物种类、货柜尺寸及估算装箱尺寸方法

作者: 天天见闻 时间:2022-03-02 阅读: 911
2、货物重量超过集装箱的最大承重重量的,这种货物不能通过集装箱海运。3、特大尺寸的货物不能使用集装箱海运,一些大的配件超高、超宽,这些货物只能通过散货船放在船舱或者甲板上来运输。可选用办理集装箱海运的货物参考:在集装箱货物运输中,为了船、货、箱的安全,必须根据货物的性质、种类、容积、重量和形状来选择适当的集装箱;否则,不仅对某些货物不能承运,而且也会因选用不当而导致货损。估算装箱尺寸的方法...
我来说两句

年度爆文