"""Functional Virtual Casing API with differentiable geometry inputs."""
from __future__ import annotations
import math
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from .utils import autotune_chunk_sizes
from .surface_ops import (
complete_vec_field,
resample,
rotate_toroidal,
grad2d,
surf_normal_area_elem,
dot_prod,
cross_prod,
)
from .integrals import (
laplace_fxd_u_eval_singular,
laplace_fxd_u_eval_vec_singular,
laplace_fxd2_u_eval_singular,
laplace_fxd2_u_eval_vec_singular,
laplace_fxd_u_eval,
laplace_fxd2_u_eval,
laplace_fxd2_u_eval_vec,
biotsavart_fx_u_eval,
computeB_offsurface_adaptive,
_offsurface_adapt_grid,
_build_patch_indices,
_surface_cond,
select_patch_dim,
)
[docs]
@dataclass(frozen=True)
class FunctionalSetup:
    """Static quadrature setup for functional API.

    Instances are produced by ``prepare_functional_setup``, which converts each
    scalar field to a plain Python ``int``/``bool``/``float`` and stores the
    array returned by ``build_patch_idx`` in ``patch_idx``.

    NOTE(review): ``patch_idx`` is a JAX array. The dataclass-generated
    ``__eq__``/``__hash__`` (``frozen=True`` with the default ``eq=True``)
    include every field, so hashing an instance (e.g. passing it as a static
    ``jit`` argument) is expected to fail — confirm before relying on that.
    """
    # Field-period count forwarded to the geometry/field completion helpers.
    nfp: int
    # 2 * nfp when half_period is true, otherwise nfp (from build_surface_coord).
    nfp_eff: int
    # Whether the input grids cover half a field period.
    half_period: bool
    # Input surface grid; X is reshaped to (3, surf_nt, surf_np).
    surf_nt: int
    surf_np: int
    # Input field grid; B0 is reshaped to (3, src_nt, src_np).
    src_nt: int
    src_np: int
    # Target grid on which on-surface results are returned.
    trg_nt: int
    trg_np: int
    # Quadrature grid the surface and field are resampled onto.
    quad_nt: int
    quad_np: int
    # Patch size for the singular quadrature (select_patch_dim when not given).
    patch_dim0: int
    # Patch indices returned by build_patch_idx.
    patch_idx: jnp.ndarray
    # Orientation value returned by build_quad_setup, stored as a float.
    orient: float
def _resolve_chunk_sizes(op: str, chunk_size, target_chunk_size, *, nsrc: int, ntrg: int):
chunk_auto = chunk_size is None or (isinstance(chunk_size, str) and chunk_size.lower() == "auto")
target_auto = isinstance(target_chunk_size, str) and target_chunk_size.lower() == "auto"
if chunk_auto:
src_auto, trg_auto = autotune_chunk_sizes(op, nsrc, ntrg)
chunk_size = src_auto
if target_auto:
target_chunk_size = trg_auto
else:
chunk_size = int(chunk_size)
if target_auto:
_, trg_auto = autotune_chunk_sizes(op, nsrc, ntrg)
target_chunk_size = trg_auto
if target_chunk_size is not None and not isinstance(target_chunk_size, str):
target_chunk_size = int(target_chunk_size)
return chunk_size, target_chunk_size
def _resolve_pou_dtype(pou_dtype, value_dtype):
if pou_dtype is None:
return None
if isinstance(pou_dtype, str):
if pou_dtype.lower() == "auto":
return jnp.float32 if value_dtype == jnp.float64 else value_dtype
return jnp.dtype(pou_dtype)
return jnp.dtype(pou_dtype)
def _resolve_patch_dtype(patch_dtype, value_dtype):
if patch_dtype is None:
return None
if isinstance(patch_dtype, str):
if patch_dtype.lower() == "auto":
return jnp.float32 if value_dtype == jnp.float64 else value_dtype
return jnp.dtype(patch_dtype)
return jnp.dtype(patch_dtype)
[docs]
def build_surface_coord(X, nfp: int, half_period: bool, surf_nt: int, surf_np: int, trg_nt: int):
    """Build full-field-period surface coordinates from the base grid.

    ``X`` is reshaped to ``(3, surf_nt, surf_np)``.

    * Without ``half_period``, the grid is completed by ``complete_vec_field``
      with a zero shift.
    * With ``half_period``, the grid is completed with a shift of
      ``-pi / (2 * nfp * surf_nt)``, resampled to ``2 * nfp * (surf_nt + 1)``
      toroidal points, and rotated by ``pi / (2 * nfp * trg_nt)``.

    Returns:
        ``(surface_coord, nfp_eff)``, where ``nfp_eff`` is ``int(2 * nfp)`` in
        the half-period case and ``int(nfp)`` otherwise.
    """
    coords = jnp.asarray(X).reshape((3, surf_nt, surf_np))
    if not half_period:
        full_period = complete_vec_field(coords, True, half_period, nfp, surf_nt, surf_np, 0.0)
        return full_period, int(nfp)

    periods = 2 * nfp
    completed = complete_vec_field(
        coords,
        True,
        half_period,
        nfp,
        surf_nt,
        surf_np,
        -math.pi / (periods * surf_nt),
    )
    padded_nt = periods * (surf_nt + 1)
    refined = resample(completed, periods * surf_nt, surf_np, padded_nt, surf_np)
    rotated = rotate_toroidal(refined, padded_nt, surf_np, math.pi / (periods * trg_nt))
    return rotated, int(periods)
