One of the most important outstanding problems in observational and theoretical astrophysics is understanding the physical origin and evolution of galaxies. Galaxies are gravitationally bound systems of tens to hundreds of billions of stars, gas, and dust, along with large amounts of dark matter, and we observe them across nearly the entire 14-billion-year history of the universe. Fortunately, sophisticated models exist that allow us to interpret the observed spectral energy distributions of galaxies (in essence, how bright they appear in different parts of the electromagnetic spectrum, particularly the ultraviolet, optical, and infrared) in terms of physical properties such as stellar mass and star-formation rate. For example, the stellar mass of a galaxy reveals how efficiently gas has been converted into stars over the galaxy's history, while the star-formation rate indicates how rapidly new stars are currently being born, or whether star formation has ceased entirely.
Not surprisingly, the parameter space over which the likelihood must be explored to model galaxy observations effectively can be very large. In addition, the latest generation of massively multiplexed surveys, such as the Dark Energy Spectroscopic Instrument (DESI) survey, is observing tens of millions of galaxies. Consequently, there is an acute need for massively parallel, computationally efficient code that can extract astrophysically meaningful constraints from large observational datasets of galaxies.
The open-source Python package at the center of this project is FastSpecFit (https://fastspecfit.readthedocs.org/en/latest). The code is reasonably well documented, and it has already been run on a high-performance computing (HPC) system on samples of millions of galaxies observed by DESI. However, two computational bottlenecks currently prevent FastSpecFit from being deployed at the next scale, both in input sample size and in the complexity of the underlying astrophysical models. These bottlenecks are non-negative least-squares (NNLS) fitting and non-linear least-squares fitting, both of which are currently done with the CPU-based SciPy library.
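To make the two bottlenecks concrete, here is a minimal SciPy sketch of toy versions of both problems: an NNLS fit of a spectrum as a non-negative combination of basis templates (`scipy.optimize.nnls`), and a bounded non-linear fit of a single Gaussian emission line (`scipy.optimize.least_squares`). The Gaussian "templates," wavelength grid, noise levels, line parameters, and bounds are all hypothetical stand-ins chosen for illustration; they are not FastSpecFit's actual templates or line model.

```python
import numpy as np
from scipy.optimize import least_squares, nnls

rng = np.random.default_rng(42)
wave = np.linspace(4000.0, 7000.0, 3001)  # toy wavelength grid, 1 Angstrom pixels

# Bottleneck 1 (NNLS): model a continuum as a non-negative linear combination
# of basis templates. These broad Gaussians are hypothetical stand-ins for
# real stellar-population templates.
centers = np.linspace(4200.0, 6800.0, 6)
templates = np.exp(-0.5 * ((wave[:, None] - centers[None, :]) / 250.0) ** 2)
true_coeff = np.array([0.0, 1.5, 0.0, 2.0, 0.5, 0.0])
continuum = templates @ true_coeff + rng.normal(0.0, 0.01, wave.size)

# Minimizes ||templates @ coeff - continuum||_2 subject to coeff >= 0.
coeff, rnorm = nnls(templates, continuum)


# Bottleneck 2 (bounded non-linear least squares): fit one Gaussian emission line.
def line_model(params, wave):
    amp, center, sigma = params
    return amp * np.exp(-0.5 * ((wave - center) / sigma) ** 2)


def residuals(params, wave, flux, ivar):
    # least_squares minimizes 0.5 * sum(residuals**2), i.e. 0.5 * chi-squared.
    return (flux - line_model(params, wave)) * np.sqrt(ivar)


true_line = np.array([3.0, 6563.0, 5.0])  # hypothetical amplitude, center (A), sigma (A)
noise = 0.1
flux = line_model(true_line, wave) + rng.normal(0.0, noise, wave.size)
ivar = np.full(wave.size, 1.0 / noise**2)

fit = least_squares(
    residuals,
    x0=[1.0, 6560.0, 3.0],
    bounds=([0.0, 6550.0, 1.0], [100.0, 6575.0, 50.0]),  # arbitrary toy bounds
    args=(wave, flux, ivar),
)
print(coeff.round(2), fit.x.round(1))
```

Each call is fast on its own; the cost comes from repeating such fits for millions of spectra, one at a time, on CPUs.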
With these issues in mind, the goal of this project is to port the computational "heart" of FastSpecFit to GPUs. We propose using JAX (https://jax.readthedocs.io/en/latest), a Python library that combines automatic differentiation with just-in-time compilation for CPUs, GPUs, and TPUs. Specifically, the open-source JAXopt project (https://jaxopt.github.io/stable) provides well-tested, hardware-accelerated solvers for a wide range of linear and non-linear constrained optimization problems, with gradients computed by automatic differentiation. After testing these algorithms on simple simulated datasets, we will implement an optional GPU version of FastSpecFit and ultimately test it on real DESI data.
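As a rough illustration of what automatic differentiation plus compilation provides, the sketch below writes the NNLS objective as an inverse-variance-weighted chi-squared, lets `jax.grad` derive its gradient, and minimizes it with plain projected gradient descent (a gradient step of size 1/L, then clipping negative coefficients to zero) inside a `jax.jit`-compiled loop. The solver choice, the Gaussian toy "templates," and the names `chi2` and `nnls_projected_gradient` are illustrative assumptions, not FastSpecFit's implementation. JAXopt packages tested solvers of this kind (for example `jaxopt.ProjectedGradient` combined with `jaxopt.projection.projection_non_negative`), which would be the natural choice in practice.

```python
from functools import partial

import jax
import jax.numpy as jnp


def chi2(coeff, templates, flux, ivar):
    """Inverse-variance-weighted chi-squared of a linear template model."""
    resid = flux - templates @ coeff
    return jnp.sum(ivar * resid**2)


# Automatic differentiation: gradient of chi2 with respect to `coeff`.
grad_chi2 = jax.grad(chi2)


@partial(jax.jit, static_argnames="n_iter")
def nnls_projected_gradient(templates, flux, ivar, n_iter=500):
    """Minimize chi2(coeff) subject to coeff >= 0 by projected gradient descent."""
    # The Hessian of chi2 is 2 * A^T W A, so its largest eigenvalue L is the
    # Lipschitz constant of the gradient; 1/L is a safe fixed step size.
    normal_matrix = templates.T @ (ivar[:, None] * templates)
    step = 1.0 / (2.0 * jnp.linalg.eigvalsh(normal_matrix)[-1])

    def body(_, coeff):
        coeff = coeff - step * grad_chi2(coeff, templates, flux, ivar)
        return jnp.maximum(coeff, 0.0)  # project onto coeff >= 0

    coeff0 = jnp.zeros(templates.shape[1], dtype=templates.dtype)
    return jax.lax.fori_loop(0, n_iter, body, coeff0)


# Hypothetical, noiseless toy problem (not FastSpecFit's templates).
wave = jnp.linspace(4000.0, 7000.0, 1001)
centers = jnp.linspace(4200.0, 6800.0, 6)
templates = jnp.exp(-0.5 * ((wave[:, None] - centers[None, :]) / 250.0) ** 2)
true_coeff = jnp.array([0.0, 1.5, 0.0, 2.0, 0.5, 0.0])
flux = templates @ true_coeff
ivar = jnp.ones_like(flux)

coeff = nnls_projected_gradient(templates, flux, ivar)
print(coeff)
```

The same function runs unchanged on a CPU or a GPU; JAX places the computation on whichever device is available.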
Project Information Subsection
This project includes three major deliverables:
1. Documentation which clearly describes how all software products and their dependencies (particularly JAX and JAXopt) should be installed and run, both with and without GPUs.
2. Executable, well-documented code which solves both simulated and real-data bounded non-linear least-squares problems.
3. Comparisons (via benchmarking runs) of the existing CPU (e.g., scipy.optimize) and new GPU/JAX implementations on identical problems.
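Deliverable 3 depends on fair timings. A minimal benchmarking sketch follows, assuming toy Gaussian templates (not FastSpecFit's), a batch of 200 simulated spectra, and a plain projected-gradient NNLS solver written directly in JAX (a stand-in, not a specific JAXopt solver). Two JAX details matter for honest numbers: the first call to a `jax.jit`-compiled function includes compilation time, so it is run once before timing, and JAX dispatches work asynchronously, so the timer waits on `block_until_ready()`. `jax.vmap` batches all spectra into a single compiled call, while the SciPy baseline loops over spectra one at a time.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import nnls


@partial(jax.jit, static_argnames="n_iter")
def nnls_pg(templates, flux, n_iter=500):
    """NNLS for one spectrum via projected gradient descent (unit weights)."""
    normal = templates.T @ templates
    step = 1.0 / jnp.linalg.eigvalsh(normal)[-1]  # 1/L for 0.5 * ||A x - b||^2
    atb = templates.T @ flux

    def body(_, coeff):
        # Gradient of 0.5 * ||A x - b||^2 is A^T A x - A^T b; then clip to >= 0.
        return jnp.maximum(coeff - step * (normal @ coeff - atb), 0.0)

    coeff0 = jnp.zeros(templates.shape[1], dtype=templates.dtype)
    return jax.lax.fori_loop(0, n_iter, body, coeff0)


# One compiled call fits every spectrum in the batch (templates are shared).
nnls_pg_batch = jax.jit(jax.vmap(nnls_pg, in_axes=(None, 0)))

# Hypothetical toy data set.
rng = np.random.default_rng(0)
n_spectra, n_pix = 200, 1001
wave = np.linspace(4000.0, 7000.0, n_pix)
centers = np.linspace(4200.0, 6800.0, 6)
templates = np.exp(-0.5 * ((wave[:, None] - centers[None, :]) / 250.0) ** 2)
true_coeff = rng.uniform(0.0, 2.0, size=(n_spectra, 6))
fluxes = true_coeff @ templates.T + rng.normal(0.0, 0.05, size=(n_spectra, n_pix))

# CPU baseline: SciPy NNLS, one spectrum at a time.
t0 = time.perf_counter()
cpu_coeff = np.array([nnls(templates, f)[0] for f in fluxes])
cpu_time = time.perf_counter() - t0

# JAX: move data to the device, then call once to compile (excluded from timing).
templates_j = jnp.asarray(templates, dtype=jnp.float32)
fluxes_j = jnp.asarray(fluxes, dtype=jnp.float32)
nnls_pg_batch(templates_j, fluxes_j).block_until_ready()

t0 = time.perf_counter()
jax_coeff = nnls_pg_batch(templates_j, fluxes_j).block_until_ready()
jax_time = time.perf_counter() - t0

print(f"SciPy: {cpu_time:.4f} s   JAX ({jax.devices()[0].platform}): {jax_time:.4f} s")
print("max |difference|:", float(np.max(np.abs(cpu_coeff - np.asarray(jax_coeff)))))
```

Timings from a laptop CPU say little about GPU performance, so the same script would need to be run on the target GPU nodes; the final agreement check guards against comparing a fast answer to a wrong one.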
Samyak (Sam) Tuladhar (sd10tula@siena.edu) is a sophomore undergraduate physics major at Siena College, and he has both the interest and the technical background needed to undertake this project.
Some hands-on experience
Siena College
Department of Physics and Astronomy, 515 Loudon Rd., Loudonville, New York 12211
CR-Rensselaer Polytechnic Institute
12/01/2023
No
3 (on a scale from "Already behind" to "Start date is flexible")
6
01/05/2024
05/17/2024
Milestone Title: Complete JAX and JAXopt Tutorials
Milestone Description: Gain familiarity with JAX and JAXopt by completing several of the tutorials at https://jax.readthedocs.io/en/latest/advanced_guide.html and https://jaxopt.github.io/stable/notebooks/index.html.
Completion Date Goal: 2024-01-15
Milestone Title: Generate and model synthetic data
Milestone Description: Generate synthetic data (a simple emission-line spectrum with noise), write the objective function, and optimize its performance using JAX/JAXopt.
Completion Date Goal: 2024-03-01
Milestone Title: Modify FastSpecFit to model real data
Milestone Description: Add a "GPU/JAX mode" to FastSpecFit and use it to model at least one real galaxy spectrum.
Completion Date Goal: 2024-04-15
Milestone Title: Benchmarking and documentation
Milestone Description: Carry out benchmarking runs on a larger set of spectra and finalize all documentation.
Completion Date Goal: 2024-06-01
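The synthetic-data milestone can be prototyped in a few dozen lines. A minimal sketch, assuming a single Gaussian line with made-up parameters, uniform noise, and arbitrary bounds: it generates the synthetic spectrum, writes the chi-squared objective explicitly, and minimizes it with a hand-written, bounded Levenberg-Marquardt loop (Jacobian from `jax.jacfwd`, adaptive damping, trial steps clipped to the bounds, and only downhill steps accepted). The loop is a swapped-in illustration; in the project itself a tested JAXopt solver would be the natural replacement, and none of these function names come from FastSpecFit.

```python
from functools import partial

import jax
import jax.numpy as jnp


def line_model(params, wave):
    """Single Gaussian emission line; params = (amplitude, center, sigma)."""
    amp, center, sigma = params
    return amp * jnp.exp(-0.5 * ((wave - center) / sigma) ** 2)


def chi2(params, wave, flux, ivar):
    """Objective function: inverse-variance-weighted chi-squared."""
    return jnp.sum(ivar * (flux - line_model(params, wave)) ** 2)


@partial(jax.jit, static_argnames="n_iter")
def fit_line(params0, lower, upper, wave, flux, ivar, n_iter=100):
    """Minimize chi2 with a simple bounded Levenberg-Marquardt loop (illustrative)."""

    def resid(p):
        # chi2(p) == sum(resid(p) ** 2)
        return (flux - line_model(p, wave)) * jnp.sqrt(ivar)

    def body(_, state):
        params, lam = state
        r = resid(params)
        jac = jax.jacfwd(resid)(params)  # (n_pixels, 3), by forward-mode autodiff
        jtj = jac.T @ jac
        # Marquardt damping: inflate the diagonal, then solve for the step.
        step = jnp.linalg.solve(jtj + lam * jnp.diag(jnp.diag(jtj)), -jac.T @ r)
        trial = jnp.clip(params + step, lower, upper)  # enforce the bounds
        better = chi2(trial, wave, flux, ivar) < chi2(params, wave, flux, ivar)
        params = jnp.where(better, trial, params)  # keep only downhill steps
        # Shrink the damping after a success, grow it after a failure.
        lam = jnp.clip(jnp.where(better, lam * 0.3, lam * 10.0), 1e-6, 1e6)
        return params, lam

    lam0 = jnp.asarray(1e-3, dtype=params0.dtype)
    params, _ = jax.lax.fori_loop(0, n_iter, body, (params0, lam0))
    return params


# Synthetic emission-line spectrum with Gaussian noise (made-up values).
key = jax.random.PRNGKey(0)
wave = jnp.linspace(6500.0, 6625.0, 500)
true_params = jnp.array([3.0, 6563.0, 5.0])
noise = 0.1
flux = line_model(true_params, wave) + noise * jax.random.normal(key, wave.shape)
ivar = jnp.full(wave.shape, 1.0 / noise**2)

params0 = jnp.array([1.0, 6560.0, 3.0])  # initial guess
lower = jnp.array([0.0, 6550.0, 1.0])  # arbitrary bounds for the sketch
upper = jnp.array([100.0, 6575.0, 50.0])
best = fit_line(params0, lower, upper, wave, flux, ivar)
print(best)
```

Because `fit_line` is a pure, compiled function, the same code could later be batched over many spectra with `jax.vmap`.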
If the project is successful, I anticipate describing the work and its outcomes in a larger publication, most likely to be submitted to The Astrophysical Journal, one of the leading journals in astrophysics. Alternatively, depending on the student's interests, we could prepare a shorter, more technical paper for a GPU/HPC computing journal (TBD).
The student will learn how deploying GPUs on HPC systems can lead to significant improvements in computing speed, and how those speed-ups directly improve our ability to do science with large astronomical datasets. The student will also improve their Python programming skills and learn how to clearly document and communicate their results to collaborators with a wide range of technical backgrounds.
JAX and JAXopt are powerful tools for a range of applications in scientific computing, machine learning, artificial intelligence, and much more. The Cyberteam will gain (1) documentation and example code demonstrating how these libraries can be deployed on GPUs on HPC systems, and (2) benchmarked, well-documented code illustrating how they can be applied to one specific class of astrophysics problems.
We will need access to a multi-node GPU system and a modern software stack with an isolated environment in which all code dependencies (Python, JAX, JAXopt, etc.) can be installed.