My project uses large graphs that can take more than 5 minutes to compile. I decorated the large intermediate parts with tf.function(reduce_retracing=True), and I also make heavy use of tf.experimental.ExtensionType instances.
My question is: is it possible to cache the graphs compiled by these functions and restore them in a different execution? In the case where multiple concrete functions have been retraced, I would like to save them all.
To be clear, I am not interested in model saving, since that saves weights rather than graphs and requires recompilation anyway.
Moreover, I have tried the following code:
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def f(x):
    print('retrace')
    return x + 1.

f(tf.constant(0.))  # prints 'retrace'
exported = tf.train.Checkpoint(f=f)
tf.saved_model.save(exported, 'tf_cache/')
imported = tf.saved_model.load('tf_cache/')  # prints 'retrace'
imported.f(tf.constant(0.))
The issue is that I get a retrace when loading the saved checkpoint. Moreover, when f takes tf.experimental.ExtensionType inputs/outputs, tf.saved_model.save raises a ValueError.
Thanks in advance for any answer.