[docs]
def build_quad_setup(surface_coord, quad_nt: int, quad_np: int, *, orient: float | None = None):
    """Resample the surface onto the quadrature grid and compute its geometry.

    Returns:
        ``(quad_coord, dX, normal, area_elem, orient)``.

        * When ``orient`` is None, the orientation reported by
          ``surf_normal_area_elem`` is returned with gradients stopped, and
          ``normal`` is left as computed.
        * Otherwise ``normal`` is rescaled by ``orient / orient0`` and the
          given ``orient`` is returned unchanged.
    """
    full_nt = int(surface_coord.shape[1])
    full_np = int(surface_coord.shape[2])
    quad_coord = resample(surface_coord, full_nt, full_np, quad_nt, quad_np)
    dX = grad2d(quad_coord, quad_nt, quad_np)
    normal, area_elem, natural_orient = surf_normal_area_elem(
        dX, quad_coord, return_orientation=True
    )
    if orient is not None:
        return quad_coord, dX, normal * (orient / natural_orient), area_elem, orient
    return quad_coord, dX, normal, area_elem, jax.lax.stop_gradient(natural_orient)
[docs]
def build_patch_idx(quad_nt: int, quad_np: int, trg_nt: int, trg_np: int, nfp_eff: int, patch_dim0: int):
    """Build patch indices for singular quadrature.

    Target ``(i, j)`` maps to quadrature node
    ``(i * (quad_nt // (nfp_eff * trg_nt)), j * (quad_np // trg_np))``. The
    flattened (row-major) node indices are passed to ``_build_patch_indices``.
    """
    stride_theta = quad_nt // (nfp_eff * trg_nt)
    stride_phi = quad_np // trg_np
    theta_nodes, phi_nodes = jnp.meshgrid(
        stride_theta * jnp.arange(trg_nt),
        stride_phi * jnp.arange(trg_np),
        indexing="ij",
    )
    return _build_patch_indices(
        theta_nodes.ravel(),
        phi_nodes.ravel(),
        quad_nt,
        quad_np,
        patch_dim0,
    )
[docs]
def select_patch_dim_from_geom(dX, quad_nt: int, quad_np: int, digits: int):
    """Select patch_dim0 from the surface condition estimate.

    The condition value is converted to a Python ``float``, so this is
    non-differentiable and must not be traced.
    """
    condition = _surface_cond(dX, quad_nt, quad_np)
    n_digits = int(digits)
    return select_patch_dim(n_digits, float(condition))
[docs]
def target_surface_normal(
    X,
    *,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    trg_nt: int,
    trg_np: int,
    orient: float | None = None,
):
    """Return the normals from ``surf_normal_area_elem`` on the target grid.

    The surface is resampled to ``nfp_eff * trg_nt`` by ``trg_np`` points and
    the first ``trg_nt`` toroidal rows are returned. When ``orient`` is given,
    the normals are rescaled by ``orient / orient0`` to follow the requested
    orientation.
    """
    surface_coord, nfp_eff = build_surface_coord(X, nfp, half_period, surf_nt, surf_np, trg_nt)
    all_rows = nfp_eff * trg_nt
    trg_coord = resample(
        surface_coord,
        int(surface_coord.shape[1]),
        int(surface_coord.shape[2]),
        all_rows,
        trg_np,
    )
    derivs = grad2d(trg_coord, all_rows, trg_np)
    normal, _, natural_orient = surf_normal_area_elem(derivs, trg_coord, return_orientation=True)
    if orient is not None:
        normal = normal * (orient / natural_orient)
    return normal[:, :trg_nt, :]
