Hung Yu Ling

Fast Matryoshka Representation Learning

Matryoshka Representation Learning (MRL) Overview

In Matryoshka Representation Learning, the main goal is to design a way to learn representations at different granularities for downstream tasks. This flexibility allows downstream applications to adjust the granularity depending on their hardware and performance requirements.

For example, many recommendation systems adopt a two-stage approach of shortlisting and re-ranking. For shortlisting, using a low granularity representation could be beneficial since the reduced dimensionality means fewer FLOPs and faster computations. For re-ranking, speed is not as important since it operates on a much shorter list, and using a high granularity representation can improve the precision and recall.

With MRL, the learned representations can be sliced into smaller chunks, and each chunk can be used on its own to represent the input. For instance, suppose we train a neural network with MRL to encode images into 2048-dimensional vectors; at inference time, we can use just the first half of each vector (i.e., the first 1024 dimensions) to represent the same images.
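
To make the slicing idea concrete, here is a minimal, dependency-free sketch of the two-stage retrieval described above: shortlist with a prefix of each vector, then re-rank the shortlist with the full vector. The toy 8-dimensional embeddings, the helper names (`cosine`, `shortlist_then_rerank`), and the prefix length of 4 are all assumptions made for illustration; they are not part of the MRL codebase.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def shortlist_then_rerank(query, items, prefix_dim, shortlist_size):
    """Shortlist with the first `prefix_dim` dims, re-rank with all dims."""
    # Stage 1: cheap scoring on the sliced (low granularity) representation
    coarse = sorted(
        items,
        key=lambda name: cosine(query[:prefix_dim], items[name][:prefix_dim]),
        reverse=True,
    )[:shortlist_size]
    # Stage 2: precise scoring on the full representation, shortlist only
    return sorted(coarse, key=lambda name: cosine(query, items[name]), reverse=True)


# Toy embeddings (hypothetical values); a real model would produce these
items = {
    "a": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
    "b": [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
    "c": [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
}
query = [1.0, 0.1, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]

ranking = shortlist_then_rerank(query, items, prefix_dim=4, shortlist_size=2)
print(ranking)  # ['b', 'a']
```

In this toy case the 4-dimensional prefix is enough to discard "c", while the full vector is needed to tell "a" and "b" apart.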

How is this possible?

Matryoshka Linear Layer

The authors provide a very clean implementation of MRL in the official codebase on GitHub. I found reading through the codebase to be the easiest way to understand the secret sauce behind MRL.

The paper is also very well-written and contains detailed experiments and discussions, such as adaptive classification and Grad-CAM visualization; I definitely recommend reading it. For our purposes, we will focus on the code. The code snippets below are modified from the official source code to improve readability.

import torch
import torch.nn as nn


class MRL_Linear_Layer(nn.Module):

    # ResNet50 example: the original feature dimension is 2048, so
    # nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]

    def __init__(self, nesting_list, num_classes):
        super().__init__()
        self.nesting_list = nesting_list
        self.num_classes = num_classes

        # create independent nn.Linear for each nesting level;
        # nn.ModuleList registers each one so its parameters are tracked

        self.nesting_classifiers = nn.ModuleList()
        for num_feat in self.nesting_list:
            classifier = nn.Linear(num_feat, self.num_classes)
            self.nesting_classifiers.append(classifier)

    def forward(self, x):
        nesting_logits = ()

        # forward pass and store all predictions

        for num_feat, classifier in zip(self.nesting_list, self.nesting_classifiers):
            pred_logits = classifier(x[:, :num_feat])
            nesting_logits += (pred_logits,)

        return nesting_logits

Note that MRL_Linear_Layer is meant to replace the last layer of ResNet50, hence the linear modules are referred to as classifiers. MRL creates M independent linear modules each with a different input size, where M is the size of the nesting_list.
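
As a quick sanity check on what "M independent linear modules" costs, the arithmetic below counts the parameters of the nested classifiers against a single full-width classifier. It is only a back-of-the-envelope sketch: `linear_params` is a hypothetical helper, and `num_classes = 1000` is assumed to match the ResNet50 example.

```python
def linear_params(in_features, out_features):
    """Parameter count of one linear layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features


nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
num_classes = 1000  # assumed, to match the ResNet50 example

# One classifier per nesting level versus one classifier on the full vector
mrl_params = sum(linear_params(d, num_classes) for d in nesting_list)
single_params = linear_params(max(nesting_list), num_classes)

print(mrl_params)     # 4097000
print(single_params)  # 2049000
```

Because the nesting sizes double at each level, the nested classifiers together hold roughly twice the parameters of the single full-width classifier.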

During the forward pass, the input tensor is sliced to the appropriate sizes and fed into the respective classifiers. Each classifier’s predictions are kept separate; a loss is computed for each set of predictions against the same target, and the per-classifier losses are then summed. The loss calculation is analogous to that of the specialization mixture-of-experts mentioned in one of Hinton’s lectures.
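
Written as a formula, the summed objective can be sketched as follows. The notation is an assumption made here for illustration: z is the full feature vector, z_{1:m} is its first m entries, M is the nesting list, W^(m) and b^(m) are the weight and bias of the classifier for granularity m, y is the target, CE is the cross-entropy loss, and c_m is an optional per-granularity weight (taken to be 1 when unspecified).

```latex
\mathcal{L}(z, y) = \sum_{m \in \mathcal{M}} c_m \,
    \mathrm{CE}\!\left( W^{(m)} z_{1:m} + b^{(m)},\; y \right)
```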

Fast Matryoshka Linear Layer

The original implementation of the Matryoshka Linear Layer (and the snippet above) is meant to be easy to understand. But it is possible to remove the loop to improve training speed. The idea is to use a similar approach to the fast mixture-of-experts mentioned in an earlier post.

Intuitively, the different classifiers become the experts, and a mask is used to simulate the slicing operation.

class MRL_Linear_Layer_Fast(nn.Module):
    def __init__(self, nesting_list, num_classes):
        super().__init__()
        self.nesting_list = nesting_list
        self.num_classes = num_classes

        num_nestings = len(nesting_list)
        rep_size = max(nesting_list)

        # `num_nestings` becomes the equivalent of `num_experts` in Fast-MoE

        self._weights = nn.Parameter(torch.empty(num_nestings, rep_size, num_classes))
        self._biases = nn.Parameter(torch.empty(num_nestings, 1, num_classes))
        self.register_buffer("_mask", torch.zeros(num_nestings, rep_size, num_classes))

        nn.init.orthogonal_(self._weights, gain=1.0)
        nn.init.orthogonal_(self._biases, gain=1.0)

        # initialize mask to simulate the slicing operation
        # write 1s in each row where the weights will be used

        for i, num_feat in enumerate(self.nesting_list):
            self._mask[i, :num_feat, :] = 1

    def copy_parameters_(self, model: MRL_Linear_Layer):
        # helper function to copy weights from `MRL_Linear_Layer`
        # for comparison purposes only, not needed for training

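        assert len(model.nesting_list) == len(self.nesting_list)
        with torch.no_grad():
            for i, classifier in enumerate(model.nesting_classifiers):
                num_feat = model.nesting_list[i]
                # nn.Linear stores weight as [out, in]; transpose to [in, out]
                self._weights[i, :num_feat, :] = classifier.weight.T
                self._biases[i, :, :] = classifier.bias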

    def forward(self, x):
        # mask out portions of weight matrix to simulate slicing
        # then loop can be replaced with basic torch operators

        masked_weights = self._weights.mul(self._mask)
        return x.matmul(masked_weights).add(self._biases)
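
The masking trick in `forward` can be checked by hand on a tiny example. The sketch below uses plain Python lists rather than tensors, with made-up numbers and a hypothetical `matvec` helper, to show that zeroing the unused weight rows gives the same result as slicing the input.

```python
def matvec(x, w):
    """Multiply a row vector x (length d) by a d x c matrix w."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


x = [1.0, 2.0, 3.0, 4.0]  # one 4-dimensional feature vector
w = [[0.5, -1.0], [1.5, 2.0], [-0.5, 0.25], [3.0, 1.0]]  # 4 x 2 weight matrix
num_feat = 2  # keep only the first 2 features

# Slicing: use only the first `num_feat` inputs and weight rows
sliced = matvec(x[:num_feat], w[:num_feat])

# Masking: zero out the unused weight rows, then multiply the full input
mask = [[1.0 if i < num_feat else 0.0 for _ in row] for i, row in enumerate(w)]
masked_w = [
    [w_ij * m_ij for w_ij, m_ij in zip(w_row, m_row)]
    for w_row, m_row in zip(w, mask)
]
masked = matvec(x, masked_w)

print(sliced, masked)  # [3.5, 3.0] [3.5, 3.0]
```

The masked rows contribute only zeros to each dot product, which is why a single batched matrix multiplication can stand in for the loop over slices.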

Checking Equivalence

We verify that the two implementations are equivalent by checking that, given the same inputs and the same initial weights, the two models produce the same outputs, losses, and gradients.

class Matryoshka_CE_Loss(nn.Module):
    def __init__(self, relative_importance=None):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.relative_importance = relative_importance

    def forward(self, output, target):
        # output shape: [G granularities, N batch size, C num. classes]
        # target shape: [N batch size]

        # Calculate losses for each output and stack them
        nesting_losses = [self.criterion(out_i, target) for out_i in output]
        losses = torch.stack(nesting_losses)

        # Set relative_importance to 1 if not specified
        rel_importance = (
            torch.ones_like(losses)
            if self.relative_importance is None
            else torch.tensor(self.relative_importance, device=losses.device)
        )

        # Apply relative importance weights
        weighted_losses = rel_importance * losses
        return weighted_losses.sum()

import timeit


def check_model_equivalence(device, copy):
    device = torch.device(device)

    num_classes = 1000
    nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024]
    slow_model = MRL_Linear_Layer(nesting_list, num_classes).to(device)
    fast_model = MRL_Linear_Layer_Fast(nesting_list, num_classes).to(device)

    if copy:
        # use same weights as slow model
        fast_model.copy_parameters_(slow_model)

    criterion = Matryoshka_CE_Loss().to(device)

    # check equivalence
    # run the forward-backward-update twice
    # second time is to make sure the gradients are the same

    for iter in range(2):
        batch_size = 64
        data = torch.randn(batch_size, nesting_list[-1]).to(device)
        target = torch.randint(num_classes, (batch_size,)).to(device)

        pred_stacked_logits = slow_model(data)
        pred_logits_slow = torch.stack(pred_stacked_logits)
        loss_slow = criterion(pred_stacked_logits, target)

        pred_logits_fast = fast_model(data)
        loss_fast = criterion(pred_logits_fast, target)

        outputs_equal = torch.allclose(pred_logits_slow, pred_logits_fast, atol=1e-5)
        losses_equal = torch.allclose(loss_slow, loss_fast, atol=1e-5)
        print(f"({iter=}) outputs {'are' if outputs_equal else 'are NOT'} equal")
        print(f"({iter=}) losses {'are' if losses_equal else 'are NOT'} equal")

        # backward pass, then one plain SGD update per model
        loss_slow.backward()
        loss_fast.backward()
        torch.optim.SGD(slow_model.parameters(), lr=0.01).step()
        torch.optim.SGD(fast_model.parameters(), lr=0.01).step()

    # benchmark speed
    # simulate training iterations

    repeats = 100

    def benchmark_forward_backward(model):
        for _ in range(repeats):
            criterion(model(data), target).backward()
            torch.optim.SGD(model.parameters(), lr=0.01).step()

    from functools import partial

    benchmark_slow = partial(benchmark_forward_backward, slow_model)
    benchmark_fast = partial(benchmark_forward_backward, fast_model)

    execution_time = timeit.timeit(benchmark_slow, number=1)
    print(f"(slow model): {execution_time/repeats:.5f} seconds ({repeats} passes)")

    execution_time = timeit.timeit(benchmark_fast, number=1)
    print(f"(fast model): {execution_time/repeats:.5f} seconds ({repeats} passes)")

The results on my desktop machine with an RTX 3090 are below. Each reported time is the average per pass over 100 passes.

Running on CUDA
(iter=0) outputs are equal
(iter=0) losses are equal
(iter=1) outputs are equal
(iter=1) losses are equal
(slow model): 0.00208 seconds (100 passes)
(fast model): 0.00118 seconds (100 passes)

Running on CUDA (no copy)
(iter=0) outputs are NOT equal
(iter=0) losses are NOT equal
(iter=1) outputs are NOT equal
(iter=1) losses are NOT equal
(slow model): 0.00212 seconds (100 passes)
(fast model): 0.00117 seconds (100 passes)

Running on CPU
(iter=0) outputs are equal
(iter=0) losses are equal
(iter=1) outputs are equal
(iter=1) losses are equal
(slow model): 0.00982 seconds (100 passes)
(fast model): 0.07489 seconds (100 passes)
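
One thing the timings do not show directly is that the masked layer performs more arithmetic than the loop, because every granularity multiplies a full-width weight matrix. The sketch below counts multiply-accumulate operations per input sample for the nesting list used in `check_model_equivalence`; the variable names are illustrative. This may be one factor behind the CPU result, where the fast model is slower, though the benchmark above does not isolate the cause.

```python
nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024]  # as in check_model_equivalence
num_classes = 1000

# Loop version: each classifier only multiplies its own slice of the input
macs_sliced = sum(d * num_classes for d in nesting_list)

# Masked version: every granularity multiplies the full-width weight matrix
macs_masked = len(nesting_list) * max(nesting_list) * num_classes

print(macs_sliced)                # 2040000
print(macs_masked)                # 8192000
print(macs_masked / macs_sliced)  # about 4.02
```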