Academic Paper

BlackJAX: Composable Bayesian inference in JAX
Document Type
Working Paper
Subject
Computer Science - Mathematical Software
Computer Science - Machine Learning
Statistics - Computation
Statistics - Machine Learning
Abstract
BlackJAX is a library implementing sampling and variational inference algorithms commonly used in Bayesian computation. It is designed for ease of use, speed, and modularity by taking a functional approach to the algorithms' implementation. BlackJAX is written in Python, using JAX to compile and run NumPy-like samplers and variational methods on CPUs, GPUs, and TPUs. The library integrates well with probabilistic programming languages by working directly with the (un-normalized) target log density function. BlackJAX is intended as a collection of low-level, composable implementations of basic statistical 'atoms' that can be combined to perform well-defined Bayesian inference, but also provides high-level routines for ease of use. It is designed for users who need cutting-edge methods, researchers who want to create complex sampling methods, and people who want to learn how these work.
Comment: Companion paper for the library https://github.com/blackjax-devs/blackjax. Update: minor changes; the list of authors now includes technical contributors.
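
The functional design described in the abstract, in which a sampler is a pure function that maps a state and a random key to a new state and needs only the un-normalized target log density, can be illustrated with a minimal sketch in plain JAX. This is not BlackJAX's API: the names `RWState`, `init`, `build_kernel`, and `run_chain` are hypothetical, and a random-walk Metropolis kernel with a standard normal target stands in for the library's algorithms.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RWState(NamedTuple):
    """Sampler state (hypothetical name): position and its log density."""
    position: jnp.ndarray
    logdensity: jnp.ndarray


def init(position, logdensity_fn):
    # Cache the log density so each step evaluates the target only once.
    return RWState(position, logdensity_fn(position))


def build_kernel(logdensity_fn, step_size):
    """Return a pure step function for random-walk Metropolis."""

    def step(rng_key, state):
        key_proposal, key_accept = jax.random.split(rng_key)
        noise = jax.random.normal(key_proposal, state.position.shape)
        proposal = state.position + step_size * noise
        proposal_logdensity = logdensity_fn(proposal)
        # The proposal is symmetric, so the normalizing constant of the
        # target cancels in the acceptance ratio.
        log_ratio = proposal_logdensity - state.logdensity
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        new_state = RWState(
            jnp.where(accept, proposal, state.position),
            jnp.where(accept, proposal_logdensity, state.logdensity),
        )
        return new_state, accept

    return step


def logdensity_fn(x):
    # Un-normalized log density of a standard normal (illustrative target).
    return -0.5 * jnp.sum(x**2)


def run_chain(rng_key, step, initial_state, num_samples):
    # Compose the step function into a full chain with jax.lax.scan.
    def one_step(state, key):
        state, accepted = step(key, state)
        return state, (state.position, accepted)

    keys = jax.random.split(rng_key, num_samples)
    _, (positions, accepted) = jax.lax.scan(one_step, initial_state, keys)
    return positions, accepted


step = jax.jit(build_kernel(logdensity_fn, step_size=1.0))
initial_state = init(jnp.zeros(2), logdensity_fn)
positions, accepted = run_chain(jax.random.PRNGKey(0), step, initial_state, 2000)
```

Because the step function carries no hidden state, it can be jit-compiled, scanned, or vectorized over several chains with `jax.vmap` without modification, which is the kind of composability the abstract refers to.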