[docs]
def prepare_functional_setup(
    X,
    *,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    quad_nt: int,
    quad_np: int,
    patch_dim0: int | None = None,
    orient: float | None = None,
):
    """Prepare static quadrature setup for the functional API.

    Intended to be called outside autodiff. When ``patch_dim0`` is omitted, it
    is chosen via ``select_patch_dim_from_geom``, which is non-differentiable.
    The resolved orientation is stored with ``float``.

    Returns:
        A ``FunctionalSetup`` whose scalar fields are plain Python values.
    """
    surface_coord, nfp_eff = build_surface_coord(X, nfp, half_period, surf_nt, surf_np, trg_nt)
    # Only the derivatives and the resolved orientation are needed here.
    _, dX, _, _, orient = build_quad_setup(surface_coord, quad_nt, quad_np, orient=orient)
    if patch_dim0 is None:
        patch_dim0 = select_patch_dim_from_geom(dX, quad_nt, quad_np, digits)
    indices = build_patch_idx(quad_nt, quad_np, trg_nt, trg_np, nfp_eff, patch_dim0)
    return FunctionalSetup(
        nfp=int(nfp),
        nfp_eff=int(nfp_eff),
        half_period=bool(half_period),
        surf_nt=int(surf_nt),
        surf_np=int(surf_np),
        src_nt=int(src_nt),
        src_np=int(src_np),
        trg_nt=int(trg_nt),
        trg_np=int(trg_np),
        quad_nt=int(quad_nt),
        quad_np=int(quad_np),
        patch_dim0=int(patch_dim0),
        patch_idx=indices,
        orient=float(orient),
    )
def _compute_dtheta(nfp: int, half_period: bool, trg_nt: int, src_nt: int):
if not half_period:
return 0.0
return math.pi * (1.0 / (nfp * trg_nt * 2) - 1.0 / (nfp * src_nt * 2))
def _complete_b0(B0, nfp: int, half_period: bool, src_nt: int, src_np: int, dtheta: float):
    """Reshape ``B0`` to ``(3, src_nt, src_np)`` and complete it via ``complete_vec_field``."""
    field = jnp.reshape(jnp.asarray(B0), (3, src_nt, src_np))
    return complete_vec_field(field, False, half_period, nfp, src_nt, src_np, dtheta)
