
In my TF model, the call function calls an external energy function, which depends on a function to which a single parameter is passed twice (see simplified version below):

import tensorflow as tf

@tf.function
def calc_sw3(gamma, gamma2, cutoff_jk):
    E3 = 2.0
    return E3

@tf.function
def calc_sw3_noerr(gamma0, cutoff_jk):
    E3 = 2.0
    return E3

@tf.function  # without tf.function this works fine
def energy(coords, gamma):
    xyz_i = coords[0, 0:3]
    xyz_j = coords[0, 3:6]
    rij = xyz_j - xyz_i
    norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
    E3 = calc_sw3(gamma, gamma, norm_rij)    # repeating gamma gives error
    # E3 = calc_sw3_noerr(gamma, norm_rij)   # this gives no error
    return E3



class SWLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.gamma = tf.Variable(2.51412, dtype=tf.float32)

    def call(self, coords_all):
        total_conf_energy = energy(coords_all, self.gamma)
        return total_conf_energy
# =============================================================================


SWL = SWLayer()
coords2 = tf.constant([[
                        1.9434,  1.0817,  1.0803,  
                        2.6852,  2.7203,  1.0802,  
                        1.3807,  1.3573,  1.3307]])

with tf.GradientTape() as tape:
    tape.watch(coords2)
    E = SWL(coords2)

If gamma is passed only once, or if I do not use the tf.function decorator, this works fine. But with tf.function and the same variable passed twice, I get the following error:

Traceback (most recent call last):
  File "temp_tf.py", line 47, in <module>
    E = SWL( coords2)
  File "...venv/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "temp_tf.py", line 34, in call
    total_conf_energy = energy( coords_all, self.gamma)
tensorflow.python.autograph.impl.api.StagingError: Exception encountered when calling layer "sw_layer" (type SWLayer).

in user code:

    File "temp_tf.py", line 22, in energy  *
        E3 = calc_sw3( gamma,gamma,norm_rij)    # repeating gamma gives error

    IndexError: list index out of range


Call arguments received:
  • coords_all=tf.Tensor(shape=(1, 9), dtype=float32)

Is this expected behaviour?

1 Answer


Interesting question! I think the error originates from retracing, which causes the tf.function to evaluate the Python code in energy more than once. See this issue. It could also be related to a bug.

A couple of observations:

1. Removing the tf.function decorator from calc_sw3 works and is consistent with the docs:

[...] tf.function applies to a function and all other functions it calls.

So if you apply tf.function explicitly to calc_sw3 again, you may trigger a retracing. But then you may wonder why calc_sw3_noerr works: it must have something to do with the variable gamma being passed twice.
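For reference, the first workaround can be sketched like this (same setup as in the question, with tf.function removed from calc_sw3 only):

```python
import tensorflow as tf

def calc_sw3(gamma, gamma2, cutoff_jk):  # no tf.function here
    return 2.0

@tf.function
def energy(coords, gamma):
    xyz_i = coords[0, 0:3]
    xyz_j = coords[0, 3:6]
    rij = xyz_j - xyz_i
    norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
    # calc_sw3 is now traced as part of energy's single graph
    # rather than being a separately traced tf.function
    return calc_sw3(gamma, gamma, norm_rij)

gamma = tf.Variable(2.51412, dtype=tf.float32)
coords = tf.constant([[1.9434, 1.0817, 1.0803,
                       2.6852, 2.7203, 1.0802,
                       1.3807, 1.3573, 1.3307]])
E = energy(coords, gamma)  # no IndexError
```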

2. Adding input signatures to the tf.function above the energy function, while leaving the rest of the code the way it is, also works:

@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32),
                              tf.TensorSpec(shape=None, dtype=tf.float32)])
def energy(coords, gamma):
    xyz_i = coords[0, 0:3]
    xyz_j = coords[0, 3:6]
    rij = xyz_j - xyz_i
    norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5

    E3 = calc_sw3(gamma, gamma, norm_rij)
    return E3

This method:

[...] ensures only one ConcreteFunction is created, and restricts the GenericFunction to the specified shapes and types. It is an effective way to limit retracing when Tensors have dynamic shapes.

So perhaps tf.function assumes that gamma is passed with a different shape each time, thereby triggering a retrace (just an assumption). The fact that an error is raised then actually appears to be intentional, as stated here. Another interesting comment:

tf.functions can only handle a pre defined input shape, if the shape changes, or if different python objects get passed, tensorflow automagically rebuilds the function
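That rebuilding is easy to observe with a Python side effect: Python code in a tf.function body runs only while tracing, so counting its executions shows when a retrace happens (a generic sketch, not tied to the question's code):

```python
import tensorflow as tf

trace_count = []

@tf.function
def f(x):
    trace_count.append(1)  # Python side effect: runs only during tracing
    return x + 1

f(tf.constant(1.0))      # first call: traces the function
f(tf.constant(2.0))      # same input signature: the stored graph is reused
print(len(trace_count))  # 1
f(1)                     # a new Python value: triggers a retrace
print(len(trace_count))  # 2
```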

Finally, why do I think it is a tracing problem? Because the actual error is coming from this part of the code snippet:

xyz_i = coords[0, 0:3]
xyz_j = coords[0, 3:6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5

which you can confirm by commenting those lines out, replacing norm_rij with some fixed value, and then calling calc_sw3: it will work. This means that this code snippet is probably executed more than once, maybe for the reasons mentioned above. This is also well documented here:

In the first stage, referred to as "tracing", Function creates a new tf.Graph. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are deferred: they are captured by the tf.Graph and not run.

In the second stage, a tf.Graph which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.
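As a concrete sketch of the confirmation described above (the slicing and norm computation replaced by a fixed value, everything else left as in the question):

```python
import tensorflow as tf

@tf.function
def calc_sw3(gamma, gamma2, cutoff_jk):
    return 2.0

@tf.function
def energy(coords, gamma):
    # the slicing and norm computation are commented out;
    # a fixed value stands in for norm_rij
    norm_rij = tf.constant(1.0)
    return calc_sw3(gamma, gamma, norm_rij)  # no IndexError now

gamma = tf.Variable(2.51412, dtype=tf.float32)
coords = tf.constant([[0.0] * 9])
E = energy(coords, gamma)
```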


8 Comments

Thanks for the detailed reply. Following it, I just noticed that this occurs only when gamma is a tf.Variable; if I make it a tf.constant it works fine (maybe this helps in pinpointing the issue?). For my work the workarounds you suggested will be enough, but I am curious: if tf.function re-evaluates the function, should it not re-evaluate it with the same parameters? If so, why the index error, and which component exactly raises it?
That is also an interesting observation regarding tf.constant. It is hard to say, but I assume that norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5 is causing the indexing error. The behaviour is admittedly really tricky. Glad it helped though.
tf.constant came to my mind because the link you provided said that each tf.Variable object is assigned an ID. So I assumed clashes might arise if the calc_sw3 function has two inputs with different IDs but gets passed the same Variable object with an identical ID. But I am still a relative novice in TF and going through the docs. I was using tf.function as a substitute for PyTorch's jit.script.
I have also filed a bug report on GitHub and will update once I get something there: github.com/tensorflow/tensorflow/issues/53494
Ok. Maybe also edit it with your new insights.
