
[2024-06-15] runtime patching of an nn.Module

so i was trying to patch a torch Module to change its behavior on forward()

my requirements were:

  1. i want to be able to create a patched object from an existing instance of the parent module
  2. i want to preserve the parent object's behavior aside from forward (in particular, the patched object should still pass isinstance() checks against the parent class)
  3. i don't want to create a new object, to avoid expensive copies of the parent module's parameters
  4. i want to be reasonably invariant to the implementation of the parent module
  5. i want to be able to tell whether an object has been patched
  6. i want to add additional attributes / parameters to the module
  7. i don't want my ide to complain

subclassing the parent class and overriding the __init__ and forward methods is one option, but it runs into issues with requirements (1) and (4): i would need to write some sort of class method that initializes a patched object and transfers all of the parent object's existing attributes over, which adds a dependency on exactly which parameters the parent module has. and of course doing a simple deepcopy instead violates (3).
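for concreteness, a rough sketch of what that subclassing approach might look like (the names here are made up; note how the alternate constructor has to spell out nn.Embedding's constructor arguments by hand):

    import torch.nn as nn


    class SubclassedEmbedding(nn.Embedding):
        # (the forward override is omitted; the problem is the constructor)
        @classmethod
        def from_existing(cls, embedding: nn.Embedding) -> "SubclassedEmbedding":
            # enumerating the parent's constructor arguments couples this code
            # to nn.Embedding's implementation (breaks req 4)
            patched = cls(
                embedding.num_embeddings,
                embedding.embedding_dim,
                padding_idx=embedding.padding_idx,
                max_norm=embedding.max_norm,
                norm_type=embedding.norm_type,
                scale_grad_by_freq=embedding.scale_grad_by_freq,
                sparse=embedding.sparse,
            )
            # reusing the existing weight avoids keeping two copies around, but
            # cls(...) above already allocated a fresh one, and this is still a
            # brand new object (breaks req 3)
            patched.weight = embedding.weight
            return patched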

something closer to satisfying the reqs would be using MethodType to bind a newly defined forward function to the existing object (shadowing the class's forward), along with adding an extra attribute at runtime like _is_patched. but checking _is_patched with hasattr is a bit messy, and some type checkers will complain when the newly added attributes and parameters are accessed, since they still think the object is of the parent type, violating (7).
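for illustration, that version would look something like this (the patch_forward helper and the _is_patched flag are placeholder names):

    from types import MethodType

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    def patch_forward(embedding: nn.Embedding) -> nn.Embedding:
        def forward(self: nn.Embedding, input: torch.Tensor) -> torch.Tensor:
            # the new behavior would go here; this just mirrors nn.Embedding.forward
            return F.embedding(
                input, self.weight, self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse,
            )

        # bind the function to this specific instance, shadowing the class's forward
        embedding.forward = MethodType(forward, embedding)
        embedding._is_patched = True  # later checked via hasattr(), which is messy
        return embedding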

a possible fix could be to use Protocols, specifically the inheritance syntax shown here. unfortunately, protocols inheriting from non-protocols is not actually supported in python at the time of writing, and also, who the heck knows what a protocol is anyway?
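for reference, the attempt would be a declaration roughly like the one below, and the typing module rejects it as soon as the class statement runs:

    from typing import Protocol

    import torch.nn as nn


    # hypothetical: a "patched" interface that also inherits the concrete module.
    # this raises TypeError: Protocols can only inherit from other protocols
    class PatchedEmbedding(nn.Embedding, Protocol):
        _is_patched: bool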

the method i ended up settling on does satisfy all seven requirements, but is a bit cursed: i realized that i could just directly override the __class__ attribute of an existing object, which uhh, is probably fine if you're overriding it with a subclass, i think... as a plus, python picks up the method overrides automatically (lookups go through the new class), so no mucking around with MethodType either.
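as a minimal illustration of the mechanism with plain python classes (no torch involved):

    class A:
        def hello(self) -> str:
            return "A"


    class B(A):
        def hello(self) -> str:
            return "B"


    a = A()
    a.__class__ = B  # swap the instance's class in place

    assert isinstance(a, A) and isinstance(a, B)
    assert a.hello() == "B"  # method lookup now goes through B automatically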

here's a simple example of this method in action. we're patching an existing nn.Embedding object to finetune the embeddings of only a few specific tokens:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class MostlyFrozenEmbedding(nn.Embedding):
        # declared so type checkers know the extra attributes exist (requirement 7);
        # the actual values are assigned in from_existing
        delta: nn.Parameter
        trainable_token_ids: torch.Tensor

        @classmethod
        def from_existing(
            cls,
            embedding: nn.Embedding,
            trainable_token_ids: list[int],
        ) -> "MostlyFrozenEmbedding":
            # freeze every existing parameter of the original embedding
            with torch.no_grad():
                for param in embedding.parameters():
                    param.requires_grad_(False)

            # trainable correction, one row per trainable token id
            embedding.delta = nn.Parameter(
                torch.zeros(len(trainable_token_ids), embedding.embedding_dim)
            )
            # token ids expanded to (num_trainable, embedding_dim) so the buffer
            # can be used directly as the scatter_add index below
            embedding.register_buffer(
                "trainable_token_ids",
                torch.tensor(trainable_token_ids)
                .reshape(-1, 1)
                .repeat(1, embedding.embedding_dim),
            )

            # the cursed part: swap the instance's class in place
            embedding.__class__ = cls
            return embedding

        def forward(self, input: torch.Tensor):
            # add the delta rows onto the frozen weight at the trainable token ids
            updated_weight = torch.scatter_add(
                self.weight, dim=0, index=self.trainable_token_ids, src=self.delta
            )
            # taken from nn.Embedding forward implementation
            return F.embedding(
                input,
                updated_weight,
                self.padding_idx,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            )
    

the patched object is created by simply calling the from_existing method:

    original = nn.Embedding(16, 32)
    patched = MostlyFrozenEmbedding.from_existing(
        original, trainable_token_ids=[1, 2]
    )
    

the patched object works as expected with torch.compile too.
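for completeness, a quick sanity check along these lines should pass, covering the isinstance requirement and the fact that only the delta rows receive gradients:

    assert isinstance(patched, nn.Embedding)
    assert isinstance(patched, MostlyFrozenEmbedding)  # doubles as the "is it patched" check

    patched(torch.tensor([[1, 2, 3]])).sum().backward()
    assert patched.weight.grad is None     # frozen weight gets no gradient
    assert patched.delta.grad is not None  # only the delta rows get trained

    compiled = torch.compile(patched)
    compiled(torch.tensor([[1, 2, 3]]))  # runs under torch.compile as well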