
As far as I know, jax.jacrev only supports "rank 1" vector-valued functions for autodiff. How do I handle higher-rank tensors?

I don't want to flatten my matrix and then unflatten it inside the function. That workaround is ugly: it's unintuitive and obscures the essence of the operations.

So far, I have tried implementing an FFN:

    @jax.jit
    def forward_train(self, input_vector, weights, biases):
        current_vector = input_vector
        for i in range(self.num_layers): #constant
            current_vector = jnp.clip(weights[i] @ current_vector + biases[i], 0, None)
        return current_vector

    @jax.jit
    def grads(self, input_vector, weights, biases, dC_dY):
        dY_dX, dY_dW, dY_dB = jax.jacrev(self.forward_train, [0, 1, 2])(input_vector, weights, biases)
        dC_dX = dC_dY @ dY_dX
        dC_dW = [dC_dY @ dY_dW_i for dY_dW_i in dY_dW] # <- Is this correct???
        dC_dB = [dC_dY @ dY_dB_i for dY_dB_i in dY_dB]
        return [dC_dX, dC_dW, dC_dB]
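As an aside, when only the cotangent products are needed (as in the grads method above), jax.vjp computes them directly without materializing any full Jacobian. A minimal standalone sketch — the layer sizes and random weights here are illustrative, not from the post:

```python
import jax
import jax.numpy as jnp

def forward(x, weights, biases):
    # same ReLU feed-forward structure as forward_train above
    for W, b in zip(weights, biases):
        x = jnp.clip(W @ x + b, 0, None)
    return x

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
weights = [jax.random.normal(k1, (4, 3)), jax.random.normal(k2, (2, 4))]
biases = [jnp.zeros(4), jnp.zeros(2)]
x = jnp.ones(3)

# vjp returns the primal output and a function that pulls back a cotangent
y, vjp_fn = jax.vjp(forward, x, weights, biases)
dC_dY = jnp.ones_like(y)             # upstream cotangent, e.g. dC/dY
dC_dX, dC_dW, dC_dB = vjp_fn(dC_dY)  # same pytree shapes as the inputs
print(dC_dX.shape)               # (3,)
print([g.shape for g in dC_dW])  # [(4, 3), (2, 4)]
```

Because vjp_fn returns cotangents with exactly the shapes of the inputs (including list-of-matrices weights), there is no ambiguity about how to contract dC_dY against a higher-rank Jacobian.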
2 Comments

  • jax.jacrev supports Jacobians of arbitrary-rank tensors. Have you tried computing the Jacobian directly, without flattening? Commented Jul 29 at 23:24
  • @jakevdp if you decide to post your comment as an answer, please let me know and I will remove mine. Commented Aug 5 at 14:35

1 Answer


You do not need to flatten: jax.jacrev computes Jacobians of higher-rank inputs and outputs directly, e.g.:

import jax
import jax.numpy as jnp

def f(t: jnp.ndarray, n: int = 5) -> jnp.ndarray:
    # rank-1 input of shape (k,) -> rank-2 output of shape (n, k)
    return jnp.array([t**i for i in range(n)])

J = jax.jacrev(f)(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]))
print(J.shape)  # (5, 5, 5): output shape (n, k) followed by input shape (k,)
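The same holds for higher-rank inputs. A small sketch (the function g is illustrative) showing that a matrix-to-matrix function yields a rank-4 Jacobian, whose shape is the output shape followed by the input shape:

```python
import jax
import jax.numpy as jnp

def g(W):
    # matrix -> matrix: input shape (2, 3), output shape (2, 2)
    return W @ W.T

W = jnp.arange(6.0).reshape(2, 3)
J = jax.jacrev(g)(W)
print(J.shape)  # (2, 2, 2, 3): output dims first, then input dims
```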