Performance

The virtual casing evaluation is dominated by surface integrals with singular kernels. Performance and memory efficiency are addressed via:

  • 2D blocking over sources and targets (tiling).

  • JIT-compiled kernels with static shapes.

  • Precomputed quadrature tables and interpolation matrices.

  • Auto-tuned chunk sizes per operator (B vs GradB) and backend (CPU/GPU).

  • Optional rematerialization to reduce memory in the backward pass.

  • Mixed precision for POU/patch intermediates with float64 outputs.

Baseline (Direct-Sum) Path

The initial JAX implementation evaluates Laplace FxdU using a direct quadrature with chunking and jax.lax.scan. This avoids materializing the full N_trg x N_src kernel matrix and keeps memory use linear in the chunk size. The baseline is primarily for correctness and parity instrumentation; singular corrections will be layered on later.

Target Blocking (2D Tiling)

The direct-sum kernels now support a second tiling dimension over targets. This avoids large broadcasted temporaries such as [ntrg, nsrc, 3] when ntrg is large. The API exposes target_chunk_size to control the target tile size. When enabled, each tile performs a source scan using jax.lax.scan so accumulation happens inside the kernel loop.

Singular Correction

The singular correction introduces per-target patch work. For now it is implemented in Python with JAX primitives, which is adequate for parity tests but not yet optimized. The next step is to batch patches and use vmap/scan to reduce overhead and enable GPU acceleration.

Rematerialization

The GradB singular correction supports optional rematerialization via jax.checkpoint to trade recomputation for memory. Use remat=True in GradB paths to reduce the size of saved intermediates during autodiff.

Adaptive Off-Surface

Adaptive refinement requires repeated evaluations of a double-layer test. This is currently a Python loop; performance will improve once the refinement is JIT-compiled with static shape schedules.

Off-Surface GradB

The off-surface gradient evaluates second-derivative kernels and is more expensive than the field evaluation. The default path matches the C++ reference and uses the base resampled grid (no adaptive refinement). Enable adaptive=True only when additional accuracy is needed, and use max_Nt/max_Np to cap growth.

CPU and GPU

The same JAX code runs on CPU and GPU. GPU acceleration is achieved via large batched kernel evaluations. The code avoids Python loops in performance-critical paths.

Tips and Tricks

  • Use a fixed set of nphi, ntheta for JIT reuse.

  • Cache POU and interpolation tables for each (PATCH_DIM, RAD_DIM).

  • For parity checks, use float64; for production, use mixed precision with float32 inputs. Enable jax_enable_x64 in tests to match the reference C++ results.

Performance Guide

Chunk Size

The chunk_size parameter controls the source tiling in direct quadrature. Larger chunks improve arithmetic intensity but increase peak memory. target_chunk_size provides a second tiling dimension over targets. For parity tests, chunk_size=1024 and target_chunk_size=auto is a good balance.

Auto tuning is enabled by passing chunk_size="auto" and target_chunk_size="auto" (default in the high-level API). The CPU B and off-surface B heuristics leave small target grids unblocked because the extra target scan can dominate the kernel time. The heuristics can be overridden via environment variables:

  • VCJAX_CHUNK_B / VCJAX_CHUNK_B_SRC / VCJAX_CHUNK_B_TRG

  • VCJAX_CHUNK_BOFF / VCJAX_CHUNK_BOFF_SRC / VCJAX_CHUNK_BOFF_TRG

  • VCJAX_CHUNK_GRADB / VCJAX_CHUNK_GRADB_SRC / VCJAX_CHUNK_GRADB_TRG

Interpolation Blocking

interp_block_size controls blocking of the polar interpolation in the singular correction. interp_block_size="auto" (default) uses a block size of 64 for B and 32 for GradB when the polar grid is large, reducing temporary memory. Set interp_block_size=None to restore the full (unblocked) interpolation.

Target-Scan Fusion

The GradB singular correction can optionally replace the vmap over target points with a lax.scan loop. Enable this with scan_targets=True in compute_external_gradB (or the JIT wrapper). This reduces large broadcasted temporaries by avoiding replication of the interpolation weights across the target batch. The tradeoff is less parallelism per chunk, so this should be enabled only when memory is the bottleneck.

