Model-based reinforcement learning with diffusion models
Diffusion models are best known for their image-generation abilities. Now, they are being used to learn world models for reinforcement learning systems.
Diffusion models are the technology behind popular image generators such as DALL-E, Midjourney, and Stable Diffusion.
But diffusion can do more than generate images. A new study shows how it can be used to learn the world model in model-based reinforcement learning (RL).
Diffusion World Model (DWM), the technique introduced in the paper, uses a two-stage process to train model-based RL systems.
In the first stage, the DWM is trained on trajectories collected from the RL environment. Conditioned on the current state and action, it learns to predict several future states of the environment at once. Most other model-based methods instead predict one step at a time and feed each prediction back into the model to reach longer horizons, so small prediction errors can compound over the rollout.
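To make the first stage more concrete, here is a minimal, heavily simplified sketch of a diffusion model that generates a whole segment of future states conditioned on the first state and action. Everything in it is an assumption for illustration, not the paper's implementation: the one-dimensional toy environment and its fixed behavior policy, the horizon `H`, the noise schedule, and the helper names `rollout`, `features`, and `sample_future`. A real DWM uses a neural denoiser; here each diffusion step gets its own linear noise predictor fit by least squares, which happens to be exact for this linear toy environment.

```python
import numpy as np

rng = np.random.default_rng(0)
H = 5   # assumed prediction horizon: number of future states generated jointly
T = 50  # assumed number of diffusion steps

# Hypothetical toy environment: s' = 0.9*s + 0.5*a. After the first action,
# a fixed behavior policy a = -0.2*s generates the rest of the trajectory.
def rollout(s0, a0, horizon=H):
    """Ground-truth future states s_1..s_H, shape (batch, horizon)."""
    states = []
    s = 0.9 * s0 + 0.5 * a0
    for _ in range(horizon):
        states.append(s)
        s = 0.9 * s + 0.5 * (-0.2 * s)
    return np.stack(states, axis=-1)

# DDPM-style linear noise schedule.
betas = np.linspace(1e-4, 0.2, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

def features(x_t, cond):
    """Denoiser input: noisy future segment, conditioning (s0, a0), bias term."""
    return np.concatenate([x_t, cond, np.ones((x_t.shape[0], 1))], axis=1)

# Stage 1 (toy): fit one linear noise predictor per diffusion step using the
# standard noise-prediction objective ||eps - eps_hat||^2, solved by least squares.
N = 1000
cond = rng.normal(size=(N, 2))          # columns: s0, a0
x0 = rollout(cond[:, 0], cond[:, 1])    # clean target: the whole future segment
weights = []
for t in range(T):
    eps = rng.normal(size=x0.shape)
    x_t = np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps
    W, *_ = np.linalg.lstsq(features(x_t, cond), eps, rcond=None)
    weights.append(W)

def sample_future(s0, a0):
    """Generate s_1..s_H for one (s0, a0) by denoising the whole segment jointly."""
    c = np.array([[s0, a0]])
    x = rng.normal(size=(1, H))         # start from Gaussian noise
    for t in reversed(range(T)):
        eps_hat = features(x, c) @ weights[t]
        # Standard DDPM reverse-step mean computed from the predicted noise.
        mean = (x - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps_hat) / np.sqrt(alphas[t])
        if t > 0:
            x = mean + np.sqrt(betas[t]) * rng.normal(size=x.shape)
        else:
            x = mean                    # final step: no noise added
    return x[0]

print(sample_future(0.5, -1.0))  # approximately [-0.05, -0.04, -0.032, -0.0256, -0.02048]
```

The point to notice is in `sample_future`: all H future states are denoised together, so no predicted state is ever fed back in as the input for predicting the next one, which is what a one-step model has to do.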
In the second stage, the system uses the trained DWM to train a model-free RL agent with an actor-critic algorithm, using the DWM's imagined future trajectories to estimate returns for the critic. According to the paper's experiments, the resulting system outperforms both classic model-based and model-free RL systems.
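One common way to plug a multistep world model into actor-critic training is model-based value expansion: the critic is regressed toward a target built from the rewards along an imagined trajectory plus a bootstrapped critic estimate at its last state. The sketch below assumes that approach and assumes the world model predicts rewards as well as states; the callables `dwm_sample`, `critic`, and `policy` are hypothetical placeholders, and the paper's exact procedure may differ.

```python
import numpy as np

def value_expansion_target(rewards, terminal_q, gamma=0.99):
    """H-step target: sum_k gamma^k * r_k + gamma^H * Q(s_H, pi(s_H)).

    rewards:    (batch, H) rewards r_0..r_{H-1} along an imagined trajectory
    terminal_q: (batch,)   critic's estimate at the last imagined state
    """
    rewards = np.asarray(rewards, dtype=float)
    horizon = rewards.shape[-1]
    discounts = gamma ** np.arange(horizon)  # [1, gamma, gamma^2, ...]
    return rewards @ discounts + gamma ** horizon * np.asarray(terminal_q, dtype=float)

def critic_targets(dwm_sample, critic, policy, states, actions, gamma=0.99):
    """Build critic regression targets from one world-model rollout per (s, a).

    Hypothetical interfaces (assumptions, not a real library API):
      dwm_sample(states, actions) -> (future_states (B, H, d), rewards (B, H))
      critic(states, actions)     -> (B,) Q-value estimates
      policy(states)              -> (B, ...) actions
    """
    future_states, rewards = dwm_sample(states, actions)
    s_H = future_states[:, -1]  # last imagined state
    return value_expansion_target(rewards, critic(s_H, policy(s_H)), gamma)

# Worked example: three rewards of 1, terminal value 10, gamma = 0.9
# 1 + 0.9 + 0.81 + 0.729 * 10 = 10.0
print(value_expansion_target([[1.0, 1.0, 1.0]], [10.0], gamma=0.9))  # ≈ [10.]
```

In this setup the critic is trained toward these targets and the actor is updated to maximize the critic. The horizon H sets how much the target relies on the world model's predictions versus the critic's own bootstrapped estimate.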
Read more about DWM on TechTalks.