"""High-level Virtual Casing routines in JAX."""
from __future__ import annotations
import math
from dataclasses import dataclass, field
import jax
import jax.numpy as jnp
from .utils import autotune_chunk_sizes, build_offsurface_levels
from .surface_ops import (
complete_vec_field,
resample,
rotate_toroidal,
grad2d,
surf_normal_area_elem,
dot_prod,
cross_prod,
)
from .integrals import (
laplace_fxd_u_eval,
laplace_fxd_u_eval_singular,
laplace_fxd_u_eval_vec_singular,
laplace_fxd2_u_eval,
laplace_fxd2_u_eval_vec,
laplace_fxd2_u_eval_singular,
laplace_fxd2_u_eval_vec_singular,
laplace_dx_u_eval_singular,
computeB_offsurface_adaptive,
computeB_offsurface_adaptive_schedule,
computeGradB_offsurface_adaptive_schedule,
_offsurface_adapt_grid,
_build_patch_indices,
_surface_cond,
select_patch_dim,
)
[docs]
@dataclass
class QuadSetup:
    """Quadrature-grid data cached by ``VirtualCasingJAX``.

    Instances are built by ``VirtualCasingJAX._build_quad_setup`` and reused
    across calls for as long as the requested quadrature size is unchanged.
    """

    # Number of quadrature points along the first / second surface grid axis.
    quad_nt: int
    quad_np: int
    # Surface coordinates resampled onto the quadrature grid.
    quad_coord: jnp.ndarray
    # grad2d(quad_coord, quad_nt, quad_np).
    dX: jnp.ndarray
    # Surface normal returned by surf_normal_area_elem on this grid.
    normal: jnp.ndarray
    # Orientation value from surf_normal_area_elem, stored as a Python float.
    orient: float
    # Patch index arrays memoized per patch dimension (filled by _get_patch_idx).
    patch_idx_cache: dict[int, jnp.ndarray] = field(default_factory=dict)
[docs]
class VirtualCasingJAX:
"""JAX mirror of VirtualCasing for external field and GradB."""
def __init__(self):
self._setup = False
self._grad_setup: QuadSetup | None = None
self._b_setup: QuadSetup | None = None
self._jit_cache: dict[tuple, callable] = {}
def _resolve_chunk_sizes(
self,
op: str,
chunk_size,
target_chunk_size,
*,
nsrc: int,
ntrg: int,
):
"""Resolve source/target chunk sizes with optional auto tuning."""
chunk_auto = chunk_size is None or (isinstance(chunk_size, str) and chunk_size.lower() == "auto")
target_auto = isinstance(target_chunk_size, str) and target_chunk_size.lower() == "auto"
if chunk_auto:
src_auto, trg_auto = autotune_chunk_sizes(op, nsrc, ntrg)
chunk_size = src_auto
if target_auto:
target_chunk_size = trg_auto
else:
chunk_size = int(chunk_size)
if target_auto:
_, trg_auto = autotune_chunk_sizes(op, nsrc, ntrg)
target_chunk_size = trg_auto
if target_chunk_size is not None and not isinstance(target_chunk_size, str):
target_chunk_size = int(target_chunk_size)
return chunk_size, target_chunk_size
def _resolve_pou_dtype(self, pou_dtype, value_dtype):
if pou_dtype is None:
return None
if isinstance(pou_dtype, str):
if pou_dtype.lower() == "auto":
return jnp.float32 if value_dtype == jnp.float64 else value_dtype
return jnp.dtype(pou_dtype)
return jnp.dtype(pou_dtype)
def _resolve_patch_dtype(self, patch_dtype, value_dtype):
if patch_dtype is None:
return None
if isinstance(patch_dtype, str):
if patch_dtype.lower() == "auto":
return jnp.float32 if value_dtype == jnp.float64 else value_dtype
return jnp.dtype(patch_dtype)
return jnp.dtype(patch_dtype)
def _resolve_offsurface_levels(
self,
levels,
*,
nt0: int,
np0: int,
max_Nt: int,
max_Np: int,
max_levels: int,
):
if levels is None or (isinstance(levels, str) and levels.lower() == "auto"):
return build_offsurface_levels(nt0, np0, max_Nt=max_Nt, max_Np=max_Np, max_levels=max_levels)
return tuple((int(nt), int(np)) for nt, np in levels)
[docs]
def setup(
    self,
    digits: int,
    nfp: int,
    half_period: bool,
    surf_nt: int,
    surf_np: int,
    X,
    src_nt: int,
    src_np: int,
    trg_nt: int,
    trg_np: int,
):
    """Store the surface geometry and grid sizes used by all later evaluations.

    Args:
        digits: Default accuracy used when a compute method gets ``digits=None``.
        nfp: Field-period count.
        half_period: When True the surface is completed with
            ``complete_vec_field``, resampled to ``nfp * 2 * (surf_nt + 1)``
            toroidal points, rotated, and ``nfp_eff`` becomes ``2 * nfp``
            (presumably because the inputs cover half a period — confirm).
        surf_nt, surf_np: Grid size of ``X``; ``X`` is reshaped to
            ``(3, surf_nt, surf_np)``.
        src_nt, src_np: Grid size of the ``B0`` fields given to compute methods.
        trg_nt, trg_np: Grid size of on-surface outputs.

    Calling ``setup`` again drops the cached quadrature setups *and* every
    compiled function in ``self._jit_cache``. Those functions close over the
    previous geometry and patch indices, so reusing them after a new setup
    would silently return results for the old surface.
    """
    X = jnp.asarray(X).reshape((3, surf_nt, surf_np))
    if half_period:
        X_completed = complete_vec_field(
            X,
            True,
            half_period,
            nfp,
            surf_nt,
            surf_np,
            -math.pi / (nfp * surf_nt * 2),
        )
        X_resampled = resample(X_completed, nfp * 2 * surf_nt, surf_np, nfp * 2 * (surf_nt + 1), surf_np)
        surface_coord = rotate_toroidal(
            X_resampled,
            nfp * 2 * (surf_nt + 1),
            surf_np,
            math.pi / (nfp * trg_nt * 2),
        )
        nfp_eff = nfp * 2
    else:
        surface_coord = complete_vec_field(
            X,
            True,
            half_period,
            nfp,
            surf_nt,
            surf_np,
            0.0,
        )
        nfp_eff = nfp
    self.digits = int(digits)
    self.nfp = int(nfp)
    self.nfp_eff = int(nfp_eff)
    self.half_period = bool(half_period)
    self.surf_nt = int(surf_nt)
    self.surf_np = int(surf_np)
    self.src_nt = int(src_nt)
    self.src_np = int(src_np)
    self.trg_nt = int(trg_nt)
    self.trg_np = int(trg_np)
    self.surface_coord = surface_coord
    self._setup = True
    # Everything derived from the previous surface is now stale.
    self._grad_setup = None
    self._b_setup = None
    self._jit_cache = {}
