
I wrote a simple wrapper to add special methods to a given PyTorch neural network. While the implementation below works well for general objects like strings, lists, etc., I get a RecursionError when applying it to a torch.nn.Module. It seems that in the latter case the access to self.instance inside the __getattr__ method is unsuccessful, so it falls back to __getattr__ again, leading to an infinite loop (I also tried self.__dict__['instance'] without luck).

I assume that this behaviour stems from the implementations of the __getattr__ and __setattr__ methods of torch.nn.Module, but after inspecting them I still don't see how.

I would like to understand in detail what is going on and how to fix the error in my implementation.

(I am aware of a similar, linked question, but it does not answer my question.)

Here is a minimal implementation to reproduce my situation:

    import torch

    class MyWrapper(torch.nn.Module):
        def __init__(self, instance):
            super().__init__()
            self.instance = instance

        def __getattr__(self, name):
            print("trace", name)
            return getattr(self.instance, name)

    # Working example
    obj = "test string"
    obj_wrapped = MyWrapper(obj)
    print(obj_wrapped.split(" "))  # prints "trace split", then ['test', 'string']

    # Failing example
    net = torch.nn.Linear(12, 12)
    net.test_attribute = "hello world"
    b = MyWrapper(net)

    print(b.test_attribute)  # RecursionError: maximum recursion depth exceeded
    b.instance  # RecursionError: maximum recursion depth exceeded

1 Answer

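The recursion comes from how attribute look-up works in Python classes, but it is torch.nn.Module that triggers it.

Python calls __getattr__ only after normal look-up (the instance's __dict__, then the class hierarchy) has failed. torch.nn.Module overrides __setattr__: when the assigned value is itself a Module (such as the wrapped torch.nn.Linear), it registers the value in self._modules instead of storing it in self.__dict__. Other values, such as the string in your working example, are stored in self.__dict__ as usual. torch.nn.Module's own __getattr__ would then find the submodule in _modules, but MyWrapper overrides that method.

So when the wrapped object is a Module, self.instance misses normal look-up and Python calls MyWrapper.__getattr__(self, "instance"), which evaluates self.instance again, which calls MyWrapper.__getattr__ again, and so on until the recursion limit is reached. With a string, self.instance is found in self.__dict__ and __getattr__ is never called for it. For the same reason, self.__dict__['instance'] does not work either: the wrapped module lives in self.__dict__['_modules']['instance'].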
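To see where instance actually ends up, here is a quick check (a sketch; w_str and w_net are illustrative names) using the question's wrapper without the trace print. It only reads __dict__ and _modules, which normal look-up finds, so it does not trigger the recursion. Note that _modules is an internal attribute of torch.nn.Module, not public API.

```python
import torch


class MyWrapper(torch.nn.Module):
    # The question's wrapper, minus the trace print
    def __init__(self, instance):
        super().__init__()
        self.instance = instance

    def __getattr__(self, name):
        return getattr(self.instance, name)


w_str = MyWrapper("test string")
w_net = MyWrapper(torch.nn.Linear(12, 12))

# A plain value goes into the instance __dict__, so normal look-up finds it
# and __getattr__ is never called for "instance"
print("instance" in w_str.__dict__)   # True

# A Module is registered in _modules instead, so normal look-up misses it
print("instance" in w_net.__dict__)   # False
print("instance" in w_net._modules)   # True
```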

Fix:

You can call the superclass's __getattr__ (accessible through super()). torch.nn.Module.__getattr__ searches _parameters, _buffers and _modules, so it finds instance when the wrapped object is a Module; you can then use the built-in getattr to look up the requested name on the wrapped object. For example:


    import torch

    class MyWrapper(torch.nn.Module):
        def __init__(self, instance):
            super().__init__()
            self.instance = instance

        def __getattr__(self, name):
            instance = super().__getattr__("instance")
            return getattr(instance, name)

    # Your failing example, now working
    net = torch.nn.Linear(12, 12)
    net.test_attribute = "hello world"
    b = MyWrapper(net)

    print(b.test_attribute)  # hello world
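One caveat with this fix, as far as I can tell: torch.nn.Module.__getattr__ only searches _parameters, _buffers and _modules, never self.__dict__. So super().__getattr__("instance") raises AttributeError when the wrapped object is not a Module (the string example from the question stops working), and b.instance itself fails because the look-up is forwarded to the Linear layer, which has no attribute named instance. A minimal sketch that handles both cases, assuming you want to keep the wrapped network registered as a submodule, could look like this:

```python
import torch


class MyWrapper(torch.nn.Module):
    def __init__(self, instance):
        super().__init__()
        # nn.Module.__setattr__ puts a Module into self._modules and any
        # other value (e.g. a str) into self.__dict__
        self.instance = instance

    def __getattr__(self, name):
        # 1) Let nn.Module search _parameters, _buffers and _modules first.
        #    This resolves b.instance when the wrapped object is a Module.
        try:
            return super().__getattr__(name)
        except AttributeError:
            pass
        # 2) Fetch the wrapped object straight from the storage dicts.
        #    Never write self.instance here: it could re-enter __getattr__.
        modules = self.__dict__.get("_modules", {})
        if "instance" in modules:
            instance = modules["instance"]
        elif "instance" in self.__dict__:
            instance = self.__dict__["instance"]
        else:
            raise AttributeError(name)
        # 3) Delegate the look-up to the wrapped object
        return getattr(instance, name)


net = torch.nn.Linear(12, 12)
net.test_attribute = "hello world"
b = MyWrapper(net)
print(b.test_attribute)   # hello world
print(b.instance is net)  # True

s = MyWrapper("test string")
print(s.split(" "))       # ['test', 'string']
```

Keeping instance in _modules means b.parameters() and b.to(...) still see the wrapped network's weights. The sketch reads the internal _modules dictionary directly, which is an implementation detail of torch.nn.Module rather than documented API.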
