How is bidirectional information retrieved and generated in masked diffusion language models?

Understanding Bidirectional Information Retrieval in MDLMs with ROME

Summary

tl;dr

In this post, I discuss intermediate progress on using ROME to study bidirectional information retrieval in masked (discrete) diffusion language models.

Analysis in-progress!

Abstract

This research proposal aims to conduct the first mechanistic investigation of bidirectional retrieval and generation in masked diffusion language models. We focus on masked diffusion models (MDMs), which have shown surprising resistance to the “reversal curse” that plagues autoregressive architectures. The study will map the causal pathways of knowledge recall and analyze the internal representations learned by the model. By leveraging recent insights into the “factorization curse” - a more fundamental problem underlying the reversal curse - we seek to understand how non-autoregressive architectures overcome these limitations. This mechanistic investigation will be among the first to systematically examine the internal workings of masked diffusion language models, providing insights into their emerging capabilities for bidirectional logical reasoning. The findings will contribute to the development of more robust language models with improved reasoning abilities and more reliable knowledge representation.

Introduction

Large language models (LLMs) have demonstrated impressive capabilities across diverse tasks, but recent studies have identified a significant limitation called the “reversal curse.” When autoregressive models like GPT learn a fact such as “A is B,” they fail to automatically deduce the reverse relationship “B is A.” For example, if a model learns that “Paris is the capital of France,” it struggles to answer “The capital of France is __” correctly.

Recent work by Kitouni et al. (2024) reframes this as the “factorization curse” - a more fundamental problem where language models fail to learn the same joint distribution under different factorizations. This insight suggests the issue isn’t just about reversing relationships, but about a broader limitation in how next-token prediction objectives force models to learn specific factorizations of joint distributions, undermining their ability to reason bidirectionally.

Certain model architectures — particularly bidirectional encoder models like BERT and masked diffusion models (MDMs) — appear to be more resistant to this limitation. Understanding why these architectural differences lead to better bidirectional reasoning capabilities could provide valuable insights for developing more robust AI systems that overcome the factorization curse. This research proposal aims to investigate the mechanisms behind bidirectional information retrieval in language models, with a specific focus on MDMs.

The reversal curse was first formally identified by Berglund et al. (2023), who demonstrated that autoregressive models fine-tuned on statements like “A is B” failed to generalize to “B is A.” The authors showed that models like GPT-3 (175B) and Llama-1 (7B) scored no better than random chance when evaluated on reversed relationships.

The Reversal Curse
Figure 1: The Reversal Curse

Kitouni et al. (2024) significantly expanded the understanding of this phenomenon by reframing it as the “factorization curse” - a fundamental limitation of the next-token prediction objective used in most LLMs. Through controlled experiments using their novel WikiReversal benchmark based on Wikipedia knowledge graphs, they demonstrated that this is an inherent failure mode that cannot be solved merely through scaling, reversed tokens, or even naive bidirectional-attention training. Their work identified factorization-agnostic objectives as a promising solution, showing significant improvements across various tasks of different complexity levels.

Ma et al. (2023) investigated this problem specifically in the context of model editing, introducing the “Bidirectional Assessment for Knowledge Editing” (BAKE) benchmark and the “Bidirectionally Inversible Relationship moDeling” (BIRD) method. Their work revealed that while existing editing methods can effectively recall facts in the direction they were trained on, they perform poorly in the reverse direction. For instance, LLaMA-2 edited with state-of-the-art ROME could recall 99.70% of editing facts in the forward direction but only 0.26% in the reverse direction.

Wu and Wang (2023) explored this phenomenon further, comparing autoregressive decoder-only models (like GPT) with bidirectional encoder models (like BERT). They found that BERT-style models were largely immune to the reversal curse for basic logical deductions, suggesting that architectural differences play a critical role.

Recent advances in masked diffusion models have shown particularly promising results. Nie et al. (2024) demonstrated that a 1.1B parameter MDM could break the reversal curse that much larger autoregressive models struggle with. Similarly, research on LLaDA (Large Language Diffusion Models) has shown that diffusion-based approaches effectively address bidirectional reasoning challenges and even outperform GPT-4o on certain reversal tasks.

A primer on masked discrete diffusion

Research Objectives

This study aims to:

  1. Map the causal pathways of knowledge recall to understand how information flows when retrieving facts in forward versus reverse directions.
  2. Analyze learned representations to determine if and how they encode bidirectional logical relationships, identifying any symmetries in forward and backward representation of concepts.

Methodology

Models

Table 1: Model configuration parameters used in experiments.

| Model Type           | Masked Diffusion | Autoregressive |
| -------------------- | ---------------- | -------------- |
| Base Architecture    | Llama2           | Llama2         |
| Transformer Blocks   | 20               | 20             |
| Attention Heads      | 14               | 14             |
| n_embed              | 1792             | 1792           |
| embed_dim            | 7168             | 7168           |
| vocab_size           | 32000            | 32000          |
| Sampling temperature | 0.0              | 0.0            |
| Sampling algorithm   | Greedy           | Greedy         |

Table 2: Sampling configuration for SMDM-1028M (Masked Diffusion)

| Config          | Value |
| --------------- | ----- |
| Denoising Steps | 16    |
| Context Length  | 52    |
| CFG Scale       | 0.8   |

The models chosen in this study are pre-trained on the SlimPajama dataset. These models are based on the Llama2 architecture and are configured with the settings listed in Table 1. The sampling configuration for the masked diffusion language model SMDM-1028M is borrowed from the same work and is shown in Table 2. The denoising steps indicate the number of steps over which the masked model response is unmasked. The context length refers to the total length of the sequence given as input to the model, which includes both the prompt and the unmasked response (Fig. 2). The CFG scale is the factor by which classifier-free guidance is scaled and applied for conditional generation.

Denoising response sequence during inference in masked diffusion language models
Figure 2: The prompt token is padded with masked tokens which are iteratively denoised during sampling from a masked diffusion language model.
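
To make the sampling configuration concrete, below is a minimal sketch of how a single conditional denoising step could combine classifier-free guidance with greedy (temperature-0) selection. The `mask_predictor` callable, the `MASK_ID` constant, and the particular guidance formula are assumptions made for illustration, not the actual SMDM-1028M implementation.

```python
import torch

MASK_ID = 32000  # hypothetical id of the [MASK] token

def cfg_denoise_step(mask_predictor, x, prompt_len, cfg_scale=0.8):
    """One guided denoising step on a (batch, seq_len) token tensor x."""
    cond_logits = mask_predictor(x)                       # prompt tokens visible
    x_uncond = x.clone()
    x_uncond[:, :prompt_len] = MASK_ID                    # prompt masked out
    uncond_logits = mask_predictor(x_uncond)
    # Guided logits: push the conditional prediction away from the unconditional one.
    logits = (1.0 + cfg_scale) * cond_logits - cfg_scale * uncond_logits
    preds = logits.argmax(dim=-1)                         # greedy (temperature-0) choice
    # In the full sampler only a schedule-determined subset of masked positions
    # would be committed per step; here all masked positions are filled.
    return torch.where(x == MASK_ID, preds, x)
```

With `cfg_scale = 0`, the guided logits reduce to the ordinary conditional prediction.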

Dataset

The reversal curse can be studied with factual recall on two candidate sources: Wikipedia knowledge prompts and the Reversal Curse dataset.

The masked diffusion model was first prompted to generate single-token completions on Wikipedia knowledge prompts, which the model was trained on as part of SlimPajama. However, the first generated token was high-entropy and did not reliably elicit informative completions. A subset of the responses can be seen in Fig. 3.

Wikipedia Knowledge Prompt Completions
Figure 3: Generations by SMDM-1028M on Wikipedia knowledge prompts have high entropy (little information in the first generated token) and could not be used as a controlled database to test the reversal curse.

The models were then fine-tuned on the Reversal Curse dataset and attained high accuracy (Table 3), so we chose this dataset. Samples from the dataset are depicted in Fig. 4.

Dataset
Figure 4: The reversal curse dataset.

Causal Tracing

Meng et al. (2022) introduced causal tracing as a method to identify where factual knowledge is stored and processed in language models. The technique aims to isolate the causal effect of individual states within the network while processing factual statements, essentially mapping the information flow path through the model.

The method aims to understand the model’s prediction, focusing on identifying key neurons and layers involved in recalling factual associations.

Tracing Information Flow

Meng et al. (2022) use causal mediation analysis to quantify the contribution of intermediate variables (hidden states) in the model’s causal graph to a correct factual prediction.

To calculate each state’s contribution, they observe the model’s internal activations during three runs:

  1. Clean Run: A factual prompt $x$ is passed into the model $G$, and all hidden activations are collected.
  2. Corrupted Run: The embeddings of the subject $s$ are corrupted with noise before the network runs, damaging the prediction.
  3. Corrupted-with-Restoration Run: The model runs computations on the noisy embeddings as in the corrupted baseline, except at some token $i$ and layer $l$, the model is forced to output the clean state $h^{(l)}_i$.

Indirect Effect (IE): The indirect effect of a specific mediating state $h^{(l)}_i$ is defined as the difference between the probability of $o$ under the corrupted version and the probability when that state is set to its clean version, while the subject remains corrupted.

Average Indirect Effect (AIE): Averaging over a sample of statements, the average total effect (ATE) and average indirect effect (AIE) is obtained for each hidden state variable.
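
The sketch below shows how the three runs could be combined into an indirect effect score for one mediating state; `run_model`, `corrupt_subject_embeddings`, and `patch_hidden_state` are hypothetical hook-based helpers, not functions from any particular library.

```python
import torch

def indirect_effect(model, prompt_ids, subject_span, answer_id, token_i, layer_l):
    # 1. Clean run: cache all hidden activations (clean_probs would give the total effect).
    clean_probs, clean_hidden = run_model(model, prompt_ids, cache_hidden=True)

    # 2. Corrupted run: noise the subject token embeddings to damage the prediction.
    noisy_embeds = corrupt_subject_embeddings(model, prompt_ids, subject_span)
    corrupt_probs, _ = run_model(model, prompt_ids, inputs_embeds=noisy_embeds)

    # 3. Corrupted-with-restoration: same noise, but h_i^(l) is patched back
    #    to its clean value during the forward pass.
    with patch_hidden_state(model, layer_l, token_i, clean_hidden[layer_l][token_i]):
        restored_probs, _ = run_model(model, prompt_ids, inputs_embeds=noisy_embeds)

    # IE = P_restored(o) - P_corrupted(o); averaging IE over many factual
    # statements gives the AIE for each (token, layer) pair.
    return float(restored_probs[answer_id] - corrupt_probs[answer_id])
```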

Extracting Factual Representations

Meng et al. (2022) build upon the localized factual association hypothesis to devise a method for localizing factual associations within transformer models. By locating the layer and weight matrix where a new fact can be injected with minimal interference, ROME provides evidence that factual knowledge is encoded in specific MLP projection layers as linear associations between key and value vectors. The proposed method is explained briefly below for context.

Extracting the Key $k_{\ast}$

To identify the memory location corresponding to the subject $s$ in the factual triple $(s, r, o)$, a key vector $k_{\ast}$ is constructed, under the assumption that certain activations within the MLP are responsible for retrieving facts associated with a given subject. To find a robust key, the activation vectors are averaged across multiple short text prefixes that lead into and end with the subject token $s$:

\[k_{\ast} = \frac{1}{N} \sum_{j=1}^N k(x_j + s)\]

where $k(x)$ is derived from the layer $l^*$’s MLP activations after the non-linearity is applied.
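
As a rough illustration of this averaging step, assuming hypothetical helpers `tokenize` and `get_mlp_activation` that return the post-nonlinearity activations feeding layer $l^*$'s projection matrix:

```python
import torch

def compute_key(model, subject, prefixes, layer_star):
    """Average the layer-l* MLP activation at the last subject token over prefixes."""
    keys = []
    for prefix in prefixes:
        ids = tokenize(prefix + subject)                    # x_j + s, prefix ends at the subject
        acts = get_mlp_activation(model, ids, layer_star)   # (seq_len, hidden_dim), post-nonlinearity
        keys.append(acts[-1])                               # activation at the last subject token
    return torch.stack(keys).mean(dim=0)                    # k_*
```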

Extracting the Value $v_{\ast}$

Meng et al. (2022) compute a value vector $v_{\ast}$ that encodes a new object $o^*$ (i.e., the fact to be recalled) as a property of the subject $s$. This is formulated as an optimization problem with a linear combination of two objectives:

  1. Maximize the likelihood that the model, when prompted with subject $s$ and relation $r$, predicts the new object $o^*$.
  2. Minimize the KL divergence between the model’s original predictions for $s$ and its predictions after editing, to prevent disrupting the subject’s broader semantic identity.

This results in the following loss objective:

\[\mathcal{L}(z) = - \frac{1}{N} \sum_{j=1}^N \log \mathbb{P}_{G(m_i^{(l^*)} := z)}(o^* \mid x_j + p) + D_{\mathrm{KL}}\left( \mathbb{P}_{G(m_{i'}^{(l^*)} := z)}[x \mid p'] \,\Vert\, \mathbb{P}_G[x \mid p'] \right)\]

where $z$ is optimized to serve as $v_{\ast}$, the value that, when output by the MLP, causes the model to recall the new object $o^*$ in response to the original factual prompt $p$. $p'$ represents the set of adjacent prompts (of the form “{subject} is a”) which aim to preserve the model’s understanding of the subject’s essence.

In contrast, we aim to determine the representation of the original object $o$ when preceded by the subject $s$ in the prompt. We therefore update the objective function to account for $o$ instead of $o^*$, while the rest remains the same:

\[\mathcal{L}(z) = - \frac{1}{N} \sum_{j=1}^N \log \mathbb{P}_{G(m_i^{(l^*)} := z)}(o \mid x_j + p) + D_{\mathrm{KL}}\left( \mathbb{P}_{G(m_{i'}^{(l^*)} := z)}[x \mid p'] \,\Vert\, \mathbb{P}_G[x \mid p'] \right)\]
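
A minimal sketch of this optimization is shown below. It assumes hypothetical helpers `logits_with_mlp_override` (a forward pass in which the layer-$l^*$ MLP output at the last subject token is replaced by `z`, returning next-token logits over the vocabulary) and `original_log_probs` (the unmodified model's detached next-token log-probabilities); neither is part of the original ROME code.

```python
import torch

def optimize_value(model, prompts, essence_prompts, object_id, layer_star, hidden_dim, steps=200):
    """Optimize z to serve as v_* for the *original* object o."""
    z = torch.zeros(hidden_dim, requires_grad=True)
    opt = torch.optim.Adam([z], lr=0.5)
    for _ in range(steps):
        # Term 1: average negative log-likelihood of o over the prefixed prompts x_j + p.
        nll = -torch.stack([
            logits_with_mlp_override(model, p, layer_star, z).log_softmax(-1)[object_id]
            for p in prompts
        ]).mean()
        # Term 2: KL(P_edited || P_original) on the essence prompts p' ("{subject} is a"),
        # to preserve the subject's broader identity.
        kl = 0.0
        for p in essence_prompts:
            log_p_edit = logits_with_mlp_override(model, p, layer_star, z).log_softmax(-1)
            log_p_orig = original_log_probs(model, p)        # detached original distribution
            kl = kl + (log_p_edit.exp() * (log_p_edit - log_p_orig)).sum()
        loss = nll + kl / max(len(essence_prompts), 1)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return z.detach()                                        # v_* encoding the original object o
```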

Meng et al. (2022) go a step further and apply a rank-one update to the projection matrix $W^{(l^*)}_{\text{proj}}$ of the MLP at layer $l^*$, in order to encode a new key–value pair into the model’s weights. This is beyond the scope of this research study.

Causal Tracing for Diffusion Language Models

Unlike autoregressive models, which generate the response one token per forward pass, and masked language models, which predict all masked tokens in a single pass, response generation in masked diffusion language models starts from a fully masked sequence $x_1$ of fixed length $L$ (the context length); a mask predictor $f_\theta$ then unmasks tokens following a noise schedule $\alpha_t$ over $T$ denoising steps to generate the response $x_0$. Specifically, the masked diffusion model simulates a diffusion process from $t = 1$ (fully masked) to $t = 0$ (unmasked), predicting all masked positions simultaneously at each step, with the possibility of flexible remasking strategies (Fig. 5).

Unmasking tokens via sampling during inference
Figure 5: Masked diffusion models simulate a diffusion process from t = 1 (fully masked) to t = 0 (unmasked), predicting all masks simultaneously at each step with the possibility of flexible remasking strategies.
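
The loop below sketches this reverse process under simplifying assumptions: a linear schedule that commits roughly $L/T$ tokens per step and a confidence-based unmasking order; `mask_predictor` and `MASK_ID` are placeholders rather than the actual SMDM components.

```python
import torch

MASK_ID = 32000  # hypothetical [MASK] token id

def sample_response(mask_predictor, prompt_ids, response_len, num_steps=16):
    """Iteratively unmask an all-[MASK] response appended to the prompt."""
    x = torch.cat([prompt_ids, torch.full((response_len,), MASK_ID, dtype=torch.long)])
    per_step = max(1, response_len // num_steps)             # simplified linear schedule
    for _ in range(num_steps):
        logits = mask_predictor(x.unsqueeze(0))[0]           # (seq_len, vocab), all positions scored
        conf, preds = logits.softmax(-1).max(-1)
        masked = (x == MASK_ID).nonzero().squeeze(-1)
        if masked.numel() == 0:
            break
        # Commit the most confident masked positions this step; a remasking
        # strategy could instead revisit already-unmasked tokens.
        k = min(per_step, masked.numel())
        chosen = masked[conf[masked].topk(k).indices]
        x[chosen] = preds[chosen]
    return x
```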

For masked diffusion language modeling, the individual response tokens (as seen in Fig. 6) can be unmasked at different time steps.

Response generation from masked diffusion language models
Figure 6: Response generation from masked diffusion language model.

During training, masked diffusion models (MDMs) are trained to increase the likelihood of the data distribution under different masking conditions. During inference, different sampling strategies determine the order in which the response tokens are unmasked. In the context of causal information flow analysis, this introduces both challenges and opportunities.

Causal tracing is therefore run over the denoising steps from $t = t_1$ to $t = t_2$, as sketched below.
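
One way to realize this, sketched under assumptions below, is to record the partially unmasked sequences $x_t$ from a clean sampling run and repeat the single-pass tracing recipe on the mask predictor's forward pass at each step in $[t_1, t_2]$; `indirect_effect` refers to the hypothetical helper sketched earlier.

```python
def trace_over_steps(model, step_inputs, subject_span, answer_id,
                     token_i, layer_l, t1, t2):
    """Run the ROME-style trace on every denoising step in [t1, t2]."""
    scores = {}
    for t, x_t in enumerate(step_inputs):                   # x_t: partially unmasked input at step t
        if t1 <= t <= t2:
            # Each denoising step is an independent forward pass of the mask
            # predictor, so the three-run tracing procedure applies per step.
            scores[t] = indirect_effect(model, x_t, subject_span,
                                        answer_id, token_i, layer_l)
    return scores
```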

When retrieving information from a language model, the process acts like a key-based value lookup, with the subject $s$ acting as the key and the object $o$ acting as the value to be extracted. Following Meng et al. (2022), the key vectors at the last subject token in the important MLP units identified by causal tracing will be fixed, and the objective function defined in equation 2 will be optimized to extract the representation of the object $o$.

This procedure will be carried out while eliciting knowledge in both the forward direction (sample prompt: “The trailblazer known as Daphne Barrington was once”) and the reverse direction (sample prompt: “Immersed in the world of directing the virtual reality masterpiece, A Journey Through Time.”). The key and value representations would be interchanged between the two runs; for example, the key in prompt 1, “Daphne Barrington”, would act as the value in prompt 2. The aim is to compare the key and value representations of the same concept via dimensionality reduction techniques (SVD and PCA) to see whether there is a symmetry in how the information is encoded in the diffusion model.
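
A rough sketch of the planned comparison, assuming the forward-direction keys and reverse-direction values for a concept have already been collected as `(num_samples, hidden_dim)` tensors:

```python
import torch
import torch.nn.functional as F

def compare_representations(forward_keys, reverse_values, k=2):
    """Project both sets onto a shared PCA basis (via SVD) and measure alignment."""
    stacked = torch.cat([forward_keys, reverse_values], dim=0)
    mean = stacked.mean(dim=0, keepdim=True)
    # Top-k principal directions of the pooled, centered representations.
    _, _, vh = torch.linalg.svd(stacked - mean, full_matrices=False)
    basis = vh[:k].T                                   # (hidden_dim, k)
    fwd = (forward_keys - mean) @ basis
    rev = (reverse_values - mean) @ basis
    # Cosine similarity of the mean projections; a value near 1 would suggest
    # a shared (symmetric) encoding of the concept across directions.
    return F.cosine_similarity(fwd.mean(dim=0), rev.mean(dim=0), dim=0)
```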

Current Progress

We begin with a scaling experiment to study the impact of sampling steps on knowledge recall, measured with exact-match accuracy. As the number of sampling steps increases, knowledge recall improves in both the forward and reverse directions (see Table 3).

Please note that the accuracy values for Description2Person (reverse) and Person2Description (reverse) differ from the ones reported previously. *We believe this is an error in the reported values, since reverse recall is harder for language models, chiefly because the perplexity of the description text is usually higher than that of the person names in the Reversal Curse dataset.*
Table 3: Scaling sampling steps leads to better knowledge recall performance in forward and reverse directions. All values are % exact-match accuracy (higher is better).

| Sampling Steps | Description2Person, Same (easy) | Description2Person, Reverse (difficult) | Person2Description, Same (difficult) | Person2Description, Reverse (easy) |
| -------------- | ------------------------------- | --------------------------------------- | ------------------------------------ | ---------------------------------- |
| 1 step         | 91.3                            | 0                                        | 0                                    | 40.3                               |
| 2 steps        | 97                              | 0.7                                      | 2.3                                  | 79                                 |
| 4 steps        | 96.3                            | 17                                       | 25                                   | 87.7                               |
| 8 steps        | 97.3                            | 27.3                                     | 37.7                                 | 87.3                               |
| 16 steps       | 98                              | 31                                       | 42                                   | 88.3                               |
| 32 steps       | 97.7                            | 29.7                                     | 43                                   | 90                                 |

This is a work-in-progress. More results coming soon!

The existing code for the study can be found at raishish/diffusion-interp.

Future Work

Acknowledgements

References

If you would like to cite this work, please use the following BibTeX entry:

@article{rai2025discrete-diffusion-rome,
  title={How is information retrieved and generated in masked diffusion language models?},
  author={Rai, Ashish},
  year={2025},
  month={July},
  url={https://raishish.github.io/blog/2025/discrete-diffusion-rome/}
}