How to freeze a graph in TensorFlow 2.x

Jimmy (xiaoke) Shen
Oct 15, 2020


If you are using Keras and want to save a frozen graph as a single model.pb file instead of model_weights.h5, you need to freeze the graph first and then save it.

In TensorFlow 1.x, you can do this by following the instructions in my previous article.

As leimao points out in [1], TensorFlow 2.x does support freezing models, and these frozen models should be equivalent to the frozen models used in TensorFlow 1.x.

The critical code is:

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
# Get Keras model
# model = ...
# Convert Keras model to ConcreteFunction (one TensorSpec per model input)
full_model = tf.function(lambda inputs: model(inputs))
full_model = full_model.get_concrete_function(
    [tf.TensorSpec(model_input.shape, model_input.dtype) for model_input in model.inputs])
# Get frozen ConcreteFunction (all variables replaced by constants)
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_graph_def = frozen_func.graph.as_graph_def()
# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_graph_def, logdir="./frozen_models",
                  name="simple_frozen_graph.pb", as_text=False)
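The snippet above assumes you already have a `model`. Here is a minimal self-contained sketch of the same steps, assuming TensorFlow 2.x with Keras and a hypothetical two-layer `Dense` model with 4 input features. For a single-input model one `tf.TensorSpec` is enough. The sketch also checks that no variable ops survive freezing and that the frozen function reproduces the model's outputs. Note that `convert_variables_to_constants_v2` is imported from the internal `tensorflow.python` package, so its location is not guaranteed to stay the same across releases.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
# Internal (non-public) module path; it may move between TensorFlow releases.
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Hypothetical tiny Keras model: 4 input features -> 2 outputs.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="relu"),
    tf.keras.layers.Dense(2),
])

# Trace the Keras call into a ConcreteFunction with a fixed input signature.
full_model = tf.function(lambda x: model(x)).get_concrete_function(
    tf.TensorSpec(shape=[None, 4], dtype=tf.float32))

# Replace every captured variable with a constant holding its current value.
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_graph_def = frozen_func.graph.as_graph_def()

# Collect the op types in the frozen GraphDef (no variable ops should remain).
op_types = {node.op for node in frozen_graph_def.node}
print(sorted(op_types))

# The frozen function should reproduce the original model's outputs.
x = np.random.rand(5, 4).astype(np.float32)
frozen_out = frozen_func(tf.constant(x))
if isinstance(frozen_out, (list, tuple)):  # outputs may come back as a flat list
    frozen_out = frozen_out[0]
np.testing.assert_allclose(frozen_out.numpy(), model(x).numpy(), rtol=1e-5, atol=1e-6)

# Serialize the frozen GraphDef as a binary .pb file (write_graph returns its path).
pb_path = tf.io.write_graph(graph_or_graph_def=frozen_graph_def,
                            logdir=tempfile.mkdtemp(),
                            name="simple_frozen_graph.pb",
                            as_text=False)
print(pb_path, os.path.getsize(pb_path), "bytes")
```

Printing the op types is a quick sanity check: after freezing you should see ops such as `Const`, `MatMul` and `BiasAdd`, but not `VarHandleOp` or `ReadVariableOp`.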

More details can be found in [2].
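[2] also covers loading the frozen graph back for inference. Below is a hedged sketch of that idea, assuming TensorFlow 2.x. It reads the .pb into a `tf.compat.v1.GraphDef`, imports it inside `tf.compat.v1.wrap_function`, and prunes the result down to a callable that maps the input tensor to the output tensor. The helper name `wrap_frozen_graph` and the tiny model are illustrative. To keep the block self-contained, it first re-creates a small frozen graph using the same freezing steps as the code in this post.

```python
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


def wrap_frozen_graph(graph_def, inputs, outputs):
    """Hypothetical helper: turn a frozen GraphDef into a callable function."""
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    # Import the GraphDef into a TF1-style graph wrapped as a TF2 function.
    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    # Keep only the subgraph between the requested input and output tensors.
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))


# Build and freeze a tiny hypothetical model so there is a .pb file to load.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2)])
concrete = tf.function(lambda x: model(x)).get_concrete_function(
    tf.TensorSpec(shape=[None, 4], dtype=tf.float32))
frozen_func = convert_variables_to_constants_v2(concrete)
pb_path = tf.io.write_graph(frozen_func.graph.as_graph_def(), tempfile.mkdtemp(),
                            "frozen_graph.pb", as_text=False)

# Read the binary .pb back into a GraphDef.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(pb_path, "rb") as f:
    graph_def.ParseFromString(f.read())

# After freezing, the only Placeholder left in the graph is the model input.
input_names = [node.name + ":0" for node in graph_def.node if node.op == "Placeholder"]
output_names = [t.name for t in frozen_func.outputs]

loaded = wrap_frozen_graph(graph_def, inputs=input_names, outputs=output_names)
x = np.random.rand(3, 4).astype(np.float32)
prediction = loaded(tf.constant(x))[0]  # the pruned function returns a list of outputs
print(prediction.shape)
```

Reading the tensor names from the frozen graph itself avoids hard-coding names such as `x:0`, which depend on how the function was traced.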

References

[1] https://github.com/tensorflow/tensorflow/issues/27614

[2] https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/

[3] https://leimao.github.io/blog/Save-Load-Inference-From-TF-Frozen-Graph/

[4] https://gist.github.com/FlorentGuinier/57edf0b644333278dd06c9851f480bc5