def _select_quad_sizes(self, digits: int):
    """Choose the quadrature grid size ``(quad_nt, quad_np)`` for ``digits`` accuracy.

    Procedure, as implemented below:

    1. From the ``grad2d`` derivatives of the full surface, compute per-point
       ratios ``sqrt(m00 / m11)`` of grid-scaled metric terms. The geometric
       mean of the extreme ratios (rescaled by the grid sizes) gives
       ``optim_aspect_ratio``, and their spread gives ``cond``.
    2. Start with ``quad_np`` a multiple of ``trg_np`` that is at least
       ``max(surf_np_full, src_np, 2 * pdim + 1)`` (with ``pdim = 1.6 * digits * cond``).
       Start with ``quad_nt`` a multiple of ``nfp_eff * trg_nt`` that is at
       least ``max(surf_nt_full, src_nt_full, optim_aspect_ratio * quad_np)``.
    3. Up to three times: resample onto a grid aligned to multiples of the
       surface grid. Evaluate ``laplace_dx_u_eval_singular`` with unit density
       at one period's worth of surface-grid points, and measure the maximum
       deviation from 0.5. This is presumably the on-surface double-layer
       identity — TODO confirm against the integrals module. Scale both sizes
       from that error estimate, and stop once the error is below
       ``10 ** -digits`` or the scale factor is below 1.5.
    4. Round ``quad_np`` to a multiple of ``trg_np`` and ``quad_nt`` to a
       multiple of ``nfp_eff * trg_nt``, consistent with ``optim_aspect_ratio``.

    This converts array values to Python floats, so it can only run eagerly
    (it cannot be evaluated on traced values inside ``jax.jit``).

    Returns:
        ``(quad_nt, quad_np)`` as Python ints.
    """
    surf_nt_full = int(self.surface_coord.shape[1])
    surf_np_full = int(self.surface_coord.shape[2])
    src_nt_full = int(self.nfp_eff * self.src_nt)
    dX = grad2d(self.surface_coord, surf_nt_full, surf_np_full)
    dX_np = jnp.asarray(dX)
    # Rows of the grad2d output, named (per the variable names) as the first
    # and second grid-direction derivatives of x, y, z.
    xt = dX_np[0]
    xp = dX_np[1]
    yt = dX_np[2]
    yp = dX_np[3]
    zt = dX_np[4]
    zp = dX_np[5]
    # Squared derivative norms scaled by the grid size in each direction.
    m00 = (xt * xt + yt * yt + zt * zt) / (surf_nt_full * surf_nt_full)
    m11 = (xp * xp + yp * yp + zp * zp) / (surf_np_full * surf_np_full)
    ratio = jnp.sqrt(m00 / m11)
    amin = float(jnp.min(ratio))
    amax = float(jnp.max(ratio))
    optim_aspect_ratio = math.sqrt(amin * amax) * surf_nt_full / surf_np_full
    cond = math.sqrt(amax / amin)
    # Heuristic patch size that grows with accuracy and metric anisotropy.
    pdim = digits * cond * 1.6
    quad_np = self.trg_np * math.ceil(
        max(surf_np_full, self.src_np, 2 * pdim + 1) / self.trg_np
    )
    quad_nt = self.nfp_eff * self.trg_nt * math.ceil(
        max(max(surf_nt_full, src_nt_full), optim_aspect_ratio * quad_np)
        / (self.nfp_eff * self.trg_nt)
    )
    # Self-test targets: one period's worth of surface-grid rows.
    trg_nt_self = surf_nt_full // self.nfp_eff
    trg_np_self = surf_np_full
    for _ in range(3):
        # Align to multiples of the surface grid so the surface points are
        # (presumably) a subset of the resampled grid.
        quad_nt_aligned = math.ceil(quad_nt / surf_nt_full) * surf_nt_full
        quad_np_aligned = math.ceil(quad_np / surf_np_full) * surf_np_full
        X_quad = resample(
            self.surface_coord,
            surf_nt_full,
            surf_np_full,
            quad_nt_aligned,
            quad_np_aligned,
        )
        dX_quad = grad2d(X_quad, quad_nt_aligned, quad_np_aligned)
        ones = jnp.ones((quad_nt_aligned, quad_np_aligned), dtype=X_quad.dtype)
        U = laplace_dx_u_eval_singular(
            X_quad,
            dX_quad,
            ones,
            trg_nt_self,
            trg_np_self,
            self.nfp_eff,
            digits=digits,
            chunk_size=1024,
        )
        err = float(jnp.max(jnp.abs(jnp.asarray(U).reshape(-1) - 0.5)))
        # Exact result: log(err) below would be undefined; keep current sizes.
        if err <= 0:
            break
        # log(err)/log(0.1) is the number of digits achieved; scale the grid
        # by the ratio of (digits + 1) to that, but never shrink it.
        scal = max(1.0, (digits + 1) / (math.log(err) / math.log(0.1)))
        quad_nt = int(scal * quad_nt_aligned)
        quad_np = int(scal * quad_np_aligned)
        if err < 10 ** (-digits) or scal < 1.5:
            break
    # Final rounding onto the target-grid multiples.
    # NOTE(review): round(...) can yield 0 for very coarse inputs, which would
    # produce a zero-size grid — consider max(1, ...) if that can occur.
    quad_np = self.trg_np * round((quad_nt / optim_aspect_ratio) / self.trg_np)
    quad_nt = self.nfp_eff * self.trg_nt * round(
        (optim_aspect_ratio * quad_np) / (self.nfp_eff * self.trg_nt)
    )
    return int(quad_nt), int(quad_np)
def _build_quad_setup(self, quad_nt: int, quad_np: int):
    """Resample the stored surface onto a ``quad_nt x quad_np`` grid and package it.

    Returns a :class:`QuadSetup` holding the resampled coordinates, their
    ``grad2d`` derivatives, the normal from ``surf_normal_area_elem`` and that
    function's orientation value converted to a Python float.
    """
    full_nt, full_np = (int(extent) for extent in self.surface_coord.shape[1:3])
    coords = resample(self.surface_coord, full_nt, full_np, quad_nt, quad_np)
    derivs = grad2d(coords, quad_nt, quad_np)
    unit_normal, _area, orientation = surf_normal_area_elem(derivs, coords, return_orientation=True)
    return QuadSetup(
        quad_nt=quad_nt,
        quad_np=quad_np,
        quad_coord=coords,
        dX=derivs,
        normal=unit_normal,
        orient=float(orientation),
    )
def _get_patch_idx(self, setup: QuadSetup, digits: int):
    """Return ``(patch_dim0, patch_idx)`` for the singular correction on ``setup``.

    The patch dimension comes from ``select_patch_dim(digits, cond)``, with
    ``cond`` from ``_surface_cond``. Index arrays are built with
    ``_build_patch_indices`` from the target positions on the quadrature grid
    (every ``quad_nt // (nfp_eff * trg_nt)`` rows and ``quad_np // trg_np``
    columns). They are memoized in ``setup.patch_idx_cache`` keyed by the
    patch dimension.
    """
    surface_cond = float(_surface_cond(setup.dX, setup.quad_nt, setup.quad_np))
    dim = select_patch_dim(digits, surface_cond)
    cached = setup.patch_idx_cache.get(dim)
    if cached is not None:
        return dim, cached
    row_stride = setup.quad_nt // (self.nfp_eff * self.trg_nt)
    col_stride = setup.quad_np // self.trg_np
    rows, cols = jnp.meshgrid(
        jnp.arange(self.trg_nt) * row_stride,
        jnp.arange(self.trg_np) * col_stride,
        indexing="ij",
    )
    indices = _build_patch_indices(
        rows.reshape(-1),
        cols.reshape(-1),
        setup.quad_nt,
        setup.quad_np,
        dim,
    )
    setup.patch_idx_cache[dim] = indices
    return dim, indices
def _ensure_grad_setup(self, quad_nt: int | None, quad_np: int | None, digits: int):
if not self._setup:
raise RuntimeError("VirtualCasingJAX.setup must be called before compute_external_gradB")
if quad_nt is None or quad_np is None:
quad_nt_sel, quad_np_sel = self._select_quad_sizes(digits)
quad_nt = quad_nt_sel
quad_np = quad_np_sel
if (
self._grad_setup is not None
and self._grad_setup.quad_nt == quad_nt
and self._grad_setup.quad_np == quad_np
):
return
self._grad_setup = self._build_quad_setup(quad_nt, quad_np)
def _ensure_b_setup(self, quad_nt: int | None, quad_np: int | None, digits: int):
if not self._setup:
raise RuntimeError("VirtualCasingJAX.setup must be called before compute_external_B")
if quad_nt is None or quad_np is None:
quad_nt_sel, quad_np_sel = self._select_quad_sizes(digits)
quad_nt = quad_nt_sel
quad_np = quad_np_sel
if (
self._b_setup is not None
and self._b_setup.quad_nt == quad_nt
and self._b_setup.quad_np == quad_np
):
return
self._b_setup = self._build_quad_setup(quad_nt, quad_np)
def _compute_gradB_signed(
    self,
    B0,
    *,
    sign: float,
    quad_nt: int | None = None,
    quad_np: int | None = None,
    digits: int | None = None,
    hedgehog_order: int = 8,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool | None = None,
    scan_targets: bool = False,
    patch_dim0: int | None = None,
    patch_idx=None,
):
    """Shared implementation of the external/internal gradB evaluations.

    ``B0`` is reshaped to ``(3, src_nt, src_np)``, completed with
    ``complete_vec_field`` and resampled onto the gradB quadrature grid, where
    ``J = normal x B`` and ``BdotN = B . normal`` are formed. Two singular
    second-derivative Laplace evaluations give ``gradG_J`` with shape
    ``(3, 3, 3, trg_nt, trg_np)`` and ``gradgradG_BdotN`` with shape
    ``(3, 3, trg_nt, trg_np)``.

    Returns:
        ``sign * (curl_term + gradgradG_BdotN)`` with shape
        ``(3, 3, trg_nt, trg_np)``, where
        ``curl_term[k] = gradG_J[k+1, k+2] - gradG_J[k+2, k+1]`` (indices mod 3).
        ``remat`` defaults to True on this path.
    """
    if not self._setup:
        raise RuntimeError("VirtualCasingJAX.setup must be called before compute_gradB")
    digits = self.digits if digits is None else int(digits)
    self._ensure_grad_setup(quad_nt, quad_np, digits)
    assert self._grad_setup is not None
    quad = self._grad_setup
    B0 = jnp.asarray(B0).reshape((3, self.src_nt, self.src_np))
    remat = True if remat is None else remat
    pou_dtype = self._resolve_pou_dtype(pou_dtype, B0.dtype)
    patch_dtype = self._resolve_patch_dtype(patch_dtype, B0.dtype)
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "gradb",
        chunk_size,
        target_chunk_size,
        nsrc=quad.quad_nt * quad.quad_np,
        ntrg=self.trg_nt * self.trg_np,
    )

    # Toroidal shift aligning the completed source field with the target grid.
    if self.half_period:
        shift = math.pi * (
            1.0 / (self.nfp * self.trg_nt * 2) - 1.0 / (self.nfp * self.src_nt * 2)
        )
    else:
        shift = 0.0
    B_full = complete_vec_field(
        B0, False, self.half_period, self.nfp, self.src_nt, self.src_np, shift
    )
    B_on_quad = resample(
        B_full, self.nfp_eff * self.src_nt, self.src_np, quad.quad_nt, quad.quad_np
    )
    n_cross_B = cross_prod(quad.normal, B_on_quad)
    B_dot_n = dot_prod(B_on_quad, quad.normal)

    if patch_dim0 is None or patch_idx is None:
        patch_dim0, patch_idx = self._get_patch_idx(quad, digits)
    shared = dict(
        digits=digits,
        patch_dim0=patch_dim0,
        hedgehog_order=hedgehog_order,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
        patch_idx=patch_idx,
        orient=quad.orient,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
        interp_block_size=interp_block_size,
        remat=remat,
        scan_targets=scan_targets,
    )
    grad_grad_J = jnp.asarray(
        laplace_fxd2_u_eval_vec_singular(
            quad.quad_coord, quad.dX, n_cross_B, self.trg_nt, self.trg_np, self.nfp_eff, **shared
        )
    ).reshape((3, 3, 3, self.trg_nt, self.trg_np))
    hess_term = jnp.asarray(
        laplace_fxd2_u_eval_singular(
            quad.quad_coord, quad.dX, B_dot_n, self.trg_nt, self.trg_np, self.nfp_eff, **shared
        )
    ).reshape((3, 3, self.trg_nt, self.trg_np))

    curl_term = jnp.stack(
        [
            grad_grad_J[(k + 1) % 3, (k + 2) % 3] - grad_grad_J[(k + 2) % 3, (k + 1) % 3]
            for k in range(3)
        ]
    )
    return (curl_term + hess_term) * sign
