DGL message passing: fn.sum vs. fn.mean

Understanding the message passing process by playing with a toy example

Observations

Build a toy graph

>>> import dgl
Using backend: pytorch
>>> import dgl.function as fn
>>> import networkx as nx
>>> import torch as th
>>> u = th.tensor([0, 0, 3])
>>> v = th.tensor([1, 2, 2])
>>> g = dgl.graph((u, v))  # dgl.graph is the recommended constructor; dgl.DGLGraph((u, v)) is deprecated
>>> nx.draw(g.to_networkx(), with_labels=True)
Figure: A toy example (the graph above has edges 0→1, 0→2 and 3→2)
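Before running any message passing, it helps to list who sends messages to whom. The plain-Python sketch below (no DGL needed) builds the in-neighbor list of every node from the same `u`/`v` edge lists. The names `src`, `dst` and `in_neighbors` are illustrative, not DGL API. It makes clear that nodes 0 and 3 have no incoming edges, which explains the zero rows in the results that follow.

```python
# Edge lists copied from the session above: edge i goes from src[i] to dst[i].
src = [0, 0, 3]
dst = [1, 2, 2]

# When no node count is given, DGL infers it as the largest node ID + 1.
num_nodes = max(src + dst) + 1

# Illustrative structure: for every node, the list of nodes that send it a message.
in_neighbors = {n: [] for n in range(num_nodes)}
for s, d in zip(src, dst):
    in_neighbors[d].append(s)

for n, senders in in_neighbors.items():
    print(f"node {n} receives from {senders}")
```

Nodes 1 and 2 are the only message receivers; node 2 is the only one with two senders, so it is where sum and mean will differ.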

Result of fn.sum

>>> g.ndata['h'] = th.randn((g.number_of_nodes(), 2))
>>> g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
>>> g.ndata['h']
tensor([[-1.4749, -2.7265],
        [ 0.6077, -0.2449],
        [-0.6196, -2.0411],
        [ 0.0917, -1.6501]])
>>> g.ndata['h_sum']
tensor([[ 0.0000,  0.0000],
        [-1.4749, -2.7265],
        [-1.3832, -4.3765],
        [ 0.0000,  0.0000]])
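To check the `fn.sum` result by hand, here is a plain-Python sketch of what `copy_u` followed by `sum` computes: every edge carries its source node's feature, and every destination node adds up what it receives. The helper `copy_u_sum` is hypothetical and only mirrors the semantics, not how DGL implements it. The features are the rounded values printed above, so the last digit can differ slightly from DGL's output (e.g. -4.3766 here vs. -4.3765).

```python
# Toy graph: edges 0->1, 0->2, 3->2 (same u/v as above).
src = [0, 0, 3]
dst = [1, 2, 2]
num_nodes = 4

# Node features copied (rounded to 4 decimals) from the g.ndata['h'] printout.
h = [
    [-1.4749, -2.7265],
    [0.6077, -0.2449],
    [-0.6196, -2.0411],
    [0.0917, -1.6501],
]

def copy_u_sum(src, dst, h, num_nodes):
    """Hypothetical helper: each edge sends h[src]; each node sums its incoming messages."""
    dim = len(h[0])
    out = [[0.0] * dim for _ in range(num_nodes)]  # nodes with no messages stay at zero
    for s, d in zip(src, dst):
        for k in range(dim):
            out[d][k] += h[s][k]
    return out

h_sum = copy_u_sum(src, dst, h, num_nodes)
for row in h_sum:
    print([round(x, 4) for x in row])
```

Node 2 gets h[0] + h[3], node 1 gets a copy of h[0], and nodes 0 and 3 receive nothing, matching the `h_sum` tensor above.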

Result of fn.mean

>>> g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_mean'))
>>> g.ndata['h_mean']
tensor([[ 0.0000,  0.0000],
        [-1.4749, -2.7265],
        [-0.6916, -2.1883],
        [ 0.0000,  0.0000]])
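The mean is the same sum divided by the number of incoming messages (the in-degree). Nodes with no incoming edges have nothing to average, and, as the output above shows, DGL leaves them at zero. A plain-Python sketch of that behaviour, with a hypothetical helper `copy_u_mean` and the rounded features from the printout:

```python
# Toy graph: edges 0->1, 0->2, 3->2.
src = [0, 0, 3]
dst = [1, 2, 2]
num_nodes = 4

# Node features copied (rounded) from the g.ndata['h'] printout.
h = [
    [-1.4749, -2.7265],
    [0.6077, -0.2449],
    [-0.6196, -2.0411],
    [0.0917, -1.6501],
]

def copy_u_mean(src, dst, h, num_nodes):
    """Hypothetical helper: average of the source features each node receives."""
    dim = len(h[0])
    total = [[0.0] * dim for _ in range(num_nodes)]
    count = [0] * num_nodes  # in-degree of each node
    for s, d in zip(src, dst):
        count[d] += 1
        for k in range(dim):
            total[d][k] += h[s][k]
    # Zero in-degree: keep 0.0 instead of dividing by zero.
    return [
        [x / count[n] for x in total[n]] if count[n] else [0.0] * dim
        for n in range(num_nodes)
    ]

h_mean = copy_u_mean(src, dst, h, num_nodes)
for row in h_mean:
    print([round(x, 4) for x in row])
```

Only node 2 changes compared with `h_sum`, because it is the only node with more than one incoming edge: (-1.3832, -4.3765) / 2 gives (-0.6916, -2.1883).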

Add self-loops for each node

>>> u = th.tensor([0, 0, 3, 0, 1, 2, 3])  # the last four edges are self-loops (i, i)
>>> v = th.tensor([1, 2, 2, 0, 1, 2, 3])
>>> g = dgl.graph((u, v))
>>> g.ndata['h'] = th.randn((g.number_of_nodes(), 2))
>>> g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
>>> g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_mean'))
>>> g.ndata['h']
tensor([[-0.3935,  0.6913],
        [-0.1614, -0.5214],
        [ 0.2733,  0.6520],
        [ 0.1839, -0.9459]])
>>> g.ndata['h_sum']
tensor([[-0.3935,  0.6913],
        [-0.5549,  0.1699],
        [ 0.0637,  0.3975],
        [ 0.1839, -0.9459]])
>>> g.ndata['h_mean']
tensor([[-0.3935,  0.6913],
        [-0.2775,  0.0849],
        [ 0.0212,  0.1325],
        [ 0.1839, -0.9459]])
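With a self-loop on every node, each node also receives its own feature: nodes 0 and 3 now return their own `h`, node 1 gets h[0] + h[1], and node 2 gets h[0] + h[2] + h[3] (divided by 3 for the mean). Recent DGL releases provide `dgl.add_self_loop(g)` to add these edges instead of writing them by hand. The sketch below stays in plain Python: it appends one (i, i) edge per node and re-runs both reductions with a hypothetical `aggregate` helper, using the rounded features printed above.

```python
# Original edges 0->1, 0->2, 3->2, then one self-loop (i, i) per node,
# giving the same u/v as the hand-written tensors above.
src = [0, 0, 3]
dst = [1, 2, 2]
num_nodes = 4
src_sl = src + list(range(num_nodes))  # [0, 0, 3, 0, 1, 2, 3]
dst_sl = dst + list(range(num_nodes))  # [1, 2, 2, 0, 1, 2, 3]

# Node features copied (rounded) from the g.ndata['h'] printout.
h = [
    [-0.3935, 0.6913],
    [-0.1614, -0.5214],
    [0.2733, 0.6520],
    [0.1839, -0.9459],
]

def aggregate(src, dst, h, num_nodes, reduce):
    """Hypothetical helper: copy_u messages reduced by 'sum' or 'mean'."""
    if reduce not in ("sum", "mean"):
        raise ValueError(f"unknown reduce: {reduce}")
    dim = len(h[0])
    inbox = [[] for _ in range(num_nodes)]  # messages received by each node
    for s, d in zip(src, dst):
        inbox[d].append(h[s])
    out = []
    for msgs in inbox:
        if not msgs:
            out.append([0.0] * dim)  # cannot happen here: every node has a self-loop
            continue
        sums = [sum(m[k] for m in msgs) for k in range(dim)]
        out.append(sums if reduce == "sum" else [x / len(msgs) for x in sums])
    return out

h_sum = aggregate(src_sl, dst_sl, h, num_nodes, "sum")
h_mean = aggregate(src_sl, dst_sl, h, num_nodes, "mean")
print([[round(x, 4) for x in row] for row in h_sum])
print([[round(x, 4) for x in row] for row in h_mean])
```

Because every node now has at least one incoming edge, no row falls back to zero, and a node whose only edge is its self-loop simply keeps its own feature under both sum and mean.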

Jimmy Shen

Data Scientist/MLE/SWE @takemobi
