REALM: Retrieval-Augmented Language Model Pre-Training

Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat, Ming-Wei Chang

ICML · 2020

Tags: realm, pretraining, retrieval, end-to-end

TL;DR

Introduces end-to-end training of retrieval-augmented language models, jointly learning what to retrieve and how to use retrieved information during pre-training.

Key Contribution

REALM shows that retrieval can be incorporated into language model pre-training itself. Rather than adding retrieval as a post-hoc component, the retriever is trained end-to-end with the language model, learning what knowledge to retrieve for the masked language modeling objective.

Architecture

Neural Knowledge Retriever

Given input x, retrieve relevant document z:

p(z|x) = exp(f(x,z)) / Σ_z' exp(f(x,z'))

where f(x,z) = Embed_input(x) · Embed_doc(z), an inner product of dense embeddings of the input and the document.
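
A minimal sketch of this scoring step, assuming the input and document embeddings have already been produced by their respective encoders (BERT-style towers in the paper); the function and variable names below are illustrative, not from the released code.

```python
import numpy as np

def retrieval_distribution(query_emb, doc_embs):
    """Compute p(z|x) as a softmax over inner-product relevance scores f(x, z).

    query_emb: (d,) vector for the input x, i.e. Embed_input(x)
    doc_embs:  (num_docs, d) matrix of document vectors, i.e. Embed_doc(z)
    """
    scores = doc_embs @ query_emb          # f(x, z) for every candidate z
    scores -= scores.max()                 # subtract the max for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum()         # normalized retrieval distribution p(z|x)

# Toy example: 3 candidate documents in a 4-dimensional embedding space.
rng = np.random.default_rng(0)
q = rng.normal(size=4)
docs = rng.normal(size=(3, 4))
print(retrieval_distribution(q, docs))     # three probabilities summing to 1
```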

Knowledge-Augmented Encoder

Combine input with retrieved document:

p(y|z,x) = Language_Model([x; z])

Overall: p(y|x) = Σ_z p(y|z,x) p(z|x)
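
In practice the sum over z is approximated with the top-k documents returned by the index. A toy numeric sketch of the marginalization, assuming p(z|x) for the top-k candidates and the encoder's per-document predictions p(y|z,x) are already available; all numbers are made up.

```python
import numpy as np

p_z_given_x = np.array([0.7, 0.2, 0.1])    # retriever weights p(z|x) over 3 retrieved docs
p_y_given_zx = np.array([                  # encoder predictions p(y|z,x), 5-word toy vocabulary
    [0.60, 0.10, 0.10, 0.10, 0.10],
    [0.10, 0.60, 0.10, 0.10, 0.10],
    [0.20, 0.20, 0.20, 0.20, 0.20],
])

# Marginal over the latent document: p(y|x) = sum_z p(y|z,x) p(z|x)
p_y_given_x = p_z_given_x @ p_y_given_zx
print(p_y_given_x)   # first vocabulary item: 0.7*0.6 + 0.2*0.1 + 0.1*0.2 = 0.46
```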

End-to-End Training

Both the retriever and the language model are trained jointly; the masked-language-modeling loss backpropagates into both components (a minimal sketch follows the list below):

  • Retriever learns what's useful for MLM
  • LM learns to use retrieved context
  • Gradients flow through retrieval
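
A minimal end-to-end sketch under heavy simplifying assumptions: documents are a small trainable embedding table, the query encoder and reader are single linear layers, and retrieval scores every document exactly instead of using an approximate MIPS index. It illustrates the one property that matters here: the -log p(y|x) loss produces gradients for the retriever and the reader at the same time.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_docs, dim, vocab = 8, 16, 100

doc_embs = torch.nn.Parameter(torch.randn(num_docs, dim))   # retriever side: Embed_doc table
query_encoder = torch.nn.Linear(dim, dim)                   # retriever side: Embed_input (toy)
reader = torch.nn.Linear(2 * dim, vocab)                    # knowledge-augmented encoder (toy)
opt = torch.optim.Adam(
    [doc_embs, *query_encoder.parameters(), *reader.parameters()], lr=1e-3
)

x = torch.randn(dim)       # a masked input, already encoded to a vector for simplicity
y = torch.tensor(3)        # index of the masked token to predict

q = query_encoder(x)
p_z = F.softmax(doc_embs @ q, dim=0)                         # p(z|x)
logits = reader(torch.cat([x.expand(num_docs, dim), doc_embs], dim=-1))
p_y_given_z = F.softmax(logits, dim=-1)[:, y]                # p(y|z,x) for each candidate z
loss = -torch.log((p_z * p_y_given_z).sum())                 # -log p(y|x)

loss.backward()            # gradients reach BOTH the retriever and the reader
opt.step()
print(doc_embs.grad.abs().sum() > 0, reader.weight.grad.abs().sum() > 0)
```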
Pre-Training

Objective

Masked language modeling with retrieval:

  • Mask tokens in input
  • Retrieve relevant documents
  • Predict masked tokens using input + retrieved
  • Backpropagate through everything
Async Index Refresh

Challenge: document embeddings change during training, so a pre-built index goes stale.

Solution (sketched after the list below):

  • Periodically re-encode and re-index documents
  • Use stale index between refreshes
  • Update every few hundred steps
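
A sketch of the refresh logic, with the simplification that re-indexing happens inline every refresh_every steps; in REALM the re-encoding runs as an asynchronous background job while training continues against the stale index. The encoder and corpus here are placeholders.

```python
import numpy as np

def refresh_index(doc_encoder, corpus):
    """Re-embed the corpus with the *current* document encoder; the resulting
    matrix is what retrieval searches against (a stand-in for the MIPS index)."""
    return np.stack([doc_encoder(doc) for doc in corpus])

def train(corpus, doc_encoder, total_steps=1000, refresh_every=500):
    index = refresh_index(doc_encoder, corpus)            # initial index
    for step in range(1, total_steps + 1):
        # ... one pre-training step, retrieving against the (possibly stale) index ...
        if step % refresh_every == 0:
            index = refresh_index(doc_encoder, corpus)    # periodic refresh
    return index

# Toy usage with a placeholder encoder that just maps text length to a vector.
toy_corpus = ["doc one", "doc two", "a longer third document"]
toy_encoder = lambda text: np.full(4, float(len(text)))
print(train(toy_corpus, toy_encoder, total_steps=10, refresh_every=5).shape)  # (3, 4)
```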
Salient Span Masking

Mask named entities and dates (see the sketch after this list):

  • Forces model to retrieve world knowledge
  • More meaningful than random masking
  • Encourages knowledge-intensive retrieval
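
A small illustration of the idea using spaCy's off-the-shelf NER (assuming the en_core_web_sm model is installed); the paper itself uses a BERT-based tagger for entities and a regular expression for dates, so treat this only as a stand-in.

```python
import spacy  # assumption: spaCy with the en_core_web_sm model is available

nlp = spacy.load("en_core_web_sm")

def salient_span_mask(text, mask_token="[MASK]"):
    """Mask one named entity or date so filling it in requires world knowledge."""
    doc = nlp(text)
    spans = [e for e in doc.ents if e.label_ in {"PERSON", "ORG", "GPE", "DATE", "EVENT"}]
    if not spans:
        return text, None          # fall back to ordinary masking in a real pipeline
    ent = spans[0]                 # in practice a span would be sampled at random
    masked = text[:ent.start_char] + mask_token + text[ent.end_char:]
    return masked, ent.text

print(salient_span_mask("The Eiffel Tower was completed in 1889."))
# e.g. ('The Eiffel Tower was completed in [MASK].', '1889') -- exact spans depend on the NER model
```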
Fine-Tuning

Open-Domain QA

  • Replace MLM head with QA head
  • Fine-tune on QA datasets
  • Retriever adapts to questions
  • Model learns to extract answers
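
A toy numeric sketch of the fine-tuned objective: the QA head scores candidate answer spans inside each retrieved document, and those span scores are marginalized with the same retrieval weights p(z|x). All numbers and shapes are made up.

```python
import numpy as np

p_z = np.array([0.8, 0.2])                 # retrieval weights p(z|x) for 2 documents
span_logits = np.array([[2.0, 0.5, -1.0],  # QA-head logits for 3 candidate spans per document
                        [0.1, 1.5,  0.3]])

p_span_given_z = np.exp(span_logits) / np.exp(span_logits).sum(axis=1, keepdims=True)
p_span = p_z[:, None] * p_span_given_z     # joint probability of choosing (document, span)
doc_idx, span_idx = np.unravel_index(p_span.argmax(), p_span.shape)
print(doc_idx, span_idx)                   # highest-probability answer span overall
```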
Results

State-of-the-art on:

  • Natural Questions
  • WebQuestions
  • CuratedTREC
All without any corpus-specific pretraining.

Key Insights

Retrieval as Latent Variable

Treating z as a latent variable enables:

  • Unsupervised retriever training
  • Soft attention over documents
  • End-to-end optimization
Knowledge Externalization

Benefits of storing knowledge in an external corpus rather than in model parameters:

  • Interpretable (see what was retrieved)
  • Updatable (change corpus, not model)
  • Scalable (add more documents)
Pre-training Matters

Retrieval-augmented pre-training:

  • Teaches model to use retrieval
  • Better than adding retrieval at fine-tuning
  • Transfers across tasks
Comparison to DPR

| Aspect | DPR | REALM |
|--------|-----|-------|
| Training signal | Supervised | Self-supervised |
| Data needed | QA pairs | Raw text |
| Retriever | Fixed during reader training | Trained jointly with LM |
| Retrieval-augmented pre-training | No | Yes |

Relevance to Agent Memory

Learned Memory Access

REALM principles for agents:

  • Learn when memory is useful
  • Learn what to retrieve
  • End-to-end optimization
Memory Pre-training

Could pre-train agents to:

  • Use memory effectively
  • Know what's worth storing
  • Retrieve appropriately
Implementation Challenges

Compute Requirements

  • Need to index a large corpus (see the FAISS sketch after this list)
  • Periodic re-indexing expensive
  • Marginalization over documents
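
For a concrete sense of the indexing step, here is a minimal FAISS example building an exact inner-product index over random document vectors; REALM-scale systems use approximate MIPS search instead, and the embedding dimensionality and corpus size below are arbitrary.

```python
import numpy as np
import faiss  # assumption: the faiss-cpu package is installed

dim = 128
doc_vectors = np.random.rand(10_000, dim).astype("float32")   # stand-in document embeddings

index = faiss.IndexFlatIP(dim)      # exact maximum-inner-product search
index.add(doc_vectors)

query = np.random.rand(1, dim).astype("float32")
scores, ids = index.search(query, 5)            # top-5 documents for the query
print(ids[0], scores[0])
```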
Engineering Complexity

  • Async index updates
  • Gradient through retrieval
  • Large-scale infrastructure
Limitations

  • Computational cost
  • Requires large pre-training corpus
  • Fixed retrieval granularity
  • Single retrieval per prediction
Citation

@inproceedings{guu2020realm,
  title={REALM: Retrieval-Augmented Language Model Pre-Training},
  author={Guu, Kelvin and Lee, Kenton and Tung, Zora and Pasupat, Panupong and Chang, Ming-Wei},
  booktitle={International Conference on Machine Learning},
  year={2020}
}