[docs]
def compute_external_gradB(
self,
B0,
*,
quad_nt: int | None = None,
quad_np: int | None = None,
digits: int | None = None,
hedgehog_order: int = 8,
chunk_size: int | str | None = "auto",
target_chunk_size: int | str | None = "auto",
pou_dtype=None,
patch_dtype=None,
interp_block_size: int | str | None = "auto",
remat: bool | None = None,
scan_targets: bool = False,
patch_dim0: int | None = None,
patch_idx=None,
):
"""Compute GradBext from total B on the source grid.
Args:
scan_targets: If True, use a ``lax.scan`` over target points
inside the singular correction to reduce peak memory.
"""
return self._compute_gradB_signed(
B0,
sign=1.0,
quad_nt=quad_nt,
quad_np=quad_np,
digits=digits,
hedgehog_order=hedgehog_order,
chunk_size=chunk_size,
target_chunk_size=target_chunk_size,
pou_dtype=pou_dtype,
patch_dtype=patch_dtype,
interp_block_size=interp_block_size,
remat=remat,
scan_targets=scan_targets,
patch_dim0=patch_dim0,
patch_idx=patch_idx,
)
[docs]
def compute_internal_gradB(
self,
B0,
*,
quad_nt: int | None = None,
quad_np: int | None = None,
digits: int | None = None,
hedgehog_order: int = 8,
chunk_size: int | str | None = "auto",
target_chunk_size: int | str | None = "auto",
pou_dtype=None,
patch_dtype=None,
interp_block_size: int | str | None = "auto",
remat: bool | None = None,
scan_targets: bool = False,
patch_dim0: int | None = None,
patch_idx=None,
):
"""Compute GradBint from total B on the source grid.
Args:
scan_targets: If True, use a ``lax.scan`` over target points
inside the singular correction to reduce peak memory.
"""
return self._compute_gradB_signed(
B0,
sign=-1.0,
quad_nt=quad_nt,
quad_np=quad_np,
digits=digits,
hedgehog_order=hedgehog_order,
chunk_size=chunk_size,
target_chunk_size=target_chunk_size,
pou_dtype=pou_dtype,
patch_dtype=patch_dtype,
interp_block_size=interp_block_size,
remat=remat,
scan_targets=scan_targets,
patch_dim0=patch_dim0,
patch_idx=patch_idx,
)
[docs]
def compute_external_gradB_jit(self, B0, **kwargs):
"""JIT-compiled version of compute_external_gradB."""
if "X_trg" in kwargs and kwargs["X_trg"] is not None:
raise ValueError("compute_external_gradB_jit does not support X_trg")
donate = bool(kwargs.pop("donate", False))
digits = self.digits if kwargs.get("digits") is None else int(kwargs["digits"])
quad_nt = kwargs.get("quad_nt")
quad_np = kwargs.get("quad_np")
self._ensure_grad_setup(quad_nt, quad_np, digits)
patch_dim0, patch_idx = self._get_patch_idx(self._grad_setup, digits)
chunk_size = kwargs.get("chunk_size", "auto")
target_chunk_size = kwargs.get("target_chunk_size", "auto")
remat = kwargs.get("remat")
pou_dtype = kwargs.get("pou_dtype")
patch_dtype = kwargs.get("patch_dtype")
interp_block_size = kwargs.get("interp_block_size", "auto")
nsrc = self._grad_setup.quad_nt * self._grad_setup.quad_np
ntrg = self.trg_nt * self.trg_np
chunk_size, target_chunk_size = self._resolve_chunk_sizes(
"gradb", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
)
pou_dtype = self._resolve_pou_dtype(pou_dtype, jnp.asarray(B0).dtype)
patch_dtype = self._resolve_patch_dtype(patch_dtype, jnp.asarray(B0).dtype)
key = (
"gradB",
digits,
quad_nt,
quad_np,
kwargs.get("hedgehog_order", 8),
chunk_size,
target_chunk_size,
remat,
pou_dtype,
patch_dtype,
interp_block_size,
kwargs.get("scan_targets", False),
donate,
)
fn = self._jit_cache.get(key)
if fn is None:
call_kwargs = dict(kwargs)
call_kwargs["chunk_size"] = chunk_size
call_kwargs["target_chunk_size"] = target_chunk_size
call_kwargs["remat"] = remat
call_kwargs["pou_dtype"] = pou_dtype
call_kwargs["patch_dtype"] = patch_dtype
call_kwargs["interp_block_size"] = interp_block_size
fn = jax.jit(
lambda B: self.compute_external_gradB(
B,
patch_dim0=patch_dim0,
patch_idx=patch_idx,
**call_kwargs,
),
donate_argnums=(0,) if donate else (),
)
self._jit_cache[key] = fn
return fn(B0)
def _compute_B_signed(
    self,
    B0,
    *,
    sign: float,
    X_trg=None,
    quad_nt: int | None = None,
    quad_np: int | None = None,
    digits: int | None = None,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool | None = None,
    patch_dim0: int | None = None,
    patch_idx=None,
):
    """Shared implementation of the external/internal B evaluations.

    ``B0`` is reshaped to ``(3, src_nt, src_np)``, completed with
    ``complete_vec_field`` and resampled onto the B quadrature grid, where
    ``J = normal x B`` and ``BdotN = B . normal`` are formed. Two singular
    Laplace evaluations give ``gradG_J`` with shape ``(3, 3, trg_nt, trg_np)``
    and ``gradG_BdotN`` with shape ``(3, trg_nt, trg_np)``.

    Returns:
        ``sign * (curl_term + gradG_BdotN) + 0.5 * B_on``, where
        ``curl_term[k] = gradG_J[k+1, k+2] - gradG_J[k+2, k+1]`` (indices mod 3)
        and ``B_on`` is the input field resampled onto the first ``trg_nt``
        rows of the full target grid. ``remat`` defaults to False on this path.

    Raises:
        RuntimeError: if :meth:`setup` has not been called.
        ValueError: if ``X_trg`` is neither 2-D nor 3-D.

    NOTE(review): both integral results are reshaped to the
    ``(trg_nt, trg_np)`` grid and ``B_on`` lives on that grid, so a custom
    ``X_trg`` presumably must have exactly ``trg_nt * trg_np`` points laid out
    like the target grid — confirm against the integrals module.
    """
    if not self._setup:
        raise RuntimeError("VirtualCasingJAX.setup must be called before compute_B")
    digits = self.digits if digits is None else int(digits)
    self._ensure_b_setup(quad_nt, quad_np, digits)
    assert self._b_setup is not None
    quad = self._b_setup
    B0 = jnp.asarray(B0).reshape((3, self.src_nt, self.src_np))
    remat = False if remat is None else remat
    pou_dtype = self._resolve_pou_dtype(pou_dtype, B0.dtype)
    patch_dtype = self._resolve_patch_dtype(patch_dtype, B0.dtype)

    if X_trg is None:
        n_targets = self.trg_nt * self.trg_np
    else:
        trg_dims = jnp.asarray(X_trg).shape
        if len(trg_dims) == 3:
            n_targets = trg_dims[1] * trg_dims[2]
        elif len(trg_dims) == 2:
            n_targets = trg_dims[1]
        else:
            raise ValueError("X_trg must have shape (3, nt, np) or (3, ntrg)")
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "b",
        chunk_size,
        target_chunk_size,
        nsrc=quad.quad_nt * quad.quad_np,
        ntrg=n_targets,
    )

    # Toroidal shift aligning the completed source field with the target grid.
    if self.half_period:
        shift = math.pi * (
            1.0 / (self.nfp * self.trg_nt * 2) - 1.0 / (self.nfp * self.src_nt * 2)
        )
    else:
        shift = 0.0
    B_full = complete_vec_field(
        B0, False, self.half_period, self.nfp, self.src_nt, self.src_np, shift
    )
    B_on_quad = resample(
        B_full, self.nfp_eff * self.src_nt, self.src_np, quad.quad_nt, quad.quad_np
    )
    n_cross_B = cross_prod(quad.normal, B_on_quad)
    B_dot_n = dot_prod(B_on_quad, quad.normal)

    if patch_dim0 is None or patch_idx is None:
        patch_dim0, patch_idx = self._get_patch_idx(quad, digits)
    shared = dict(
        X_trg=X_trg,
        digits=digits,
        patch_dim0=patch_dim0,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
        patch_idx=patch_idx,
        orient=quad.orient,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
        interp_block_size=interp_block_size,
        remat=remat,
    )
    grad_J = jnp.asarray(
        laplace_fxd_u_eval_vec_singular(
            quad.quad_coord, quad.dX, n_cross_B, self.trg_nt, self.trg_np, self.nfp_eff, **shared
        )
    ).reshape((3, 3, self.trg_nt, self.trg_np))
    grad_BdotN = jnp.asarray(
        laplace_fxd_u_eval_singular(
            quad.quad_coord, quad.dX, B_dot_n, self.trg_nt, self.trg_np, self.nfp_eff, **shared
        )
    ).reshape((3, self.trg_nt, self.trg_np))

    B_trg_full = resample(
        B_full,
        self.nfp_eff * self.src_nt,
        self.src_np,
        self.nfp_eff * self.trg_nt,
        self.trg_np,
    )
    B_on = B_trg_full[:, : self.trg_nt, :]
    curl_term = jnp.stack(
        [grad_J[(k + 1) % 3, (k + 2) % 3] - grad_J[(k + 2) % 3, (k + 1) % 3] for k in range(3)]
    )
    return sign * (curl_term + grad_BdotN) + 0.5 * B_on
