In my TF model, my call functions calls an external energy function which is dependent on a function where single parameter is passed twice (see simplified version below):
import tensorflow as tf
@tf.function
def calc_sw3(gamma,gamma2, cutoff_jk):
E3 = 2.0
return E3
@tf.function
def calc_sw3_noerr( gamma0, cutoff_jk):
E3 = 2.0
return E3
@tf.function # without tf.function this works fine
def energy(coords, gamma):
xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
E3 = calc_sw3( gamma,gamma,norm_rij) # repeating gamma gives error
# E3 = calc_sw3_noerr( gamma, norm_rij) # this gives no error
return E3
class SWLayer(tf.keras.layers.Layer):
def __init__(self):
super().__init__()
self.gamma = tf.Variable(2.51412, dtype=tf.float32)
def call(self, coords_all):
total_conf_energy = energy( coords_all, self.gamma)
return total_conf_energy
# =============================================================================
SWL = SWLayer()
coords2 = tf.constant([[
1.9434, 1.0817, 1.0803,
2.6852, 2.7203, 1.0802,
1.3807, 1.3573, 1.3307]])
with tf.GradientTape() as tape:
tape.watch(coords2)
E = SWL( coords2)
Here if gamma is passed only once, or if I do not use tf.function decorator. But with tf.function and passing same variable twice, I get the following error:
Traceback (most recent call last):
File "temp_tf.py", line 47, in <module>
E = SWL( coords2)
File "...venv/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "temp_tf.py", line 34, in call
total_conf_energy = energy( coords_all, self.gamma)
tensorflow.python.autograph.impl.api.StagingError: Exception encountered when calling layer "sw_layer" (type SWLayer).
in user code:
File "temp_tf.py", line 22, in energy *
E3 = calc_sw3( gamma,gamma,norm_rij) # repeating gamma gives error
IndexError: list index out of range
Call arguments received:
• coords_all=tf.Tensor(shape=(1, 9), dtype=float32)
Is this expected behaviour?