PyTorch Model Transfer
How to transfer parts of one model’s weights to another model
Problem
Suppose you train a model with layers A, B, C, and D on one dataset, and you want to reuse some of those layers in a new model: say layers A, B, and C are shared between the pre-trained model and a new model with layers A, B, C, and E. How can we transfer the weights of A, B, and C from the pre-trained model to the new model?
The problem is illustrated in Figure 1.
Preparation
How to save and load a model in PyTorch
What is the load_state_dict function?
If you read this code for ResNet, you will find the load_state_dict function. It loads pre-trained weights into a model and is called as model.load_state_dict. What exactly does this function do?
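Before going further, it helps to recall the standard save/load round trip. A minimal sketch (the file name checkpoint.pth is just a placeholder):

import torch

model = torch.nn.Linear(2, 3)  # any nn.Module works here

# Save only the parameter dictionary, not the whole model object.
torch.save(model.state_dict(), "checkpoint.pth")

# To load, construct the model first, then copy the saved parameters into it.
model2 = torch.nn.Linear(2, 3)
model2.load_state_dict(torch.load("checkpoint.pth"))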
Solution
From [1]:
“The keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
If a name is not in the new model, it will raise a KeyError, but I guess this may work for you:

def load_my_state_dict(self, state_dict):
    own_state = self.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            continue
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        own_state[name].copy_(param)
”
I think the approach in this link will solve the problem, and it works; see the last section of this article. (Parameter in the snippet above is torch.nn.Parameter.) It addresses the same problem as mine:
pretrained_dict: [‘A’, ‘B’, ‘C’, ‘D’]
model_dict: [‘A’, ‘B’, ‘C’, ‘E’]
How to transfer the weights of ‘ABC’ from a pre-trained model to a new model.
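A minimal sketch of how the quoted helper could be applied to this situation (my assumptions: the helper is written as a plain function taking the model as its first argument, and ModelA/ModelB are the classes defined in the code section below):

import torch
from torch.nn import Parameter

def load_my_state_dict(model, state_dict):
    # Same logic as the quoted helper, as a free function.
    own_state = model.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            continue  # e.g. 'D.weight' is skipped because ModelB has no layer D
        if isinstance(param, Parameter):
            param = param.data  # unwrap old-style serialized parameters
        own_state[name].copy_(param)

modelA, modelB = ModelA(), ModelB()  # classes from the code section below
load_my_state_dict(modelB, modelA.state_dict())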
What are the keys of the state_dict of a model?
From the discussion above, we can see that the model parameters are saved in a dictionary. The next question is: how are the keys of that dictionary defined?
From [2], we can pretty much get the answer.
Each key is the name you assign to the submodule attribute in the nn.Module. For example:
class test(torch.nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = torch.nn.Conv2d(10, 15, 10)
        self.customconv = torch.nn.Conv2d(100, 1000, 10)

test()
Out[7]:
test(
  (conv1): Conv2d(10, 15, kernel_size=(10, 10), stride=(1, 1))
  (customconv): Conv2d(100, 1000, kernel_size=(10, 10), stride=(1, 1))
)
If you assign each layer a proper, unique name, you will get a usable state dict. In this case:

q.keys()
Out[17]: odict_keys(['conv1.weight', 'conv1.bias', 'customconv.weight', 'customconv.bias'])

Since I assigned different names, I get different keys. If you were to use the name conv1 for both layers, the second assignment would silently replace the first, and you would get a wrong, non-functional state dict.
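To see why duplicate names are a problem, here is a minimal sketch (my own example): assigning to the same attribute twice silently replaces the first layer, so only one layer ends up registered.

import torch

class dup(torch.nn.Module):
    def __init__(self):
        super(dup, self).__init__()
        self.conv1 = torch.nn.Conv2d(10, 15, 10)
        self.conv1 = torch.nn.Conv2d(100, 1000, 10)  # silently replaces the first conv1

print(dup().state_dict().keys())
# odict_keys(['conv1.weight', 'conv1.bias'])  -- only the second layer survives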
Code
import torch
import torch.nn as nn
import torch.nn.functional as F

class test(torch.nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = torch.nn.Conv2d(10, 15, 10)
        # input channels changed from 100 to 15 so that forward() is consistent
        # with conv1's output; the state dict keys are unaffected
        self.customconv = torch.nn.Conv2d(15, 1000, 10)
        # pooling layer was missing in the original snippet
        self.pool = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.customconv(x)))
        return x

model = test()
q = model.state_dict()
print(q.keys())
Output
odict_keys(['conv1.weight', 'conv1.bias', 'customconv.weight', 'customconv.bias'])
It works.
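Continuing from the snippet above, each key maps to the actual parameter tensor, so the state dict can be indexed directly (the expected shapes follow from the Conv2d definitions):

# q is the state dict from the previous snippet
print(q['conv1.weight'].shape)       # torch.Size([15, 10, 10, 10])
print(q['conv1.bias'].shape)         # torch.Size([15])
print(q['customconv.weight'].shape)  # torch.Size([1000, 15, 10, 10])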
The link discussed above introduces a method for transferring the pre-trained weights to a new model. I verified the idea, and it works well.
Code
The full code is posted on GitHub.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelA(torch.nn.Module):
    def __init__(self):
        super(ModelA, self).__init__()
        self.A = torch.nn.Linear(2, 3)
        self.B = torch.nn.Linear(3, 4)
        self.C = torch.nn.Linear(4, 4)
        self.D = torch.nn.Linear(4, 3)

    def forward(self, x):
        x = F.relu(self.A(x))
        x = F.relu(self.B(x))
        x = F.relu(self.C(x))
        x = F.relu(self.D(x))
        return x

class ModelB(torch.nn.Module):
    def __init__(self):
        super(ModelB, self).__init__()
        self.A = torch.nn.Linear(2, 3)
        self.B = torch.nn.Linear(3, 4)
        self.C = torch.nn.Linear(4, 4)
        self.E = torch.nn.Linear(4, 2)

    def forward(self, x):
        x = F.relu(self.A(x))
        x = F.relu(self.B(x))
        x = F.relu(self.C(x))
        x = F.relu(self.E(x))
        return x

modelA = ModelA()
modelA_dict = modelA.state_dict()
print('-' * 40)
for key in sorted(modelA_dict.keys()):
    parameter = modelA_dict[key]
    print(key)
    print(parameter.size())
    print(parameter)

modelB = ModelB()
modelB_dict = modelB.state_dict()
print('-' * 40)
for key in sorted(modelB_dict.keys()):
    parameter = modelB_dict[key]
    print(key)
    print(parameter.size())
    print(parameter)

print('-' * 40)
print("modelB is going to use the ABC layers parameters from modelA")
pretrained_dict = modelA_dict
model_dict = modelB_dict
# 1. filter out unnecessary keys (drops 'D.weight' and 'D.bias' here)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
modelB.load_state_dict(model_dict)

modelB_dict = modelB.state_dict()
for key in sorted(modelB_dict.keys()):
    parameter = modelB_dict[key]
    print(key)
    print(parameter.size())
    print(parameter)
Output
----------------modelA weights------------------------
A.bias
torch.Size([3])
tensor([ 0.4012, -0.3587, 0.6650])
A.weight
torch.Size([3, 2])
tensor([[ 0.5574, 0.4757],
[-0.3795, -0.4850],
[ 0.2248, -0.3578]])
B.bias
torch.Size([4])
tensor([ 0.1353, -0.3448, 0.4272, -0.1463])
B.weight
torch.Size([4, 3])
tensor([[-0.4960, 0.2930, 0.1822],
[-0.4309, -0.4259, -0.3604],
[ 0.2976, 0.2279, -0.3805],
[-0.2423, -0.2915, 0.5130]])
C.bias
torch.Size([4])
tensor([-0.2964, -0.3516, -0.2900, 0.2390])
C.weight
torch.Size([4, 4])
tensor([[ 0.0877, 0.4150, -0.1938, 0.3659],
[-0.3505, 0.1734, -0.1803, 0.2914],
[ 0.3375, -0.2661, 0.4651, 0.0041],
[-0.1866, 0.0055, 0.0230, 0.0502]])
D.bias
torch.Size([3])
tensor([0.2733, 0.3856, 0.2848])
D.weight
torch.Size([3, 4])
tensor([[ 0.4498, 0.4846, -0.2461, 0.1043],
[-0.1462, -0.1684, 0.0155, -0.2861],
[-0.2750, 0.3607, 0.4295, -0.3481]])
-----------------modelB weights before using modelA's-----------------------
A.bias
torch.Size([3])
tensor([-0.2486, -0.3553, -0.3503])
A.weight
torch.Size([3, 2])
tensor([[ 0.1880, -0.6102],
[-0.1288, 0.6273],
[ 0.1040, -0.5014]])
B.bias
torch.Size([4])
tensor([ 0.2349, 0.1911, -0.5200, -0.1111])
B.weight
torch.Size([4, 3])
tensor([[ 0.3223, 0.4178, -0.1244],
[-0.2392, 0.5335, -0.4440],
[-0.4544, 0.3134, 0.1886],
[-0.3317, 0.2892, -0.5672]])
C.bias
torch.Size([4])
tensor([ 0.4484, 0.3125, -0.1636, -0.1316])
C.weight
torch.Size([4, 4])
tensor([[-0.1965, -0.3447, -0.4057, -0.2020],
[-0.3002, 0.0170, -0.0360, 0.2502],
[ 0.3630, -0.2502, 0.2334, -0.1819],
[ 0.1432, 0.1483, -0.2965, -0.0004]])
E.bias
torch.Size([2])
tensor([-0.1594, 0.4471])
E.weight
torch.Size([2, 4])
tensor([[ 0.0461, -0.3409, 0.3723, -0.1613],
[-0.0548, 0.3238, -0.2238, 0.1237]])
-----------------modelB weights after using modelA's-----------------------
modelB is going to use the ABC layers parameters from modelA
A.bias
torch.Size([3])
tensor([ 0.4012, -0.3587, 0.6650])
A.weight
torch.Size([3, 2])
tensor([[ 0.5574, 0.4757],
[-0.3795, -0.4850],
[ 0.2248, -0.3578]])
B.bias
torch.Size([4])
tensor([ 0.1353, -0.3448, 0.4272, -0.1463])
B.weight
torch.Size([4, 3])
tensor([[-0.4960, 0.2930, 0.1822],
[-0.4309, -0.4259, -0.3604],
[ 0.2976, 0.2279, -0.3805],
[-0.2423, -0.2915, 0.5130]])
C.bias
torch.Size([4])
tensor([-0.2964, -0.3516, -0.2900, 0.2390])
C.weight
torch.Size([4, 4])
tensor([[ 0.0877, 0.4150, -0.1938, 0.3659],
[-0.3505, 0.1734, -0.1803, 0.2914],
[ 0.3375, -0.2661, 0.4651, 0.0041],
[-0.1866, 0.0055, 0.0230, 0.0502]])
E.bias
torch.Size([2])
tensor([-0.1594, 0.4471])
E.weight
torch.Size([2, 4])
tensor([[ 0.0461, -0.3409, 0.3723, -0.1613],
[-0.0548, 0.3238, -0.2238, 0.1237]])
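One refinement worth noting (my own suggestion, not from the linked posts): filtering by key alone will break if a key exists in both models but with different tensor shapes, since load_state_dict raises a size-mismatch error in that case. A slightly safer filter, continuing from the variables in the code above, also compares shapes:

# Keep a pre-trained tensor only if the new model has the same key *and* shape.
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.size() == model_dict[k].size()
}
model_dict.update(pretrained_dict)
modelB.load_state_dict(model_dict)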
Thanks for reading.