[docs]
def compute_external_B(
self,
B0,
*,
X_trg=None,
quad_nt: int | None = None,
quad_np: int | None = None,
digits: int | None = None,
chunk_size: int | str | None = "auto",
target_chunk_size: int | str | None = "auto",
pou_dtype=None,
patch_dtype=None,
interp_block_size: int | str | None = "auto",
remat: bool | None = None,
patch_dim0: int | None = None,
patch_idx=None,
):
"""Compute Bext from total B on the source grid."""
return self._compute_B_signed(
B0,
sign=1.0,
X_trg=X_trg,
quad_nt=quad_nt,
quad_np=quad_np,
digits=digits,
chunk_size=chunk_size,
target_chunk_size=target_chunk_size,
pou_dtype=pou_dtype,
patch_dtype=patch_dtype,
interp_block_size=interp_block_size,
remat=remat,
patch_dim0=patch_dim0,
patch_idx=patch_idx,
)
[docs]
def compute_internal_B(
self,
B0,
*,
X_trg=None,
quad_nt: int | None = None,
quad_np: int | None = None,
digits: int | None = None,
chunk_size: int | str | None = "auto",
target_chunk_size: int | str | None = "auto",
pou_dtype=None,
patch_dtype=None,
interp_block_size: int | str | None = "auto",
remat: bool | None = None,
patch_dim0: int | None = None,
patch_idx=None,
):
"""Compute Bint from total B on the source grid."""
return self._compute_B_signed(
B0,
sign=-1.0,
X_trg=X_trg,
quad_nt=quad_nt,
quad_np=quad_np,
digits=digits,
chunk_size=chunk_size,
target_chunk_size=target_chunk_size,
pou_dtype=pou_dtype,
patch_dtype=patch_dtype,
interp_block_size=interp_block_size,
remat=remat,
patch_dim0=patch_dim0,
patch_idx=patch_idx,
)
[docs]
def compute_external_B_jit(self, B0, **kwargs):
"""JIT-compiled version of compute_external_B."""
if "X_trg" in kwargs and kwargs["X_trg"] is not None:
raise ValueError("compute_external_B_jit does not support X_trg; jit externally if needed")
donate = bool(kwargs.pop("donate", False))
digits = self.digits if kwargs.get("digits") is None else int(kwargs["digits"])
quad_nt = kwargs.get("quad_nt")
quad_np = kwargs.get("quad_np")
self._ensure_b_setup(quad_nt, quad_np, digits)
patch_dim0, patch_idx = self._get_patch_idx(self._b_setup, digits)
chunk_size = kwargs.get("chunk_size", "auto")
target_chunk_size = kwargs.get("target_chunk_size", "auto")
remat = kwargs.get("remat")
pou_dtype = kwargs.get("pou_dtype")
patch_dtype = kwargs.get("patch_dtype")
interp_block_size = kwargs.get("interp_block_size", "auto")
nsrc = self._b_setup.quad_nt * self._b_setup.quad_np
ntrg = self.trg_nt * self.trg_np
chunk_size, target_chunk_size = self._resolve_chunk_sizes(
"b", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
)
pou_dtype = self._resolve_pou_dtype(pou_dtype, jnp.asarray(B0).dtype)
patch_dtype = self._resolve_patch_dtype(patch_dtype, jnp.asarray(B0).dtype)
key = (
"B",
digits,
quad_nt,
quad_np,
chunk_size,
target_chunk_size,
remat,
pou_dtype,
patch_dtype,
interp_block_size,
donate,
)
fn = self._jit_cache.get(key)
if fn is None:
call_kwargs = dict(kwargs)
call_kwargs["chunk_size"] = chunk_size
call_kwargs["target_chunk_size"] = target_chunk_size
call_kwargs["remat"] = remat
call_kwargs["pou_dtype"] = pou_dtype
call_kwargs["patch_dtype"] = patch_dtype
call_kwargs["interp_block_size"] = interp_block_size
fn = jax.jit(
lambda B: self.compute_external_B(
B,
patch_dim0=patch_dim0,
patch_idx=patch_idx,
**call_kwargs,
),
donate_argnums=(0,) if donate else (),
)
self._jit_cache[key] = fn
return fn(B0)
[docs]
def compute_internal_B_jit(self, B0, **kwargs):
"""JIT-compiled version of compute_internal_B."""
if "X_trg" in kwargs and kwargs["X_trg"] is not None:
raise ValueError("compute_internal_B_jit does not support X_trg; jit externally if needed")
donate = bool(kwargs.pop("donate", False))
digits = self.digits if kwargs.get("digits") is None else int(kwargs["digits"])
quad_nt = kwargs.get("quad_nt")
quad_np = kwargs.get("quad_np")
self._ensure_b_setup(quad_nt, quad_np, digits)
patch_dim0, patch_idx = self._get_patch_idx(self._b_setup, digits)
chunk_size = kwargs.get("chunk_size", "auto")
target_chunk_size = kwargs.get("target_chunk_size", "auto")
remat = kwargs.get("remat")
pou_dtype = kwargs.get("pou_dtype")
patch_dtype = kwargs.get("patch_dtype")
interp_block_size = kwargs.get("interp_block_size", "auto")
nsrc = self._b_setup.quad_nt * self._b_setup.quad_np
ntrg = self.trg_nt * self.trg_np
chunk_size, target_chunk_size = self._resolve_chunk_sizes(
"b", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
)
pou_dtype = self._resolve_pou_dtype(pou_dtype, jnp.asarray(B0).dtype)
patch_dtype = self._resolve_patch_dtype(patch_dtype, jnp.asarray(B0).dtype)
key = (
"Bint",
digits,
quad_nt,
quad_np,
chunk_size,
target_chunk_size,
remat,
pou_dtype,
patch_dtype,
interp_block_size,
donate,
)
fn = self._jit_cache.get(key)
if fn is None:
call_kwargs = dict(kwargs)
call_kwargs["chunk_size"] = chunk_size
call_kwargs["target_chunk_size"] = target_chunk_size
call_kwargs["remat"] = remat
call_kwargs["pou_dtype"] = pou_dtype
call_kwargs["patch_dtype"] = patch_dtype
call_kwargs["interp_block_size"] = interp_block_size
fn = jax.jit(
lambda B: self.compute_internal_B(
B,
patch_dim0=patch_dim0,
patch_idx=patch_idx,
**call_kwargs,
),
donate_argnums=(0,) if donate else (),
)
self._jit_cache[key] = fn
return fn(B0)
[docs]
def compute_external_B_batch(self, B0_batch, *, X_trg=None, **kwargs):
"""Vectorized compute_external_B over a batch dimension."""
if X_trg is None:
return jax.vmap(lambda b: self.compute_external_B(b, **kwargs), in_axes=0)(B0_batch)
return jax.vmap(lambda b, xt: self.compute_external_B(b, X_trg=xt, **kwargs), in_axes=(0, 0))(B0_batch, X_trg)
[docs]
def compute_internal_B_batch(self, B0_batch, *, X_trg=None, **kwargs):
"""Vectorized compute_internal_B over a batch dimension."""
if X_trg is None:
return jax.vmap(lambda b: self.compute_internal_B(b, **kwargs), in_axes=0)(B0_batch)
return jax.vmap(lambda b, xt: self.compute_internal_B(b, X_trg=xt, **kwargs), in_axes=(0, 0))(B0_batch, X_trg)
[docs]
def compute_external_gradB_batch(self, B0_batch, **kwargs):
"""Vectorized compute_external_gradB over a batch dimension."""
return jax.vmap(lambda b: self.compute_external_gradB(b, **kwargs), in_axes=0)(B0_batch)
[docs]
def compute_internal_gradB_batch(self, B0_batch, **kwargs):
"""Vectorized compute_internal_gradB over a batch dimension."""
return jax.vmap(lambda b: self.compute_internal_gradB(b, **kwargs), in_axes=0)(B0_batch)
[docs]
def compute_internal_gradB_jit(self, B0, **kwargs):
"""JIT-compiled version of compute_internal_gradB."""
if "X_trg" in kwargs and kwargs["X_trg"] is not None:
raise ValueError("compute_internal_gradB_jit does not support X_trg")
donate = bool(kwargs.pop("donate", False))
digits = self.digits if kwargs.get("digits") is None else int(kwargs["digits"])
quad_nt = kwargs.get("quad_nt")
quad_np = kwargs.get("quad_np")
self._ensure_grad_setup(quad_nt, quad_np, digits)
patch_dim0, patch_idx = self._get_patch_idx(self._grad_setup, digits)
chunk_size = kwargs.get("chunk_size", "auto")
target_chunk_size = kwargs.get("target_chunk_size", "auto")
remat = kwargs.get("remat")
pou_dtype = kwargs.get("pou_dtype")
patch_dtype = kwargs.get("patch_dtype")
interp_block_size = kwargs.get("interp_block_size", "auto")
nsrc = self._grad_setup.quad_nt * self._grad_setup.quad_np
ntrg = self.trg_nt * self.trg_np
chunk_size, target_chunk_size = self._resolve_chunk_sizes(
"gradb", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
)
pou_dtype = self._resolve_pou_dtype(pou_dtype, jnp.asarray(B0).dtype)
patch_dtype = self._resolve_patch_dtype(patch_dtype, jnp.asarray(B0).dtype)
key = (
"gradBint",
digits,
quad_nt,
quad_np,
kwargs.get("hedgehog_order", 8),
chunk_size,
target_chunk_size,
remat,
pou_dtype,
patch_dtype,
interp_block_size,
kwargs.get("scan_targets", False),
donate,
)
fn = self._jit_cache.get(key)
if fn is None:
call_kwargs = dict(kwargs)
call_kwargs["chunk_size"] = chunk_size
call_kwargs["target_chunk_size"] = target_chunk_size
call_kwargs["remat"] = remat
call_kwargs["pou_dtype"] = pou_dtype
call_kwargs["patch_dtype"] = patch_dtype
call_kwargs["interp_block_size"] = interp_block_size
fn = jax.jit(
lambda B: self.compute_internal_gradB(
B,
patch_dim0=patch_dim0,
patch_idx=patch_idx,
**call_kwargs,
),
donate_argnums=(0,) if donate else (),
)
self._jit_cache[key] = fn
return fn(B0)
[docs]
def compute_external_B_autodiff(
self,
B0,
*,
X_trg,
quad_nt: int | None = None,
quad_np: int | None = None,
digits: int | None = None,
chunk_size: int | str | None = "auto",
target_chunk_size: int | str | None = "auto",
pou_dtype=None,
patch_dtype=None,
interp_block_size: int | str | None = "auto",
remat: bool | None = None,
hedgehog_order: int = 8,
):
"""Compute Bext with a custom JVP that matches ComputeGradB on-surface."""
if X_trg is None:
raise ValueError("X_trg must be provided for autodiff-enabled evaluation")
B0 = jnp.asarray(B0)
X_trg = jnp.asarray(X_trg)
digits = self.digits if digits is None else int(digits)
@jax.custom_jvp
def _eval(xtrg):
return self.compute_external_B(
B0,
X_trg=xtrg,
quad_nt=quad_nt,
quad_np=quad_np,
digits=digits,
chunk_size=chunk_size,
target_chunk_size=target_chunk_size,
pou_dtype=pou_dtype,
remat=remat,
)
@_eval.defjvp
def _eval_jvp(primals, tangents):
(xtrg,) = primals
(dxtrg,) = tangents
b = self.compute_external_B(
B0,
X_trg=xtrg,
quad_nt=quad_nt,
quad_np=quad_np,
digits=digits,
chunk_size=chunk_size,
target_chunk_size=target_chunk_size,
pou_dtype=pou_dtype,
patch_dtype=patch_dtype,
interp_block_size=interp_block_size,
remat=remat,
)
gradb = self.compute_external_gradB(
B0,
quad_nt=quad_nt,
quad_np=quad_np,
digits=digits,
hedgehog_order=hedgehog_order,
chunk_size=chunk_size,
target_chunk_size=target_chunk_size,
pou_dtype=pou_dtype,
patch_dtype=patch_dtype,
interp_block_size=interp_block_size,
remat=remat,
)
db = jnp.einsum("k i t p, i t p -> k t p", gradb, dxtrg)
return b, db
return _eval(X_trg)
def _offsurface_densities(self, B0):
    """Build the source grid and densities used by the off-surface evaluators.

    The surface is resampled to a grid at least as fine as the surface and
    (all-period) source grids and at least 13 points per direction. ``B0`` is
    reshaped to ``(3, src_nt, src_np)``, expanded with ``complete_vec_field``
    and resampled onto that grid.

    Returns:
        ``(X_src, BdotN, J)`` with ``BdotN = B . normal`` and ``J = normal x B``.

    Raises:
        RuntimeError: if :meth:`setup` has not been called.
    """
    if not self._setup:
        raise RuntimeError("VirtualCasingJAX.setup must be called before off-surface evaluation")
    B0 = jnp.asarray(B0).reshape((3, self.src_nt, self.src_np))
    surf_nt_full, surf_np_full = (int(extent) for extent in self.surface_coord.shape[1:3])
    min_points = 13  # 2*6+1 to match BIEST off-surface minimum
    grid_nt = max(self.nfp_eff * self.src_nt, surf_nt_full, min_points)
    grid_np = max(self.src_np, surf_np_full, min_points)
    X_src = resample(self.surface_coord, surf_nt_full, surf_np_full, grid_nt, grid_np)
    normal, _ = surf_normal_area_elem(grad2d(X_src, grid_nt, grid_np), X_src)

    shift = 0.0
    if self.half_period:
        shift = math.pi * (
            1.0 / (self.nfp * self.trg_nt * 2) - 1.0 / (self.nfp * self.src_nt * 2)
        )
    B_full = complete_vec_field(
        B0, False, self.half_period, self.nfp, self.src_nt, self.src_np, shift
    )
    B_grid = resample(B_full, self.nfp_eff * self.src_nt, self.src_np, grid_nt, grid_np)
    return X_src, dot_prod(B_grid, normal), cross_prod(normal, B_grid)