def _compute_B_signed(
    X,
    B0,
    *,
    sign: float,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    quad_nt: int,
    quad_np: int,
    patch_dim0: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    X_trg=None,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool | None = None,
):
    """Shared on-surface field evaluation for the external/internal variants.

    ``B0`` is completed and resampled to the quadrature grid, where
    ``J = n x B`` and ``B . n`` are formed. Both are then passed through the
    singular Laplace evaluators. The result is
    ``sign * (Bvc + grad_term) + 0.5 * B_on``, where:

    * ``Bvc[k] = G[k+1, k+2] - G[k+2, k+1]`` (indices mod 3) from the ``J``
      evaluation;
    * ``grad_term`` comes from the ``B . n`` evaluation;
    * ``B_on`` is ``B0`` resampled to the first ``trg_nt`` toroidal rows of the
      target grid.

    ``remat`` defaults to False.

    Raises:
        ValueError: if ``X_trg`` is given and is neither 2-D nor 3-D.
    """
    surface_coord, nfp_eff = build_surface_coord(X, nfp, half_period, surf_nt, surf_np, trg_nt)
    quad_coord, dX, normal, _, orient = build_quad_setup(
        surface_coord, quad_nt, quad_np, orient=orient
    )
    remat = False if remat is None else remat

    field_dtype = jnp.asarray(B0).dtype
    pou_dtype = _resolve_pou_dtype(pou_dtype, field_dtype)
    patch_dtype = _resolve_patch_dtype(patch_dtype, field_dtype)

    # The target count only feeds the chunk-size autotuner.
    if X_trg is None:
        n_targets = trg_nt * trg_np
    else:
        trg_shape = jnp.asarray(X_trg).shape
        if len(trg_shape) == 3:
            n_targets = trg_shape[1] * trg_shape[2]
        elif len(trg_shape) == 2:
            n_targets = trg_shape[1]
        else:
            raise ValueError("X_trg must have shape (3, nt, np) or (3, ntrg)")
    chunk_size, target_chunk_size = _resolve_chunk_sizes(
        "b", chunk_size, target_chunk_size, nsrc=quad_nt * quad_np, ntrg=n_targets
    )

    if patch_dim0 is None:
        patch_dim0 = select_patch_dim_from_geom(dX, quad_nt, quad_np, digits)
    if patch_idx is None:
        patch_idx = build_patch_idx(quad_nt, quad_np, trg_nt, trg_np, nfp_eff, patch_dim0)

    dtheta = _compute_dtheta(nfp, half_period, trg_nt, src_nt)
    B0_complete = _complete_b0(B0, nfp, half_period, src_nt, src_np, dtheta)
    B_quad = resample(B0_complete, nfp_eff * src_nt, src_np, quad_nt, quad_np)
    surface_current = cross_prod(normal, B_quad)
    normal_flux = dot_prod(B_quad, normal)

    # Options shared by both singular layer-potential evaluations.
    singular_kwargs = dict(
        X_trg=X_trg,
        digits=digits,
        patch_dim0=patch_dim0,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
        patch_idx=patch_idx,
        orient=orient,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
        interp_block_size=interp_block_size,
        remat=remat,
    )
    grad_current = jnp.asarray(
        laplace_fxd_u_eval_vec_singular(
            quad_coord, dX, surface_current, trg_nt, trg_np, nfp_eff, **singular_kwargs
        )
    ).reshape((3, 3, trg_nt, trg_np))
    grad_flux = jnp.asarray(
        laplace_fxd_u_eval_singular(
            quad_coord, dX, normal_flux, trg_nt, trg_np, nfp_eff, **singular_kwargs
        )
    ).reshape((3, trg_nt, trg_np))

    B_on = resample(B0_complete, nfp_eff * src_nt, src_np, nfp_eff * trg_nt, trg_np)[:, :trg_nt, :]
    Bvc = jnp.stack(
        [
            grad_current[(k + 1) % 3, (k + 2) % 3] - grad_current[(k + 2) % 3, (k + 1) % 3]
            for k in range(3)
        ]
    )
    return sign * (Bvc + grad_flux) + 0.5 * B_on
[docs]
def compute_external_B_functional(
    X,
    B0,
    *,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    quad_nt: int,
    quad_np: int,
    patch_dim0: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    X_trg=None,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool | None = None,
):
    """Compute Bext with surface coordinates as differentiable inputs.

    Evaluates the shared on-surface kernel with ``sign=+1``: the layer-potential
    contribution is added to half the on-surface field.
    """
    grid = dict(
        nfp=nfp,
        half_period=half_period,
        surf_nt=surf_nt,
        surf_np=surf_np,
        src_nt=src_nt,
        src_np=src_np,
        trg_nt=trg_nt,
        trg_np=trg_np,
        quad_nt=quad_nt,
        quad_np=quad_np,
    )
    tuning = dict(
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
        interp_block_size=interp_block_size,
        remat=remat,
    )
    return _compute_B_signed(
        X,
        B0,
        sign=1.0,
        digits=digits,
        patch_dim0=patch_dim0,
        patch_idx=patch_idx,
        orient=orient,
        X_trg=X_trg,
        **grid,
        **tuning,
    )
[docs]
def compute_external_B_normal_functional(
    X,
    B0,
    *,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    quad_nt: int,
    quad_np: int,
    patch_dim0: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool | None = None,
):
    """Compute on-surface Bext dot n with differentiable geometry inputs.

    Bext is evaluated on the default target grid (no ``X_trg``). The normals
    come from ``target_surface_normal`` with the same ``orient``.
    """
    # Geometry arguments shared by the field and the normal computation.
    geometry = dict(
        nfp=nfp,
        half_period=half_period,
        surf_nt=surf_nt,
        surf_np=surf_np,
        trg_nt=trg_nt,
        trg_np=trg_np,
    )
    field = compute_external_B_functional(
        X,
        B0,
        digits=digits,
        src_nt=src_nt,
        src_np=src_np,
        quad_nt=quad_nt,
        quad_np=quad_np,
        patch_dim0=patch_dim0,
        patch_idx=patch_idx,
        orient=orient,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
        interp_block_size=interp_block_size,
        remat=remat,
        **geometry,
    )
    unit_normal = target_surface_normal(X, orient=orient, **geometry)
    return dot_prod(field, unit_normal)
