[2024-06-15] runtime patching of an nn.Module
so i was trying to patch a torch Module to change its behavior on forward()
my requirements were:
- i want to be able to create a patched object from an existing instance of the parent module
- i want to preserve the parent object's behavior aside from forward (in particular, the patched object should still return True for isinstance() of the parent class)
- i don't want to create a new object, to avoid expensive copies of the parent module's parameters
- i want to be reasonably invariant to the implementation of the parent module
- i want to be able to tell whether an object has been patched
- i want to add additional attributes / parameters to the module
- i don't want my ide to complain
subclassing the parent object and overriding the __init__ and forward methods is one option, but this runs into issues with requirements (1) and (4), since i would need to write some sort of class method to initialize a patched object and transfer all of the parent object's existing attributes over, which adds a dependency on exactly what parameters the parent object has. and of course doing a simple deepcopy instead violates (3).
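for concreteness, here's roughly what that subclassing route looks like (a sketch, specialized to nn.Embedding; the class name is illustrative, and the constructor arguments plus the weight hand-off are exactly the parts that tie it to the parent's implementation):

```python
import torch.nn as nn


class SubclassedEmbedding(nn.Embedding):
    @classmethod
    def from_existing(cls, embedding: nn.Embedding) -> "SubclassedEmbedding":
        # we have to know the parent's constructor signature to rebuild it...
        patched = cls(
            embedding.num_embeddings,
            embedding.embedding_dim,
            padding_idx=embedding.padding_idx,
            max_norm=embedding.max_norm,
            norm_type=embedding.norm_type,
            scale_grad_by_freq=embedding.scale_grad_by_freq,
            sparse=embedding.sparse,
        )
        # ...and explicitly hand its parameters over (sharing here to avoid a copy)
        patched.weight = embedding.weight
        return patched

    def forward(self, input):
        ...  # new behavior would go here
```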
something closer to satisfying the reqs would be using MethodType to bind a newly defined forward function to the existing object, replacing its forward, along with adding an extra attribute at runtime like _is_patched. but checking for _is_patched with hasattr is a bit messy, and some type checkers would complain when the newly added attributes / parameters are accessed, since they still think the object is of the parent type, violating (7).
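a rough sketch of that version (patched_forward and _is_patched are just illustrative names):

```python
from types import MethodType

import torch
import torch.nn as nn
import torch.nn.functional as F


def patched_forward(self, input: torch.Tensor) -> torch.Tensor:
    # whatever new behavior we want; self is the existing instance
    return F.embedding(input, self.weight)


embedding = nn.Embedding(16, 32)
embedding.forward = MethodType(patched_forward, embedding)  # instance-level override
embedding._is_patched = True  # runtime marker, invisible to type checkers

# detection ends up looking like this, which is a bit messy
if getattr(embedding, "_is_patched", False):
    print("patched")
```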
a possible fix could be to use Protocols, specifically the inheritance syntax shown here. unfortunately, protocols inheriting from non-protocols is not actually supported in python at the time of writing, and also, who the heck knows what a protocol is anyway?
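(very roughly, a protocol is a structural type: any object with the right attributes matches it, no inheritance required. a sketch of the idea with illustrative names; the appealing-but-unsupported part would be making the protocol also inherit from nn.Embedding, so that one annotation covers both the parent's interface and the added attributes:)

```python
from typing import Protocol

import torch.nn as nn


class PatchedAttrs(Protocol):
    # just declares the attributes the patch adds
    delta: nn.Parameter
    _is_patched: bool


def tweak(module: PatchedAttrs) -> None:
    # a type checker is fine with the added attributes through the protocol...
    module.delta.requires_grad_(True)

# ...but `class PatchedEmbedding(nn.Embedding, Protocol)` is the unsupported bit:
# protocols can only inherit from other protocols.
```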
the method i ended up settling on does satisfy all seven requirements, but is a bit cursed: i realized that i could just directly override the __class__ attribute of an existing object, which, uhh, is probably fine if you're overriding it with a subclass, i think... as a plus, since method lookup goes through the instance's class, python picks up the overridden forward automatically, so there's no mucking around with MethodType either.
here's a simple example of this method in action. we're patching an existing nn.Embedding object to finetune the embeddings of only a few specific tokens:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MostlyFrozenEmbedding(nn.Embedding):
    @classmethod
    def from_existing(
        cls,
        embedding: nn.Embedding,
        trainable_token_ids: list[int],
    ) -> "MostlyFrozenEmbedding":
        with torch.no_grad():
            for param in embedding.parameters():
                param.requires_grad_(False)
            embedding.delta = nn.Parameter(
                torch.zeros(len(trainable_token_ids), embedding.embedding_dim)
            )
            embedding.register_buffer(
                "trainable_token_ids",
                torch.tensor(trainable_token_ids)
                .reshape(-1, 1)
                .repeat(1, embedding.embedding_dim),
            )
        embedding.__class__ = cls
        return embedding

    def forward(self, input: torch.Tensor):
        updated_weight = torch.scatter_add(
            self.weight, dim=0, index=self.trainable_token_ids, src=self.delta
        )
        # taken from nn.Embedding forward implementation
        return F.embedding(
            input,
            updated_weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
```
the patched object is created by simply calling the from_existing method:
```python
original = nn.Embedding(16, 32)
patched = MostlyFrozenEmbedding.from_existing(
    original, trainable_token_ids=[1, 2]
)
```
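for what it's worth, a quick sanity check against the requirements list:

```python
print(patched is original)                # True: same object, no copies made
print(isinstance(patched, nn.Embedding))  # True: still an instance of the parent class
print(type(patched).__name__)             # "MostlyFrozenEmbedding": patching is detectable
print([name for name, p in patched.named_parameters() if p.requires_grad])  # ['delta']
```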
the patched object works as expected with torch.compile too.
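e.g., a minimal smoke test along these lines (shapes and token ids are just taken from the example above):

```python
compiled = torch.compile(patched)

tokens = torch.randint(0, 16, (4, 8))
out = compiled(tokens)  # shape (4, 8, 32)
out.sum().backward()

print(patched.delta.grad is not None)  # True: gradients flow into the added parameter
print(patched.weight.grad is None)     # True: the original embedding weights stay frozen
```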