Fast Matryoshka Representation Learning
Matryoshka Representation Learning (MRL) Overview
In Matryoshka Representation Learning, the main goal is to design a way to learn representations at different granularities for downstream tasks. This flexibility allows downstream applications to adjust the granularity depending on their hardware and performance requirements.
For example, many recommendation systems adopt a two-stage approach of shortlisting and reranking. For shortlisting, a low-granularity representation could be beneficial, since the reduced dimensionality means fewer FLOPs and faster computation. For reranking, speed is less important because it operates on a much shorter list, and a high-granularity representation can improve precision and recall.
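Here is a minimal sketch of that two-stage idea. It assumes dot-product scoring and embeddings whose leading coordinates are usable on their own, which is what MRL aims to provide. The function name `two_stage_search` and the values 64, 100 and 10 are illustrative choices, not taken from the official code. The vectors are random, so the retrieved items themselves mean nothing; the sketch only shows the mechanics and the cost.

```python
import torch

def two_stage_search(query, database, shortlist_dim=64, shortlist_k=100, final_k=10):
    """Hypothetical helper: shortlist with a low-dim prefix, rerank with full vectors."""
    # Stage 1 (shortlisting): score every item using only the first shortlist_dim coordinates
    coarse_scores = database[:, :shortlist_dim] @ query[:shortlist_dim]
    shortlist = coarse_scores.topk(shortlist_k).indices
    # Stage 2 (reranking): rescore only the shortlisted items with the full representation
    fine_scores = database[shortlist] @ query
    order = fine_scores.topk(final_k).indices
    return shortlist[order]

torch.manual_seed(0)
num_items, full_dim = 10_000, 2048
database = torch.randn(num_items, full_dim)  # random stand-ins for MRL embeddings
query = torch.randn(full_dim)
top10 = two_stage_search(query, database, shortlist_dim=64, shortlist_k=100, final_k=10)

# multiply-adds spent on dot products: two-stage vs. scanning everything at full size
two_stage_cost = num_items * 64 + 100 * full_dim
full_scan_cost = num_items * full_dim
print(top10.shape, full_scan_cost / two_stage_cost)
```

In this configuration, most of the work moves to the cheap 64-dimensional pass. The full 2048-dimensional vectors are only touched for the 100 shortlisted items.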
With MRL, the learned representations can be sliced into smaller chunks, and each chunk (a prefix of the full vector) can be used on its own to represent the input. For example, if we train a neural network with MRL to encode images into 2048-dimensional vectors, we can use the first half of each vector (i.e., 1024 dimensions) to represent the same images at inference time.
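In code, "using the first half" is plain slicing. The sketch below assumes a batch of 2048-dimensional embeddings; random numbers stand in for the output of a trained MRL encoder. It builds every granularity from the ResNet50 nesting list used later in this post.

```python
import torch

nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
# stand-in for a batch of 2048-dim embeddings from an MRL-trained encoder (random here)
emb = torch.randn(4, 2048)

# each granularity is just a prefix of the full vector, so no extra model is needed
nested = {m: emb[:, :m] for m in nesting_list}
half = nested[1024]                        # the "first half" representation
print(half.shape)                          # torch.Size([4, 1024])
# basic slicing returns a view, so no data is copied
print(half.data_ptr() == emb.data_ptr())   # True
```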
How is this possible?
Matryoshka Linear Layer
The authors provide a very clean implementation of MRL in the official codebase on GitHub. I found reading through the codebase to be the easiest way to understand the secret sauce behind MRL.
The paper is also very well-written and contains detailed experiments and discussions, such as adaptive classification and Grad-CAM visualization; I definitely recommend reading it. For our purposes, we will focus on the code. The code snippets below are modified from the official source code to improve readability.
import torch
import torch.nn as nn
class MRL_Linear_Layer(nn.Module):
    # ResNet50 example: the original feature dimension is 2048, so
    # nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
    def __init__(self, nesting_list, num_classes):
        super().__init__()
        self.nesting_list = nesting_list
        self.num_classes = num_classes
        self.nesting_classifiers = []
        # create an independent nn.Linear for each nesting level
        for i, num_feat in enumerate(self.nesting_list):
            classifier = nn.Linear(num_feat, self.num_classes)
            # register explicitly, since a plain Python list does not register submodules
            self.add_module(f"classifier_{i}", classifier)
            self.nesting_classifiers.append(classifier)
    def forward(self, x):
        nesting_logits = ()
        # feed the first num_feat features to each classifier and keep every prediction
        for num_feat, classifier in zip(self.nesting_list, self.nesting_classifiers):
            pred_logits = classifier(x[:, :num_feat])
            nesting_logits += (pred_logits,)
        return nesting_logits
Note that MRL_Linear_Layer is meant to replace the last layer of ResNet50, hence the linear modules are referred to as classifiers.
MRL creates M independent linear modules, each with a different input size, where M is the length of nesting_list.
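Because the M classifiers are independent, the head is larger than a standard one. The arithmetic below is a quick sketch that counts parameters for the ResNet50 nesting list, assuming 1000 classes (the value used in the equivalence check later in this post). The input sizes sum to 4088, roughly twice 2048, so the MRL head has about twice as many parameters as a single `nn.Linear(2048, 1000)`.

```python
import torch.nn as nn

nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
num_classes = 1000

# one independent nn.Linear(num_feat, num_classes) per nesting level, as in MRL_Linear_Layer
classifiers = [nn.Linear(num_feat, num_classes) for num_feat in nesting_list]
mrl_params = sum(p.numel() for clf in classifiers for p in clf.parameters())

# a standard classification head: a single nn.Linear(2048, num_classes)
single_params = sum(p.numel() for p in nn.Linear(2048, num_classes).parameters())

print(mrl_params, single_params)  # 4097000 2049000
```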
During the forward pass, the input tensor is sliced to the appropriate sizes and fed into the respective classifiers. Each classifier's predictions are kept separate. The loss is computed by comparing each set of predictions to the target on its own, and the per-granularity losses are then summed. This loss calculation is analogous to the mixture-of-experts loss that encourages specialization, mentioned in one of Hinton's lectures.
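Here is a small sketch of that loss on a hypothetical four-level nesting list. It also suggests one way to see the answer to "How is this possible?". Every loss term reads the leading features, so training pushes those features to be useful on their own, while the last features are read only by the largest classifier. The sizes and the `readers` count are illustrative, not from the official code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
nesting_list = [8, 16, 32, 64]  # a small, hypothetical nesting list
num_classes, batch_size = 10, 4
classifiers = [nn.Linear(m, num_classes) for m in nesting_list]

x = torch.randn(batch_size, max(nesting_list))
y = torch.randint(num_classes, (batch_size,))

# compare each granularity's prediction to the target separately, then sum
losses = [F.cross_entropy(clf(x[:, :m]), y) for m, clf in zip(nesting_list, classifiers)]
total_loss = torch.stack(losses).sum()
total_loss.backward()  # one backward pass trains all classifiers together

# number of loss terms that read each feature index
readers = [sum(j < m for m in nesting_list) for j in range(max(nesting_list))]
print(readers[0], readers[8], readers[-1])  # 4 3 1
```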
Fast Matryoshka Linear Layer
The original implementation of the Matryoshka Linear Layer (and the snippet above) is meant to be easy to understand, but the loop can be removed to improve training speed. The idea is similar to the fast mixture-of-experts approach mentioned in an earlier post.
Intuitively, the different classifiers become the experts, and a mask is used to simulate the slicing operation.
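The trick relies on a simple identity: multiplying the full input by a weight matrix whose unused rows are zeroed gives the same result as slicing the input first. Below is a minimal single-level check with arbitrary small sizes. The weight is stored as `[in, out]`, matching the layout used in the fast layer below.

```python
import torch

torch.manual_seed(0)
batch_size, rep_size, num_classes, num_feat = 4, 16, 3, 8
x = torch.randn(batch_size, rep_size)
W = torch.randn(rep_size, num_classes)   # full-size weight, stored as [in, out]

mask = torch.zeros(rep_size, num_classes)
mask[:num_feat, :] = 1                   # keep only the first num_feat rows

sliced = x[:, :num_feat] @ W[:num_feat]  # original approach: slice the input
masked = x @ (W * mask)                  # fast approach: zero out unused weight rows
print(torch.allclose(sliced, masked, atol=1e-5))  # True
```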
class MRL_Linear_Layer_Fast(nn.Module):
    def __init__(self, nesting_list, num_classes):
        super().__init__()
        self.nesting_list = nesting_list
        self.num_classes = num_classes
        num_nestings = len(nesting_list)
        rep_size = max(nesting_list)
        # `num_nestings` becomes the equivalent of `num_experts` in FastMoE
        self._weights = nn.Parameter(torch.empty(num_nestings, rep_size, num_classes))
        self._biases = nn.Parameter(torch.empty(num_nestings, 1, num_classes))
        self.register_buffer("_mask", torch.zeros(num_nestings, rep_size, num_classes))
        nn.init.orthogonal_(self._weights, gain=1.0)
        nn.init.orthogonal_(self._biases, gain=1.0)
        # initialize mask to simulate the slicing operation:
        # write 1s in each row where the weights will be used
        for i, num_feat in enumerate(self.nesting_list):
            self._mask[i, :num_feat, :] = 1
    def copy_parameters_(self, model: MRL_Linear_Layer):
        # helper function to copy weights from `MRL_Linear_Layer`
        # for comparison purposes only, not needed for training
        assert len(model.nesting_list) == len(self.nesting_list)
        for i, classifier in enumerate(model.nesting_classifiers):
            num_feat = model.nesting_list[i]
            # nn.Linear stores its weight as [out, in]; here the layout is [in, out]
            self._weights.data[i, :num_feat, :] = classifier.weight.data.t()
            self._biases.data[i, :, :] = classifier.bias.data
    def forward(self, x):
        # mask out portions of the weight matrix to simulate slicing,
        # so the loop can be replaced with basic torch operators
        # x: [N, rep_size] broadcasts against [G, rep_size, C], giving [G, N, C]
        masked_weights = self._weights.mul(self._mask)
        return x.matmul(masked_weights).add(self._biases)
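The single `matmul` in `forward` works because `torch.matmul` broadcasts. A 2-D input `[N, D]` multiplied by a 3-D weight `[G, D, C]` is treated as a batch of G matrix products, and the result has shape `[G, N, C]`. The sketch below checks this against an explicit `einsum`. The sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
G, N, D, C = 3, 5, 16, 7  # granularities, batch size, representation size, classes
x = torch.randn(N, D)
W = torch.randn(G, D, C)
b = torch.randn(G, 1, C)

# matmul treats x as a single [N, D] matrix and broadcasts it across the G weight matrices
out = x.matmul(W).add(b)
print(out.shape)  # torch.Size([3, 5, 7])

# the same computation written with explicit indices
ref = torch.einsum("nd,gdc->gnc", x, W) + b
print(torch.allclose(out, ref, atol=1e-5))  # True
```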
Checking Equivalence
We verify that the two implementations are equivalent by checking that, given the same inputs and the same initial weights, the two models produce the same outputs, losses, and gradients.
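The gradient part of this claim can be seen directly on a single nesting level. The sketch below copies an `nn.Linear` into a full-size masked weight, as `copy_parameters_` does. It then compares gradients after one cross-entropy backward pass. The sizes are arbitrary and the setup is a simplified stand-in, not the official test.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
rep_size, num_feat, num_classes, batch_size = 16, 8, 5, 4
x = torch.randn(batch_size, rep_size)
y = torch.randint(num_classes, (batch_size,))

# "slow" path: a linear layer applied to the sliced input
slow = nn.Linear(num_feat, num_classes)

# "fast" path: full-size weight stored as [in, out], plus a row mask
weight = torch.zeros(rep_size, num_classes)
weight[:num_feat] = slow.weight.detach().t()
weight.requires_grad_(True)
bias = slow.bias.detach().clone().requires_grad_(True)
mask = torch.zeros(rep_size, num_classes)
mask[:num_feat] = 1

F.cross_entropy(slow(x[:, :num_feat]), y).backward()
F.cross_entropy(x @ (weight * mask) + bias, y).backward()

# used rows get the same gradient as the slow layer's (transposed) weight
print(torch.allclose(weight.grad[:num_feat], slow.weight.grad.t(), atol=1e-5))  # True
# masked-out rows get exactly zero gradient, so a plain SGD step leaves them unchanged
print(torch.count_nonzero(weight.grad[num_feat:]).item())  # 0
print(torch.allclose(bias.grad, slow.bias.grad, atol=1e-5))  # True
```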
import timeit
from functools import partial
class Matryoshka_CE_Loss(nn.Module):
    def __init__(self, relative_importance=None):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.relative_importance = relative_importance
    def forward(self, output, target):
        # output shape: [G granularities, N batch size, C num. classes]
        # target shape: [N batch size]
        # calculate the loss for each granularity and stack them
        nesting_losses = [self.criterion(out_i, target) for out_i in output]
        losses = torch.stack(nesting_losses)
        # set relative_importance to 1 if not specified
        # (when given, create it on the same device/dtype as the losses)
        rel_importance = (
            torch.ones_like(losses)
            if self.relative_importance is None
            else torch.tensor(self.relative_importance, dtype=losses.dtype, device=losses.device)
        )
        # apply relative importance weights
        weighted_losses = rel_importance * losses
        return weighted_losses.sum()
def check_model_equivalence(device, copy):
    device = torch.device(device)
    num_classes = 1000
    nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024]
    slow_model = MRL_Linear_Layer(nesting_list, num_classes).to(device)
    fast_model = MRL_Linear_Layer_Fast(nesting_list, num_classes).to(device)
    if copy:
        # use same weights as slow model
        fast_model.copy_parameters_(slow_model)
    criterion = Matryoshka_CE_Loss().to(device)
    # check equivalence
    # run the forward-backward-update twice;
    # the second pass checks that the updates (and hence the gradients) matched.
    # gradients are not zeroed between iterations, but they accumulate identically in both models
    for iter in range(2):
        batch_size = 64
        # inputs must have the full representation size, i.e. the largest nesting level
        data = torch.randn(batch_size, nesting_list[-1]).to(device)
        target = torch.randint(num_classes, (batch_size,)).to(device)
        pred_stacked_logits = slow_model(data)
        pred_logits_slow = torch.stack(pred_stacked_logits)
        loss_slow = criterion(pred_stacked_logits, target)
        pred_logits_fast = fast_model(data)
        loss_fast = criterion(pred_logits_fast, target)
        outputs_equal = torch.allclose(pred_logits_slow, pred_logits_fast, atol=1e-5)
        losses_equal = torch.allclose(loss_slow, loss_fast, atol=1e-5)
        print(f"({iter=}) outputs {'are' if outputs_equal else 'are NOT'} equal")
        print(f"({iter=}) losses {'are' if losses_equal else 'are NOT'} equal")
        loss_slow.backward()
        loss_fast.backward()
        torch.optim.SGD(slow_model.parameters(), lr=0.01).step()
        torch.optim.SGD(fast_model.parameters(), lr=0.01).step()
    # benchmark speed
    # simulate training iterations
    repeats = 100
    def benchmark_forward_backward(model):
        for _ in range(repeats):
            criterion(model(data), target).backward()
            torch.optim.SGD(model.parameters(), lr=0.01).step()
    benchmark_slow = partial(benchmark_forward_backward, slow_model)
    benchmark_fast = partial(benchmark_forward_backward, fast_model)
    execution_time = timeit.timeit(benchmark_slow, number=1)
    print(f"(slow model): {execution_time/repeats:.5f} seconds ({repeats} passes)")
    execution_time = timeit.timeit(benchmark_fast, number=1)
    print(f"(fast model): {execution_time/repeats:.5f} seconds ({repeats} passes)")
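The benchmark above times the loop with `timeit` directly. On the CUDA backend, PyTorch launches kernels asynchronously. Because of this, a timing loop usually calls `torch.cuda.synchronize()` before reading the clock, so that it measures finished work rather than kernel launches. Below is a hedged sketch of such a helper. The name `time_per_call` and the warm-up count are my own choices, not part of the official code.

```python
import time
import torch

def time_per_call(fn, repeats=100, warmup=10):
    """Hypothetical helper: average wall-clock seconds per call of fn()."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):
        fn()  # warm-up calls are not timed
    if use_cuda:
        torch.cuda.synchronize()  # wait for warm-up kernels to finish
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for the timed kernels before stopping the clock
    return (time.perf_counter() - start) / repeats

# usage on a tiny stand-in model
model = torch.nn.Linear(64, 10)
data = torch.randn(32, 64)

def train_step():
    model(data).sum().backward()

seconds = time_per_call(train_step, repeats=20)
print(f"{seconds:.6f} seconds per call")
```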
The results on my desktop machine with an RTX 3090 are below.

In the first run, using the CUDA backend, we verify that the models are equivalent. The fast implementation takes only about 57% of the time of the slow model.

For the second run, we set copy=False and verify that the results are different. Again, the fast model takes about 55% of the time of the slow model, consistent with the previous run.
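The third run uses the CPU backend. This time the fast model is much slower, taking roughly 7.6 times as long per pass as the slow model. This is expected, because the fast model performs more floating-point operations: every nesting level multiplies the full-size weight matrix, even though the masked rows contribute nothing. The tradeoff is only worth it when the backend can parallelize these computations.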
Running on CUDA

(iter=0) outputs are equal
(iter=0) losses are equal
(iter=1) outputs are equal
(iter=1) losses are equal
(slow model): 0.00208 seconds (100 passes)
(fast model): 0.00118 seconds (100 passes)
Running on CUDA (no copy)

(iter=0) outputs are NOT equal
(iter=0) losses are NOT equal
(iter=1) outputs are NOT equal
(iter=1) losses are NOT equal
(slow model): 0.00212 seconds (100 passes)
(fast model): 0.00117 seconds (100 passes)
Running on CPU

(iter=0) outputs are equal
(iter=0) losses are equal
(iter=1) outputs are equal
(iter=1) losses are equal
(slow model): 0.00982 seconds (100 passes)
(fast model): 0.07489 seconds (100 passes)
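The extra work on the CPU run can be estimated by counting the multiply-accumulates in the forward matmuls. The sketch below uses the configuration from `check_model_equivalence` and ignores biases, the loss, and the backward pass. By this count the fast model does about four times the matmul work of the slow model. That covers only part of the roughly 7.6x slowdown measured on the CPU, so the count alone does not explain the full gap.

```python
nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024]  # as in check_model_equivalence
num_classes, batch_size = 1000, 64
rep_size, num_nestings = max(nesting_list), len(nesting_list)

# slow model: each classifier multiplies [N, m] by [m, C]
slow_macs = sum(batch_size * m * num_classes for m in nesting_list)
# fast model: every nesting level multiplies [N, rep_size] by a full [rep_size, C] matrix
fast_macs = num_nestings * batch_size * rep_size * num_classes
# the fast model also multiplies the [G, rep_size, C] weight tensor by the mask
mask_mults = num_nestings * rep_size * num_classes

print(slow_macs, fast_macs, round(fast_macs / slow_macs, 2))  # 130560000 524288000 4.02
```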