Functional API
The functional API makes the surface coordinates X a JAX input so shape
derivatives are native and end-to-end differentiable. This is essential for
single-stage optimization pipelines where geometry updates are part of the
computational graph.
Key idea
The class-based VirtualCasingJAX caches quadrature geometry in a mutable
object. That is great for throughput, but it separates the geometry from the
autodiff graph. The functional API instead rebuilds the geometry from X on
each call, while keeping all discrete quadrature choices static so JAX can
trace and compile.
Core functions live in virtual_casing_jax.functional:
build_surface_coord: reproduce the BIEST-style full-field-period surface coordinates from a single-period input grid.build_quad_setup: resample to quadrature grid, compute derivatives and normals.build_patch_idx: precompute the patch indices for singular quadrature.compute_external_B_functional/compute_internal_B_functional: on-surface virtual casing fields with singular correction.compute_external_gradB_functional/compute_internal_gradB_functional: on-surface gradients.
Because the patch size and quadrature sizes are discrete decisions, they must
remain fixed during differentiation. The helper prepare_functional_setup
can be used to compute those values outside autodiff, and then passed into the
functional calls as static arguments.
Example
import jax
import jax.numpy as jnp
from virtual_casing_jax.functional import (
prepare_functional_setup,
compute_external_B_functional,
)
# X: surface coordinates (3, surf_nt, surf_np)
# B0: total field on source grid (3, src_nt, src_np)
setup = prepare_functional_setup(
X,
digits=6,
nfp=1,
half_period=False,
surf_nt=16,
surf_np=16,
src_nt=16,
src_np=16,
trg_nt=16,
trg_np=16,
quad_nt=24,
quad_np=24,
)
def scalar_objective(xsurf):
b = compute_external_B_functional(
xsurf,
B0,
digits=6,
nfp=setup.nfp,
half_period=setup.half_period,
surf_nt=setup.surf_nt,
surf_np=setup.surf_np,
src_nt=setup.src_nt,
src_np=setup.src_np,
trg_nt=setup.trg_nt,
trg_np=setup.trg_np,
quad_nt=setup.quad_nt,
quad_np=setup.quad_np,
patch_dim0=setup.patch_dim0,
patch_idx=setup.patch_idx,
orient=setup.orient,
)
return jnp.sum(b * b)
grad_x = jax.grad(scalar_objective)(X)
Guidelines
Keep
quad_nt,quad_np, andpatch_dim0static during differentiation. If these change, the discretization changes and autodiff is not meaningful.Use
prepare_functional_setupoutside the gradient context to select appropriate quadrature sizes.orientis treated as a fixed sign (±1) computed from the input geometry. This is a topological choice that should not be differentiated.