Adapted from my answer for Convert integer to pytorch tensor of binary bits, here's something more concise than the repo from your answer:
a = torch.tensor([[[-3, -6, -1],
[-6, -10, -1],
[ 9, 9, 6]],
[[-4, -7, -3],
[-4, -6, -1],
[14, 16, 8]],
[[-4, -6, -2],
[-6, -9, -2],
[ 9, 10, 5]]], dtype=torch.int8)
def int_to_bits(x, bits=None, dtype=torch.uint8):
assert not(x.is_floating_point() or x.is_complex()), "x isn't an integer type"
if bits is None: bits = x.element_size() * 8
mask = 2**torch.arange(bits-1,-1,-1).to(x.device, x.dtype)
return x.unsqueeze(-1).bitwise_and(mask).ne(0).to(dtype=dtype)
int_to_bits(a, dtype=torch.float32)
This returns:
tensor([[[[1., 1., 1., 1., 1., 1., 0., 1.],
[1., 1., 1., 1., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1., 0., 1., 0.],
[1., 1., 1., 1., 0., 1., 1., 0.],
[1., 1., 1., 1., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 1., 0., 0., 1.],
[0., 0., 0., 0., 1., 0., 0., 1.],
[0., 0., 0., 0., 0., 1., 1., 0.]]],
[[[1., 1., 1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 0., 0., 1.],
[1., 1., 1., 1., 1., 1., 0., 1.]],
[[1., 1., 1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 1., 1., 1., 0.],
[0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0.]]],
[[[1., 1., 1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 1., 1., 0.]],
[[1., 1., 1., 1., 1., 0., 1., 0.],
[1., 1., 1., 1., 0., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 0.]],
[[0., 0., 0., 0., 1., 0., 0., 1.],
[0., 0., 0., 0., 1., 0., 1., 0.],
[0., 0., 0., 0., 0., 1., 0., 1.]]]])