DGL message passing: fn.sum vs. fn.mean

Jimmy (xiaoke) Shen
May 27, 2020

Objective

Understand DGL's message-passing process by playing with a toy example.

Observations

Build a toy graph

>>> import dgl
Using backend: pytorch
>>> import dgl.function as fn
>>> import networkx as nx
>>> import torch as th
>>> u = th.tensor([0, 0, 3])  # source node of each edge
>>> v = th.tensor([1, 2, 2])  # destination node of each edge
>>> g = dgl.graph((u, v))  # directed edges 0->1, 0->2, 3->2 (older code: dgl.DGLGraph((u, v)))
>>> nx.draw(g.to_networkx(), with_labels=True)
Figure: A toy example (directed edges 0->1, 0->2 and 3->2)
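The graph is directed, so a node only receives messages along its incoming edges. As a quick check, the sketch below lists each node's in-neighbors in plain Python, with no DGL needed. `in_neighbors` is a hypothetical helper, not a DGL function. The sketch assumes the graph has max(node ID) + 1 nodes, which matches how DGL infers the node count when none is given.

```python
def in_neighbors(src, dst, num_nodes):
    """Return, for every node, the list of source nodes of its incoming edges."""
    incoming = [[] for _ in range(num_nodes)]
    for s, d in zip(src, dst):
        incoming[d].append(s)  # edge s->d delivers a message from s to d
    return incoming


# Same edge lists as u and v above: 0->1, 0->2, 3->2
u = [0, 0, 3]
v = [1, 2, 2]
num_nodes = max(u + v) + 1  # nodes are numbered 0..3

for node, sources in enumerate(in_neighbors(u, v, num_nodes)):
    print(node, sources)
# 0 []
# 1 [0]
# 2 [0, 3]
# 3 []
```

Nodes 0 and 3 have no incoming edges. This is why they end up with all-zero aggregates in the results that follow.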

Result of fn.sum

>>> g.ndata['h'] = th.randn((g.number_of_nodes(), 2))  # random 2-d feature per node
>>> g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))  # message = source 'h', summed per destination
>>> g.ndata['h']
tensor([[-1.4749, -2.7265],
        [ 0.6077, -0.2449],
        [-0.6196, -2.0411],
        [ 0.0917, -1.6501]])
>>> g.ndata['h_sum']
tensor([[ 0.0000,  0.0000],
        [-1.4749, -2.7265],
        [-1.3832, -4.3765],
        [ 0.0000,  0.0000]])
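How to read these numbers:

- fn.copy_u('h', 'm') puts each source node's h on its outgoing edges as message m.
- fn.sum('m', 'h_sum') adds up the messages that arrive at each destination.

So node 1 gets h[0], and node 2 gets h[0] + h[3] = [-1.4749 + 0.0917, -2.7265 - 1.6501] = [-1.3832, -4.3766]. Nodes 0 and 3 have no incoming edges, so they get zeros.

The plain-Python sketch below reproduces h_sum from the printed h. `sum_aggregate` is a hypothetical helper that mimics the semantics, not DGL's implementation. The last digit can differ from DGL's -4.3765, because the printed h is rounded to four decimals.

```python
def sum_aggregate(src, dst, h):
    """Mimic update_all(copy_u('h', 'm'), sum('m', 'h_sum')) on lists of feature rows."""
    dim = len(h[0])
    out = [[0.0] * dim for _ in h]  # nodes without incoming edges stay at zero
    for s, d in zip(src, dst):
        # copy_u: the message on edge s->d is the source feature h[s]
        for k in range(dim):
            out[d][k] += h[s][k]
    return out


u = [0, 0, 3]
v = [1, 2, 2]
# The node features printed above (rounded to 4 decimals)
h = [[-1.4749, -2.7265],
     [ 0.6077, -0.2449],
     [-0.6196, -2.0411],
     [ 0.0917, -1.6501]]

for node, row in enumerate(sum_aggregate(u, v, h)):
    print(node, [round(x, 4) for x in row])
```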

Result of fn.mean

>>> g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_mean'))
>>> g.ndata['h_mean']
tensor([[ 0.0000,  0.0000],
        [-1.4749, -2.7265],
        [-0.6916, -2.1883],
        [ 0.0000,  0.0000]])
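fn.mean divides the same sum by the number of incoming messages, which is the node's in-degree. Node 2 therefore gets (h[0] + h[3]) / 2 = [-0.6916, -2.1883]. Nodes 0 and 3 have in-degree 0, and the output shows they are filled with zeros rather than triggering a division by zero.

A minimal sketch of that rule in plain Python follows. `mean_aggregate` is a hypothetical helper that mimics the observed behavior.

```python
def mean_aggregate(src, dst, h):
    """Mimic update_all(copy_u('h', 'm'), mean('m', 'h_mean')) on lists of feature rows."""
    num_nodes, dim = len(h), len(h[0])
    total = [[0.0] * dim for _ in range(num_nodes)]
    in_degree = [0] * num_nodes
    for s, d in zip(src, dst):
        in_degree[d] += 1            # one more message arrives at d
        for k in range(dim):
            total[d][k] += h[s][k]   # copy_u: message is the source feature h[s]
    # Divide by in-degree; rows with no incoming messages stay all-zero
    return [[x / deg for x in row] if deg > 0 else row
            for row, deg in zip(total, in_degree)]


u = [0, 0, 3]
v = [1, 2, 2]
# The node features printed earlier (rounded to 4 decimals)
h = [[-1.4749, -2.7265],
     [ 0.6077, -0.2449],
     [-0.6196, -2.0411],
     [ 0.0917, -1.6501]]

for node, row in enumerate(mean_aggregate(u, v, h)):
    print(node, [round(x, 4) for x in row])
```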

Add self-loops for each node

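>>> u = th.tensor([0, 0, 3, 0, 1, 2, 3])  # original sources + self-loop sources 0, 1, 2, 3
>>> v = th.tensor([1, 2, 2, 0, 1, 2, 3])  # original destinations + self-loop destinations 0, 1, 2, 3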
>>> g = dgl.graph((u, v))  # older code: dgl.DGLGraph((u, v))
>>> g.ndata['h'] = th.randn((g.number_of_nodes(), 2))
>>> g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
>>> g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_mean'))
>>> g.ndata['h']
tensor([[-0.3935,  0.6913],
        [-0.1614, -0.5214],
        [ 0.2733,  0.6520],
        [ 0.1839, -0.9459]])
>>> g.ndata['h_sum']
tensor([[-0.3935,  0.6913],
        [-0.5549,  0.1699],
        [ 0.0637,  0.3975],
        [ 0.1839, -0.9459]])
>>> g.ndata['h_mean']
tensor([[-0.3935,  0.6913],
        [-0.2775,  0.0849],
        [ 0.0212,  0.1325],
        [ 0.1839, -0.9459]])
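With a self-loop on every node, each node's own feature becomes one of its messages:

- Nodes 0 and 3 now receive only themselves, so h_sum and h_mean equal h for those rows.
- Node 1 receives from nodes 0 and 1, so its h_mean is h_sum / 2.
- Node 2 receives from nodes 0, 2 and 3, so its h_mean is h_sum / 3.

The sketch below builds the self-loop edge lists programmatically and reproduces both tensors from the printed h. It uses plain Python, and `add_self_loops` and `aggregate` are hypothetical helpers. Results match up to rounding in the last digit (for example 0.3974 here versus DGL's 0.3975). On an actual DGL graph, recent releases provide `dgl.add_self_loop(g)` for the edge-adding step.

```python
def add_self_loops(src, dst, num_nodes):
    """Return new edge lists with one extra edge i->i for every node i."""
    nodes = list(range(num_nodes))
    return src + nodes, dst + nodes


def aggregate(src, dst, h, reducer):
    """Mimic update_all(copy_u('h', 'm'), <reducer>('m', ...)) for reducer 'sum' or 'mean'."""
    num_nodes, dim = len(h), len(h[0])
    total = [[0.0] * dim for _ in range(num_nodes)]
    in_degree = [0] * num_nodes
    for s, d in zip(src, dst):
        in_degree[d] += 1
        for k in range(dim):
            total[d][k] += h[s][k]  # copy_u: message is the source feature h[s]
    if reducer == "sum":
        return total
    if reducer == "mean":
        # Rows with in-degree 0 stay all-zero (cannot happen once self-loops exist)
        return [[x / deg for x in row] if deg > 0 else row
                for row, deg in zip(total, in_degree)]
    raise ValueError(f"unsupported reducer: {reducer}")


u, v = add_self_loops([0, 0, 3], [1, 2, 2], num_nodes=4)
print(u, v)  # [0, 0, 3, 0, 1, 2, 3] [1, 2, 2, 0, 1, 2, 3]

# The node features printed above (rounded to 4 decimals)
h = [[-0.3935,  0.6913],
     [-0.1614, -0.5214],
     [ 0.2733,  0.6520],
     [ 0.1839, -0.9459]]

h_sum = aggregate(u, v, h, "sum")
h_mean = aggregate(u, v, h, "mean")
for node, (s, m) in enumerate(zip(h_sum, h_mean)):
    print(node, [round(x, 4) for x in s], [round(x, 4) for x in m])
```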