[docs]
def compute_external_B_offsurf(
    self,
    B0,
    *,
    X_trg,
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
):
    """Compute Bext at off-surface targets using adaptive quadrature.

    ``X_trg`` may be ``(3, nt, np)``, in which case the result is reshaped to
    the same grid, or ``(3, ntrg)``, in which case the adaptive routine's output
    is returned as-is.
    """
    digits = self.digits if digits is None else int(digits)
    X_src, BdotN, J = self._offsurface_densities(B0)
    X_trg = jnp.asarray(X_trg)
    is_grid = X_trg.ndim == 3
    n_src = X_src.shape[1] * X_src.shape[2]
    n_trg = X_trg.shape[1] * X_trg.shape[2] if is_grid else X_trg.shape[1]
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "boff", chunk_size, target_chunk_size, nsrc=n_src, ntrg=n_trg
    )
    result = computeB_offsurface_adaptive(
        X_src,
        BdotN,
        J,
        X_trg,
        digits=digits,
        max_Nt=max_Nt,
        max_Np=max_Np,
        ext=True,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )
    if not is_grid:
        return result
    return jnp.asarray(result).reshape((3,) + tuple(X_trg.shape[1:]))
[docs]
def compute_internal_B_offsurf(
    self,
    B0,
    *,
    X_trg,
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
):
    """Compute Bint at off-surface targets using adaptive quadrature.

    ``X_trg`` may be ``(3, nt, np)``, in which case the result is reshaped to
    the same grid, or ``(3, ntrg)``, in which case the adaptive routine's output
    is returned as-is.
    """
    digits = self.digits if digits is None else int(digits)
    X_src, BdotN, J = self._offsurface_densities(B0)
    X_trg = jnp.asarray(X_trg)
    is_grid = X_trg.ndim == 3
    n_src = X_src.shape[1] * X_src.shape[2]
    n_trg = X_trg.shape[1] * X_trg.shape[2] if is_grid else X_trg.shape[1]
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "boff", chunk_size, target_chunk_size, nsrc=n_src, ntrg=n_trg
    )
    result = computeB_offsurface_adaptive(
        X_src,
        BdotN,
        J,
        X_trg,
        digits=digits,
        max_Nt=max_Nt,
        max_Np=max_Np,
        ext=False,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )
    if not is_grid:
        return result
    return jnp.asarray(result).reshape((3,) + tuple(X_trg.shape[1:]))
