# DGL message passing: fn.sum vs. fn.mean

Understanding the message passing process by playing with a toy example.

## Observations

Build a toy graph with four nodes and three edges: 0→1, 0→2 and 3→2.

```python
>>> import dgl
Using backend: pytorch
>>> import dgl.function as fn
>>> import networkx as nx
>>> import torch as th
>>> u = th.tensor([0, 0, 3])  # source node IDs
>>> v = th.tensor([1, 2, 2])  # destination node IDs
>>> g = dgl.graph((u, v))
>>> nx.draw(g.to_networkx(), with_labels=True)
```
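Message passing in DGL follows edge direction: with `dgl.graph((u, v))`, edge `i` goes from `u[i]` to `v[i]`, so what a node receives depends only on its incoming edges. As a minimal pure-Python sketch (no DGL required; `in_neighbors` is a hypothetical helper, not a DGL API), the incoming sources of each node in the toy graph are:

```python
def in_neighbors(u, v, num_nodes):
    """Hypothetical helper: for each destination node, list the sources of its in-edges.

    Edge i goes from u[i] to v[i], mirroring the (src, dst) convention of
    dgl.graph((u, v)).
    """
    nbrs = [[] for _ in range(num_nodes)]
    for src, dst in zip(u, v):
        nbrs[dst].append(src)
    return nbrs


u = [0, 0, 3]
v = [1, 2, 2]
# When the node count is not given, it is inferred as max node ID + 1.
num_nodes = max(u + v) + 1
print(num_nodes)                      # 4
print(in_neighbors(u, v, num_nodes))  # [[], [0], [0, 3], []]
```

Nodes 0 and 3 have no incoming edges, which is why they end up with zeros in the aggregation results below.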

## Result of fn.sum

```python
>>> g.ndata['h'] = th.randn((g.number_of_nodes(), 2))
>>> g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
>>> g.ndata['h']
tensor([[-1.4749, -2.7265],
        [ 0.6077, -0.2449],
        [-0.6196, -2.0411],
        [ 0.0917, -1.6501]])
>>> g.ndata['h_sum']
tensor([[ 0.0000,  0.0000],
        [-1.4749, -2.7265],
        [-1.3832, -4.3765],
        [ 0.0000,  0.0000]])
```
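Reading the output: node 1 only receives from node 0, so `h_sum[1]` equals `h[0]`; node 2 receives from nodes 0 and 3, so `h_sum[2]` is `h[0] + h[3]`; nodes 0 and 3 have no incoming edges and get zero vectors. The sketch below reproduces this with plain Python lists and the features printed above. `copy_u_sum` is a hypothetical helper that mimics the semantics of `update_all(fn.copy_u(...), fn.sum(...))`; it is not how DGL implements it. Because the printed features are rounded to four decimals, the last digit can differ slightly (here -4.3766 versus DGL's -4.3765).

```python
def copy_u_sum(u, v, h):
    """Hypothetical helper mimicking
    update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum')) on plain lists.

    Each edge (src -> dst) copies h[src] as its message; each destination
    sums its incoming messages. Nodes without in-edges keep a zero vector.
    """
    dim = len(h[0])
    out = [[0.0] * dim for _ in h]
    for src, dst in zip(u, v):
        for k in range(dim):
            out[dst][k] += h[src][k]
    return out


# Node features copied from the REPL output above.
h = [[-1.4749, -2.7265],
     [0.6077, -0.2449],
     [-0.6196, -2.0411],
     [0.0917, -1.6501]]
h_sum = copy_u_sum([0, 0, 3], [1, 2, 2], h)
print([[round(x, 4) for x in row] for row in h_sum])
# [[0.0, 0.0], [-1.4749, -2.7265], [-1.3832, -4.3766], [0.0, 0.0]]
```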

## Result of fn.mean

```python
>>> g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_mean'))
>>> g.ndata['h_mean']
tensor([[ 0.0000,  0.0000],
        [-1.4749, -2.7265],
        [-0.6916, -2.1883],
        [ 0.0000,  0.0000]])
```
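`fn.mean` averages the same messages, i.e. it divides each node's sum by its in-degree: node 2 has in-degree 2, so `h_mean[2]` is `h_sum[2] / 2`, while node 1 (in-degree 1) is unchanged. Nodes with in-degree 0 still come out as zeros rather than NaN, as the transcript shows. A minimal plain-Python sketch of that behaviour, with `copy_u_mean` as a hypothetical helper and the features from the `fn.sum` run as input:

```python
def copy_u_mean(u, v, h):
    """Hypothetical helper mimicking
    update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_mean')) on plain lists.

    Sums the copied source features per destination, then divides by the
    in-degree. Nodes with in-degree 0 keep a zero vector (no division by zero).
    """
    dim = len(h[0])
    total = [[0.0] * dim for _ in h]
    deg = [0] * len(h)
    for src, dst in zip(u, v):
        deg[dst] += 1
        for k in range(dim):
            total[dst][k] += h[src][k]
    return [[x / d for x in row] if d > 0 else row
            for row, d in zip(total, deg)]


# Node features copied from the fn.sum REPL output.
h = [[-1.4749, -2.7265],
     [0.6077, -0.2449],
     [-0.6196, -2.0411],
     [0.0917, -1.6501]]
h_mean = copy_u_mean([0, 0, 3], [1, 2, 2], h)
print([[round(x, 4) for x in row] for row in h_mean])
# [[0.0, 0.0], [-1.4749, -2.7265], [-0.6916, -2.1883], [0.0, 0.0]]
```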

## Add self-loops for each node

```python
>>> u = th.tensor([0, 0, 3, 0, 1, 2, 3])
>>> v = th.tensor([1, 2, 2, 0, 1, 2, 3])
>>> g = dgl.graph((u, v))
>>> g.ndata['h'] = th.randn((g.number_of_nodes(), 2))
>>> g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
>>> g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_mean'))
>>> g.ndata['h']
tensor([[-0.3935,  0.6913],
        [-0.1614, -0.5214],
        [ 0.2733,  0.6520],
        [ 0.1839, -0.9459]])
>>> g.ndata['h_sum']
tensor([[-0.3935,  0.6913],
        [-0.5549,  0.1699],
        [ 0.0637,  0.3975],
        [ 0.1839, -0.9459]])
>>> g.ndata['h_mean']
tensor([[-0.3935,  0.6913],
        [-0.2775,  0.0849],
        [ 0.0212,  0.1325],
        [ 0.1839, -0.9459]])
```
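With a self-loop on every node, each node's own feature joins its incoming messages: nodes 0 and 3 now have in-degree 1, so both `h_sum` and `h_mean` equal `h`; node 1 has in-degree 2 and node 2 has in-degree 3, so `h_mean[2]` is `h_sum[2] / 3`. The sketch below (plain Python; `add_self_loops` and `aggregate` are hypothetical helpers, not DGL APIs) appends the loops to the original three edges, producing the same seven edges typed out by hand above, and recomputes both aggregations from the printed features.

```python
def add_self_loops(u, v, num_nodes):
    """Hypothetical helper: append one edge i -> i for every node."""
    loops = list(range(num_nodes))
    return u + loops, v + loops


def aggregate(u, v, h):
    """Hypothetical helper: return (sum, mean, in-degree) of copied source features."""
    dim = len(h[0])
    total = [[0.0] * dim for _ in h]
    deg = [0] * len(h)
    for src, dst in zip(u, v):
        deg[dst] += 1
        for k in range(dim):
            total[dst][k] += h[src][k]
    # Zero in-degree nodes keep a zero vector instead of dividing by zero.
    mean = [[x / d for x in row] if d > 0 else list(row)
            for row, d in zip(total, deg)]
    return total, mean, deg


# Node features copied from the self-loop REPL output above.
h = [[-0.3935, 0.6913],
     [-0.1614, -0.5214],
     [0.2733, 0.6520],
     [0.1839, -0.9459]]
u, v = add_self_loops([0, 0, 3], [1, 2, 2], num_nodes=4)
h_sum, h_mean, deg = aggregate(u, v, h)
print(deg)                              # [1, 2, 3, 1]
print([round(x, 4) for x in h_mean[2]])  # [0.0212, 0.1325]
```

Recent DGL versions also provide `dgl.add_self_loop(g)`, which should give the same graph without listing the loop edges by hand.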

Author: Jimmy (xiaoke) Shen, Data Scientist/MLE/SWE @takemobi