# fitting

Gradient-based parameter optimisation via JAX autodiff and optax.

## jbubble.fitting.FitResult(params, loss_history, result)

*dataclass*

Output of `fit_parameters`.
Attributes:

| Name | Type | Description |
|---|---|---|
| `params` | `PyTree` | Fitted parameter values — same type and structure as `params0`. |
| `loss_history` | `Array`, shape `(n_steps,)` | Loss value recorded at each optimisation step. |
| `result` | `SimulationResult` | Final forward simulation run with the fitted parameters. |
## jbubble.fitting.fit_parameters(make_model, params0, *, save_spec, t_max=None, loss_fn, optimizer, n_steps=200, config=None, adjoint=None, step_callback=None, log_every=25)
Fit model parameters by differentiating through the ODE integration.

`make_model` maps the current `params` to an `(eom, pulse)` pair, so anything that affects either the bubble physics or the acoustic drive can be optimised jointly from a single `params` pytree. Examples:
```python
# Single parameter as a dict
fit_parameters(
    make_model=lambda p: (make_eom(p), my_pulse),
    params0={'kappa_s': 2.4e-9},
    ...
)

# Joint frequency + radius optimisation
fit_parameters(
    make_model=lambda p: (make_eom(p['R0']), make_pulse(p['freq'])),
    params0={'R0': 5e-6, 'freq': 1e6},
    ...
)

# Heterogeneous — neural surface tension + scalar kappa_s + neural pulse
fit_parameters(
    make_model=lambda p: (
        KellerMiksis(shell=LipidShell(sigma=p.sigma, kappa_s=p.kappa_s), ...),
        p.pulse,
    ),
    params0=LearnedParams(sigma=NeuralProperty(...), kappa_s=2.4e-9, pulse=NeuralPulse(...)),
    ...
)
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `make_model` | `callable` | Maps the current `params` to an `(eom, pulse)` pair. | required |
| `params0` | `PyTree` | Initial parameter values — any JAX-compatible pytree (scalar, array, dict, tuple, …). | required |
| `save_spec` | `SaveSpec` | Output sampling specification. | required |
| `t_max` | `float` | Integration end time [s]. | `None` |
| `loss_fn` | `callable` | | required |
| `optimizer` | `GradientTransformation` | Gradient-based optimiser. | required |
| `n_steps` | `int` | Number of optimisation steps. | `200` |
| `config` | `SolverConfig` | ODE solver settings. Default: Tsit5 with PID(rtol=1e-4, atol=1e-8), 50 000 max steps. | `None` |
| `adjoint` | `AbstractAdjoint` | Adjoint method. | `None` |
| `step_callback` | `callable` | Called outside JIT after each gradient step. | `None` |
| `log_every` | `int` | Print loss and current parameters every this many steps; set to 0 to disable logging. | `25` |
Returns:

| Type | Description |
|---|---|
| `FitResult` | Fitted parameters, full loss history, and a final `SimulationResult` run with the fitted parameters. |
Source code in `jbubble/fitting.py`.
### How it works
`fit_parameters` traces through the entire pipeline: JAX's autodiff (via diffrax's backpropagation through time) computes the gradient of the loss with respect to `params`, which is then fed to an optax optimiser.
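Schematically, that optimisation loop looks like the following pure-Python sketch, with a hand-written gradient standing in for autodiff (a toy illustration, not jbubble's actual implementation):

```python
# Toy gradient descent mirroring the structure of the fit loop: evaluate
# the loss, compute a gradient, apply an optimiser update, record history.

def fit_loop(loss_and_grad, params, learning_rate=0.1, n_steps=50):
    """Minimal stand-in for the loss -> gradient -> update cycle."""
    loss_history = []
    for _ in range(n_steps):
        loss, grad = loss_and_grad(params)       # forward + backward pass
        loss_history.append(loss)
        params = params - learning_rate * grad   # optimiser update
    return params, loss_history

# Example: minimise (p - 3)^2, whose gradient is 2 * (p - 3).
fitted, history = fit_loop(lambda p: ((p - 3.0) ** 2, 2.0 * (p - 3.0)), params=0.0)
```

In the real function, `loss_and_grad` is produced by JAX from the ODE solve, and the plain update line is replaced by the supplied optax `optimizer`.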
The function automatically partitions `params` into:

- **Differentiable leaves:** `jax.Array` scalars and arrays — these receive gradients.
- **Static leaves:** Python scalars, strings, non-array objects — these are held fixed.

This means integer hyperparameters (e.g. number of cycles) can live inside `params` without causing errors.
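The partition rule can be sketched in plain Python, with `float` standing in for `jax.Array` (a hypothetical helper, not jbubble's actual code):

```python
# Split a flat dict "pytree" into differentiable and static leaves.
# The real pipeline tests "is this a jax.Array?"; floats stand in here.

def partition(params):
    """Return (differentiable, static) sub-dicts of a parameter dict."""
    diff, static = {}, {}
    for name, leaf in params.items():
        (diff if isinstance(leaf, float) else static)[name] = leaf
    return diff, static

params = {"kappa_s": 2.4e-9, "n_cycles": 8, "label": "pulse-A"}
diff, static = partition(params)
# diff   -> {"kappa_s": 2.4e-9}                  (receives gradients)
# static -> {"n_cycles": 8, "label": "pulse-A"}  (held fixed)
```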
### Memory considerations
For long simulations with many steps, backpropagation through the ODE solve requires storing intermediate solver states. If memory is a concern, use the recursive checkpoint adjoint:
```python
import diffrax

fit_result = fit_parameters(
    ...,
    adjoint=diffrax.RecursiveCheckpointAdjoint(checkpoints=100),
)
```
This trades computation for memory: it recomputes intermediate states during the backward pass rather than storing all of them.
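The trade-off can be illustrated with a toy chain of states (hypothetical code, unrelated to diffrax internals): store only every k-th state on the forward pass, then recompute intermediate states from the nearest checkpoint when the backward pass needs them.

```python
# Forward pass storing only every `checkpoint_every`-th state.
def forward(step, x0, n, checkpoint_every):
    checkpoints = {0: x0}
    x = x0
    for i in range(1, n + 1):
        x = step(x)
        if i % checkpoint_every == 0:
            checkpoints[i] = x
    return x, checkpoints

# Recover state i by re-running `step` from the nearest earlier checkpoint:
# extra compute traded for not storing all n + 1 states.
def state_at(step, checkpoints, i, checkpoint_every):
    base = (i // checkpoint_every) * checkpoint_every
    x = checkpoints[base]
    for _ in range(i - base):
        x = step(x)
    return x

step = lambda x: x + 1
final, cps = forward(step, 0, 100, checkpoint_every=10)
# Only 11 states are stored instead of 101; state 57 is rebuilt from state 50.
```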
### Convergence monitoring
The `step_callback` is called outside JIT after each gradient step, so it can perform arbitrary Python-side operations (plotting, logging, early stopping):

```python
import matplotlib.pyplot as plt

losses = []
params_history = []

def callback(step, params, loss):
    losses.append(float(loss))
    params_history.append(params)
    if step % 50 == 0:
        plt.clf()
        plt.plot(losses)
        plt.xlabel("Step")
        plt.ylabel("Loss")
        plt.pause(0.01)

fit_result = fit_parameters(..., step_callback=callback)
```
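The callback can also feed a stopping rule. One minimal plateau detector it could consult (a hypothetical helper, not part of jbubble, which prescribes no stopping criterion):

```python
# True once the loss improved by less than rtol (relative) over the
# last `window` steps; the callback can record when this first fires.

def has_plateaued(losses, window=10, rtol=1e-3):
    if len(losses) < window + 1:
        return False
    old, new = losses[-window - 1], losses[-1]
    return (old - new) <= rtol * abs(old)
```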