[docs]
def compute_external_B_offsurf_schedule(
    self,
    B0,
    *,
    X_trg,
    levels: tuple[tuple[int, int], ...] | str | None,
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    max_levels: int = 6,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
):
    """Compute Bext off-surface using a fixed adaptive refinement schedule.

    ``levels`` is resolved by ``_resolve_offsurface_levels``: ``None``/``"auto"``
    builds a schedule from the source grid size, and explicit levels are
    coerced to int pairs. A ``(3, nt, np)`` ``X_trg`` gets a grid-shaped result.
    """
    digits = self.digits if digits is None else int(digits)
    X_src, BdotN, J = self._offsurface_densities(B0)
    src_nt, src_np = int(X_src.shape[1]), int(X_src.shape[2])
    schedule = self._resolve_offsurface_levels(
        levels,
        nt0=src_nt,
        np0=src_np,
        max_Nt=max_Nt,
        max_Np=max_Np,
        max_levels=max_levels,
    )
    X_trg = jnp.asarray(X_trg)
    is_grid = X_trg.ndim == 3
    n_trg = X_trg.shape[1] * X_trg.shape[2] if is_grid else X_trg.shape[1]
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "boff", chunk_size, target_chunk_size, nsrc=src_nt * src_np, ntrg=n_trg
    )
    result = computeB_offsurface_adaptive_schedule(
        X_src,
        BdotN,
        J,
        X_trg,
        levels=schedule,
        digits=digits,
        ext=True,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )
    if not is_grid:
        return result
    return jnp.asarray(result).reshape((3,) + tuple(X_trg.shape[1:]))
[docs]
def compute_external_B_offsurf_schedule_jit(
    self,
    B0,
    *,
    X_trg,
    levels: tuple[tuple[int, int], ...] | str | None = "auto",
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    max_levels: int = 6,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    donate: bool = False,
):
    """JIT-compiled schedule-based off-surface Bext.

    Static configuration (digits, refinement levels, chunk sizes) is resolved
    eagerly. ``compute_external_B_offsurf_schedule`` is then wrapped in
    ``jax.jit`` and cached in ``self._jit_cache`` per configuration.

    Parameters
    ----------
    B0
        Field passed to ``self._offsurface_densities``.
    X_trg
        Target points: a ``(3, n1, n2)`` grid or a flat 2-D array.
    levels
        Refinement schedule, or ``"auto"`` / ``None`` resolved by
        ``self._resolve_offsurface_levels``.
    digits
        Requested accuracy; defaults to ``self.digits``.
    max_Nt, max_Np, max_levels
        Bounds forwarded to level resolution.
    chunk_size, target_chunk_size
        Source/target chunk sizes, or ``"auto"`` to autotune.
    donate
        If true, ``B0`` is passed through ``donate_argnums`` so XLA may reuse
        its buffer; the caller must not use ``B0`` afterwards.
    """
    digits = self.digits if digits is None else int(digits)
    # Only the source-grid shape is needed to resolve levels and chunk sizes.
    # jax.eval_shape obtains it abstractly instead of eagerly computing the
    # densities on every call (the jitted function recomputes them anyway).
    src_shape = jax.eval_shape(
        lambda b: self._offsurface_densities(b)[0], B0
    ).shape
    levels = self._resolve_offsurface_levels(
        levels,
        nt0=int(src_shape[1]),
        np0=int(src_shape[2]),
        max_Nt=max_Nt,
        max_Np=max_Np,
        max_levels=max_levels,
    )
    X_trg = jnp.asarray(X_trg)
    nsrc = src_shape[1] * src_shape[2]
    ntrg = X_trg.shape[1] * X_trg.shape[2] if X_trg.ndim == 3 else X_trg.shape[1]
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "boff", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
    )
    # Every value captured by the jitted closure must be part of the key;
    # otherwise a cache hit would silently reuse stale settings.
    key = (
        "Boff_schedule",
        digits,
        levels,
        max_Nt,
        max_Np,
        max_levels,
        chunk_size,
        target_chunk_size,
        donate,
    )
    fn = self._jit_cache.get(key)
    if fn is None:
        fn = jax.jit(
            lambda B, Xt: self.compute_external_B_offsurf_schedule(
                B,
                X_trg=Xt,
                levels=levels,
                digits=digits,
                max_Nt=max_Nt,
                max_Np=max_Np,
                max_levels=max_levels,
                chunk_size=chunk_size,
                target_chunk_size=target_chunk_size,
            ),
            donate_argnums=(0,) if donate else (),
        )
        self._jit_cache[key] = fn
    return fn(B0, X_trg)
[docs]
def compute_internal_B_offsurf_schedule(
    self,
    B0,
    *,
    X_trg,
    levels: tuple[tuple[int, int], ...] | str | None = "auto",
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    max_levels: int = 6,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
):
    """Compute Bint off-surface using a fixed adaptive refinement schedule.

    Parameters
    ----------
    B0
        Field from which ``self._offsurface_densities`` builds the source
        geometry and the ``BdotN`` / ``J`` densities.
    X_trg
        Target points: either a ``(3, n1, n2)`` grid or a flat 2-D array
        whose second axis indexes the targets.
    levels
        Refinement schedule (a tuple of integer pairs), or a value such as
        ``"auto"`` / ``None`` that ``self._resolve_offsurface_levels``
        resolves into one. Defaults to ``"auto"``, like the ``_jit`` variant.
    digits
        Requested accuracy; defaults to ``self.digits``.
    max_Nt, max_Np, max_levels
        Bounds forwarded to ``self._resolve_offsurface_levels``.
    chunk_size, target_chunk_size
        Source/target chunk sizes. ``"auto"`` (and ``None`` for the source
        chunk) selects autotuned values.

    Returns
    -------
    jnp.ndarray
        Bint reshaped to ``(3, n1, n2)`` for grid targets; for flat targets
        the kernel output is returned unchanged.
    """
    digits = self.digits if digits is None else int(digits)
    X_src, BdotN, J = self._offsurface_densities(B0)
    levels = self._resolve_offsurface_levels(
        levels,
        nt0=int(X_src.shape[1]),
        np0=int(X_src.shape[2]),
        max_Nt=max_Nt,
        max_Np=max_Np,
        max_levels=max_levels,
    )
    X_trg = jnp.asarray(X_trg)
    nsrc = X_src.shape[1] * X_src.shape[2]
    ntrg = X_trg.shape[1] * X_trg.shape[2] if X_trg.ndim == 3 else X_trg.shape[1]
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "boff", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
    )
    out = computeB_offsurface_adaptive_schedule(
        X_src,
        BdotN,
        J,
        X_trg,
        levels=levels,
        digits=digits,
        ext=False,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )
    if X_trg.ndim == 3:
        return jnp.asarray(out).reshape((3, X_trg.shape[1], X_trg.shape[2]))
    return out
