For my project I am using PyTorch as a linear algebra backend. In the performance-critical part of my code, I need to compute 1D convolutions of two small vectors (1D tensors with a length between 2 and 9) a very large number of times. My code supports batch processing of inputs, so I can stack several input vectors into matrices that can then all be convolved at once. Since torch.conv1d does not seem to allow convolving along a single dimension of 2D inputs, I wrote my own convolution function called convolve. However, this function contains a double for-loop and is therefore very slow.

Question: how can I make the convolve function perform faster through better code-design and let it be able to deal with batched inputs (=2D tensors)?

Partial answer: somehow avoid the double for-loop

Below are three Jupyter notebook cells that recreate a minimal example. Note that you need line_profiler and the %%writefile magic command to make this work!

%%writefile SO_CONVOLVE_QUESTION.py
import torch

def conv1d(a, v):
    padding = v.shape[-1] - 1
    return torch.conv1d(
        input=a.view(1, 1, -1), weight=v.flip(0).view(1, 1, -1), padding=padding, stride=1
    ).squeeze()

def convolve(a, v):
    if a.ndim == 1:
        a = a.view(1, -1)
        v = v.view(1, -1) 

    nrows, vcols = v.shape
    acols = a.shape[1]

    expanded = a.view((nrows, acols, 1)) * v.view((nrows, 1, vcols))
    noutdim = vcols + acols - 1
    out = torch.zeros((nrows, noutdim))
    for i in range(acols):  
        for j in range(vcols):
            out[:, i+j] += expanded[:, i, j]  
    return out.squeeze()
    
x = torch.randn(5)
y = torch.randn(7)

I write the code to SO_CONVOLVE_QUESTION.py because line_profiler needs the function to live in a file, and because the file can then serve as the setup for timeit.timeit.

Now we can evaluate the output and performance of the code above on non-batch input (x, y) and batched input (x_batch, y_batch):

from SO_CONVOLVE_QUESTION import *
# Without batch processing
res1 = conv1d(x, y)
res = convolve(x, y)
print(torch.allclose(res1, res)) # True

# With batch processing, NB first dimension!
x_batch = torch.randn(5, 5)
y_batch = torch.randn(5, 7)

results = []
for i in range(5):
    results.append(conv1d(x_batch[i, :], y_batch[i, :]))
res1 = torch.stack(results)
res = convolve(x_batch, y_batch)
print(torch.allclose(res1, res))  # True

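import timeit

# x, y, convolve and conv1d all come from the file written above
setup = 'from SO_CONVOLVE_QUESTION import *'
print(timeit.timeit('convolve(x, y)', setup=setup, number=10000)) # 4.83391789999996
print(timeit.timeit('conv1d(x, y)', setup=setup, number=10000))   # 0.2799923000000035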

In the block above you can see that performing the convolution 5 times with the conv1d function produces the same result as a single call to convolve on batched inputs. We can also see that, for 10000 calls on non-batched input, convolve (4.8 s) is much slower than conv1d (0.28 s). Below we profile the convolve function WITHOUT batch processing using line_profiler to find the slow parts:

%load_ext line_profiler
%lprun -f convolve convolve(x, y)  # evaluated without batch-processing!

Output:

Timer unit: 1e-07 s

Total time: 0.0010383 s
File: C:\python_projects\pysumo\SO_CONVOLVE_QUESTION.py
Function: convolve at line 9

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     9                                           def convolve(a, v):
    10         1         68.0     68.0      0.7      if a.ndim == 1:
    11         1        271.0    271.0      2.6          a = a.view(1, -1)
    12         1         44.0     44.0      0.4          v = v.view(1, -1) 
    13                                           
    14         1         28.0     28.0      0.3      nrows, vcols = v.shape
    15         1         12.0     12.0      0.1      acols = a.shape[1]
    16                                           
    17         1       4337.0   4337.0     41.8      expanded = a.view((nrows, acols, 1)) * v.view((nrows, 1, vcols))
    18         1         12.0     12.0      0.1      noutdim = vcols + acols - 1
    19         1        127.0    127.0      1.2      out = torch.zeros((nrows, noutdim))
    20         6         32.0      5.3      0.3      for i in range(acols):  
    21        40        209.0      5.2      2.0          for j in range(vcols):
    22        35       5194.0    148.4     50.0              out[:, i+j] += expanded[:, i, j]  
    23         1         49.0     49.0      0.5      return out.squeeze()

The double for-loop (about 50% of the time) and the line creating the expanded tensor (about 42%) are clearly the slowest parts. Can these be avoided with better code design?

2 Answers

PyTorch provides a module of functional operations that work on batches, torch.nn.functional, which includes a conv1d function (as well as conv2d and much more). We will use conv1d.

Suppose you want to convolve 100 vectors given in v1 with one other vector given in v2. The input v1 must have shape (minibatch, in_channels, iW), where iW is the vector length, and you need only 1 channel here. The weight v2 must have shape (out_channels, in_channels / groups, kW), where kW is the kernel length. You are using 1 channel and therefore 1 group, so v1 and v2 are given by:

import torch
from torch.nn import functional as F

num_vectors = 100
len_vectors = 9
v1 = torch.rand((num_vectors, 1, len_vectors))
v2 = torch.rand(1, 1, 6)

Now we can compute the padding needed for a full convolution, which is the kernel length minus one:

padding = v2.shape[-1] - 1  # output length becomes len(v1) + len(v2) - 1

and the convolution can be done using

conv_result = F.conv1d(v1, v2, padding=padding)

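I did not time it, but it should be considerably faster than your initial double for-loop. Note that F.conv1d actually computes a cross-correlation, so flip the kernel first (v2.flip(-1)) if you need a true convolution, as your conv1d helper does.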
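
As a rough check of the approach above, here is a minimal sketch that convolves every row of a batch with one shared kernel and compares the result against a plain double loop. The helper names full_convolve_shared_kernel and reference_convolve are made up for this illustration, and the kernel flip is added on the assumption that a true convolution (not a cross-correlation) is wanted:

```python
import torch
from torch.nn import functional as F


def full_convolve_shared_kernel(batch, kernel):
    """Full 1D convolution of every row of `batch` (N, L) with one `kernel` (K,)."""
    k = kernel.shape[-1]
    # conv1d cross-correlates, so flip the kernel to get a true convolution
    weight = kernel.flip(-1).view(1, 1, k)
    # (N, L) -> (N, 1, L): N samples with a single channel each
    out = F.conv1d(batch.unsqueeze(1), weight, padding=k - 1)
    return out.squeeze(1)  # (N, L + K - 1)


def reference_convolve(a, v):
    """Textbook double loop on two 1D tensors, used only as a check."""
    out = torch.zeros(a.shape[0] + v.shape[0] - 1, dtype=a.dtype)
    for i in range(a.shape[0]):
        for j in range(v.shape[0]):
            out[i + j] += a[i] * v[j]
    return out


torch.manual_seed(0)
batch = torch.rand(100, 9)
kernel = torch.rand(6)

result = full_convolve_shared_kernel(batch, kernel)
expected = torch.stack([reference_convolve(row, kernel) for row in batch])

print(result.shape)
print(torch.allclose(result, expected, atol=1e-5))
```

This only covers the case where all rows share the same kernel. If every row has its own kernel, as in the question's x_batch and y_batch, a single shared weight is not enough.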

Turns out that there is a way to do it without for-loops via grouping of the inputs along a dimension:

out = torch.conv1d(x_batch.unsqueeze(0), y_batch.unsqueeze(1).flip(2), padding=y_batch.size(1)-1, groups=x_batch.size(0))
print(torch.allclose(out, res1))  # True
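
The one-liner works because groups=N makes conv1d treat each of the N rows as its own channel with its own kernel. Below is a small sketch that wraps it in a function and checks it against the loop-based version from the question. The name batched_convolve and the 2D-only input are assumptions for this illustration:

```python
import torch


def batched_convolve(a, v):
    """Row-wise full convolution of a (N, La) with v (N, Lv) via grouped conv1d."""
    n, lv = v.shape
    out = torch.conv1d(
        a.unsqueeze(0),          # (1, N, La): each row becomes its own channel
        v.flip(1).unsqueeze(1),  # (N, 1, Lv): one flipped kernel per group
        padding=lv - 1,          # full convolution: output length La + Lv - 1
        groups=n,                # channel i is convolved only with kernel i
    )
    return out.squeeze(0)        # (N, La + Lv - 1)


def loop_convolve(a, v):
    """Double-loop reference in the spirit of the question's convolve (2D inputs only)."""
    nrows, vcols = v.shape
    acols = a.shape[1]
    out = torch.zeros((nrows, acols + vcols - 1), dtype=a.dtype)
    for i in range(acols):
        for j in range(vcols):
            out[:, i + j] += a[:, i] * v[:, j]
    return out


torch.manual_seed(0)
x_batch = torch.randn(5, 5)
y_batch = torch.randn(5, 7)

fast = batched_convolve(x_batch, y_batch)
slow = loop_convolve(x_batch, y_batch)

print(fast.shape)
print(torch.allclose(fast, slow, atol=1e-5))
```

Both functions can be timed with timeit in the same way as in the question to see how much the grouped call helps for a given batch size.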
