Defining Statistical Models in Jax?

(statmodeling.stat.columbia.edu)

64 points | by hackandthink 4 days ago

10 comments

  • JHonaker 3 hours ago

    I'm very excited by the work being put in to make Bayesian inference more manageable. It's in a spot that feels very similar to deep learning circa the mid-2010s, when Caffe, Torch, and hand-written gradients were the options. We can do it, but doing anything more complicated than common model structures like hierarchical Gaussian linear models requires dropping out of the nice places and into the guts.

    I've had a lot of success with NumPyro (a JAX library) and have used quite a lot of tools that are simpler interfaces to Stan. I've also had to write quite a few model-specific things from scratch by hand (more for sequential Monte Carlo than MCMC). I'm very excited for a world where PPLs (probabilistic programming languages) become scalable and easier to use and customize.

    > I think there is a good chance that normalizing flow-based variational inference will displace MCMC as the go-to method for Bayesian posterior inference as soon as everyone gets access to good GPUs.

    Wow. This is incredibly surprising. I'm only tangentially aware of normalizing flows, but apparently I need to look at the intersection of them and Bayesian statistics now! Any sources from anyone would be most appreciated!

    • sarosh 3 hours ago

      I defer to other experts, but briefly: normalizing flows are a method for constructing complex distributions by transforming a simple probability density through a series of invertible transformations. When fit to data, they are trained by maximizing the exact log-likelihood, and they are capable of exact density evaluation and efficient sampling. See:

      Danilo Rezende and Shakir Mohamed. Variational Inference with Normalizing Flows. In ICML, 2015. Link: https://bigdata.duke.edu/wp-content/uploads/2022/08/1505.057...

      Laurent Dinh, David Krueger, and Yoshua Bengio. NICE: Non-linear Independent Components Estimation. In ICLR Workshop, 2015. Link: https://arxiv.org/pdf/1410.8516

      And for your direct question, the following paper "Efficient Bayesian Sampling Using Normalizing Flows to Assist Markov Chain Monte Carlo Methods" appears upon a superficial glance to be relevant. Link: https://arxiv.org/pdf/2107.08001
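      A toy sketch of the two properties mentioned above, exact density evaluation and efficient sampling, for a one-dimensional flow. It is written in plain Python rather than JAX to stay dependency-free. The layer classes (`Affine`, `Sinh`, `Flow`) and their parameters are hypothetical and hand-picked, not trained. Real flows such as NICE use learned multi-dimensional layers designed so that the Jacobian determinant stays cheap to compute.

```python
import math
import random

# Toy 1-D normalizing flow (hypothetical layers, hand-picked parameters, no training):
# a standard normal base density pushed through a stack of invertible maps.


class Affine:
    """y = scale * x + shift, invertible whenever scale != 0."""

    def __init__(self, scale, shift):
        if scale == 0.0:
            raise ValueError("scale must be nonzero")
        self.scale, self.shift = scale, shift

    def forward(self, x):
        return self.scale * x + self.shift

    def inverse(self, y):
        return (y - self.shift) / self.scale

    def log_abs_det_jac(self, x):
        # dy/dx = scale, independent of x
        return math.log(abs(self.scale))


class Sinh:
    """y = sinh(x), a smooth bijection of the real line."""

    def forward(self, x):
        return math.sinh(x)

    def inverse(self, y):
        return math.asinh(y)

    def log_abs_det_jac(self, x):
        # dy/dx = cosh(x) >= 1, so the log is always defined
        return math.log(math.cosh(x))


def std_normal_logpdf(z):
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


class Flow:
    def __init__(self, layers):
        self.layers = layers

    def sample(self, rng):
        # Sampling: draw from the base distribution and push it forward.
        x = rng.gauss(0.0, 1.0)
        for layer in self.layers:
            x = layer.forward(x)
        return x

    def log_prob(self, y):
        # Exact density by change of variables: undo the layers in reverse order,
        # subtracting each layer's log|dy/dx| evaluated at that layer's input.
        total_log_det = 0.0
        for layer in reversed(self.layers):
            x = layer.inverse(y)
            total_log_det += layer.log_abs_det_jac(x)
            y = x
        return std_normal_logpdf(y) - total_log_det


rng = random.Random(0)
flow = Flow([Affine(0.5, 1.0), Sinh(), Affine(2.0, -0.3)])
y = flow.sample(rng)
print(y, flow.log_prob(y))
```

      Each layer only has to supply a forward map, an inverse, and the log of its derivative. Composing them keeps both sampling and density evaluation exact, which is the property that makes flows usable as flexible variational families.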

      • 1980phipsi an hour ago

        So it's like converting a normal distribution to a log-normal one (and then back), but a more general way of thinking about it.

        Where does the name "normalizing flows" come from?
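        The normal/log-normal intuition can be checked numerically. The sketch below, in plain Python with hypothetical helper names, treats `exp` as a one-layer flow on top of a Normal(mu, sigma) base. It computes the log-normal density purely from the change-of-variables rule, log p_Y(y) = log p_Z(g(y)) + log|g'(y)| with g = log as the inverse map, and compares it with the closed-form log-normal density. The derivative g' is approximated by a central difference for simplicity; in JAX one would typically get it from `jax.grad` instead.

```python
import math
import random


def normal_logpdf(z, mu=0.0, sigma=1.0):
    return -0.5 * ((z - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def pushforward_logpdf(base_logpdf, inverse, y, h=1e-6):
    """Log density of Y = f(Z) for a monotone bijection f, given its inverse g.

    Change of variables: log p_Y(y) = log p_Z(g(y)) + log|g'(y)|.
    g'(y) is approximated with a central difference (an illustrative shortcut).
    """
    dg = (inverse(y + h) - inverse(y - h)) / (2.0 * h)
    return base_logpdf(inverse(y)) + math.log(abs(dg))


def lognormal_logpdf(y, mu=0.0, sigma=1.0):
    # Closed-form log-normal log density, used only for comparison.
    return -math.log(y * sigma * math.sqrt(2.0 * math.pi)) - (math.log(y) - mu) ** 2 / (2.0 * sigma ** 2)


mu, sigma = 0.3, 0.8


def base_logpdf(z):
    # Base distribution Z ~ Normal(mu, sigma).
    return normal_logpdf(z, mu, sigma)


# Normal -> log-normal: the forward map is exp, so its inverse g is math.log.
for y in (0.5, 1.0, 2.0, 5.0):
    print(y, pushforward_logpdf(base_logpdf, math.log, y), lognormal_logpdf(y, mu, sigma))

# ...and back: taking log of log-normal draws gives back Normal(mu, sigma) draws.
rng = random.Random(1)
draws = [math.exp(rng.gauss(mu, sigma)) for _ in range(10000)]
mean_log = sum(math.log(d) for d in draws) / len(draws)
print(mean_log)  # should land near mu
```

        A normalizing flow generalizes this by stacking many such invertible maps, usually with learned parameters, instead of a single fixed `exp`.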

      • JHonaker 2 hours ago

        Thanks! I've read the first one before. I'll take a look at the other two!

    • legobmw99 3 hours ago

      The author links to https://arxiv.org/abs/2006.10343, which seems like a good place to start on normalizing flows for Bayesian inference.

  • techwizrd 2 hours ago

    This is coming at the perfect time! I was recently trying to decide whether I wanted to implement a model in Stan or Pyro/NumPyro, and I've been eyeing implementing it in JAX. I would love to write a tutorial comparing Stan to JAX.

  • helltone 2 hours ago

    Off topic: I think there are some opportunities for making Bayesian inference technology more accessible, and I'd love to chat with other people in this space. Email in my profile.