[docs]
def compute_internal_B_offsurf_schedule_jit(
    self,
    B0,
    *,
    X_trg,
    levels: tuple[tuple[int, int], ...] | str | None = "auto",
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    max_levels: int = 6,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    donate: bool = False,
):
    """JIT-compiled schedule-based off-surface Bint.

    Static configuration (digits, refinement levels, chunk sizes) is resolved
    eagerly. ``compute_internal_B_offsurf_schedule`` is then wrapped in
    ``jax.jit`` and cached in ``self._jit_cache`` per configuration.

    Parameters
    ----------
    B0
        Field passed to ``self._offsurface_densities``.
    X_trg
        Target points: a ``(3, n1, n2)`` grid or a flat 2-D array.
    levels
        Refinement schedule, or ``"auto"`` / ``None`` resolved by
        ``self._resolve_offsurface_levels``.
    digits
        Requested accuracy; defaults to ``self.digits``.
    max_Nt, max_Np, max_levels
        Bounds forwarded to level resolution.
    chunk_size, target_chunk_size
        Source/target chunk sizes, or ``"auto"`` to autotune.
    donate
        If true, ``B0`` is passed through ``donate_argnums`` so XLA may reuse
        its buffer; the caller must not use ``B0`` afterwards.
    """
    digits = self.digits if digits is None else int(digits)
    # Only the source-grid shape is needed to resolve levels and chunk sizes.
    # jax.eval_shape obtains it abstractly instead of eagerly computing the
    # densities on every call (the jitted function recomputes them anyway).
    src_shape = jax.eval_shape(
        lambda b: self._offsurface_densities(b)[0], B0
    ).shape
    levels = self._resolve_offsurface_levels(
        levels,
        nt0=int(src_shape[1]),
        np0=int(src_shape[2]),
        max_Nt=max_Nt,
        max_Np=max_Np,
        max_levels=max_levels,
    )
    X_trg = jnp.asarray(X_trg)
    nsrc = src_shape[1] * src_shape[2]
    ntrg = X_trg.shape[1] * X_trg.shape[2] if X_trg.ndim == 3 else X_trg.shape[1]
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "boff", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
    )
    # Every value captured by the jitted closure must be part of the key;
    # otherwise a cache hit would silently reuse stale settings.
    key = (
        "Bint_off_schedule",
        digits,
        levels,
        max_Nt,
        max_Np,
        max_levels,
        chunk_size,
        target_chunk_size,
        donate,
    )
    fn = self._jit_cache.get(key)
    if fn is None:
        fn = jax.jit(
            lambda B, Xt: self.compute_internal_B_offsurf_schedule(
                B,
                X_trg=Xt,
                levels=levels,
                digits=digits,
                max_Nt=max_Nt,
                max_Np=max_Np,
                max_levels=max_levels,
                chunk_size=chunk_size,
                target_chunk_size=target_chunk_size,
            ),
            donate_argnums=(0,) if donate else (),
        )
        self._jit_cache[key] = fn
    return fn(B0, X_trg)
[docs]
def compute_external_gradB_offsurf(
    self,
    B0,
    *,
    X_trg,
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    adaptive: bool = False,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
):
    """Compute GradBext at off-surface targets using direct quadrature.

    With ``adaptive=False`` (the default) the integrals are evaluated on the
    base source grid returned by ``self._offsurface_densities``, mirroring
    the reference implementation. With ``adaptive=True`` the source grid and
    densities are first refined by ``_offsurface_adapt_grid`` before the same
    quadrature is applied.

    Parameters
    ----------
    B0
        Field from which ``self._offsurface_densities`` builds the source
        geometry and the ``BdotN`` / ``J`` densities.
    X_trg
        Target points: a ``(3, n1, n2)`` grid or a flat ``(3, N)`` array.
    digits
        Accuracy target for adaptive refinement; defaults to ``self.digits``.
        Only used when ``adaptive`` is true.
    max_Nt, max_Np
        Refinement bounds; only used when ``adaptive`` is true.
    adaptive
        Whether to refine the source grid before integrating.
    chunk_size, target_chunk_size
        Source/target chunk sizes, or ``"auto"`` to autotune.

    Returns
    -------
    jnp.ndarray
        Shape ``(3, 3, n1, n2)`` for grid targets, otherwise ``(3, 3, N)``.
    """
    digits = self.digits if digits is None else int(digits)
    X_trg = jnp.asarray(X_trg)
    X_trg_flat = X_trg.reshape((3, -1)) if X_trg.ndim == 3 else X_trg
    ntrg = X_trg_flat.shape[1]
    X_src, BdotN, J = self._offsurface_densities(B0)
    # Chunk sizes are tuned against the base (unrefined) source grid size.
    nsrc = X_src.shape[1] * X_src.shape[2]
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "gradb_off", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
    )
    if adaptive:
        X_src, BdotN, J, area_elem = _offsurface_adapt_grid(
            X_src,
            BdotN,
            J,
            X_trg_flat,
            digits=digits,
            max_Nt=max_Nt,
            max_Np=max_Np,
            chunk_size=chunk_size,
            target_chunk_size=target_chunk_size,
        )
    else:
        dX = grad2d(X_src, X_src.shape[1], X_src.shape[2])
        _, area_elem = surf_normal_area_elem(dX, X_src)
    gradG_J = laplace_fxd2_u_eval_vec(
        X_src,
        X_trg_flat,
        J,
        area_elem,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )
    gradG_J = jnp.asarray(gradG_J).reshape((3, 3, 3, ntrg))
    gradgradG_BdotN = laplace_fxd2_u_eval(
        X_src,
        X_trg_flat,
        BdotN,
        area_elem,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )
    gradgradG_BdotN = jnp.asarray(gradgradG_BdotN).reshape((3, 3, ntrg))
    # Levi-Civita contraction over the first two indices of gradG_J:
    # row k is gradG_J[k+1, k+2] - gradG_J[k+2, k+1] (indices mod 3).
    curl_part = jnp.stack(
        [
            gradG_J[(k + 1) % 3, (k + 2) % 3] - gradG_J[(k + 2) % 3, (k + 1) % 3]
            for k in range(3)
        ]
    )
    gradB = curl_part + gradgradG_BdotN
    if X_trg.ndim == 3:
        return gradB.reshape((3, 3, X_trg.shape[1], X_trg.shape[2]))
    return gradB
[docs]
def compute_internal_gradB_offsurf(
    self,
    B0,
    *,
    X_trg,
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    adaptive: bool = False,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
):
    """Compute GradBint at off-surface targets using direct quadrature.

    Evaluates ``compute_external_gradB_offsurf`` with exactly the same
    arguments and returns its negation.
    """
    forwarded = dict(
        X_trg=X_trg,
        digits=digits,
        max_Nt=max_Nt,
        max_Np=max_Np,
        adaptive=adaptive,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )
    external = self.compute_external_gradB_offsurf(B0, **forwarded)
    return -external
[docs]
def compute_external_gradB_offsurf_schedule(
    self,
    B0,
    *,
    X_trg,
    levels: tuple[tuple[int, int], ...] | str | None = "auto",
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    max_levels: int = 6,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
):
    """Compute GradBext off-surface using a fixed adaptive refinement schedule.

    Parameters
    ----------
    B0
        Field from which ``self._offsurface_densities`` builds the source
        geometry and the ``BdotN`` / ``J`` densities.
    X_trg
        Target points: a ``(3, n1, n2)`` grid or a flat ``(3, N)`` array.
    levels
        Refinement schedule (a tuple of integer pairs), or a value such as
        ``"auto"`` / ``None`` that ``self._resolve_offsurface_levels``
        resolves into one. Defaults to ``"auto"``, like the ``_jit`` variant.
    digits
        Requested accuracy; defaults to ``self.digits``.
    max_Nt, max_Np, max_levels
        Bounds forwarded to ``self._resolve_offsurface_levels``.
    chunk_size, target_chunk_size
        Source/target chunk sizes, or ``"auto"`` to autotune.

    Returns
    -------
    jnp.ndarray
        Reshaped to ``(3, 3, n1, n2)`` for grid targets; for flat targets the
        kernel output is returned unchanged.
    """
    digits = self.digits if digits is None else int(digits)
    X_src, BdotN, J = self._offsurface_densities(B0)
    levels = self._resolve_offsurface_levels(
        levels,
        nt0=int(X_src.shape[1]),
        np0=int(X_src.shape[2]),
        max_Nt=max_Nt,
        max_Np=max_Np,
        max_levels=max_levels,
    )
    X_trg = jnp.asarray(X_trg)
    X_trg_flat = X_trg.reshape((3, -1)) if X_trg.ndim == 3 else X_trg
    nsrc = X_src.shape[1] * X_src.shape[2]
    ntrg = X_trg_flat.shape[1]
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "gradb_off", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
    )
    gradB = computeGradB_offsurface_adaptive_schedule(
        X_src,
        BdotN,
        J,
        X_trg_flat,
        levels=levels,
        digits=digits,
        ext=True,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )
    if X_trg.ndim == 3:
        return jnp.asarray(gradB).reshape((3, 3, X_trg.shape[1], X_trg.shape[2]))
    return gradB