[docs]
def compute_external_B_jvp_columns_functional(
    X,
    B0,
    X_tangents,
    B0_tangents=None,
    *,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    quad_nt: int,
    quad_np: int,
    patch_dim0: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool | None = None,
):
    """Return Bext and multiple forward-mode tangent columns.

    ``X_tangents`` has shape ``(ncols, 3, surf_nt, surf_np)``.
    ``B0_tangents`` has shape ``(ncols, 3, src_nt, src_np)``. When it is
    omitted, the magnetic-field input is held fixed. The returned tangent
    array has shape ``(ncols, 3, trg_nt, trg_np)``.

    The function is linearized once with ``jax.linearize`` and the resulting
    linear map is applied to every column with ``jax.vmap``.
    """
    X_tangents = jnp.asarray(X_tangents)

    def external_field(x, b0):
        return compute_external_B_functional(
            x,
            b0,
            digits=digits,
            nfp=nfp,
            half_period=half_period,
            surf_nt=surf_nt,
            surf_np=surf_np,
            src_nt=src_nt,
            src_np=src_np,
            trg_nt=trg_nt,
            trg_np=trg_np,
            quad_nt=quad_nt,
            quad_np=quad_np,
            patch_dim0=patch_dim0,
            patch_idx=patch_idx,
            orient=orient,
            chunk_size=chunk_size,
            target_chunk_size=target_chunk_size,
            pou_dtype=pou_dtype,
            patch_dtype=patch_dtype,
            interp_block_size=interp_block_size,
            remat=remat,
        )

    if B0_tangents is None:
        # B0 is held fixed, so linearize in X alone. A zero B0 tangent would
        # have to match B0's dtype exactly (jax.linearize rejects mismatched
        # tangent dtypes), and it would only push zeros through the B0 path.
        Bext, field_linear = jax.linearize(lambda x: external_field(x, B0), X)
        columns = jax.vmap(field_linear)(X_tangents)
    else:
        Bext, field_linear = jax.linearize(external_field, X, B0)
        columns = jax.vmap(field_linear)(X_tangents, jnp.asarray(B0_tangents))
    return Bext, columns
[docs]
def compute_external_B_normal_jvp_columns_functional(
    X,
    B0,
    X_tangents,
    B0_tangents=None,
    *,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    quad_nt: int,
    quad_np: int,
    patch_dim0: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool | None = None,
):
    """Return Bext dot n and multiple forward-mode tangent columns.

    ``X_tangents`` has shape ``(ncols, 3, surf_nt, surf_np)``.
    ``B0_tangents`` has shape ``(ncols, 3, src_nt, src_np)``. When it is
    omitted, the magnetic-field input is held fixed. The returned tangent
    array has shape ``(ncols, trg_nt, trg_np)``.

    The function is linearized once with ``jax.linearize`` and the resulting
    linear map is applied to every column with ``jax.vmap``.
    """
    X_tangents = jnp.asarray(X_tangents)

    def normal_field(x, b0):
        return compute_external_B_normal_functional(
            x,
            b0,
            digits=digits,
            nfp=nfp,
            half_period=half_period,
            surf_nt=surf_nt,
            surf_np=surf_np,
            src_nt=src_nt,
            src_np=src_np,
            trg_nt=trg_nt,
            trg_np=trg_np,
            quad_nt=quad_nt,
            quad_np=quad_np,
            patch_dim0=patch_dim0,
            patch_idx=patch_idx,
            orient=orient,
            chunk_size=chunk_size,
            target_chunk_size=target_chunk_size,
            pou_dtype=pou_dtype,
            patch_dtype=patch_dtype,
            interp_block_size=interp_block_size,
            remat=remat,
        )

    if B0_tangents is None:
        # B0 is held fixed, so linearize in X alone. A zero B0 tangent would
        # have to match B0's dtype exactly (jax.linearize rejects mismatched
        # tangent dtypes), and it would only push zeros through the B0 path.
        Bnormal, normal_linear = jax.linearize(lambda x: normal_field(x, B0), X)
        columns = jax.vmap(normal_linear)(X_tangents)
    else:
        Bnormal, normal_linear = jax.linearize(normal_field, X, B0)
        columns = jax.vmap(normal_linear)(X_tangents, jnp.asarray(B0_tangents))
    return Bnormal, columns
