Hung Yu Ling

Fast mixture-of-experts in PyTorch


Standard Mixture-of-Experts

A mixture-of-experts (MoE) is an ensemble of neural networks, or experts, that share the same input and output interfaces. A mixture-of-experts approach is a convenient way to scale up the number of parameters in a model at a small overhead cost. MoE also allows us to estimate the variance of the model prediction from the disagreement between experts. Typically, the experts have the same architecture but are initialized with different weights.
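The variance estimate can be sketched by stacking the expert predictions and treating their spread as an uncertainty signal. This is a minimal sketch that assumes a uniform average over experts (no gating) and uses hypothetical toy sizes chosen only for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: four small experts with the same architecture
# but independent random initializations.
num_experts, in_size, out_size, batch = 4, 8, 3, 5
experts = [
    nn.Sequential(nn.Linear(in_size, 16), nn.ReLU(), nn.Linear(16, out_size))
    for _ in range(num_experts)
]

x = torch.randn(batch, in_size)
with torch.no_grad():
    preds = torch.stack([e(x) for e in experts], dim=0)  # (num_experts, batch, out_size)

mean_pred = preds.mean(dim=0)  # ensemble prediction, (batch, out_size)
var_pred = preds.var(dim=0)    # disagreement between experts, (batch, out_size)
print(mean_pred.shape, var_pred.shape)
```

A large variance for a given input means the experts disagree about it; whether that is a well-calibrated uncertainty estimate depends on how the experts were trained.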

We start by defining a few variables. Feel free to follow along in a Jupyter notebook.

import torch
import torch.nn as nn

torch.set_default_dtype(torch.float64)
device = "cuda" if torch.cuda.is_available() else "cpu"

num_experts = 4
expert_input_size = 60
expert_output_size = 20
hidden_size = 256
batch_size = 32

Naive MoE Implementation

We first create num_experts = 4 feed-forward neural networks in a loop using PyTorch’s nn.Sequential. This is the naive baseline implementation: since each expert is an independent nn.Module, every forward pass requires num_experts separate calls, one per expert.

We will see soon that for multilayer perceptron (MLP) models, there is an efficient and elegant way to implement mixture-of-experts.

def init_experts():
    experts = [
        nn.Sequential(
            nn.Linear(expert_input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, expert_output_size),
            nn.Tanh(),
        ).to(device)
        for _ in range(num_experts)  # num_experts independent MLPs
    ]
    return experts

def forward_experts(experts, data):
    return [e(data.clone()) for e in experts]

Efficient MoE Implementation

The key idea behind an efficient MoE is to build the model layer by layer rather than expert by expert. Each layer stores num_experts copies of its parameters (weights and biases), stacked along a leading expert dimension. During the forward pass, we loop over the layers, and each layer computes the outputs of all experts at once.

At first glance, we are trading a loop over num_experts for a loop over the number of layers. However, since the majority of the time is spent in matrix multiplications, a single batched matmul over all experts can be parallelized more effectively than num_experts smaller ones, and it launches fewer GPU kernels.
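The single batched matmul relies on torch.matmul broadcasting: a 2D input of shape (B, in) multiplied by a 3D weight of shape (E, in, out) is broadcast over the leading expert dimension. A minimal self-contained check, with sizes mirroring this post and float64 to match its default dtype, comparing that one call against an explicit loop over experts:

```python
import torch

torch.manual_seed(0)

num_experts, batch, in_size, out_size = 4, 32, 60, 256
x = torch.randn(batch, in_size, dtype=torch.float64)                 # shared input, (B, in)
W = torch.randn(num_experts, in_size, out_size, dtype=torch.float64)  # stacked weights, (E, in, out)
b = torch.randn(num_experts, 1, out_size, dtype=torch.float64)        # stacked biases, (E, 1, out)

# One broadcasted matmul: x is treated as (1, B, in) and broadcast against
# the expert dimension, producing (E, B, out) in a single call.
batched = x.matmul(W) + b

# Reference: loop over experts, one matmul per expert.
looped = torch.stack([x @ W[i] + b[i] for i in range(num_experts)], dim=0)

print(batched.shape)  # torch.Size([4, 32, 256])
print(torch.allclose(batched, looped))
```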

def init_layers():
    layers = [
        (
            nn.Parameter(torch.empty(num_experts, expert_input_size, hidden_size).to(device)),
            nn.Parameter(torch.zeros(num_experts, 1, hidden_size).to(device)),
            torch.relu
        ),
        (
            nn.Parameter(torch.empty(num_experts, hidden_size, hidden_size).to(device)),
            nn.Parameter(torch.zeros(num_experts, 1, hidden_size).to(device)),
            torch.relu,
        ),
        (
            nn.Parameter(torch.empty(num_experts, hidden_size, hidden_size).to(device)),
            nn.Parameter(torch.zeros(num_experts, 1, hidden_size).to(device)),
            torch.relu,
        ),
        (
            nn.Parameter(torch.empty(num_experts, hidden_size, expert_output_size).to(device)),
            nn.Parameter(torch.zeros(num_experts, 1, expert_output_size).to(device)),
            torch.tanh,
        ),
    ]

    for weight, _, _ in layers:
        # Initialize each expert's weight matrix separately
        for w in weight:
            nn.init.orthogonal_(w, gain=1.0)

    return layers

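def forward_layers(layers, data):
    out_layers = data.clone()  # (batch_size, expert_input_size)
    for (weight, bias, activation) in layers:
        # (B, in) @ (E, in, out) broadcasts to (E, B, out): all experts in one matmul
        out_layers = activation(out_layers.matmul(weight).add(bias))
    return out_layers  # (num_experts, batch_size, expert_output_size)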

Experiment 1

The compare_outputs function checks whether each expert produces the same output under the two implementations. Since the two sets of parameters are initialized independently at random, we expect the outputs to differ.

def compare_outputs(out_layers, out_experts):
    for eid, (ol, oe) in enumerate(zip(out_layers, out_experts)):
        print(f"==== Expert {eid}:", end=" ")
        with torch.no_grad():
            print("Same?", torch.isclose(ol, oe).all(), end=" | ")
            print("Error", (ol - oe).norm())

layers = init_layers()
experts = init_experts()

data = torch.randn((batch_size, expert_input_size)).to(device)

print("Before Copy")
out_layers = forward_layers(layers, data)
out_experts = forward_experts(experts, data)
compare_outputs(out_layers, out_experts)
Before Copy
==== Expert 0: Same? tensor(False, device='cuda:0') | Error tensor(4.2092, device='cuda:0')
==== Expert 1: Same? tensor(False, device='cuda:0') | Error tensor(5.3712, device='cuda:0')
==== Expert 2: Same? tensor(False, device='cuda:0') | Error tensor(3.8386, device='cuda:0')
==== Expert 3: Same? tensor(False, device='cuda:0') | Error tensor(5.1810, device='cuda:0')

Experiment 2: Check two implementations are identical

If the two implementations are equivalent, they should produce the same output when given the same parameters. Here we copy the expert parameters from the efficient implementation to the naive one.

def copy_params_from_layers_to_experts(layers, experts):
    with torch.no_grad():
        for lid, (weight, bias, _) in enumerate(layers):
            for eid, e in enumerate(experts):
                # Linear layers sit at even indices of each nn.Sequential (0, 2, 4, 6);
                # nn.Linear stores weights as (out, in), hence the transpose.
                # copy_ writes into the existing parameters instead of sharing storage.
                e[2 * lid].weight.copy_(weight[eid].t())
                e[2 * lid].bias.copy_(bias[eid].squeeze(0))

print("After Copy")
copy_params_from_layers_to_experts(layers, experts)
out_experts = forward_experts(experts, data)
compare_outputs(out_layers, out_experts)
After Copy
==== Expert 0: Same? tensor(True, device='cuda:0') | Error tensor(0., device='cuda:0')
==== Expert 1: Same? tensor(True, device='cuda:0') | Error tensor(0., device='cuda:0')
==== Expert 2: Same? tensor(True, device='cuda:0') | Error tensor(0., device='cuda:0')
==== Expert 3: Same? tensor(True, device='cuda:0') | Error tensor(0., device='cuda:0')
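The transpose in the copy step comes from a layout difference: nn.Linear stores its weight as (out_features, in_features) and computes x @ weight.T + bias, while the layer-wise version stores each expert's weight as (in, out) so that x @ w needs no transpose. A minimal single-layer check, with sizes mirroring expert_input_size and expert_output_size; copying with copy_ under torch.no_grad() is one way to do it without making the two modules share storage:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

in_size, out_size = 60, 20
linear = nn.Linear(in_size, out_size, dtype=torch.float64)
print(linear.weight.shape)  # torch.Size([20, 60]): (out_features, in_features)

# A layer-wise weight for one expert is stored as (in, out).
w = torch.randn(in_size, out_size, dtype=torch.float64)
b = torch.zeros(1, out_size, dtype=torch.float64)

with torch.no_grad():
    linear.weight.copy_(w.t())       # (in, out) -> (out, in)
    linear.bias.copy_(b.squeeze(0))  # (1, out) -> (out,)

x = torch.randn(8, in_size, dtype=torch.float64)
print(torch.allclose(linear(x), x @ w + b))
```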

Furthermore, we should check that the backward pass is the same, i.e., both versions produce the same gradients given the same parameters, input, and loss function. To do this, we blend the experts' outputs using random weights (blend_outputs) and compute the gradients of a mean-squared error against random targets (compare_gradients).

def blend_outputs(out_layers, out_experts):
    blending_weights = torch.softmax(torch.randn(num_experts).to(device), dim=-1)
    blending_weights = blending_weights[:, None, None]

    blended_out_layers = out_layers.mul(blending_weights).sum(dim=0)
    blended_out_experts = torch.stack(out_experts, dim=0).mul(blending_weights).sum(dim=0)
    print("Outputs same?", torch.isclose(blended_out_layers, blended_out_experts).all())
    return blended_out_layers, blended_out_experts

def compare_gradients(blended_out_layers, blended_out_experts):
    # Relies on the global layers and experts
    answers = torch.randn((batch_size, expert_output_size)).to(device)
    (blended_out_layers - answers.clone()).pow(2).mean().backward(retain_graph=True)
    (blended_out_experts - answers.clone()).pow(2).mean().backward(retain_graph=True)

    all_close = torch.tensor(True, device=device)
    for lid, (weight, bias, _) in enumerate(layers):
        for eid, e in enumerate(experts):
            weights_close = torch.isclose(weight.grad[eid].t(), e[2 * lid].weight.grad).all()
            biases_close = torch.isclose(bias.grad[eid].squeeze(0), e[2 * lid].bias.grad).all()
            # & combines the checks element-wise instead of relying on tensor truthiness
            all_close = all_close & weights_close & biases_close

    print("Gradients same?", all_close)

blended_out_layers, blended_out_experts = blend_outputs(out_layers, out_experts)
compare_gradients(blended_out_layers, blended_out_experts)
Outputs same? tensor(True, device='cuda:0')
Gradients same? tensor(True, device='cuda:0')
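The same gradient comparison can also be written with torch.autograd.grad, which returns gradients directly instead of accumulating them into .grad, so repeated runs do not need zero_grad. This is a minimal self-contained sketch on a single stacked linear layer with random blending weights, mirroring blend_outputs and compare_gradients; the sizes are illustrative:

```python
import torch

torch.manual_seed(0)

num_experts, batch, in_size, out_size = 4, 32, 60, 20
x = torch.randn(batch, in_size, dtype=torch.float64)
target = torch.randn(batch, out_size, dtype=torch.float64)
coef = torch.softmax(torch.randn(num_experts, dtype=torch.float64), dim=0)[:, None, None]

# Stacked (layer-wise) parameters for a single linear layer.
W = torch.randn(num_experts, in_size, out_size, dtype=torch.float64, requires_grad=True)
# Independent per-expert copies holding the same values.
W_list = [W[i].detach().clone().requires_grad_(True) for i in range(num_experts)]

loss_stacked = ((x.matmul(W) * coef).sum(dim=0) - target).pow(2).mean()
loss_looped = (sum(coef[i] * (x @ W_list[i]) for i in range(num_experts)) - target).pow(2).mean()

(grad_stacked,) = torch.autograd.grad(loss_stacked, W)  # (E, in, out)
grads_looped = torch.autograd.grad(loss_looped, W_list)  # tuple of (in, out)

print(all(torch.allclose(grad_stacked[i], g) for i, g in enumerate(grads_looped)))
```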

Experiment 3: Check computation time

Finally, we verify that the efficient implementation is much faster than the naive implementation.

Recall that the batch size is 32, we use 4 experts, and each expert has 4 layers. The efficient implementation is almost 4 times faster in the forward pass (307 µs vs. 1.21 ms) and roughly 3 times faster in the backward pass (682 µs vs. 1.9 ms).

%timeit forward_layers(layers, data)
%timeit forward_experts(experts, data)
307 µs ± 8.74 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.21 ms ± 6.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
answers = torch.randn((batch_size, expert_output_size)).to(device)
%timeit (blended_out_layers - answers.clone()).pow(2).mean().backward(retain_graph=True)
%timeit (blended_out_experts - answers.clone()).pow(2).mean().backward(retain_graph=True)
682 µs ± 8.92 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.9 ms ± 80 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
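On a GPU, kernels are launched asynchronously, so a plain %timeit loop may partly measure launch overhead rather than completed work. A hedged alternative is torch.utils.benchmark.Timer, which takes care of CUDA synchronization. The sketch below assumes two-layer experts to stay short and makes no claim about the speedup you will observe on your hardware:

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
num_experts, batch, in_size, hidden, out_size = 4, 32, 60, 256, 20

# Naive: a list of independent two-layer MLPs.
experts = [
    nn.Sequential(nn.Linear(in_size, hidden), nn.ReLU(), nn.Linear(hidden, out_size)).to(device)
    for _ in range(num_experts)
]
# Layer-wise: the same shapes stacked along a leading expert dimension.
W1 = torch.randn(num_experts, in_size, hidden, device=device)
b1 = torch.zeros(num_experts, 1, hidden, device=device)
W2 = torch.randn(num_experts, hidden, out_size, device=device)
b2 = torch.zeros(num_experts, 1, out_size, device=device)

def forward_naive(x):
    return [e(x) for e in experts]

def forward_stacked(x):
    return torch.relu(x.matmul(W1) + b1).matmul(W2) + b2  # (E, B, out)

x = torch.randn(batch, in_size, device=device)

# Timer runs the statement repeatedly and synchronizes CUDA when measuring.
t_naive = benchmark.Timer(stmt="forward_naive(x)", globals={"forward_naive": forward_naive, "x": x})
t_stacked = benchmark.Timer(stmt="forward_stacked(x)", globals={"forward_stacked": forward_stacked, "x": x})
m_naive = t_naive.timeit(200)
m_stacked = t_stacked.timeit(200)
print(f"naive: {m_naive.mean * 1e6:.1f} us, stacked: {m_stacked.mean * 1e6:.1f} us")
```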

With 8 experts, the gap widens: the efficient forward pass is roughly 6 times faster (396 µs vs. 2.35 ms).

num_experts = 8

layers = init_layers()
experts = init_experts()
copy_params_from_layers_to_experts(layers, experts)

%timeit forward_layers(layers, data)
%timeit forward_experts(experts, data)
396 µs ± 3.47 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
2.35 ms ± 10.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Final Thoughts

With a little modification, we can implement Mode-Adaptive Neural Networks (MANN) efficiently using a very similar formulation. MANN is a mixture-of-experts style architecture. However, instead of blending the expert outputs, it blends the parameters of the experts according to the output of a gating module.

If the experts are MLPs, then blending the parameters of the experts is equivalent to blending the output of each layer (before activation), because each layer is linear in its weights and biases. The complete implementation of a MANN-style MoE module is below. Compared to before, we added a gating module, and the only change in the forward function is the additional .mul() and .sum().

This is the basis of what we called Layer-wise MoE, and the same idea was used in Motion VAEs. Check out the projects for more details.
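The equivalence between blending parameters and blending layer outputs follows from linearity. For one sample with layer input h (a row vector), gating coefficients c_1, ..., c_N, expert parameters W_i and b_i, and activation σ:

```latex
\sigma\Big(\sum_{i=1}^{N} c_i\,(h W_i + b_i)\Big)
  = \sigma\Big(h\Big(\sum_{i=1}^{N} c_i W_i\Big) + \sum_{i=1}^{N} c_i b_i\Big)
  = \sigma\big(h \tilde{W} + \tilde{b}\big),
\qquad \tilde{W} = \sum_{i=1}^{N} c_i W_i,\quad \tilde{b} = \sum_{i=1}^{N} c_i b_i
```

The left-hand side is what the .mul() and .sum() steps compute; the right-hand side is the MANN view with blended parameters. Since the c_i depend on the sample, the blended weight differs per sample, which makes blending the layer outputs the more convenient form for a batch.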

import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedActor(nn.Module):
    def __init__(self, num_experts, expert_input_size, gate_input_size, output_size, hidden_size):
        super().__init__()

        # Each layer is (weight, bias, activation); weights are stacked as
        # (num_experts, in_features, out_features) and biases as (num_experts, 1, out_features)
        self.layers = [
            (
                nn.Parameter(torch.empty(num_experts, expert_input_size, hidden_size)),
                nn.Parameter(torch.zeros(num_experts, 1, hidden_size)),
                torch.relu,
            ),
            (
                nn.Parameter(torch.empty(num_experts, hidden_size, hidden_size)),
                nn.Parameter(torch.zeros(num_experts, 1, hidden_size)),
                torch.relu,
            ),
            (
                nn.Parameter(torch.empty(num_experts, hidden_size, hidden_size)),
                nn.Parameter(torch.zeros(num_experts, 1, hidden_size)),
                torch.relu,
            ),
            (
                nn.Parameter(torch.empty(num_experts, hidden_size, output_size)),
                nn.Parameter(torch.zeros(num_experts, 1, output_size)),
                torch.tanh,
            ),
        ]

        for index, (weight, bias, _) in enumerate(self.layers):
            # Initialize each expert's weight matrix separately
            for w in weight:
                nn.init.orthogonal_(w, gain=1.0)

            # Register so the parameters appear in .parameters() and state_dict()
            self.register_parameter(f"w{index}", weight)
            self.register_parameter(f"b{index}", bias)

        # Gating network. The original wraps each nn.Linear in a helper, init_r_,
        # that is not defined in this post; PyTorch's default initialization is used here.
        self.gate = nn.Sequential(
            nn.Linear(gate_input_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, num_experts),
        )

    def forward(self, x):
        # The gate and the experts both read x, so gate_input_size must equal
        # expert_input_size in this version
        coefficients = F.softmax(self.gate(x), dim=1).t().unsqueeze(-1)  # (N, B, 1), N = num_experts
        out = x.clone()

        for (weight, bias, activation) in self.layers:
            out = activation(
                out.matmul(weight)  # (N, B, H), B = batch, H = hidden
                .add(bias)  # (N, B, H)
                .mul(coefficients)  # (N, B, H), each expert scaled by its per-sample coefficient
                .sum(dim=0)  # (B, H), blended layer output
            )

        return out
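The equivalence can also be checked numerically for one layer over a batch. This minimal sketch uses toy sizes, reproduces the .mul(coefficients).sum(dim=0) step of MixedActor.forward, and compares it against explicitly blended per-sample parameters built with torch.einsum; the einsum route is an illustrative choice, not necessarily how MANN implementations compute it:

```python
import torch

torch.manual_seed(0)

num_experts, batch, in_size, out_size = 4, 16, 10, 7
h = torch.randn(batch, in_size, dtype=torch.float64)
W = torch.randn(num_experts, in_size, out_size, dtype=torch.float64)
b = torch.randn(num_experts, 1, out_size, dtype=torch.float64)
# Per-sample gating coefficients, arranged as in MixedActor: (N, B, 1).
logits = torch.randn(batch, num_experts, dtype=torch.float64)
coefficients = torch.softmax(logits, dim=1).t().unsqueeze(-1)

# Output blending: one layer of the MixedActor forward pass, before activation.
out_blend = h.matmul(W).add(b).mul(coefficients).sum(dim=0)     # (B, out)

# Parameter blending (MANN view): one weight matrix and bias per sample.
c = coefficients.squeeze(-1)                                     # (N, B)
W_blend = torch.einsum("nb,nio->bio", c, W)                      # (B, in, out)
b_blend = torch.einsum("nb,no->bo", c, b.squeeze(1))             # (B, out)
param_blend = torch.einsum("bi,bio->bo", h, W_blend) + b_blend   # (B, out)

print(torch.allclose(out_blend, param_blend))
```

The parameter-blended version materializes a (B, in, out) weight tensor per layer, while output blending only needs the (N, B, out) activations that the batched matmul already produces.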