[docs]
def compute_external_gradB_offsurf_schedule_jit(
    self,
    B0,
    *,
    X_trg,
    levels: tuple[tuple[int, int], ...] | str | None = "auto",
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    max_levels: int = 6,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    donate: bool = False,
):
    """JIT-compiled schedule-based off-surface GradBext.

    Static configuration (digits, refinement levels, chunk sizes) is resolved
    eagerly. ``compute_external_gradB_offsurf_schedule`` is then wrapped in
    ``jax.jit`` and cached in ``self._jit_cache`` per configuration.

    Parameters
    ----------
    B0
        Field passed to ``self._offsurface_densities``.
    X_trg
        Target points: a ``(3, n1, n2)`` grid or a flat 2-D array.
    levels
        Refinement schedule, or ``"auto"`` / ``None`` resolved by
        ``self._resolve_offsurface_levels``.
    digits
        Requested accuracy; defaults to ``self.digits``.
    max_Nt, max_Np, max_levels
        Bounds forwarded to level resolution.
    chunk_size, target_chunk_size
        Source/target chunk sizes, or ``"auto"`` to autotune.
    donate
        If true, ``B0`` is passed through ``donate_argnums`` so XLA may reuse
        its buffer; the caller must not use ``B0`` afterwards.
    """
    digits = self.digits if digits is None else int(digits)
    # Only the source-grid shape is needed to resolve levels and chunk sizes.
    # jax.eval_shape obtains it abstractly instead of eagerly computing the
    # densities on every call (the jitted function recomputes them anyway).
    src_shape = jax.eval_shape(
        lambda b: self._offsurface_densities(b)[0], B0
    ).shape
    levels = self._resolve_offsurface_levels(
        levels,
        nt0=int(src_shape[1]),
        np0=int(src_shape[2]),
        max_Nt=max_Nt,
        max_Np=max_Np,
        max_levels=max_levels,
    )
    X_trg = jnp.asarray(X_trg)
    nsrc = src_shape[1] * src_shape[2]
    ntrg = X_trg.shape[1] * X_trg.shape[2] if X_trg.ndim == 3 else X_trg.shape[1]
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "gradb_off", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
    )
    # Every value captured by the jitted closure must be part of the key;
    # otherwise a cache hit would silently reuse stale settings.
    key = (
        "GradBoff_schedule",
        digits,
        levels,
        max_Nt,
        max_Np,
        max_levels,
        chunk_size,
        target_chunk_size,
        donate,
    )
    fn = self._jit_cache.get(key)
    if fn is None:
        fn = jax.jit(
            lambda B, Xt: self.compute_external_gradB_offsurf_schedule(
                B,
                X_trg=Xt,
                levels=levels,
                digits=digits,
                max_Nt=max_Nt,
                max_Np=max_Np,
                max_levels=max_levels,
                chunk_size=chunk_size,
                target_chunk_size=target_chunk_size,
            ),
            donate_argnums=(0,) if donate else (),
        )
        self._jit_cache[key] = fn
    return fn(B0, X_trg)
[docs]
def compute_internal_gradB_offsurf_schedule(
    self,
    B0,
    *,
    X_trg,
    levels: tuple[tuple[int, int], ...] | str | None = "auto",
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    max_levels: int = 6,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
):
    """Compute GradBint off-surface using a fixed adaptive refinement schedule.

    Parameters
    ----------
    B0
        Field from which ``self._offsurface_densities`` builds the source
        geometry and the ``BdotN`` / ``J`` densities.
    X_trg
        Target points: a ``(3, n1, n2)`` grid or a flat ``(3, N)`` array.
    levels
        Refinement schedule (a tuple of integer pairs), or a value such as
        ``"auto"`` / ``None`` that ``self._resolve_offsurface_levels``
        resolves into one. Defaults to ``"auto"``, like the ``_jit`` variant.
    digits
        Requested accuracy; defaults to ``self.digits``.
    max_Nt, max_Np, max_levels
        Bounds forwarded to ``self._resolve_offsurface_levels``.
    chunk_size, target_chunk_size
        Source/target chunk sizes, or ``"auto"`` to autotune.

    Returns
    -------
    jnp.ndarray
        Reshaped to ``(3, 3, n1, n2)`` for grid targets; for flat targets the
        kernel output is returned unchanged.
    """
    digits = self.digits if digits is None else int(digits)
    X_src, BdotN, J = self._offsurface_densities(B0)
    levels = self._resolve_offsurface_levels(
        levels,
        nt0=int(X_src.shape[1]),
        np0=int(X_src.shape[2]),
        max_Nt=max_Nt,
        max_Np=max_Np,
        max_levels=max_levels,
    )
    X_trg = jnp.asarray(X_trg)
    X_trg_flat = X_trg.reshape((3, -1)) if X_trg.ndim == 3 else X_trg
    nsrc = X_src.shape[1] * X_src.shape[2]
    ntrg = X_trg_flat.shape[1]
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "gradb_off", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
    )
    gradB = computeGradB_offsurface_adaptive_schedule(
        X_src,
        BdotN,
        J,
        X_trg_flat,
        levels=levels,
        digits=digits,
        ext=False,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )
    if X_trg.ndim == 3:
        return jnp.asarray(gradB).reshape((3, 3, X_trg.shape[1], X_trg.shape[2]))
    return gradB
[docs]
def compute_internal_gradB_offsurf_schedule_jit(
    self,
    B0,
    *,
    X_trg,
    levels: tuple[tuple[int, int], ...] | str | None = "auto",
    digits: int | None = None,
    max_Nt: int = -1,
    max_Np: int = -1,
    max_levels: int = 6,
    chunk_size: int | str | None = "auto",
    target_chunk_size: int | str | None = "auto",
    donate: bool = False,
):
    """JIT-compiled schedule-based off-surface GradBint.

    Static configuration (digits, refinement levels, chunk sizes) is resolved
    eagerly. ``compute_internal_gradB_offsurf_schedule`` is then wrapped in
    ``jax.jit`` and cached in ``self._jit_cache`` per configuration.

    Parameters
    ----------
    B0
        Field passed to ``self._offsurface_densities``.
    X_trg
        Target points: a ``(3, n1, n2)`` grid or a flat 2-D array.
    levels
        Refinement schedule, or ``"auto"`` / ``None`` resolved by
        ``self._resolve_offsurface_levels``.
    digits
        Requested accuracy; defaults to ``self.digits``.
    max_Nt, max_Np, max_levels
        Bounds forwarded to level resolution.
    chunk_size, target_chunk_size
        Source/target chunk sizes, or ``"auto"`` to autotune.
    donate
        If true, ``B0`` is passed through ``donate_argnums`` so XLA may reuse
        its buffer; the caller must not use ``B0`` afterwards.
    """
    digits = self.digits if digits is None else int(digits)
    # Only the source-grid shape is needed to resolve levels and chunk sizes.
    # jax.eval_shape obtains it abstractly instead of eagerly computing the
    # densities on every call (the jitted function recomputes them anyway).
    src_shape = jax.eval_shape(
        lambda b: self._offsurface_densities(b)[0], B0
    ).shape
    levels = self._resolve_offsurface_levels(
        levels,
        nt0=int(src_shape[1]),
        np0=int(src_shape[2]),
        max_Nt=max_Nt,
        max_Np=max_Np,
        max_levels=max_levels,
    )
    X_trg = jnp.asarray(X_trg)
    nsrc = src_shape[1] * src_shape[2]
    ntrg = X_trg.shape[1] * X_trg.shape[2] if X_trg.ndim == 3 else X_trg.shape[1]
    chunk_size, target_chunk_size = self._resolve_chunk_sizes(
        "gradb_off", chunk_size, target_chunk_size, nsrc=nsrc, ntrg=ntrg
    )
    # Every value captured by the jitted closure must be part of the key;
    # otherwise a cache hit would silently reuse stale settings.
    key = (
        "GradBint_off_schedule",
        digits,
        levels,
        max_Nt,
        max_Np,
        max_levels,
        chunk_size,
        target_chunk_size,
        donate,
    )
    fn = self._jit_cache.get(key)
    if fn is None:
        fn = jax.jit(
            lambda B, Xt: self.compute_internal_gradB_offsurf_schedule(
                B,
                X_trg=Xt,
                levels=levels,
                digits=digits,
                max_Nt=max_Nt,
                max_Np=max_Np,
                max_levels=max_levels,
                chunk_size=chunk_size,
                target_chunk_size=target_chunk_size,
            ),
            donate_argnums=(0,) if donate else (),
        )
        self._jit_cache[key] = fn
    return fn(B0, X_trg)