Simon Szustkowski

A blog about everything that happens to be going through my mind

Jan 12, 2018

TensorFlow - Connect two graphs

This post is in English, since the machine learning community is quite international. Don't be afraid, I won't switch to blogging only in English in the future, but will keep mixing German and English as seems appropriate.

Recently, I was struggling with loading two pre-trained model graphs in TensorFlow and connecting them for fine-tuning. "Connect" means: the output of the first graph should become the input of the second graph. Both subgraphs should merge into one supergraph and be executed together. Since I wasn't successful at first (finding helpful posts wasn't easy; the more common scenarios were "run two graphs in different sessions" or similar), I want to write down my results here.

First, make sure that your subgraphs have input and output nodes which are easy to find, i.e. nodes with memorable names. It can help to explicitly create such nodes with tf.identity(). Also, make sure that the trainable variables of each subgraph live in their own tf.variable_scope(), so that the scope of the first subgraph is completely disjoint from the scope of the second subgraph. This is very important.
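
A minimal sketch of what such a subgraph could look like (the layer sizes and the names g1_input, g1_output and graph1_MY_TRAINABLE_VARIABLE_SCOPE are just placeholder assumptions, not anything TensorFlow prescribes):

import tensorflow as tf

def build_graph1():
    # Input node with a memorable name we can refer to later
    x = tf.placeholder(tf.float32, shape=[None, 10], name='g1_input')
    # Keep all trainable variables of this subgraph in one dedicated scope
    with tf.variable_scope('graph1_MY_TRAINABLE_VARIABLE_SCOPE'):
        w = tf.get_variable('w', shape=[10, 5])
        b = tf.get_variable('b', shape=[5], initializer=tf.zeros_initializer())
        logits = tf.matmul(x, w) + b
    # Explicit, easy-to-find output node
    output = tf.identity(logits, name='g1_output')
    return x, output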

Pre-train your subgraphs, and save them to disk with tf.train.Saver(). Since the only variables worth saving are the trained ones, you can restrict the saver to the tf.variable_scope() I mentioned earlier, so basically this is your saver call:

saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='MY_TRAINABLE_VARIABLE_SCOPE'))
saver.save(sess=sess, save_path=SOME_PATH)

You do this with both subgraphs, of course.
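
Note that the scope name has to be unique per subgraph for the restore step later on to work. A hedged sketch for subgraph 1, with placeholder scope and checkpoint names (subgraph 2 is pre-trained and saved analogously in its own graph and session):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... pre-train subgraph 1 here ...
    saver1 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='graph1_MY_TRAINABLE_VARIABLE_SCOPE'))
    saver1.save(sess=sess, save_path='checkpoints/graph1')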

Now that we have the two saved models, we want to restore and connect them. For connecting, there is a handy feature called input_map, a dict argument of tf.train.import_meta_graph. Say the input node of graph 1 is called "g1_input" and its output node "g1_output", and similarly for graph 2. We also have a function which implements a queue runner for the input data and stores it in input_data (we have to remap the input as well, since we don't want to reuse the input data of the pre-training).

This is basically our code for restoring the graph structure:

# Read the input data, and store it in this variable; we will map it to the graph's input later
input_data = read_input_data()
# Load the models, and clear their device association settings
graph1 = tf.train.import_meta_graph(GRAPH1_PATH, clear_devices=True, input_map={"g1_input:0": input_data})
# The output tensor of graph 1 becomes the input tensor of graph 2
g1_output = tf.get_default_graph().get_tensor_by_name("g1_output:0")
graph2 = tf.train.import_meta_graph(GRAPH2_PATH, clear_devices=True, input_map={"g2_input:0": g1_output})

When you inspect the resulting graph with TensorBoard, you will see that the subgraphs are now connected. Notice the clear_devices argument; it comes in handy if you transfer the subgraph models to another machine with a different hardware setup, e.g. more or fewer GPUs.
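
If you just want to check the structure, it is enough to write the merged graph into a summary directory and point TensorBoard at it (the log directory is of course just a placeholder):

# Write the graph definition so TensorBoard can render it
writer = tf.summary.FileWriter('logs/merged_graph', graph=tf.get_default_graph())
writer.close()

Then run tensorboard --logdir logs/merged_graph and open the Graphs tab.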

Now that we have connected both graphs structurally, we need to restore the pre-trained variable weights. First of all, this has to happen inside a session, so I will assume that all of the following code appears inside a session block or something similar. Since we won't restore all variables (we didn't save all variables in the first place), we have to call a standard tf.global_variables_initializer() first. After that, we can restore our saved weights.

Other tutorials on saving and restoring models suggest restoring the weights with the two saver objects graph1 and graph2 which were implicitly created when we imported the meta graphs. This doesn't work in our case, since a saver object tries to restore all variables in the current graph when called. But neither saved model contains all variables: graph1 doesn't contain the weights for graph2 and vice versa. So we need to restrict the scope in which each saver restores weights. Unfortunately, tf.train.import_meta_graph provides no means to do this, so we have to define two more savers, just for restoring the variable weights of a certain scope. The code looks similar to the saving procedure earlier:

graph1_restorer = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='graph1_MY_TRAINABLE_VARIABLE_SCOPE'))
graph1_restorer.restore(sess, save_path=graph1_SOME_PATH)
graph2_restorer = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='graph2_MY_TRAINABLE_VARIABLE_SCOPE'))
graph2_restorer.restore(sess, save_path=graph2_SOME_PATH)

Please note that each saver (or restorer) restores only the variables in its trainable-variable scope; this is exactly the scope we defined in the savers of the subgraphs earlier.
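
Once both restorers have run inside the session, the supergraph can be evaluated or fine-tuned as a whole. A small hedged example, assuming graph 2 has an output node named g2_output and the input pipeline uses queue runners as mentioned above:

# Fetch the final output node of the supergraph
g2_output = tf.get_default_graph().get_tensor_by_name("g2_output:0")
# Start the input queue runners, then let data flow from input_data through graph 1 into graph 2
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
result = sess.run(g2_output)
coord.request_stop()
coord.join(threads)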

Since collections are not restricted by a var_list, each imported model still brings along all of its collections, and stored functions or constants can easily be retrieved from them - but if you didn't think of using different collection names in the two subgraphs, you will have a problem now, because the imported collections get merged into the same graph.
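
For illustration, assuming each subgraph stored, say, its loss tensor in a uniquely named collection during pre-training (the collection names and the loss tensor are hypothetical):

# During pre-training of subgraph 1 (and analogously 'graph2_losses' in subgraph 2):
tf.add_to_collection('graph1_losses', loss)

# After importing both meta graphs, the entries can be told apart by their collection names:
graph1_loss = tf.get_collection('graph1_losses')[0]
graph2_loss = tf.get_collection('graph2_losses')[0]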

But I hope this post will help you a little when combining different subgraphs into a supergraph, and will save you hours of browsing StackOverflow. Have fun.