[docs]
def compute_internal_B_functional(
    X,
    B0,
    *,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    quad_nt: int,
    quad_np: int,
    patch_dim0: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    X_trg=None,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool | None = None,
):
    """Compute Bint with surface coordinates as differentiable inputs.

    Evaluates the shared on-surface kernel with ``sign=-1``: the layer-potential
    contribution is subtracted from half the on-surface field.
    """
    grid = dict(
        nfp=nfp,
        half_period=half_period,
        surf_nt=surf_nt,
        surf_np=surf_np,
        src_nt=src_nt,
        src_np=src_np,
        trg_nt=trg_nt,
        trg_np=trg_np,
        quad_nt=quad_nt,
        quad_np=quad_np,
    )
    tuning = dict(
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
        interp_block_size=interp_block_size,
        remat=remat,
    )
    return _compute_B_signed(
        X,
        B0,
        sign=-1.0,
        digits=digits,
        patch_dim0=patch_dim0,
        patch_idx=patch_idx,
        orient=orient,
        X_trg=X_trg,
        **grid,
        **tuning,
    )
def _compute_gradB_signed(
    X,
    B0,
    *,
    sign: float,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    quad_nt: int,
    quad_np: int,
    patch_dim0: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    hedgehog_order: int = 8,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool | None = None,
):
    """Shared on-surface field-gradient evaluation for the external/internal variants.

    ``J = n x B`` and ``B . n`` are formed on the quadrature grid and passed to
    the second-derivative singular Laplace evaluators, which use
    ``hedgehog_order``. The result, of shape ``(3, 3, trg_nt, trg_np)``, is
    ``(curl_part + hess_term) * sign`` with
    ``curl_part[k] = H[k+1, k+2] - H[k+2, k+1]`` (indices mod 3) taken from the
    ``J`` evaluation. ``remat`` defaults to True.
    """
    surface_coord, nfp_eff = build_surface_coord(X, nfp, half_period, surf_nt, surf_np, trg_nt)
    quad_coord, dX, normal, _, orient = build_quad_setup(
        surface_coord, quad_nt, quad_np, orient=orient
    )
    remat = True if remat is None else remat

    field_dtype = jnp.asarray(B0).dtype
    pou_dtype = _resolve_pou_dtype(pou_dtype, field_dtype)
    patch_dtype = _resolve_patch_dtype(patch_dtype, field_dtype)
    chunk_size, target_chunk_size = _resolve_chunk_sizes(
        "gradb",
        chunk_size,
        target_chunk_size,
        nsrc=quad_nt * quad_np,
        ntrg=trg_nt * trg_np,
    )

    if patch_dim0 is None:
        patch_dim0 = select_patch_dim_from_geom(dX, quad_nt, quad_np, digits)
    if patch_idx is None:
        patch_idx = build_patch_idx(quad_nt, quad_np, trg_nt, trg_np, nfp_eff, patch_dim0)

    dtheta = _compute_dtheta(nfp, half_period, trg_nt, src_nt)
    B0_complete = _complete_b0(B0, nfp, half_period, src_nt, src_np, dtheta)
    B_quad = resample(B0_complete, nfp_eff * src_nt, src_np, quad_nt, quad_np)
    surface_current = cross_prod(normal, B_quad)
    normal_flux = dot_prod(B_quad, normal)

    # Options shared by both singular second-derivative evaluations.
    singular_kwargs = dict(
        digits=digits,
        patch_dim0=patch_dim0,
        hedgehog_order=hedgehog_order,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
        patch_idx=patch_idx,
        orient=orient,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
        interp_block_size=interp_block_size,
        remat=remat,
    )
    hess_current = jnp.asarray(
        laplace_fxd2_u_eval_vec_singular(
            quad_coord, dX, surface_current, trg_nt, trg_np, nfp_eff, **singular_kwargs
        )
    ).reshape((3, 3, 3, trg_nt, trg_np))
    hess_flux = jnp.asarray(
        laplace_fxd2_u_eval_singular(
            quad_coord, dX, normal_flux, trg_nt, trg_np, nfp_eff, **singular_kwargs
        )
    ).reshape((3, 3, trg_nt, trg_np))

    curl_part = jnp.stack(
        [
            hess_current[(k + 1) % 3, (k + 2) % 3] - hess_current[(k + 2) % 3, (k + 1) % 3]
            for k in range(3)
        ]
    )
    return (curl_part + hess_flux) * sign
