Neural Networks 101: Part 15 - Deepseek


This blog post will step through the Deepseek V3 Transformer architecture with code examples.
Input
BPE (Byte Pair Encoding)
- BPE is a common tokenization strategy used in LLMs such as GPT-2 and RoBERTa (BERT uses the closely related WordPiece algorithm).
- BPE is a greedy algorithm that repeatedly merges the most frequent symbol pairs to build a subword vocabulary.
How BPE Works
- Initialize the symbols
  - Each word is initially split into individual characters plus an end-of-word marker `</w>`
  - All adjacent symbol pairs are then counted. For example, `lower</w>` produces the pairs:
    - `l o`
    - `o w`
    - `w e`
    - `e r`
    - `r </w>`
- Count symbol pair frequencies
  - Count how often each adjacent pair of symbols occurs across the corpus:

```python
pair_freq = {
    ("l", "o"): 5,
    ("o", "w"): 3,
    # ...
}
```
- Merge the most frequent pair
  - If `l o` is the most frequent pair, it is merged into the new symbol `lo`. This continues until the maximum vocabulary size is reached.
- Update the vocabulary

```python
vocab = {
    "lo",
    # ...
}
```
- Repeat until the max size is reached
  - The steps above are repeated until the maximum vocabulary size is reached or there is nothing left to merge.
The final vocabulary might look like:

```python
bpe_vocab = {
    "low": 0,
    "er": 1,
    "est": 2,
    "ne": 3,
    "new": 4,
}
```
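To make the algorithm concrete, here is a minimal toy sketch of the merge loop (for illustration only, not how production tokenizers are implemented):

```python
import re
from collections import Counter

def get_pair_counts(corpus):
    """Count adjacent symbol pairs across all words, weighted by word frequency."""
    pairs = Counter()
    for word, freq in corpus.items():
        symbols = word.split()
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs

def merge_pair(pair, corpus):
    """Merge every standalone occurrence of the pair into a single symbol."""
    bigram = re.escape(" ".join(pair))
    pattern = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
    return {pattern.sub("".join(pair), word): freq for word, freq in corpus.items()}

# Words are pre-split into characters plus the end-of-word marker.
corpus = {"l o w e r </w>": 5, "l o w e s t </w>": 3, "n e w e r </w>": 2}
num_merges = 10  # stands in for the max vocabulary size

for _ in range(num_merges):
    pair_counts = get_pair_counts(corpus)
    if not pair_counts:
        break  # nothing left to merge
    best = max(pair_counts, key=pair_counts.get)
    corpus = merge_pair(best, corpus)
    print(best)  # the pair that was merged this round
```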
Why is BPE used?
- No OOV (out-of-vocabulary) problem: all text is representable, since unseen words fall back to smaller known subwords or characters (provided the base characters appeared in the training corpus)
- Simple and language agnostic
Example use of BPE
The Hugging Face `tokenizers` library has a BPE implementation:
- The `BpeTrainer` class counts the symbol pairs and applies the merges.
- The `BPE` model class is a wrapper that stores the vocabulary, the special-token mapping and the merge rules.
```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers import normalizers
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import NFD, Lowercase, StripAccents

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"], vocab_size=5000)
```
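The trainer is then used to build the vocabulary from a text corpus, after which the tokenizer can encode text. A minimal sketch (the file path and example sentence are placeholders):

```python
# Train on one or more plain-text files (path is a placeholder).
tokenizer.train(files=["corpus.txt"], trainer=trainer)

# Tokenize a sentence with the learned BPE merges.
encoding = tokenizer.encode("The lower towers were the newest.")
print(encoding.tokens)  # e.g. ['the', 'low', 'er', ...] depending on the learned merges
print(encoding.ids)     # the corresponding token ids
```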
Token Embeddings
This step converts token ids to a Token Embedding:
```python
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
```
In the example below, two token ids are passed in and the embedding layer returns a Token Embedding for each id:

```python
input_ids = torch.tensor([[42, 138]])
inputs_embeds = self.embed_tokens(input_ids)
```
The Token Embedding is also called a Hidden State, and this is the term that will be used to show the journey of the Token Embedding through the DeepseekV3 Decoder Layer.
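As a standalone sketch (the vocabulary size, hidden size and padding index below are illustrative, not Deepseek's actual config values):

```python
import torch
import torch.nn as nn

vocab_size, hidden_size, padding_idx = 5000, 768, 0  # illustrative values
embed_tokens = nn.Embedding(vocab_size, hidden_size, padding_idx)

input_ids = torch.tensor([[42, 138]])      # shape: [batch_size=1, seq_length=2]
inputs_embeds = embed_tokens(input_ids)    # shape: [1, 2, hidden_size]
print(inputs_embeds.shape)                 # torch.Size([1, 2, 768])
```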
RMSNorm
To recap, the process of Normalization in Transformers is to stabilize the equations of the Neural Network.
Without normalization, the numbers in the calculations can get too big, causing exploding/vanishing gradients.
LayerNorm keeps activations within a "reasonable" range by normalizing them to a fixed mean and variance, so that the network doesn't "blow up".
RMSNorm is a simplified version of LayerNorm:
- It removes the mean centering
- Keeps the re-scaling step
- Is faster and more memory-efficient
In the equation, RMSNorm simply drops the mean term and the learned bias of LayerNorm:

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \cdot \gamma$$

Where:
- $x$: the input vector (of dimension $d$)
- $\gamma$: learnable scaling parameter
- $\epsilon$: a constant for numerical stability
- $\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$: the root mean square of the input with epsilon added
```python
@use_kernel_forward_from_hub("RMSNorm")
class DeepseekV3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        DeepseekV3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
```
Where:
- `hidden_states = hidden_states.to(torch.float32)`: the cast to float32 gives better numerical precision and avoids inaccuracies in the variance computation
- `variance = hidden_states.pow(2).mean(-1, keepdim=True)`: computes the mean of squares over the last dimension
- `hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)`: applies the RMS normalization to the hidden state
- `return self.weight * hidden_states.to(input_dtype)`: multiplies $\gamma$ (the weight) by the normalized hidden states, cast back to the original dtype
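To see the effect numerically, here is a tiny standalone sketch (an independent minimal re-implementation for illustration, not the DeepseekV3 module itself):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)

norm = RMSNorm(hidden_size=8)
x = torch.randn(1, 2, 8) * 100       # deliberately large activations
out = norm(x)
print(out.pow(2).mean(-1).sqrt())    # ~1.0 per token: the RMS has been normalized
```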
RoPE
RoPE (Rotary Positional Encoding) calculates positional information by rotating the Query and Key vectors in the attention mechanism using the sine and cosine functions.
This is the same principle as the Positional Encoding calculation and addition with the Token Embedding found in earlier Transformers.
RoPE improves on the space complexity and computation since the positions are encoded in-place without needing additional embedding vectors for the positions.
Formula
- Generate K, Q, V
  - Like in previous attention mechanisms, K, Q and V are linear projections of the input Token Embedding
- Split Q and K into 2D vectors (pairs)
  - $Q = [(q_0, q_1), (q_2, q_3), \dots, (q_{d-2}, q_{d-1})]$
  - Each 2D pair only represents the Q vector for that token
  - The rotation formula is applied to each pair $(x, y)$ using the position-dependent angle $p\theta_i$
  - The same is applied to K
- Apply RoPE to the 2D pairs of Q and K; this rotates each position's vector using a position-based angle.
  - RoPE is never applied to V, only to Q and K, because these are the vectors used to compute the attention scores (via dot product), so relative position matters there.
  - $p\theta_i$ is the rotation angle, calculated from the token's position and the head dimension (the head dimension is the size of the vector per head, covered in [[#Multi-Head Latent Attention]]), similar to the original Positional Encoding in encoder-decoders.

Where:
- $\theta_i = 10000^{-2i/d}$: the base frequency for pair index $i$
- $p$: the position index of the token
  - e.g. ["The", "cat", "sat"] gives "The" -> p = 0, "cat" -> p = 1, "sat" -> p = 2
- $p\theta_i$: the resulting rotation angle for that token and pair
- The final rotated 2D vectors contain the positional encoding
The input:
- $Q = [q_0, q_1, \dots, q_{d-1}]$
- $K = [k_0, k_1, \dots, k_{d-1}]$
Split Q and K into 2D pairs:
- $Q = [(q_0, q_1), (q_2, q_3), \dots, (q_{d-2}, q_{d-1})]$
- $K = [(k_0, k_1), (k_2, k_3), \dots, (k_{d-2}, k_{d-1})]$
Apply the rotation on each pair:

$$\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos(p\theta_i) & -\sin(p\theta_i) \\ \sin(p\theta_i) & \cos(p\theta_i) \end{pmatrix} \begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix}$$

The final output is a vector of rotated pairs which now carries positional awareness.
To summarize in simpler terms, each Query and Key pair is rotated by an angle based on the token's position (e.g. $p$) and the pair's index (e.g. $i$), making it position-aware.
- Compute the attention scores using the dot product of the rotated Q and K:

$$\text{score}(i, j) = \frac{\tilde{q}_i \cdot \tilde{k}_j}{\sqrt{d}}$$

Where:
- $q_i, k_j$: the original query and key vectors for the tokens at positions $i$ and $j$
- $\tilde{q}_i$: the query vector rotated by RoPE using position $i$
- $\tilde{k}_j$: the key vector rotated by RoPE using position $j$
- $\tilde{q}_i \cdot \tilde{k}_j$: the dot product between the rotated vectors
- $d$: the head dimension (length of the query/key vectors)
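Putting the formula into code, here is a minimal sketch of the pairwise rotation and the resulting position-aware scores (plain PyTorch on interleaved pairs; real implementations typically use the rotate-half trick shown later):

```python
import torch

def rope_rotate(x, base=10000.0):
    """Rotate interleaved 2D pairs of x by position-dependent angles.
    x: [seq_len, head_dim] with head_dim even."""
    seq_len, dim = x.shape
    i = torch.arange(dim // 2, dtype=torch.float32)           # pair index
    theta = base ** (-2 * i / dim)                            # base frequency per pair
    p = torch.arange(seq_len, dtype=torch.float32)[:, None]   # token positions
    angle = p * theta                                         # [seq_len, dim / 2]
    cos, sin = angle.cos(), angle.sin()

    x1, x2 = x[:, 0::2], x[:, 1::2]                           # the (x, y) of each pair
    rot = torch.empty_like(x)
    rot[:, 0::2] = x1 * cos - x2 * sin                        # x' = x cos - y sin
    rot[:, 1::2] = x1 * sin + x2 * cos                        # y' = x sin + y cos
    return rot

q = torch.randn(3, 8)   # 3 tokens, head_dim = 8
k = torch.randn(3, 8)
q_rot, k_rot = rope_rotate(q), rope_rotate(k)
scores = (q_rot @ k_rot.T) / (8 ** 0.5)   # position-aware attention scores
```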
RoPE Implementation
```python
class DeepseekV3RotaryEmbedding(nn.Module):
    def __init__(self, config: DeepseekV3Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
```
This code sample implements the part of RoPE that computes the rotation coefficients; the actual rotation of Q and K happens later.
`inv_freq`: a 1D vector containing one frequency per 2D pair.
Full formula:

$$\text{inv\_freq}_i = \frac{1}{\text{base}^{2i/d}}, \quad i = 0, 1, \dots, \tfrac{d}{2} - 1$$

Where:
- The upper bound is $d/2 - 1$ since we split the vectors into 2D pairs, so we only need one angle per pair.
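In code, the default initialisation is typically equivalent to the following (a sketch of the standard computation; the exact values come from `ROPE_INIT_FUNCTIONS[self.rope_type]`):

```python
import torch

base, dim = 10000.0, 64  # illustrative rope theta and rotary head dimension
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
print(inv_freq.shape)  # torch.Size([32]) -> one frequency per 2D pair (dim / 2)
```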
```python
inv_freq_expanded = self.inv_freq[None, :, None] \
    .float() \
    .expand(position_ids.shape[0], -1, 1) \
    .to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
```
In the forward pass, we expand `inv_freq` from a 1D vector so we can take advantage of [[AI Cheat Sheet#Batch Matrix Multiplication|batch matrix multiplication using CUDA]].
This creates 3D tensors with the shapes:
- `inv_freq_expanded` = `[batch_size, head_dim // 2, 1]` (one entry per 2D pair)
- `position_ids_expanded` = `[batch_size, 1, seq_len]`
These shapes are set up so that multiplying the two yields one angle per frequency and position:
`angle = inv_freq_expanded @ position_ids_expanded  # [batch_size, head_dim // 2, seq_len]`
```python
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):  # Force float32
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos() * self.attention_scaling
    sin = emb.sin() * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
```
After expanding the frequencies and position ids, we can perform batched matrix multiplication using the `@` operator, which CUDA executes as parallelized matrix multiplication on the GPU.
The returned `cos` and `sin` values contain the rotation coefficients required to rotate the Q and K pair values at a later stage.
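For completeness, the returned `cos`/`sin` are later consumed by a helper along these lines (a simplified sketch of the common rotate-half formulation; Deepseek also has an interleaved variant, and the real helpers handle an extra head dimension):

```python
import torch

def rotate_half(x):
    """Swap and negate the two halves of the last dimension: (x1, x2) -> (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate q and k by the per-position angles encoded in cos/sin."""
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
```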
Grouped Query Attention / Multi-Head Latent Attention
This is a specific attention mechanism used in Deepseek.
Grouped Query Attention / Multi-Head Latent Attention builds on Multi-Head Attention (MHA) and Multi-Query Attention (MQA).
GQA/MHLA is a hybrid approach between MHA and MQA (MHLA additionally compresses Q, K and V through low-rank "latent" projections, which we will see in the code below).
MHA creates a K, Q, V for every head
- High expressiveness but high memory and compute cost
MQA reuses the same K, V across all heads but keeps Q independent for each head
- Faster and less compute intensive but may degrade quality
GQA/MHLA:
- Divides the X heads into G groups
- 8 heads -> 2 Groups of 4 heads
- Each group shares the same K and V
- Independent Q per head inside the group
Group 1: Head 1-4 share K1, V1
Group 2: Head 5-8 share K2, V2
- Using the example above:
  - There are only 2 K and 2 V tensors in total, one pair per group, shared by all heads in that group
  - Each head in each group has its own Q:
    - Therefore, there are 8 Q tensors
    - 4 (heads per group) x 2 (groups)
MHLA has the advantage that it is faster than MHA while still retaining more expressiveness than MQA.
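A shape-level sketch of the grouping idea (illustrative sizes; this shows plain GQA-style K/V sharing rather than Deepseek's latent compression):

```python
import torch

batch, seq, n_heads, n_groups, head_dim = 1, 4, 8, 2, 64

q = torch.randn(batch, n_heads, seq, head_dim)        # one Q per head (8)
k = torch.randn(batch, n_groups, seq, head_dim)       # one K per group (2)
v = torch.randn(batch, n_groups, seq, head_dim)       # one V per group (2)

# Each group's K/V is shared by n_heads // n_groups heads.
k = k.repeat_interleave(n_heads // n_groups, dim=1)   # [1, 8, 4, 64]
v = v.repeat_interleave(n_heads // n_groups, dim=1)   # [1, 8, 4, 64]

scores = (q @ k.transpose(-2, -1)) / head_dim ** 0.5  # [1, 8, 4, 4]
out = scores.softmax(dim=-1) @ v                      # [1, 8, 4, 64]
```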
```python
def forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    batch_size, seq_length = hidden_states.shape[:-1]
    query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
    key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
```
We'll step through the forward pass of the Deepseek attention implementation.
Given the `hidden_states` input, with a shape of `[batch_size, seq_length, hidden_size]`, we create the `query_shape` and `key_shape`.
This gives query and key shapes of:
`[batch_size, seq_length, num_heads, head_dim]`
- `num_heads` is inferred implicitly (the `-1` in the view) as the projected size divided by the head dimension
  - e.g. `num_heads = 768 / 64 = 12`
- Notice `self.qk_nope_head_dim + self.v_head_dim`: the key projection packs together the part of the key that will NOT have RoPE applied (`qk_nope_head_dim`) and the value (`v_head_dim`); the RoPE part of the key is produced separately and rotated later.
```python
...
q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
```
- `q_states`: the collection of all Q vectors for every attention head, accounting for all groups
- `self.q_a_proj(hidden_states)`: a linear layer that projects the hidden states down to size `q_lora_rank`. This essentially acts as a compression of the hidden states
- `self.q_a_layernorm(...)`: an RMSNorm over the compressed `q_lora_rank` tensor
  - Normalizing the compressed hidden state helps improve stability and allows the low-rank adapters to learn effectively
- `self.q_b_proj(...)`: projects the compressed tensor up to size `self.num_heads * self.qk_head_dim`
- `q_pass` and `q_rot` are the two halves of the split Q states: `q_pass` will NOT have RoPE applied, `q_rot` WILL have RoPE applied.
```python
...
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

cos, sin = position_embeddings
if self.config.rope_interleave:  # support using interleaved weights for efficiency
    q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
else:
    q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

query_states = torch.cat((q_pass, q_rot), dim=-1)
key_states = torch.cat((k_pass, k_rot), dim=-1)
```
- `compressed_kv`: the lower-rank projection for K and V
  - As with Q, a similar procedure splits K into a RoPE-applicable part and a non-RoPE part
- The rotary coefficients `cos` and `sin` are applied to `q_rot` and `k_rot`, which are then concatenated to create the `query_states` and `key_states`, ready for attention
```python
...
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
    if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
        logger.warning_once(
            "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
            'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
        )
    else:
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    dropout=0.0 if not self.training else self.attention_dropout,
    scaling=self.scaling,
    **kwargs,
)

if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
    attn_output = attn_output[:, :, :, : self.v_head_dim]

attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
```
- Q and K (with RoPE applied to their rotary parts) and V are passed to the attention mechanism
- The outputs are reshaped and projected back to the Hidden State size via `o_proj`
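For reference, the eager attention path boils down to the familiar scaled dot-product attention, roughly like this (a simplified sketch; the real `eager_attention_forward` also handles attention-mask slicing, dropout and dtype details):

```python
import torch

def eager_attention(query, key, value, attention_mask=None, scaling=1.0):
    # query/key/value: [batch, num_heads, seq_len, head_dim]
    attn_weights = (query @ key.transpose(-2, -1)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask   # mask holds -inf at disallowed positions
    attn_weights = attn_weights.softmax(dim=-1)
    attn_output = attn_weights @ value                 # [batch, num_heads, seq_len, head_dim]
    return attn_output, attn_weights
```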
Mixture of Experts
A Mixture of Experts (MOE) is a collection of Feed Forward Networks and a router that will pick a subset of experts that will transform the hidden state into an output.
Instead of one big Feed Forward network that processes all tokens, MOE uses many smaller experts that will have tokens routed to a specific expert.
This allows:
- Experts to learn a specific role (e.g. syntax, numbers, code)
- Each token only activates a few experts, reducing computational requirements
There are two types:
- Dense MoE: all experts are activated
- Sparse MoE: only a few experts are activated
Typically, an MoE pass per token looks like this (a toy sketch of the routing follows the list):
- In the earliest layers (low `layer_idx`), the input is routed to a dense FFN; this enables early learning of general patterns
- Later layers use the MoE, which handles more complex and specialised patterns via the top-k routing mechanism
- The router decides the top-k experts for each token (e.g. experts 2 and 5)
  - The router produces a probability for the token for each expert
- The token is sent to the k selected experts
- Each expert processes the token independently
- The outputs are combined in a weighted sum
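Here is the toy sketch referenced above: sparse top-k routing for a single token (illustrative sizes, not the Deepseek implementation):

```python
import torch
import torch.nn as nn

hidden_size, num_experts, top_k = 16, 4, 2
experts = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_experts))
router = nn.Linear(hidden_size, num_experts, bias=False)

token = torch.randn(hidden_size)

probs = router(token).softmax(dim=-1)        # one probability per expert
topk_probs, topk_idx = probs.topk(top_k)     # pick the k best experts for this token
topk_probs = topk_probs / topk_probs.sum()   # renormalise the selected weights

# Only the selected experts run; their outputs are combined in a weighted sum.
output = sum(w * experts[int(i)](token) for w, i in zip(topk_probs, topk_idx))
```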
```python
class DeepseekV3MoE(nn.Module):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList(
            [
                DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
                for _ in range(config.n_routed_experts)
            ]
        )
        self.gate = DeepseekV3TopkRouter(config)
        self.shared_experts = DeepseekV3MLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )

    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
        r"""
        CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
        to not have to do a loop here (deepseek has 256 experts soooo yeah).
        """
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
        expert_mask = expert_mask.permute(2, 0, 1)

        for expert_idx in range(len(self.experts)):
            expert = self.experts[expert_idx]
            mask = expert_mask[expert_idx]
            token_indices, weight_indices = torch.where(mask)

            if token_indices.numel() > 0:
                expert_weights = topk_weights[token_indices, weight_indices]
                expert_input = hidden_states[token_indices]
                expert_output = expert(expert_input)
                weighted_output = expert_output * expert_weights.unsqueeze(-1)
                final_hidden_states.index_add_(0, token_indices, weighted_output)

        # in original deepseek, the output of the experts are gathered once we leave this module
        # thus the moe module is itelsf an IsolatedParallel module
        # and all expert are "local" meaning we shard but we don't gather
        return final_hidden_states.type(hidden_states.dtype)

    def forward(self, hidden_states):
        residuals = hidden_states
        orig_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states
```
We'll go through the sample MoE implementation above.
Looking at the forward method:
- `self.gate` acts as the router; it returns the chosen top-k expert indices and their weights, see [[#Topk Router]]
- `hidden_states` are reshaped and sent to the `moe` function
- The `moe` function in this code sample is unoptimized: it uses a single loop instead of parallelizing the dispatch to experts
  - It creates an expert mask
  - It loops through each expert; if the mask shows the expert was selected for at least one token (`token_indices.numel() > 0`), the expert is activated
  - It calls the expert's forward method (remember, an expert is just an FFN)
  - The expert output is weighted by the router weight and accumulated into `final_hidden_states`
- Finally, the shared experts are applied to the original (residual) input and their output is added to the routed experts' output
Topk Router
```python
class DeepseekV3TopkRouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob

        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
        self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts)))

    @torch.no_grad()
    def get_topk_indices(self, scores):
        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
        group_scores = (
            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        return topk_indices

    def forward(self, hidden_states):
        hidden_states = hidden_states.view(-1, self.config.hidden_size)
        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
        scores = router_logits.sigmoid()
        topk_indices = self.get_topk_indices(scores)
        topk_weights = scores.gather(1, topk_indices)
        if self.norm_topk_prob:
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights
```
Above is the implementation of the TopK Router.
Looking at the forward pass:
- `router_logits` is a linear projection of the hidden states against the router weights
- `scores` is the output of passing `router_logits` through a `sigmoid` function; these are the per-expert probabilities
- `topk_indices` are the top-k expert indices, selected using expert grouping and the score-correction bias
- `topk_weights` are the gathered scores associated with the `topk_indices`
- Normalization is applied, if `norm_topk_prob` is set, to ensure the top-k weights sum to 1
- Finally, `routed_scaling_factor` is applied to the weights to maintain consistent magnitudes
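To make the forward pass concrete, here is a toy sketch of the scoring path with the group masking and bias correction omitted (illustrative sizes only):

```python
import torch
import torch.nn.functional as F

num_tokens, hidden_size, n_experts, top_k = 3, 16, 8, 2
hidden_states = torch.randn(num_tokens, hidden_size)
weight = torch.randn(n_experts, hidden_size)       # router weight, one row per expert

router_logits = F.linear(hidden_states, weight)    # [num_tokens, n_experts]
scores = router_logits.sigmoid()                   # per-expert probabilities

topk_weights, topk_indices = scores.topk(top_k, dim=-1)                          # [num_tokens, top_k]
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)   # optional normalisation
print(topk_indices)   # which experts each token is routed to
print(topk_weights)   # how much each selected expert contributes
```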