MindSpore Graph Learning( 三 )


In [1]: import mindspore as msIn [2]: from mindspore_gl import Graph, GraphFieldIn [3]: from mindspore_gl.nn import GNNCellIn [4]: n_nodes = 3In [5]: n_edges = 3In [6]: src_idx = ms.Tensor([0, 1, 2], ms.int32)In [7]: dst_idx = ms.Tensor([1, 2, 0], ms.int32)In [8]: graph_field = GraphField(src_idx, dst_idx, n_nodes, n_edges)In [9]: node_feat = ms.Tensor([[1], [2], [3]], ms.float32)In [10]: class TestSetVertexAttr(GNNCell):...:def construct(self, x, y, g: Graph):...:g.set_src_attr({"hs": x})...:g.set_dst_attr({"hd": y})...:return [v.hd for v in g.dst_vertex] * [u.hs for u in g.src_vertex]...: In [11]: ret = TestSetVertexAttr()(node_feat[src_idx], node_feat[dst_idx], *graph_field.get_graph()).asnumpy().tolist()--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|def construct(self, x, y, g: Graph):1||1def construct(||||self,||||x,||||y,||||src_idx,||||dst_idx,||||n_nodes,||||n_edges,||||UNUSED_0=None,||||UNUSED_1=None,||||UNUSED_2=None||||):||||2SCATTER_ADD = ms.ops.TensorScatterAdd()||||3SCATTER_MAX = ms.ops.TensorScatterMax()||||4SCATTER_MIN = ms.ops.TensorScatterMin()||||5GATHER = ms.ops.Gather()||||6ZEROS = ms.ops.Zeros()||||7FILL = ms.ops.Fill()||||8MASKED_FILL = ms.ops.MaskedFill()||||9IS_INF = ms.ops.IsInf()||||10SHAPE = ms.ops.Shape()||||11RESHAPE = ms.ops.Reshape()||||12scatter_src_idx = RESHAPE(src_idx, (SHAPE(src_idx)[0], 1))||||13scatter_dst_idx = RESHAPE(dst_idx, (SHAPE(dst_idx)[0], 1))||g.set_src_attr({'hs': x})2||14hs, = [x]||g.set_dst_attr({'hd': y})3||15hd, = [y]||return [v.hd for v in g.dst_vertex] * [u.hs for u in g.src_vertex]4||16return hd * hs|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------In [12]: print (ret)[[2.0], [6.0], [3.0]]从这个结果中 , 
我们获得的是三条边两头的节点值的积。除了节点id和节点值之外,mindspore-gl还支持了一些如近邻节点、节点的度等参数的获取。可以参考如下图片展示的内容(图片来自于参考链接):

MindSpore Graph Learning

文章插图
除了基本的API接口之外,还可以学习下mindspore-gl的使用中有可能出现的报错信息:
MindSpore Graph Learning

文章插图
在mindspore-gl这一个框架中 , 还有一个对于大型数据来说非常有用的功能,当然,在文章这里只是放一下大概用法,因为暂时没有遇到这种使用的场景 。那就是把一个大型的图网络根据近邻的数量去拆分成不同大小的数据块进行存储和运算 。这样做一方面可以避免动态的shape出现 , 因为网络可能随时都在改变 。另一方面本身图的近邻数大部分就不是均匀分布的,有少部分特别的密集,而更多的情况是一些比较稀疏的图,那么这个时候如果要固定shape的话,就只能padding到较大数量的那一个维度,这样一来就无形之中浪费了巨大的存储空间 。这种分块模式的存储,能够最大限度上减小显存的占用,同时还能够提高运算的速度 。
MindSpore Graph Learning

文章插图
MindSpore Graph Learning

文章插图
那么最后我们再展示一个聚合的简单案例,其实就是获取节点的近邻节点值的加和:
import mindspore as msfrom mindspore import opsfrom mindspore_gl import Graph, GraphFieldfrom mindspore_gl.nn import GNNCelln_nodes = 3n_edges = 3src_idx = ms.Tensor([0, 1, 2, 3, 4], ms.int32)dst_idx = ms.Tensor([1, 2, 0, 1, 2], ms.int32)graph_field = GraphField(src_idx, dst_idx, n_nodes, n_edges)node_feat = ms.Tensor([[1], [2], [3], [4], [5]], ms.float32)class GraphConvCell(GNNCell):def construct(self, x, y, g: Graph):g.set_src_attr({"hs": x})g.set_dst_attr({"hd": y})return [g.sum([u.hs for u in v.innbs]) for v in g.dst_vertex]ret = GraphConvCell()(node_feat[src_idx], node_feat[dst_idx], *graph_field.get_graph()).asnumpy().tolist()print (ret)那么这里只要使用一个graph.sum这样的接口就可以实现,非常的易写方便,代码可读性很高 。
$ python3 test_msgl_01.py--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|def construct(self, x, y, g: Graph):1||1def construct(||||self,||||x,||||y,||||src_idx,||||dst_idx,||||n_nodes,||||n_edges,||||UNUSED_0=None,||||UNUSED_1=None,||||UNUSED_2=None||||):||||2SCATTER_ADD = ms.ops.TensorScatterAdd()||||3SCATTER_MAX = ms.ops.TensorScatterMax()||||4SCATTER_MIN = ms.ops.TensorScatterMin()||||5GATHER = ms.ops.Gather()||||6ZEROS = ms.ops.Zeros()||||7FILL = ms.ops.Fill()||||8MASKED_FILL = ms.ops.MaskedFill()||||9IS_INF = ms.ops.IsInf()||||10SHAPE = ms.ops.Shape()||||11RESHAPE = ms.ops.Reshape()||||12scatter_src_idx = RESHAPE(src_idx, (SHAPE(src_idx)[0], 1))||||13scatter_dst_idx = RESHAPE(dst_idx, (SHAPE(dst_idx)[0], 1))||g.set_src_attr({'hs': x})2||14hs, = [x]||g.set_dst_attr({'hd': y})3||15hd, = [y]||return [g.sum([u.hs for u in v.innbs]) for v in g.dst_vertex]4||16SCATTER_INPUT_SNAPSHOT1 = GATHER(hs, src_idx, 0)||||17return SCATTER_ADD(||||ZEROS(||||(n_nodes,) + SHAPE(SCATTER_INPUT_SNAPSHOT1)[1:],||||SCATTER_INPUT_SNAPSHOT1.dtype||||),||||scatter_dst_idx,||||SCATTER_INPUT_SNAPSHOT1||||)|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------[[3.0], [5.0], [7.0]]

推荐阅读