[docs]
def compute_external_gradB_functional(
    X,
    B0,
    *,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    quad_nt: int,
    quad_np: int,
    patch_dim0: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    hedgehog_order: int = 8,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool | None = None,
):
    """Compute GradBext with surface coordinates as differentiable inputs.

    Evaluates the shared on-surface gradient kernel with ``sign=+1``.
    """
    grid = dict(
        nfp=nfp,
        half_period=half_period,
        surf_nt=surf_nt,
        surf_np=surf_np,
        src_nt=src_nt,
        src_np=src_np,
        trg_nt=trg_nt,
        trg_np=trg_np,
        quad_nt=quad_nt,
        quad_np=quad_np,
    )
    tuning = dict(
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
        interp_block_size=interp_block_size,
        remat=remat,
    )
    return _compute_gradB_signed(
        X,
        B0,
        sign=1.0,
        digits=digits,
        patch_dim0=patch_dim0,
        patch_idx=patch_idx,
        orient=orient,
        hedgehog_order=hedgehog_order,
        **grid,
        **tuning,
    )
[docs]
def compute_internal_gradB_functional(
    X,
    B0,
    *,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    quad_nt: int,
    quad_np: int,
    patch_dim0: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    hedgehog_order: int = 8,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool | None = None,
):
    """Compute GradBint with surface coordinates as differentiable inputs.

    Evaluates the shared on-surface gradient kernel with ``sign=-1``.
    """
    grid = dict(
        nfp=nfp,
        half_period=half_period,
        surf_nt=surf_nt,
        surf_np=surf_np,
        src_nt=src_nt,
        src_np=src_np,
        trg_nt=trg_nt,
        trg_np=trg_np,
        quad_nt=quad_nt,
        quad_np=quad_np,
    )
    tuning = dict(
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
        interp_block_size=interp_block_size,
        remat=remat,
    )
    return _compute_gradB_signed(
        X,
        B0,
        sign=-1.0,
        digits=digits,
        patch_dim0=patch_dim0,
        patch_idx=patch_idx,
        orient=orient,
        hedgehog_order=hedgehog_order,
        **grid,
        **tuning,
    )
[docs]
def compute_external_B_offsurf_functional(
    X,
    B0,
    *,
    X_trg,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    max_Nt: int = -1,
    max_Np: int = -1,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    adaptive: bool = True,
):
    """Compute off-surface Bext with differentiable geometry inputs.

    The surface and ``B0`` are resampled onto a common base grid of
    ``max(nfp_eff * src_nt, full surface nt, 13)`` by
    ``max(src_np, full surface np, 13)`` points. ``J = n x B`` and ``B . n``
    are formed on that grid.

    Args:
        X_trg: target points, shaped ``(3, nt, np)`` or ``(3, ntrg)``. Here only
            its shape is used, to size the chunking; the array itself is
            forwarded to the evaluator.
        adaptive: when true, delegate to ``computeB_offsurface_adaptive``
            (refinement limited by ``max_Nt``/``max_Np``). Otherwise evaluate
            ``laplace_fxd_u_eval(B . n) - biotsavart_fx_u_eval(J)`` directly on
            the base grid.

    Returns:
        The field at the targets, as produced by the selected evaluator.
    """
    surface_coord, nfp_eff = build_surface_coord(X, nfp, half_period, surf_nt, surf_np, trg_nt)
    surf_nt_full = int(surface_coord.shape[1])
    surf_np_full = int(surface_coord.shape[2])
    # Lower bound on the base grid size in each direction.
    # NOTE(review): presumably tied to the off-surface quadrature patch size —
    # confirm before changing.
    patch_dim = 13
    base_nt = max(nfp_eff * src_nt, surf_nt_full, patch_dim)
    base_np = max(src_np, surf_np_full, patch_dim)
    X_src = resample(surface_coord, surf_nt_full, surf_np_full, base_nt, base_np)
    dX = grad2d(X_src, base_nt, base_np)
    # One call yields both normals and area elements; area_elem is only
    # consumed by the direct (non-adaptive) evaluation below.
    normal, area_elem = surf_normal_area_elem(dX, X_src)
    dtheta = _compute_dtheta(nfp, half_period, trg_nt, src_nt)
    B0_complete = _complete_b0(B0, nfp, half_period, src_nt, src_np, dtheta)
    B_quad = resample(B0_complete, nfp_eff * src_nt, src_np, base_nt, base_np)
    J = cross_prod(normal, B_quad)
    BdotN = dot_prod(B_quad, normal)
    X_trg = jnp.asarray(X_trg)
    nsrc = X_src.shape[1] * X_src.shape[2]
    ntrg = X_trg.shape[1] * X_trg.shape[2] if X_trg.ndim == 3 else X_trg.shape[1]
    chunk_size, target_chunk_size = _resolve_chunk_sizes(
        "boff", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
    )
    if adaptive:
        return computeB_offsurface_adaptive(
            X_src,
            BdotN,
            J,
            X_trg,
            digits=digits,
            max_Nt=max_Nt,
            max_Np=max_Np,
            ext=True,
            chunk_size=chunk_size,
            target_chunk_size=target_chunk_size,
        )
    gradG = laplace_fxd_u_eval(
        X_src, X_trg, BdotN, area_elem, chunk_size=chunk_size, target_chunk_size=target_chunk_size
    )
    bs = biotsavart_fx_u_eval(
        X_src, X_trg, J, area_elem, chunk_size=chunk_size, target_chunk_size=target_chunk_size
    )
    return gradG - bs