On case_vc_large (CPU HLO), enabling interp_block_size="auto" reduces the largest singular-correction temporaries to ~50 MiB (B) and ~76 MiB (GradB), compared to >150 MiB without interpolation blocking even with patch_dtype="float32".

JIT Caching

The high-level wrappers in VirtualCasingJAX expose JIT-compiled variants (compute_external_B_jit and compute_external_gradB_jit). These cache compiled functions keyed by argument settings. For repeated evaluations with fixed grid sizes, prefer the JIT variants to amortize compilation cost.

For long-running loops, pass donate=True to the JIT wrappers to allow XLA to reuse the input buffers and reduce peak memory.

Batching

Use compute_external_B_batch or compute_external_gradB_batch when evaluating many fields in parallel (e.g., multiple VMEC surfaces or Monte Carlo samples). These functions use vmap to avoid Python loops.

Off-surface Schedule + JIT

The adaptive off-surface refinement loop in the Python API is not JIT-compatible. For JIT-friendly evaluation, use the schedule-based methods with levels="auto" to build a fixed refinement schedule (doubling in Nt/Np up to max_Nt/max_Np). The JIT wrappers (compute_external_B_offsurf_schedule_jit / compute_external_gradB_offsurf_schedule_jit) accept donate=True to reduce peak memory. For custom schedules, use virtual_casing_jax.utils.build_offsurface_levels.

Precision Tradeoffs

Float64 is recommended for parity with the C++ backend. Mixed precision with float32 inputs can provide significant speedups, but requires relaxed tolerances in validation. Keep jax_enable_x64 enabled in CI to maintain reference accuracy.

For singular correction, the POU/polar interpolation tables can be cast to float32 while keeping the final accumulation in float64. Use pou_dtype="auto" or pou_dtype="float32" in high-level calls to enable this optimization. The interpolation weights and patch gather values can be cast independently via patch_dtype="auto" (or "float32"), which reduces the largest temporary tensors while the final outputs remain in the input precision.

On case_vc_large (CPU HLO), patch_dtype="float32" reduces the largest singular-correction gather from ~304 MiB to ~152 MiB for B, and from ~457 MiB to ~228 MiB for GradB.

Precompute Reuse

The polar quadrature tables and interpolation weights are cached via precompute_singular. Patch index maps are cached per quadrature setup inside VirtualCasingJAX to avoid recomputing the patch gather indices on each call. Patch indices are stored as int32 to reduce memory traffic during gathers.

Profiling and Diagnostics

JAX Profiler (TensorBoard)

Use the profiling harness in tools/profile_vc.py to emit a trace that can be opened in TensorBoard:

JAX_ENABLE_X64=1 python tools/profile_vc.py \
  --case case_vc --op B --jit --repeat 5 \
  --chunk-size auto --target-chunk-size auto \
  --trace-dir /tmp/vc_trace

tensorboard --logdir /tmp/vc_trace

The trace includes XLA compilation time, kernel launches, and host-to-device transfer costs. Always call jax.block_until_ready (handled by the script) to ensure timings reflect actual execution.

XLA HLO / MLIR Dumps

To inspect XLA lowering and fusion decisions, enable dump flags:

XLA_FLAGS=\"--xla_dump_to=/tmp/xla --xla_dump_hlo_as_text\" \\
JAX_ENABLE_X64=1 python tools/profile_vc.py --case case_vc --op GradB --jit

The dump directory contains HLO modules and MLIR. These are useful for verifying fusion, identifying large intermediates, and checking precision lowering.

Kernel-Level GPU Profiling

On NVIDIA GPUs, use nsys or nvprof to profile kernel launches:

nsys profile -o /tmp/vc_profile \\
  python tools/profile_vc.py --case case_vc --op B --jit --repeat 10

Pair this with the JAX trace to correlate high-level ops with GPU kernels.

Memory Audits

Memory usage is dominated by chunked kernel evaluation and intermediate arrays during singular correction. Use smaller chunk_size values to reduce peak memory, and profile with multiple chunk sizes to identify the best speed/memory tradeoff. For large runs, consider setting:

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8

to control the allocator footprint on GPU.