RecVAE — Recurrent 3D-Conv VAE for fMRI

Open-source · github.com/YuZh98/VAE-fMRI-Alzheimer

License: Apache-2.0 · Python 3.9+ · PyTorch 2.0+ · CI: tutorials.yml workflow (github.com/YuZh98/VAE-fMRI-Alzheimer/actions/workflows/tutorials.yml)

A recurrent 3D-conv VAE for resting-state fMRI volumes, in the same family as Kim et al. (2021), but with the linear latent transition solved in closed form by ridge regression instead of learned by SGD. It ships with a synthetic-data path, so you can run it without ADNI access, and every example runs on a laptop CPU.
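The closed-form transition step can be sketched as follows. It assumes the row-vector convention g(h) = h F^T used in the Architecture section and a ridge penalty ρ‖F‖_F² on the transition. Minimizing ‖H_next − H_prev F^T‖_F² + ρ‖F‖_F² over F gives F^T = (H_prev^T H_prev + ρI)⁻¹ H_prev^T H_next. If F is fit against loss2 + loss_F exactly as written in the Loss table, the σ_h² divisor turns this into the same problem with penalty ρσ_h². `fit_transition_ridge` is a hypothetical name, not the repo's API, and NumPy stands in for PyTorch to keep the sketch light.

```python
import numpy as np


def fit_transition_ridge(h_prev: np.ndarray, h_next: np.ndarray, rho: float) -> np.ndarray:
    """Closed-form ridge estimate of F for the prior g(h) = h F^T (hypothetical helper).

    Minimizes ||h_next - h_prev F^T||_F^2 + rho * ||F||_F^2.
    h_prev, h_next: (N, d) arrays of consecutive latent pairs (h_{t-1}, h_t),
    e.g. stacked over time points and subjects. Returns F with shape (d, d).
    """
    d = h_prev.shape[1]
    gram = h_prev.T @ h_prev + rho * np.eye(d)        # (d, d), positive definite for rho > 0
    # Solve gram @ F^T = h_prev^T h_next instead of forming an explicit inverse.
    f_transposed = np.linalg.solve(gram, h_prev.T @ h_next)
    return f_transposed.T


# Usage on toy data: latent width 10 as in the diagram, with a known ground-truth F.
rng = np.random.default_rng(0)
d, n = 10, 500
F_true = rng.standard_normal((d, d)) / np.sqrt(d)
h_prev = rng.standard_normal((n, d))
h_next = h_prev @ F_true.T                             # noiseless pairs that follow g exactly
F_hat = fit_transition_ridge(h_prev, h_next, rho=1e-3)
print(np.abs(F_hat - F_true).max())                    # small, since rho barely biases here
```

Because the update is a single linear solve on a d×d system (10×10 here), it costs almost nothing compared with an SGD pass through the 3D-CNN. Larger ρ shrinks F toward zero, which is one way to keep the learned dynamics stable.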

git clone https://github.com/YuZh98/VAE-fMRI-Alzheimer
cd VAE-fMRI-Alzheimer
python -m venv .venv && source .venv/bin/activate
pip install -e ".[dev]"
pytest -v                                              # 36 tests, ~5-20s, CPU-only
python tutorials/15_train_end_to_end/train_tiny.py     # synthetic, ~3s

Tutorials

tutorials/ has 18 hands-on lessons covering tensor shapes, 3D-conv arithmetic, the reparameterization trick, recurrent rollouts, alternating optimization, reproducibility, testing DL code, and research extensions.
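As a taste of the 3D-conv arithmetic the tutorials cover, the sketch below applies the standard per-axis output-size formula for a convolution (the one documented for `torch.nn.Conv3d`) to a single 91×109×91 volume. The kernel, stride, and padding values and the three-layer stack are hypothetical, not taken from the repo's encoder.

```python
def conv3d_out_shape(shape, kernel, stride=1, padding=0, dilation=1):
    """Spatial output shape of a 3D convolution, applied per axis:
    floor((n + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    """
    return tuple(
        (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
        for n in shape
    )


# Hypothetical encoder: three stride-2, kernel-3, padding-1 convs on one volume.
shape = (91, 109, 91)
for layer in range(3):
    shape = conv3d_out_shape(shape, kernel=3, stride=2, padding=1)
    print(layer, shape)
# 0 (46, 55, 46)
# 1 (23, 28, 23)
# 2 (12, 14, 12)
```

Odd sizes matter on the way back up: a stride-2 `ConvTranspose3d` with the same kernel and padding gives (n − 1)·2 − 2 + 2 + 1, which maps 46 → 91 but 23 → 45, so a decoder that mirrors this encoder would need `output_padding` (or cropping) on some layers to land exactly on 91×109×91.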

Architecture

A 3D-convolutional VAE with a latent recurrence:

                     ┌────────────────┐
   x_t ─────────────►│ 3D-CNN encode  │
   (B,1,91,109,91)   └───────┬────────┘
                             │ (B, 100) encoder features
                             ▼
                     ┌────────────────────────────┐
   h_{t-1} ─────────►│ hidden2mu, hidden2log_var  │
   (B, 10)           └─────────────┬──────────────┘
                                   │ mu_h, log_var_h
                                   ▼
   ε ~ N(0, I) ──────────────────► h_t = mu_h + σ_h ε   (B, 10)   (reparam)
                                   │
   z_s (subject noise) ───────────►+
                                   │ h_t + z_s
                                   ▼
                     ┌────────────────────────────┐
                     │       3D-CNN decode        │
                     └─────────────┬──────────────┘
                                   ▼
                          μ_t  (B,1,91,109,91)

   Linear temporal prior: g(h) = h F^T, used in the loss term ‖h_t − g(h_{t-1})‖².
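One step of the latent recurrence in the diagram can be sketched in NumPy as below. It assumes hidden2mu and hidden2log_var are single linear maps applied to the concatenation of the encoder features and h_{t-1}, uses σ_h = exp(log_var_h / 2) (the usual convention when a network outputs a log-variance), and treats z_s as a given per-subject offset. The weights, the toy F, and `latent_step` are hypothetical stand-ins, not the repo's PyTorch modules.

```python
import numpy as np

rng = np.random.default_rng(0)
B, FEAT, LATENT = 4, 100, 10      # batch, encoder feature width, latent width (from the diagram)

# Hypothetical stand-ins for hidden2mu / hidden2log_var: linear maps on the
# concatenation [features, h_{t-1}] (an assumption about how the inputs are combined).
W_mu = rng.standard_normal((FEAT + LATENT, LATENT)) * 0.05
W_lv = rng.standard_normal((FEAT + LATENT, LATENT)) * 0.05
F = 0.9 * np.eye(LATENT)          # toy transition matrix for g(h) = h F^T


def latent_step(features, h_prev, z_s, rng):
    """One recurrent step: reparameterized h_t, decoder input, and prior residual."""
    inp = np.concatenate([features, h_prev], axis=1)     # (B, 110)
    mu_h, log_var_h = inp @ W_mu, inp @ W_lv             # (B, 10) each
    sigma_h = np.exp(0.5 * log_var_h)                    # sigma = exp(log_var / 2)
    eps = rng.standard_normal(mu_h.shape)                # eps ~ N(0, I)
    h_t = mu_h + sigma_h * eps                           # reparameterization trick
    decoder_input = h_t + z_s                            # subject noise added before decoding
    residual = h_t - h_prev @ F.T                        # h_t - g(h_{t-1}), penalized in loss2
    return h_t, decoder_input, residual


features = rng.standard_normal((B, FEAT))                # stand-in for the 3D-CNN encoder output
h_prev = np.zeros((B, LATENT))
z_s = 0.1 * rng.standard_normal((B, LATENT))             # hypothetical per-subject offsets
h_t, dec_in, res = latent_step(features, h_prev, z_s, rng)
print(h_t.shape, dec_in.shape, res.shape)                # (4, 10) (4, 10) (4, 10)
```

Sampling ε outside the learned parameters is what lets gradients flow through mu_h and log_var_h; in a rollout, the returned h_t becomes the next step's h_prev.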

Loss

| Term | Form | How it's optimized |
|---|---|---|
| loss1 | Per-volume reconstruction MSE / σ_x² | SGD |
| loss2 | ‖h_t − g(h_{t-1})‖² / σ_h² (temporal) | SGD |
| loss_z | λ_z · ‖z_s‖₁ (subject-noise sparsity) | SGD |
| loss_F | ρ · ‖F‖_F² (reported, not back-propagated) | Closed form (ridge) |
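A minimal NumPy sketch of how the four terms might be assembled for one subject's sequence. The reductions (voxel mean per volume, summed over time), the argument names, and `loss_terms` itself are assumptions rather than the repo's code. It mirrors the table in one respect: only loss1 + loss2 + loss_z form the SGD objective, while loss_F is computed for reporting.

```python
import numpy as np


def loss_terms(x, mu_x, h, F, z_s, sigma_x2, sigma_h2, lambda_z, rho):
    """Assemble the loss-table terms for one subject's sequence (hypothetical helper).

    x, mu_x : (T, 1, D, H, W) observed volumes and decoder means
    h       : (T, d) latent trajectory;  F : (d, d);  z_s : (d,) subject noise
    Reductions (mean over voxels, sum over time) are assumptions.
    """
    per_volume_mse = ((x - mu_x) ** 2).reshape(len(x), -1).mean(axis=1)  # (T,)
    loss1 = per_volume_mse.sum() / sigma_x2
    residual = h[1:] - h[:-1] @ F.T                                       # h_t - g(h_{t-1})
    loss2 = (residual ** 2).sum() / sigma_h2
    loss_z = lambda_z * np.abs(z_s).sum()                                 # L1 sparsity
    loss_F = rho * (F ** 2).sum()                                         # squared Frobenius norm
    sgd_objective = loss1 + loss2 + loss_z      # loss_F is reported, not back-propagated
    return sgd_objective, {"loss1": loss1, "loss2": loss2, "loss_z": loss_z, "loss_F": loss_F}


# Usage on toy volumes far smaller than 91x109x91, built so loss1 and loss2 are exactly zero.
rng = np.random.default_rng(0)
T, d = 5, 10
x = rng.standard_normal((T, 1, 8, 8, 8))
F = 0.5 * np.eye(d)
h = np.empty((T, d))
h[0] = 1.0
for t in range(1, T):
    h[t] = h[t - 1] @ F.T                         # trajectory that follows g exactly
z_s = np.zeros(d)
z_s[:2] = [1.0, -2.0]
total, terms = loss_terms(x, x, h, F, z_s, sigma_x2=1.0, sigma_h2=1.0, lambda_z=0.5, rho=0.1)
print(total)                                      # 1.5 (only the L1 term is nonzero)
```

Keeping loss_F out of the SGD objective matches the alternating scheme in the table: the network parameters and z_s take gradient steps, and F is then refit in closed form.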

Visuals

Figures (in docs/assets/): loss curve (loss_curve.png) · reconstruction (recon_slice.png) · latent trajectory (latent_trajectory.png)

Citation

If you use this code in research or teaching, please cite via the GitHub “Cite this repository” button (driven by CITATION.cff).