Functional API

The functional API treats the surface coordinates X as a JAX input, so shape derivatives come directly from autodiff and the whole computation is differentiable end to end. This is essential for single-stage optimization pipelines, where geometry updates are part of the computational graph.

Key idea

The class-based VirtualCasingJAX caches quadrature geometry in a mutable object. That is good for throughput, but the cached geometry sits outside the autodiff graph, so gradients with respect to X cannot flow through it. The functional API instead rebuilds the geometry from X on each call, while keeping all discrete quadrature choices static so JAX can trace and compile.
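The difference can be seen on a toy problem. The sketch below is not the library's code: CachedCurve and perimeter_functional are hypothetical, and a closed polygon stands in for the surface. The cached version stores segment lengths computed once, so jax.grad sees no dependence on X; the functional version rebuilds them on every call and recovers the actual shape gradient.

```python
import jax
import jax.numpy as jnp


def segment_lengths(X):
    """Per-segment lengths of a closed polygon X with shape (2, n)."""
    return jnp.linalg.norm(jnp.roll(X, -1, axis=1) - X, axis=0)


class CachedCurve:
    """Hypothetical stand-in for a class that caches geometry at construction."""

    def __init__(self, X0):
        # Geometry is evaluated once; later calls reuse these stored values.
        self.ds = segment_lengths(X0)

    def perimeter(self, X):
        # X is unused: the cached weights do not depend on the traced input.
        return jnp.sum(self.ds)


def perimeter_functional(X):
    # Geometry is rebuilt from X on every call, so autodiff sees it.
    return jnp.sum(segment_lengths(X))


# Regular 12-gon of radius 2.
n, rho = 12, 2.0
t = 2 * jnp.pi * jnp.arange(n) / n
X0 = rho * jnp.stack([jnp.cos(t), jnp.sin(t)])

g_cached = jax.grad(CachedCurve(X0).perimeter)(X0)  # identically zero
g_functional = jax.grad(perimeter_functional)(X0)   # radial, 2*sin(pi/n) per vertex
```

For a regular polygon the functional gradient points radially outward at each vertex, which is the expected shape derivative of the perimeter; the cached gradient is zero everywhere because the geometry never entered the graph.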

Core functions live in virtual_casing_jax.functional:

  • build_surface_coord: reproduce the BIEST-style full-field-period surface coordinates from a single-period input grid.

  • build_quad_setup: resample onto the quadrature grid and compute derivatives and normals.

  • build_patch_idx: precompute the patch indices for singular quadrature.

  • compute_external_B_functional / compute_internal_B_functional: on-surface virtual casing fields with singular correction.

  • compute_external_gradB_functional / compute_internal_gradB_functional: on-surface gradients of the external and internal fields.
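The geometry step behind build_quad_setup can be pictured with a plain-JAX sketch. Everything here is an assumption for illustration: toy_quad_setup, fourier_upsample and spectral_derivative are hypothetical helpers, the grid is taken as uniform over [0, 2π) in both angles with axis 1 toroidal and axis 2 poloidal, and resampling uses Fourier (trigonometric) interpolation. The library's own resampling, layout and conventions may differ.

```python
import jax.numpy as jnp


def fourier_upsample(f, m, axis):
    """Trigonometric interpolation of periodic samples along `axis` onto m >= n points."""
    f = jnp.moveaxis(f, axis, -1)
    n = f.shape[-1]
    if m < n:
        raise ValueError("this sketch only upsamples (m >= n)")
    F = jnp.fft.rfft(f, axis=-1)
    if n % 2 == 0 and m > n:
        # Split the Nyquist coefficient so cos(n*t/2) keeps its amplitude.
        F = F.at[..., n // 2].multiply(0.5)
    out = jnp.fft.irfft(F, n=m, axis=-1) * (m / n)  # irfft zero-pads the spectrum
    return jnp.moveaxis(out, -1, axis)


def spectral_derivative(f, axis):
    """d/dt of periodic samples on a uniform grid over [0, 2*pi) along `axis`."""
    f = jnp.moveaxis(f, axis, -1)
    n = f.shape[-1]
    k = jnp.arange(n // 2 + 1)
    if n % 2 == 0:
        k = k.at[-1].set(0)  # drop the Nyquist mode, whose derivative is ambiguous
    df = jnp.fft.irfft(1j * k * jnp.fft.rfft(f, axis=-1), n=n, axis=-1)
    return jnp.moveaxis(df, -1, axis)


def toy_quad_setup(X, quad_nt, quad_np):
    """Hypothetical analogue of build_quad_setup: resample, differentiate, form normals."""
    Xq = fourier_upsample(fourier_upsample(X, quad_nt, axis=1), quad_np, axis=2)
    dX_dphi = spectral_derivative(Xq, axis=1)    # toroidal direction (assumed axis 1)
    dX_dtheta = spectral_derivative(Xq, axis=2)  # poloidal direction (assumed axis 2)
    N = jnp.cross(dX_dphi, dX_dtheta, axis=0)    # unnormalized; |N| dphi dtheta = dA
    return Xq, dX_dphi, dX_dtheta, N


def torus(R, r, nt, np_):
    """Circular-cross-section torus sampled as (3, nt, np_)."""
    phi = 2 * jnp.pi * jnp.arange(nt) / nt
    theta = 2 * jnp.pi * jnp.arange(np_) / np_
    P, T = jnp.meshgrid(phi, theta, indexing="ij")
    rho = R + r * jnp.cos(T)
    return jnp.stack([rho * jnp.cos(P), rho * jnp.sin(P), r * jnp.sin(T)])


X = torus(3.0, 1.0, 16, 16)
Xq, dX_dphi, dX_dtheta, N = toy_quad_setup(X, quad_nt=24, quad_np=24)
area = jnp.sum(jnp.linalg.norm(N, axis=0)) * (2 * jnp.pi / 24) ** 2
```

Because every operation is a jnp call on X, jax.grad can differentiate any quantity built from Xq or N with respect to the input coordinates, while quad_nt and quad_np stay plain Python ints. For this torus the computed area matches 4π²Rr.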

Because the patch size and quadrature sizes are discrete decisions, they must stay fixed during differentiation. The helper prepare_functional_setup computes these values outside autodiff; they are then passed to the functional calls as static arguments.
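A minimal sketch of that split, under stated assumptions: ToySetup, toy_prepare_setup and toy_objective are hypothetical, and the size heuristics are made up. What it shows is the mechanism: sizes that determine array shapes or slice lengths are chosen with ordinary Python code, stored as plain ints, and passed to a jax.jit-compiled function through static_argnames, while only the coordinate array is differentiated.

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySetup(NamedTuple):
    # Hypothetical container for discrete choices; not the library's setup type.
    quad_n: int
    patch_dim0: int


def toy_prepare_setup(x, digits):
    """Chooses discrete sizes with ordinary Python logic, outside any trace."""
    n = x.shape[-1]
    quad_n = max(n, int(1.5 * n) + digits)  # made-up heuristic for illustration
    patch_dim0 = 2 * digits                 # made-up heuristic for illustration
    return ToySetup(quad_n=quad_n, patch_dim0=patch_dim0)


@partial(jax.jit, static_argnames=("quad_n", "patch_dim0"))
def toy_objective(x, quad_n, patch_dim0):
    s = jnp.linspace(0.0, 1.0, x.shape[-1])
    sq = jnp.linspace(0.0, 1.0, quad_n)   # `num` must be a concrete int
    xq = jnp.interp(sq, s, x)             # resample onto the "quadrature" grid
    return jnp.sum(xq[:patch_dim0] ** 2)  # slice length must be concrete too


x = jnp.linspace(0.0, 1.0, 8)
setup = toy_prepare_setup(x, digits=6)    # runs outside autodiff
g = jax.grad(toy_objective)(x, quad_n=setup.quad_n, patch_dim0=setup.patch_dim0)
```

Changing quad_n or patch_dim0 triggers a recompilation rather than entering the gradient, which matches the rule that these choices sit outside autodiff.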

Example

import jax
import jax.numpy as jnp
from virtual_casing_jax.functional import (
    prepare_functional_setup,
    compute_external_B_functional,
)

# X: surface coordinates, shape (3, surf_nt, surf_np)
# B0: total field on the source grid, shape (3, src_nt, src_np)
# Both are assumed to be JAX arrays defined before this point.
setup = prepare_functional_setup(
    X,
    digits=6,
    nfp=1,
    half_period=False,
    surf_nt=16,
    surf_np=16,
    src_nt=16,
    src_np=16,
    trg_nt=16,
    trg_np=16,
    quad_nt=24,
    quad_np=24,
)

def scalar_objective(xsurf):
    b = compute_external_B_functional(
        xsurf,
        B0,
        digits=6,
        nfp=setup.nfp,
        half_period=setup.half_period,
        surf_nt=setup.surf_nt,
        surf_np=setup.surf_np,
        src_nt=setup.src_nt,
        src_np=setup.src_np,
        trg_nt=setup.trg_nt,
        trg_np=setup.trg_np,
        quad_nt=setup.quad_nt,
        quad_np=setup.quad_np,
        patch_dim0=setup.patch_dim0,
        patch_idx=setup.patch_idx,
        orient=setup.orient,
    )
    return jnp.sum(b * b)

grad_x = jax.grad(scalar_objective)(X)
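Before relying on grad_x inside an optimizer, a directional finite-difference check is a cheap sanity test. The sketch below uses a hypothetical toy_objective as a stand-in so that it runs on its own; directional_check is likewise a hypothetical helper, not part of virtual_casing_jax. In float32, a moderate step (around 1e-2) balances truncation and round-off error; with jax_enable_x64 enabled, a smaller step gives a tighter match.

```python
import jax
import jax.numpy as jnp


def directional_check(f, x, key, eps=1e-2):
    """Compare the autodiff directional derivative of f at x with a central difference."""
    v = jax.random.normal(key, x.shape, dtype=x.dtype)
    v = v / jnp.sqrt(jnp.sum(v * v))                  # unit-norm random direction
    ad = jnp.vdot(jax.grad(f)(x), v)                  # <grad f(x), v> from autodiff
    fd = (f(x + eps * v) - f(x - eps * v)) / (2 * eps)
    return ad, fd


def toy_objective(x):
    # Stand-in for scalar_objective; any smooth scalar function of x works here.
    return jnp.sum(jnp.sin(x) * x)


x = jax.random.normal(jax.random.PRNGKey(0), (3, 4, 4))
ad, fd = directional_check(toy_objective, x, jax.random.PRNGKey(1))
```

With the example above, toy_objective would be replaced by scalar_objective and x by X; a large mismatch between ad and fd usually points to a non-differentiable step or to a discrete choice that changed between the perturbed evaluations.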

Guidelines

  • Keep quad_nt, quad_np, and patch_dim0 static during differentiation. If they changed with X, the discretization itself would change, and the gradient would no longer describe a single smooth function of X.

  • Use prepare_functional_setup outside the gradient context to select appropriate quadrature sizes.

  • orient is treated as a fixed sign (±1) computed from the input geometry. This is a topological choice that should not be differentiated.
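One way to obtain such a sign before differentiation is sketched below. It is an assumption, not the library's method: orientation_sign is hypothetical, uses the same assumed layout as before (axis 1 toroidal, axis 2 poloidal, uniform grids), and takes the sign of the enclosed volume from the divergence theorem, V = (1/3) ∮ X·n dA, with the unnormalized normal N = ∂X/∂φ × ∂X/∂θ. A positive V means N points outward.

```python
import jax.numpy as jnp


def orientation_sign(X):
    """Return +1 if N = X_phi x X_theta points outward, else -1, as a Python int."""
    nt, np_ = X.shape[1], X.shape[2]
    # Periodic central differences; step h = 2*pi/n, so 1/(2h) = n/(4*pi).
    dX_dphi = (jnp.roll(X, -1, axis=1) - jnp.roll(X, 1, axis=1)) * nt / (4 * jnp.pi)
    dX_dtheta = (jnp.roll(X, -1, axis=2) - jnp.roll(X, 1, axis=2)) * np_ / (4 * jnp.pi)
    N = jnp.cross(dX_dphi, dX_dtheta, axis=0)
    vol = jnp.sum(X * N) / 3.0 * (2 * jnp.pi / nt) * (2 * jnp.pi / np_)
    # float() needs a concrete value, so call this outside jax.jit and outside
    # the function being differentiated; orient is then a fixed constant.
    return 1 if float(vol) > 0 else -1


def torus(R, r, nt, np_):
    """Circular-cross-section torus sampled as (3, nt, np_)."""
    phi = 2 * jnp.pi * jnp.arange(nt) / nt
    theta = 2 * jnp.pi * jnp.arange(np_) / np_
    P, T = jnp.meshgrid(phi, theta, indexing="ij")
    rho = R + r * jnp.cos(T)
    return jnp.stack([rho * jnp.cos(P), rho * jnp.sin(P), r * jnp.sin(T)])


X = torus(3.0, 1.0, 16, 16)
orient = orientation_sign(X)                      # outward for this parametrization
orient_flipped = orientation_sign(X[:, :, ::-1])  # reversed poloidal direction
```

Reversing the poloidal direction flips the sign, which is the topological information the guideline refers to; small shape changes never alter it, so treating it as a constant loses nothing.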