Skip to content

metrics

Differentiable loss functions for parameter fitting and model comparison.

All functions in this module operate on plain jax.Array arguments and are fully differentiable through jax.grad. They are designed to be passed as (or called within) the loss_fn argument of fit_parameters.

from jbubble.metrics import mse_radius, normalised_mse_radius
from jbubble.metrics import mse_emission, normalised_mse_emission
from jbubble.metrics import peak_expansion, peak_expansion_error

jbubble.metrics.mse_radius(r_sim, r_target)

Mean squared error between simulated and target radii [m²].

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `r_sim` | `Array`, shape `(N,)` | Simulated radius trajectory [m]. | *required* |
| `r_target` | `Array`, shape `(N,)` | Target radius trajectory [m]. | *required* |
Source code in jbubble/metrics.py
def mse_radius(r_sim: jax.Array, r_target: jax.Array) -> jax.Array:
    """Return the mean squared error between two radius trajectories [m²].

    The arrays are compared element-wise, so both trajectories are
    expected to be sampled on the same grid. Note that ``jnp`` broadcasting
    applies: arrays of different but broadcast-compatible shapes are not
    rejected.

    Parameters
    ----------
    r_sim : jax.Array, shape (N,)
        Simulated radius trajectory [m].
    r_target : jax.Array, shape (N,)
        Target radius trajectory [m].

    Returns
    -------
    jax.Array
        Scalar ``mean((r_sim - r_target) ** 2)``.
    """
    residual = r_sim - r_target
    squared_residual = residual ** 2
    return jnp.mean(squared_residual)

jbubble.metrics.normalised_mse_radius(r_sim, r_target, R0)

Normalised mean squared radius error (dimensionless).

Equivalent to mse_radius(r_sim / R0, r_target / R0). The R0 normalisation makes the loss dimensionless and ~O(1) regardless of bubble size, which improves optimiser conditioning.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `r_sim` | `Array`, shape `(N,)` | Simulated radius trajectory [m]. | *required* |
| `r_target` | `Array`, shape `(N,)` | Target radius trajectory [m]. | *required* |
| `R0` | `float` or `Array` | Equilibrium radius used for normalisation [m]. | *required* |
Source code in jbubble/metrics.py
def normalised_mse_radius(
    r_sim: jax.Array,
    r_target: jax.Array,
    R0: ArrayLike,
) -> jax.Array:
    """Return the R0-scaled mean squared radius error (dimensionless).

    Mathematically equal to ``mse_radius(r_sim / R0, r_target / R0)``.
    Here the residual is divided by R0 before squaring, so floating-point
    results may differ from that form in the last bits. Scaling by R0 makes
    the loss dimensionless and ~O(1) regardless of bubble size, which
    improves optimiser conditioning.

    Parameters
    ----------
    r_sim : jax.Array, shape (N,)
        Simulated radius trajectory [m].
    r_target : jax.Array, shape (N,)
        Target radius trajectory [m].
    R0 : float or jax.Array
        Equilibrium radius used for normalisation [m].

    Returns
    -------
    jax.Array
        Scalar ``mean(((r_sim - r_target) / R0) ** 2)``.
    """
    scaled_residual = (r_sim - r_target) / R0
    return jnp.mean(scaled_residual ** 2)

jbubble.metrics.peak_expansion(r_sim, R0)

Maximum radial expansion ratio R_max / R0.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `r_sim` | `Array`, shape `(N,)` | Simulated radius trajectory [m]. | *required* |
| `R0` | `float` or `Array` | Equilibrium radius [m]. | *required* |
Source code in jbubble/metrics.py
def peak_expansion(r_sim: jax.Array, R0: ArrayLike) -> jax.Array:
    """Return the peak expansion ratio ``R_max / R0`` of a trajectory.

    Parameters
    ----------
    r_sim : jax.Array, shape (N,)
        Simulated radius trajectory [m].
    R0 : float or jax.Array
        Equilibrium radius [m].

    Returns
    -------
    jax.Array
        Scalar ratio of the largest radius in ``r_sim`` to ``R0``.
    """
    r_max = jnp.max(r_sim)
    return r_max / R0

jbubble.metrics.peak_expansion_error(r_sim, R0, target_expansion)

Squared error in peak expansion ratio.

Useful when only the maximum oscillation amplitude matters rather than the full waveform.

Parameters:

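| Name | Type | Description | Default |
|------|------|-------------|---------|
| `r_sim` | `Array`, shape `(N,)` | Simulated radius trajectory [m]. | *required* |
| `R0` | `float` or `Array` | Equilibrium radius [m]. | *required* |
| `target_expansion` | `float` or `Array` | Target `R_max / R0` value. | *required* |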
Source code in jbubble/metrics.py
def peak_expansion_error(
    r_sim: jax.Array,
    R0: ArrayLike,
    target_expansion: ArrayLike,
) -> jax.Array:
    """Return the squared mismatch of the peak expansion ratio.

    Use this when only the maximum oscillation amplitude matters, not the
    full waveform. The ratio itself comes from ``peak_expansion``.

    Parameters
    ----------
    r_sim : jax.Array, shape (N,)
        Simulated radius trajectory [m].
    R0 : float or jax.Array
        Equilibrium radius [m].
    target_expansion : float or jax.Array
        Target R_max / R0 value.

    Returns
    -------
    jax.Array
        ``(peak_expansion(r_sim, R0) - target_expansion) ** 2``.
    """
    achieved = peak_expansion(r_sim, R0)
    mismatch = achieved - target_expansion
    return mismatch ** 2

jbubble.metrics.mse_emission(p_sim, p_target)

Mean squared error between simulated and target emission pressure [Pa²].

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `p_sim` | `Array`, shape `(N,)` | Simulated radiated pressure [Pa]. | *required* |
| `p_target` | `Array`, shape `(N,)` | Target radiated pressure [Pa]. | *required* |
Source code in jbubble/metrics.py
def mse_emission(
    p_sim: jax.Array,
    p_target: jax.Array,
) -> jax.Array:
    """Return the mean squared error of the emitted pressure [Pa²].

    The signals are compared element-wise, so both are expected to be
    sampled on the same grid.

    Parameters
    ----------
    p_sim : jax.Array, shape (N,)
        Simulated radiated pressure [Pa].
    p_target : jax.Array, shape (N,)
        Target radiated pressure [Pa].

    Returns
    -------
    jax.Array
        Scalar ``mean((p_sim - p_target) ** 2)``.
    """
    pressure_residual = p_sim - p_target
    return jnp.mean(pressure_residual ** 2)

jbubble.metrics.normalised_mse_emission(p_sim, p_target, p_ref)

Normalised MSE for emission pressure (dimensionless).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `p_sim` | `Array`, shape `(N,)` | Simulated radiated pressure [Pa]. | *required* |
| `p_target` | `Array`, shape `(N,)` | Target radiated pressure [Pa]. | *required* |
| `p_ref` | `float` or `Array` | Reference pressure for normalisation [Pa]. | *required* |
Source code in jbubble/metrics.py
def normalised_mse_emission(
    p_sim: jax.Array,
    p_target: jax.Array,
    p_ref: ArrayLike,
) -> jax.Array:
    """Return the p_ref-scaled mean squared emission error (dimensionless).

    Mathematically equal to ``mse_emission(p_sim / p_ref, p_target / p_ref)``.
    The residual is divided by ``p_ref`` before squaring. Choosing ``p_ref``
    on the order of the expected signal amplitude keeps the loss close to
    O(1).

    Parameters
    ----------
    p_sim : jax.Array, shape (N,)
        Simulated radiated pressure [Pa].
    p_target : jax.Array, shape (N,)
        Target radiated pressure [Pa].
    p_ref : float or jax.Array
        Reference pressure for normalisation [Pa].

    Returns
    -------
    jax.Array
        Scalar ``mean(((p_sim - p_target) / p_ref) ** 2)``.
    """
    scaled_residual = (p_sim - p_target) / p_ref
    return jnp.mean(scaled_residual ** 2)

Choosing a loss function

| Loss | Use when |
|------|----------|
| `mse_radius` | Direct radius fitting; same units as data |
| `normalised_mse_radius` | Fitting across bubbles of different sizes; scale-invariant |
| `peak_expansion_error` | Only the amplitude matters, not the full waveform shape |
| `mse_emission` | Fitting to hydrophone measurements |
| `normalised_mse_emission` | Fitting to hydrophone data with a known reference level |

Normalised losses are generally preferable for gradient-based fitting. Given a sensible reference scale (`R0` for radius, `p_ref` for pressure), they keep the loss roughly \(\mathcal{O}(1)\) regardless of bubble size or driving pressure, which makes the choice of learning rate more robust.

Custom loss functions

Any differentiable function of SimulationResult fields works as a loss_fn. For example, fitting to the power spectral density:

import jax.numpy as jnp

def psd_loss(result, target_psd):
    """Mean squared difference between the radius periodogram and a target.

    The periodogram is the unnormalised ``|rfft(result.radius)| ** 2``. For an
    N-sample trajectory this has N // 2 + 1 entries, so ``target_psd`` should
    be computed the same way and have the same length. Because the radius is
    in metres, the raw loss can be very small in absolute terms; consider
    normalising (e.g. by R0) as the built-in metrics do.
    """
    spectrum = jnp.fft.rfft(result.radius)
    power = jnp.abs(spectrum) ** 2
    return jnp.mean((power - target_psd) ** 2)

# Bind the measured spectrum into a single-argument loss_fn via a lambda.
# ``measured_psd`` is compared element-wise with psd_loss's periodogram, so it
# should be computed the same way (|rfft(radius)|**2 of an equally long record).
fit_result = fit_parameters(
    ...,  # placeholder for the remaining fit_parameters arguments (elided)
    loss_fn=lambda result: psd_loss(result, measured_psd),
)