How to freeze a Keras or TensorFlow 1.x graph
Oct 15, 2020
What is the problem?
Keras makes it easy to save only a model's weights, e.g. as model_weights.h5. But what if you need a model.pb file?
Suppose you naively save the whole model with
keras.callbacks.ModelCheckpoint(self.checkpoint_path, verbose=0, save_weights_only=False)
and compare it with
keras.callbacks.ModelCheckpoint(self.checkpoint_path, verbose=0, save_weights_only=True)
The first call saves the whole model (architecture and weights); the second saves only the weights.
With the first one, you may get an error like this:
Traceback (most recent call last):
  File "sunrgbd.py", line 740, in <module>
    input_channels=INPUT_CHANNELS_)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/site-packages/mask_rcnn-2.1-py3.7.egg/mrcnn/model.py", line 2522, in train
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/site-packages/keras/engine/training.py", line 1732, in fit_generator
    initial_epoch=initial_epoch)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/site-packages/keras/engine/training_generator.py", line 260, in fit_generator
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/site-packages/keras/callbacks/callbacks.py", line 152, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/site-packages/keras/callbacks/callbacks.py", line 730, in on_epoch_end
    self.model.save(filepath, overwrite=True)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/site-packages/keras/engine/network.py", line 1152, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/site-packages/keras/engine/saving.py", line 449, in save_wrapper
    save_function(obj, filepath, overwrite, *args, **kwargs)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/site-packages/keras/engine/saving.py", line 541, in save_model
    _serialize_model(model, h5dict, include_optimizer)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/site-packages/keras/engine/saving.py", line 129, in _serialize_model
    model_config['config'] = model.get_config()
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/site-packages/keras/engine/network.py", line 950, in get_config
    return copy.deepcopy(config)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 216, in _deepcopy_list
    append(deepcopy(a, memo))
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 221, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 221, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 221, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 221, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/Users/jimmy/anaconda3/envs/faceana/lib/python3.7/copy.py", line 169, in deepcopy
    rv = reductor(4)
TypeError: can't pickle _thread.RLock objects
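The bottom of the traceback shows the actual cause: Network.get_config() ends with copy.deepcopy(config), and somewhere inside the model's nested config there is an object holding a _thread.RLock. deepcopy falls back to the pickle protocol (__reduce_ex__) for such objects, and locks do not support it. The sketch below reproduces only that failure with the standard library; the config dict is a hypothetical stand-in, since the traceback does not reveal which object in the real config holds the lock.

```python
import copy
import threading

# Hypothetical stand-in for a nested model config: plain values plus one
# object holding a thread lock (the real offending object is not shown).
config = {
    "name": "my_layer",
    "units": 64,
    "helper": {"lock": threading.RLock()},
}

error = None
try:
    # The same call Network.get_config() makes at the end, per the traceback.
    copy.deepcopy(config)
except TypeError as exc:
    error = exc
    # Python 3.7 says "can't pickle _thread.RLock objects";
    # newer versions word the message slightly differently.
    print(f"deepcopy failed: {exc}")
```

Because the failure happens while serializing the config, it does not depend on the weights at all, which is why save_weights_only=True avoids it.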
How to solve the problem?
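Get a frozen graph and then save it.[1] Freezing replaces every variable in the session's graph with a constant holding its current value, so the resulting GraphDef can be written to a single .pb file without going through Keras's model serialization (and its deepcopy of the config).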
import tensorflow as tf  # TensorFlow 1.x API (tf.global_variables, tf.graph_util)


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.
    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    graph = session.graph
    with graph.as_default():
        # Variables to convert into constants (all of them, minus keep_var_names).
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        # Copy the list so the caller's output_names is not mutated in place.
        output_names = list(output_names or [])
        # Also keep every variable node so pruning does not drop it.
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            # Strip device placements (e.g. "/device:GPU:0") for portability.
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph
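To make the docstring concrete, here is a toy, pure-Python model of the two steps convert_variables_to_constants performs: pruning the graph down to what the requested outputs need, and swapping variable nodes for constants holding their current values. The dict-based graph format, the op names, and both helper functions are hypothetical illustrations, not TensorFlow's GraphDef API.

```python
from typing import Dict, List, Optional

# Hypothetical toy format: node name -> {"op": str, "inputs": [names], "value": ...}
Graph = Dict[str, dict]


def freeze_toy_graph(graph: Graph, var_values: Dict[str, float],
                     output_names: List[str],
                     keep_var_names: Optional[List[str]] = None) -> Graph:
    """Prune to the nodes the outputs need, then turn variables into constants."""
    keep = set(keep_var_names or [])
    # Pruning: walk backwards from the outputs and collect every reachable node.
    needed, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name not in needed:
            needed.add(name)
            stack.extend(graph[name]["inputs"])
    # Freezing: copy the needed nodes, replacing Variable nodes with Const
    # nodes that hold the variable's current value (as read from a session).
    frozen: Graph = {}
    for name in needed:
        node = graph[name]
        if node["op"] == "Variable" and name not in keep:
            frozen[name] = {"op": "Const", "inputs": [], "value": var_values[name]}
        else:
            frozen[name] = dict(node, inputs=list(node["inputs"]))
    return frozen


def evaluate(graph: Graph, name: str, feeds: Dict[str, float]) -> float:
    """Recursively evaluate a node; supports Placeholder, Const, Add and Mul only."""
    node = graph[name]
    if node["op"] == "Placeholder":
        return feeds[name]
    if node["op"] == "Const":
        return node["value"]
    a, b = (evaluate(graph, i, feeds) for i in node["inputs"])
    if node["op"] == "Add":
        return a + b
    if node["op"] == "Mul":
        return a * b
    raise ValueError(f"cannot evaluate op {node['op']!r}")


# y = w * x + b, plus a training-only counter that inference does not need.
graph: Graph = {
    "x": {"op": "Placeholder", "inputs": []},
    "w": {"op": "Variable", "inputs": []},
    "b": {"op": "Variable", "inputs": []},
    "mul": {"op": "Mul", "inputs": ["w", "x"]},
    "y": {"op": "Add", "inputs": ["mul", "b"]},
    "step": {"op": "Variable", "inputs": []},
    "next_step": {"op": "Add", "inputs": ["step", "y"]},
}
session_values = {"w": 2.0, "b": 1.0, "step": 0.0}

frozen = freeze_toy_graph(graph, session_values, output_names=["y"])
print(sorted(frozen))                     # ['b', 'mul', 'w', 'x', 'y']
print(evaluate(frozen, "y", {"x": 3.0}))  # 7.0
```

One difference from freeze_session above: it appends every variable's op name to output_names, so all variables survive pruning (as constants) even if no requested output uses them, while the toy prunes unused ones such as step.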
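Then, after building and training the model, call freeze_session on the Keras session and write the frozen graph to disk [1]: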
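import tensorflow as tf
from keras import backend as K

# Create, compile and train model...

# Freeze the graph, keeping only what is needed to compute the model outputs.
frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])
# Serialize the GraphDef as a binary protobuf: some_directory/my_model.pb
tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False)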
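Why out.op.name rather than out.name? convert_variables_to_constants expects node (op) names, while a tensor's name carries an output-index suffix: the tensor dense_1/Softmax:0 is output 0 of the op dense_1/Softmax. If you only have tensor names as strings (for example, copied from a log), a small helper like this hypothetical one (not part of Keras or TensorFlow) recovers the op names:

```python
def tensor_name_to_op_name(tensor_name: str) -> str:
    """Turn a TF-style tensor name like 'dense_1/Softmax:0' into its op name."""
    op_name, sep, index = tensor_name.rpartition(":")
    # Strip only a trailing numeric output index; leave other names untouched.
    if sep and index.isdigit():
        return op_name
    return tensor_name


# Hypothetical tensor names, e.g. copied from print(model.outputs).
names = ["dense_1/Softmax:0", "split:1", "dense_1/Softmax"]
print([tensor_name_to_op_name(n) for n in names])
# ['dense_1/Softmax', 'split', 'dense_1/Softmax']
```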
That's it. I followed [1], and it works well with TensorFlow 1.14.
Reference
[1] https://stackoverflow.com/questions/45466020/how-to-export-keras-h5-to-tensorflow-pb