[docs]
def compute_external_gradB_offsurf_functional(
    X,
    B0,
    *,
    X_trg,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
    max_Nt: int = -1,
    max_Np: int = -1,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    adaptive: bool = False,
):
    """Compute off-surface GradBext with differentiable geometry inputs.

    The surface and ``B0`` are resampled onto a base grid of at least 13 points
    per direction, where ``J = n x B`` and ``B . n`` are formed.

    * With ``adaptive`` the source grid is refined by ``_offsurface_adapt_grid``
      (limited by ``max_Nt``/``max_Np``).
    * Targets shaped ``(3, nt, np)`` are flattened for evaluation, and the
      result is reshaped to ``(3, 3, nt, np)``.
    * Targets shaped ``(3, ntrg)`` give ``(3, 3, ntrg)``.
    """
    surface_coord, nfp_eff = build_surface_coord(X, nfp, half_period, surf_nt, surf_np, trg_nt)
    full_nt = int(surface_coord.shape[1])
    full_np = int(surface_coord.shape[2])
    min_grid = 13
    base_nt = max(nfp_eff * src_nt, full_nt, min_grid)
    base_np = max(src_np, full_np, min_grid)
    X_src = resample(surface_coord, full_nt, full_np, base_nt, base_np)
    dX = grad2d(X_src, base_nt, base_np)
    normal, area_elem = surf_normal_area_elem(dX, X_src)
    dtheta = _compute_dtheta(nfp, half_period, trg_nt, src_nt)
    B0_complete = _complete_b0(B0, nfp, half_period, src_nt, src_np, dtheta)
    B_quad = resample(B0_complete, nfp_eff * src_nt, src_np, base_nt, base_np)
    J = cross_prod(normal, B_quad)
    BdotN = dot_prod(B_quad, normal)

    X_trg = jnp.asarray(X_trg)
    grid_targets = X_trg.ndim == 3
    targets = X_trg.reshape((3, -1)) if grid_targets else X_trg
    n_targets = targets.shape[1]
    chunk_size, target_chunk_size = _resolve_chunk_sizes(
        "gradb_off",
        chunk_size,
        target_chunk_size,
        nsrc=X_src.shape[1] * X_src.shape[2],
        ntrg=n_targets,
    )
    if adaptive:
        X_src, BdotN, J, area_elem = _offsurface_adapt_grid(
            X_src,
            BdotN,
            J,
            targets,
            digits=digits,
            max_Nt=max_Nt,
            max_Np=max_Np,
            chunk_size=chunk_size,
            target_chunk_size=target_chunk_size,
        )

    chunking = dict(chunk_size=chunk_size, target_chunk_size=target_chunk_size)
    hess_current = jnp.asarray(
        laplace_fxd2_u_eval_vec(X_src, targets, J, area_elem, **chunking)
    ).reshape((3, 3, 3, n_targets))
    hess_flux = jnp.asarray(
        laplace_fxd2_u_eval(X_src, targets, BdotN, area_elem, **chunking)
    ).reshape((3, 3, n_targets))

    gradB = (
        jnp.stack(
            [
                hess_current[(k + 1) % 3, (k + 2) % 3] - hess_current[(k + 2) % 3, (k + 1) % 3]
                for k in range(3)
            ]
        )
        + hess_flux
    )
    if grid_targets:
        return gradB.reshape((3, 3, X_trg.shape[1], X_trg.shape[2]))
    return gradB