Hung Yu Ling

Fast mixture-of-experts in PyTorch

Update 2023-01-01: added functorch.vmap comparison.
Update 2024-02-23: see also Fast Matryoshka Representation Learning

Standard Mixture-of-Experts

A mixture-of-experts (MoE) is an ensemble of neural networks, or experts, that share the same input and output interfaces. A mixture-of-experts approach is a convenient way to scale up the number of parameters in a model at a small overhead cost. MoE also lets us estimate the variance of the model's prediction from the spread of the expert outputs (sketched briefly below). Typically, the experts have the same architecture but are initialized with different weights.

We start by defining a few variables. Feel free to follow along in a Jupyter notebook or use this colab notebook.

import torch
import torch.nn as nn

torch.set_default_dtype(torch.float64)
device = "cuda" if torch.cuda.is_available() else "cpu"

num_experts = 4
expert_input_size = 60
expert_output_size = 20
hidden_size = 256
batch_size = 32

Naive MoE Implementation

We first create num_experts = 4 feed-forward neural networks in a loop using PyTorch's nn.Sequential. This is the naive baseline implementation: since each expert is an independent nn.Module, every forward pass of the mixture requires num_experts separate forward passes.

We will see soon that for multilayer perceptron (MLP) models, there is an efficient and elegant way to implement mixture-of-experts.

def init_experts():
    experts = [
        nn.Sequential(
            nn.Linear(expert_input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, expert_output_size),
            nn.Tanh(),
        ).to(device)
        for i in range(num_experts)
    ]
    return experts

def forward_experts(experts, data):
    return [e(data.clone()) for e in experts]
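
As an aside, the list of per-expert outputs returned by forward_experts is all we need for the variance estimate mentioned at the beginning. A minimal sketch:

# Sketch: treat the spread across expert predictions as a rough uncertainty estimate.
experts = init_experts()
data = torch.randn((batch_size, expert_input_size)).to(device)

outputs = torch.stack(forward_experts(experts, data), dim=0)  # (num_experts, batch_size, expert_output_size)
prediction = outputs.mean(dim=0)   # ensemble prediction
uncertainty = outputs.std(dim=0)   # disagreement between the experts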

Efficient MoE Implementation

The key idea for implementing MoE efficiently is to organize the model layer-by-layer. Each layer holds num_experts copies of its parameters, i.e., weights and biases, stacked along a leading dimension. During the forward pass, we loop over the layers to produce the result.

At first glance, we are merely trading a loop over num_experts for a loop over the number of layers. However, since the majority of the time is spent on matrix multiplications, the latter option parallelizes much better: each layer performs a single batched matmul over all experts instead of num_experts separate ones.

def init_layers():
    layers = [
        (
            nn.Parameter(torch.empty(num_experts, expert_input_size, hidden_size).to(device)),
            nn.Parameter(torch.zeros(num_experts, 1, hidden_size).to(device)),
            torch.relu
        ),
        (
            nn.Parameter(torch.empty(num_experts, hidden_size, hidden_size).to(device)),
            nn.Parameter(torch.zeros(num_experts, 1, hidden_size).to(device)),
            torch.relu,
        ),
        (
            nn.Parameter(torch.empty(num_experts, hidden_size, hidden_size).to(device)),
            nn.Parameter(torch.zeros(num_experts, 1, hidden_size).to(device)),
            torch.relu,
        ),
        (
            nn.Parameter(torch.empty(num_experts, hidden_size, expert_output_size).to(device)),
            nn.Parameter(torch.zeros(num_experts, 1, expert_output_size).to(device)),
            torch.tanh,
        ),
    ]

    for index, (weight, bias, activation) in enumerate(layers):
        # Initialize each expert separately
        for w in weight:
            nn.init.orthogonal_(w, gain=1.0)
            
    return layers

def forward_layers(layers, data):
    out_layers = data.clone()
    for (weight, bias, activation) in layers:
        out_layers = activation(out_layers.matmul(weight).add(bias))
    return out_layers

Experiment 1

The compare_outputs function checks whether the two implementations produce the same output for each expert. Since the parameters are randomly initialized in both versions, we expect the outputs to differ.

def compare_outputs(out_layers, out_experts):
    for eid, (ol, oe) in enumerate(zip(out_layers, out_experts)):
        print(f"==== Expert {eid}:", end=" ")
        with torch.no_grad():
            print("Same?", torch.isclose(ol, oe).all(), end=" | ")
            print("Error", (ol - oe).norm())

layers = init_layers()
experts = init_experts()

data = torch.randn((batch_size, expert_input_size)).to(device)

print("Before Copy")
out_layers = forward_layers(layers, data)
out_experts = forward_experts(experts, data)
compare_outputs(out_layers, out_experts)
Before Copy
==== Expert 0: Same? tensor(False, device='cuda:0') | Error tensor(4.2092, device='cuda:0')
==== Expert 1: Same? tensor(False, device='cuda:0') | Error tensor(5.3712, device='cuda:0')
==== Expert 2: Same? tensor(False, device='cuda:0') | Error tensor(3.8386, device='cuda:0')
==== Expert 3: Same? tensor(False, device='cuda:0') | Error tensor(5.1810, device='cuda:0')

Experiment 2: Check that the two implementations are identical

If the implementations are identical, they should produce the same output when the parameters are the same. Here we copy the expert parameters from the efficient implementation to the original version.

def copy_params_from_layers_to_experts(layers, experts):
    for lid, (weight, bias, _) in enumerate(layers):
        for eid, e in enumerate(experts):
            e[int(lid*2)].weight.data = weight[eid].data.t()
            e[int(lid*2)].bias.data = bias[eid].data.squeeze()

print("After Copy")
copy_params_from_layers_to_experts(layers, experts)
out_experts = forward_experts(experts, data)
compare_outputs(out_layers, out_experts)
After Copy
==== Expert 0: Same? tensor(True, device='cuda:0') | Error tensor(0., device='cuda:0')
==== Expert 1: Same? tensor(True, device='cuda:0') | Error tensor(0., device='cuda:0')
==== Expert 2: Same? tensor(True, device='cuda:0') | Error tensor(0., device='cuda:0')
==== Expert 3: Same? tensor(True, device='cuda:0') | Error tensor(0., device='cuda:0')

Furthermore, we should check that the backward pass is the same, i.e., both versions produce the same gradients given the same parameters, input, and loss function. To do this, we blend the experts' outputs using random weights (blend_outputs) and compute gradients against random targets (compare_gradients).

def blend_outputs(out_layers, out_experts):
    blending_weights = torch.softmax(torch.randn(num_experts).to(device), dim=-1)
    blending_weights = blending_weights[:, None, None]

    blended_out_layers = out_layers.mul(blending_weights).sum(dim=0)
    blended_out_experts = torch.stack(out_experts, dim=0).mul(blending_weights).sum(dim=0)
    print("Outputs same?", torch.isclose(blended_out_layers, blended_out_experts).all())
    return blended_out_layers, blended_out_experts

def compare_gradients(blended_out_layers, blended_out_experts):
    answers = torch.randn((batch_size, expert_output_size)).to(device)
    (blended_out_layers - answers.clone()).pow(2).mean().backward(retain_graph=True)
    (blended_out_experts - answers.clone()).pow(2).mean().backward(retain_graph=True)
    
    all_close = torch.Tensor([True]).bool().to(device)
    for lid, (weight, bias, _) in enumerate(layers):
        for eid, e in enumerate(experts):
            weights_close = torch.isclose(weight.grad[eid].t(), e[int(lid*2)].weight.grad).all()
            biases_close = torch.isclose(bias.grad[eid].squeeze(), e[int(lid*2)].bias.grad).all()
            all_close = all_close and weights_close and biases_close

    print("Gradients same?", all_close)
blended_out_layers, blended_out_experts = blend_outputs(out_layers, out_experts)
compare_gradients(blended_out_layers, blended_out_experts)
Outputs same? tensor(True, device='cuda:0')
Gradients same? tensor(True, device='cuda:0')

Experiment 3: Check computation time

Finally, we verify that the efficient implementation is much faster than the naive implementation.

Recall that the batch size is 32, we use 4 experts, and each expert has 4 layers. The efficient implementation is almost 4 times faster in the forward pass and roughly 3 times faster in the backward pass.

%timeit forward_layers(layers, data)
%timeit forward_experts(experts, data)
307 µs ± 8.74 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.21 ms ± 6.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
answers = torch.randn((batch_size, expert_output_size)).to(device)
%timeit (blended_out_layers - answers.clone()).pow(2).mean().backward(retain_graph=True)
%timeit (blended_out_experts - answers.clone()).pow(2).mean().backward(retain_graph=True)
682 µs ± 8.92 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.9 ms ± 80 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

For 8 experts, the efficient implementation is roughly 6 times faster in the forward pass.

num_experts = 8

layers = init_layers()
experts = init_experts()
copy_params_from_layers_to_experts(layers, experts)

%timeit forward_layers(layers, data)
%timeit forward_experts(experts, data)
396 µs ± 3.47 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
2.35 ms ± 10.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Functorch vmap Implementation

Functorch is a library for JAX-like composable function transforms in PyTorch. As of PyTorch 1.13, functorch is now included in the PyTorch binary.

Functorch has a tutorial on model ensembling, which does pretty much the same thing as what we are calling mixture-of-experts here. Following the tutorial, MoE can be implemented using vmap and combine_state_for_ensemble as follows:

from functorch import vmap
from functorch import combine_state_for_ensemble

fmodel, params, buffers = combine_state_for_ensemble(experts)
for p in params: p.requires_grad_()

vmap_data = data.expand((num_experts, *data.shape))
out_vmap = vmap(fmodel)(params, buffers, vmap_data)

From the colab notebook, we can see that the vmap implementation is functionally identical to our previous implementations. However, the forward pass is about 3x slower than our efficient implementation, while the backward pass is roughly the same.

Interestingly, the forward pass time stays roughly constant, around 600 microseconds, when the number of experts is increased from 4 to 8. This suggests that combine_state_for_ensemble does something similar to our efficient MoE under the hood, but carries slightly more overhead (likely the price of being more general).
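
As an aside, as of PyTorch 2.0 the functorch APIs live under torch.func, and combine_state_for_ensemble is superseded by stack_module_state. A rough equivalent of the snippet above (a sketch, assuming PyTorch 2.x) looks like this:

import copy
from torch.func import functional_call, stack_module_state

# Stack each expert's parameters and buffers along a new leading dimension.
params, buffers = stack_module_state(experts)

# A stateless "meta" copy of one expert defines the per-expert computation.
base_model = copy.deepcopy(experts[0]).to("meta")

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

out_vmap = torch.vmap(fmodel)(params, buffers, vmap_data)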

Final Thoughts

With a little modification, we can implement Mode-Adaptive Neural Networks (MANN) efficiently using a very similar formulation. MANN is a mixture-of-experts style architecture. However, instead of blending the expert outputs, it blends the parameters of the experts according to the output of a gating module.

If the experts are MLPs, then blending the parameters of the experts is the same as blending the output of each layer before the activation: each layer is affine, so taking a weighted sum of x·W_i + b_i over the experts equals applying x·(weighted sum of W_i) + (weighted sum of b_i). The complete implementation of a MANN-style MoE module is below. Compared to before, we add a gating module, and the only change in the forward function is the additional .mul() and .sum().

This is the basis of what we called Layer-wise MoE, and the same idea was used in Motion VAEs. Check out the projects for more detail.
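
As a quick sanity check of this equivalence, here is a small sketch for a single affine layer with arbitrary sizes: blending the pre-activation outputs with coefficients c matches blending the parameters with c first.

# Sketch: output blending vs. parameter blending for one affine layer.
W = torch.randn(num_experts, 10, 5)   # per-expert weights
b = torch.randn(num_experts, 1, 5)    # per-expert biases
x = torch.randn(3, 10)                # a small batch of inputs
c = torch.softmax(torch.randn(num_experts), dim=0)[:, None, None]  # blending coefficients

blended_by_outputs = x.matmul(W).add(b).mul(c).sum(dim=0)               # blend each expert's output
blended_by_params = x.matmul((W * c).sum(dim=0)) + (b * c).sum(dim=0)   # blend the parameters first
print(torch.isclose(blended_by_outputs, blended_by_params).all())       # True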

import torch.nn.functional as F

class MixedActor(nn.Module):
    def __init__(self):
        super().__init__()

        expert_input_size = ...
        gate_input_size = ...
        output_size = ...
        hidden_size = ...

        self.layers = [
            (
                nn.Parameter(torch.empty(num_experts, expert_input_size, hidden_size)),
                nn.Parameter(torch.zeros(num_experts, 1, hidden_size)),
                torch.relu,
            ),
            (
                nn.Parameter(torch.empty(num_experts, hidden_size, hidden_size)),
                nn.Parameter(torch.zeros(num_experts, 1, hidden_size)),
                torch.relu,
            ),
            (
                nn.Parameter(torch.empty(num_experts, hidden_size, hidden_size)),
                nn.Parameter(torch.zeros(num_experts, 1, hidden_size)),
                torch.relu,
            ),
            (
                nn.Parameter(torch.empty(num_experts, hidden_size, output_size)),
                nn.Parameter(torch.zeros(num_experts, 1, output_size)),
                torch.tanh,
            ),
        ]

        for index, (weight, bias, activation) in enumerate(self.layers):

            for w in weight:
                nn.init.orthogonal_(w, gain=1.0)

            self.register_parameter(f"w{index}", weight)
            self.register_parameter(f"b{index}", bias)

        # Gating network
        self.gate = nn.Sequential(
            init_r_(nn.Linear(gate_input_size, hidden_size)),
            nn.ELU(),
            init_r_(nn.Linear(hidden_size, hidden_size)),
            nn.ELU(),
            init_r_(nn.Linear(hidden_size, hidden_size)),
            nn.ELU(),
            init_r_(nn.Linear(hidden_size, num_experts)),
        )

    def forward(self, x):
        # Blending coefficients from the gate: (num_experts, batch_size, 1)
        coefficients = F.softmax(self.gate(x), dim=1).t().unsqueeze(-1)
        out = x.clone()

        for (weight, bias, activation) in self.layers:
            out = activation(
                out.matmul(weight)   # (N, B, H): N = experts, B = batch, H = hidden
                .add(bias)           # (N, B, H)
                .mul(coefficients)   # (N, B, H), each expert scaled by its coefficient
                .sum(dim=0)          # (B, H), experts blended
            )

        return out
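
For reference, exercising the module could look like the sketch below. The init_r_ stand-in and the concrete sizes are assumptions for illustration only; the original helper and the elided sizes (...) in __init__ are not shown in this post.

def init_r_(layer, gain=1.0):
    # Hypothetical stand-in for the original init_r_ helper (not shown above):
    # orthogonal weights, zero bias, and return the layer so calls can be chained.
    nn.init.orthogonal_(layer.weight, gain=gain)
    nn.init.zeros_(layer.bias)
    return layer

# Assuming the elided sizes in __init__ are filled in, e.g.
# expert_input_size = gate_input_size = 60, output_size = 20, hidden_size = 256
# (as written, the gate and the experts both consume x, so their input sizes must match),
# the module behaves like any other nn.Module:
actor = MixedActor().to(device)
x = torch.randn((batch_size, expert_input_size)).to(device)
action = actor(x)  # shape: (batch_size, output_size)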