"""Boundary integral evaluation (baseline direct-sum)."""
from __future__ import annotations
import jax
import jax.numpy as jnp
from .kernels import laplace_fxd_u, laplace_fxd2_u, biotsavart_fx_u, biotsavart_fxd_u, laplace_dx_u
from .surface_ops import upsample, resample, grad2d, surf_normal_area_elem, normal_orientation
from .singular_quadrature import precompute_singular, select_patch_dim, INTERP_ORDER
def _flatten_soa(x, name: str):
x = jnp.asarray(x)
if x.ndim == 3:
return x.reshape((x.shape[0], -1))
if x.ndim == 2:
return x
raise ValueError(f"{name} must have shape (dof, nt, np) or (dof, n)")
def _pad_to_multiple(x, axis: int, chunk: int):
if chunk is None or chunk <= 0:
return x, 0
n = x.shape[axis]
pad = (-n) % chunk
if pad:
pad_width = [(0, 0)] * x.ndim
pad_width[axis] = (0, pad)
x = jnp.pad(x, pad_width)
return x, pad
# [docs]
def field_period_target_coords(X_quad, trg_nt: int, trg_np: int, nfp: int):
    """Select target coordinates used by FieldPeriodBIOp.

    The quadrature grid is subsampled with uniform strides along both grid
    axes and the first ``trg_nt`` rows of the result are kept.

    Args:
        X_quad: (3, quad_nt, quad_np) coordinates for the full NFP surface.
        trg_nt: number of target points along the first grid axis.
        trg_np: number of target points along the second grid axis.
        nfp: number of field periods covered by ``X_quad``.

    Returns:
        X_trg: (3, trg_nt, trg_np) for the first field period.

    Raises:
        ValueError: if ``X_quad`` has fewer than three axes, if any of
            ``trg_nt``/``trg_np``/``nfp`` is not positive, or if the grid
            sizes are not divisible by the requested target sizes.
    """
    X_quad = jnp.asarray(X_quad)
    if X_quad.ndim < 3:
        raise ValueError("X_quad must have shape (dof, quad_nt, quad_np)")
    # Reject zero/negative counts up front; otherwise the modulo below fails
    # with an unrelated ZeroDivisionError.
    if trg_nt <= 0 or trg_np <= 0 or nfp <= 0:
        raise ValueError("trg_nt, trg_np and nfp must be positive")
    quad_nt = X_quad.shape[1]
    quad_np = X_quad.shape[2]
    if quad_nt % (nfp * trg_nt) != 0:
        raise ValueError("quad_nt must be divisible by nfp*trg_nt")
    if quad_np % trg_np != 0:
        raise ValueError("quad_np must be divisible by trg_np")
    skip_nt = quad_nt // (nfp * trg_nt)
    skip_np = quad_np // trg_np
    X_trg_full = X_quad[:, ::skip_nt, ::skip_np]
    # The strided grid still spans all field periods; keep the first one.
    return X_trg_full[:, :trg_nt, :]
# [docs]
def laplace_fxd_u_eval(
    X_src,
    X_trg,
    density,
    area_elem,
    *,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
):
    """Evaluate Laplace FxdU (grad single-layer) by direct quadrature.

    For every target, sums ``laplace_fxd_u(x_trg - x_src, density * area_elem)``
    over all sources. Sources and/or targets can be processed in blocks to
    bound the size of the pairwise intermediate.

    Args:
        X_src: source coordinates, (3, nt, np) or (3, nsrc).
        X_trg: target coordinates, (3, nt, np) or (3, ntrg).
        density: scalar density, (nt, np) or (nsrc,).
        area_elem: per-source area element, (nt, np) or (nsrc,); it is
            multiplied into the density to form the quadrature weight.
        chunk_size: sources per scan step. ``None`` or ``<= 0`` keeps all
            sources in one block.
        target_chunk_size: targets per block. ``None`` or ``<= 0`` disables
            target blocking.

    Returns:
        Array of shape (3, ntrg). The target axis is always flattened, also
        when ``X_trg`` is given as (3, nt, np).

    Raises:
        ValueError: if a coordinate array does not have leading dimension 3
            (or has an unsupported rank), or if ``density``/``area_elem`` do
            not match the number of sources.
    """
    X_src = _flatten_soa(X_src, "X_src")
    X_trg = _flatten_soa(X_trg, "X_trg")
    density = jnp.asarray(density).reshape(-1)
    area_elem = jnp.asarray(area_elem).reshape(-1)
    if X_src.shape[0] != 3 or X_trg.shape[0] != 3:
        raise ValueError("X_src and X_trg must be 3D coordinates in SoA layout")
    nsrc = X_src.shape[1]
    ntrg = X_trg.shape[1]
    if density.shape[0] != nsrc or area_elem.shape[0] != nsrc:
        raise ValueError("density/area_elem must match source grid size")

    weights = density * area_elem
    Xs = jnp.transpose(X_src, (1, 0))  # (nsrc, 3)
    Xt = jnp.transpose(X_trg, (1, 0))  # (ntrg, 3)

    block_sources = chunk_size is not None and chunk_size > 0
    block_targets = target_chunk_size is not None and target_chunk_size > 0

    if not block_sources and not block_targets:
        # Fully dense evaluation: one (ntrg, nsrc, 3) intermediate.
        dx = Xt[:, None, :] - Xs[None, :, :]
        out = jnp.sum(laplace_fxd_u(dx, weights), axis=1)
        return jnp.transpose(out, (1, 0))

    if block_sources:
        # Padded sources sit at the origin with zero weight.
        # NOTE(review): assumes the kernel stays finite for such entries.
        Xs, pad_src = _pad_to_multiple(Xs, 0, chunk_size)
        if pad_src:
            weights = jnp.pad(weights, (0, pad_src))
        per_chunk = chunk_size
        n_chunks = Xs.shape[0] // chunk_size
    else:
        # Target blocking only: treat all sources as a single chunk.
        per_chunk = Xs.shape[0]
        n_chunks = 1
    X_chunks = Xs.reshape((n_chunks, per_chunk, 3))
    w_chunks = weights.reshape((n_chunks, per_chunk))

    # Accumulate in the widest input dtype. A carry initialised from X_src
    # alone gets promoted by wider targets/weights, and lax.scan rejects a
    # carry whose dtype changes between iterations.
    acc_dtype = jnp.result_type(Xs.dtype, Xt.dtype, weights.dtype)

    def accumulate(trg_block):
        """Sum the contribution of every source chunk onto ``trg_block``."""

        def body(acc, chunk):
            Xc, wc = chunk
            dx = trg_block[:, None, :] - Xc[None, :, :]
            return acc + jnp.sum(laplace_fxd_u(dx, wc), axis=1), None

        init = jnp.zeros((trg_block.shape[0], 3), dtype=acc_dtype)
        total, _ = jax.lax.scan(body, init, (X_chunks, w_chunks))
        return total

    if not block_targets:
        return jnp.transpose(accumulate(Xt), (1, 0))

    # Target blocking: padded targets are evaluated and then sliced away.
    Xt, _ = _pad_to_multiple(Xt, 0, target_chunk_size)
    ntrg_pad = Xt.shape[0]
    Xt_chunks = Xt.reshape((ntrg_pad // target_chunk_size, target_chunk_size, 3))
    _, outs = jax.lax.scan(lambda c, x: (c, accumulate(x)), None, Xt_chunks)
    out = outs.reshape((ntrg_pad, 3))[:ntrg]
    return jnp.transpose(out, (1, 0))
# [docs]
def laplace_fxd2_u_eval(
    X_src,
    X_trg,
    density,
    area_elem,
    *,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
):
    """Evaluate Laplace Fxd2U (second derivatives) by direct quadrature.

    Args:
        X_src: source coordinates, (3, nt, np) or (3, nsrc).
        X_trg: target coordinates, (3, nt, np) or (3, ntrg).
        density: scalar density, (nt, np) or (nsrc,).
        area_elem: per-source area element, (nt, np) or (nsrc,).
        chunk_size: sources per scan step. ``None`` or ``<= 0`` keeps all
            sources in one block.
        target_chunk_size: targets per block. ``None`` or ``<= 0`` disables
            target blocking.

    Returns:
        Array of shape (9, ntrg): the per-target 3x3 result flattened in
        row-major order along the leading axis.

    Raises:
        ValueError: if a coordinate array does not have leading dimension 3
            (or has an unsupported rank), or if ``density``/``area_elem`` do
            not match the number of sources.
    """
    X_src = _flatten_soa(X_src, "X_src")
    X_trg = _flatten_soa(X_trg, "X_trg")
    density = jnp.asarray(density).reshape(-1)
    area_elem = jnp.asarray(area_elem).reshape(-1)
    if X_src.shape[0] != 3 or X_trg.shape[0] != 3:
        raise ValueError("X_src and X_trg must be 3D in SoA layout")
    nsrc = X_src.shape[1]
    ntrg = X_trg.shape[1]
    if density.shape[0] != nsrc or area_elem.shape[0] != nsrc:
        raise ValueError("density/area_elem must match source grid size")

    weights = density * area_elem
    Xs = jnp.transpose(X_src, (1, 0))  # (nsrc, 3)
    Xt = jnp.transpose(X_trg, (1, 0))  # (ntrg, 3)

    block_sources = chunk_size is not None and chunk_size > 0
    block_targets = target_chunk_size is not None and target_chunk_size > 0

    if not block_sources and not block_targets:
        dx = Xt[:, None, :] - Xs[None, :, :]
        out = jnp.sum(laplace_fxd2_u(dx, weights), axis=1)
        out = out.reshape((ntrg, 9))
        return jnp.transpose(out, (1, 0))

    if block_sources:
        # Padded sources sit at the origin with zero weight.
        # NOTE(review): assumes the kernel stays finite for such entries.
        Xs, pad_src = _pad_to_multiple(Xs, 0, chunk_size)
        if pad_src:
            weights = jnp.pad(weights, (0, pad_src))
        per_chunk = chunk_size
        n_chunks = Xs.shape[0] // chunk_size
    else:
        # Target blocking only: treat all sources as a single chunk.
        per_chunk = Xs.shape[0]
        n_chunks = 1
    X_chunks = Xs.reshape((n_chunks, per_chunk, 3))
    W_chunks = weights.reshape((n_chunks, per_chunk))

    # Widest input dtype keeps the scan carry dtype fixed for mixed-precision
    # inputs (lax.scan rejects a carry that gets promoted inside the body).
    acc_dtype = jnp.result_type(Xs.dtype, Xt.dtype, weights.dtype)

    def accumulate(trg_block):
        """Sum all source chunks onto ``trg_block``; returns (n, 9)."""

        def body(acc, chunk):
            Xc, Wc = chunk
            dx = trg_block[:, None, :] - Xc[None, :, :]
            return acc + jnp.sum(laplace_fxd2_u(dx, Wc), axis=1), None

        nblock = trg_block.shape[0]
        init = jnp.zeros((nblock, 3, 3), dtype=acc_dtype)
        total, _ = jax.lax.scan(body, init, (X_chunks, W_chunks))
        return total.reshape((nblock, 9))

    if not block_targets:
        return jnp.transpose(accumulate(Xt), (1, 0))

    # Target blocking: padded targets are evaluated and then sliced away.
    Xt, _ = _pad_to_multiple(Xt, 0, target_chunk_size)
    ntrg_pad = Xt.shape[0]
    Xt_chunks = Xt.reshape((ntrg_pad // target_chunk_size, target_chunk_size, 3))
    _, outs = jax.lax.scan(lambda c, x: (c, accumulate(x)), None, Xt_chunks)
    out = outs.reshape((ntrg_pad, 9))[:ntrg]
    return jnp.transpose(out, (1, 0))
# [docs]
def laplace_fxd2_u_eval_vec(
    X_src,
    X_trg,
    density_vec,
    area_elem,
    *,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
):
    """Vector-density wrapper for Laplace Fxd2U.

    Maps ``laplace_fxd2_u_eval`` over the leading axis of ``density_vec``
    (given as (dof, nt, np) or (dof, n)); the results are stacked along a
    new leading axis, one entry per density component.
    """
    components = _flatten_soa(density_vec, "density_vec")

    def eval_component(dens):
        # Geometry and blocking options are shared by every component.
        return laplace_fxd2_u_eval(
            X_src,
            X_trg,
            dens,
            area_elem,
            chunk_size=chunk_size,
            target_chunk_size=target_chunk_size,
        )

    batched = jax.vmap(eval_component, in_axes=0, out_axes=0)
    return batched(components)
# [docs]
def laplace_fxd_u_eval_vec(
    X_src,
    X_trg,
    density_vec,
    area_elem,
    *,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
):
    """Vector-density wrapper for Laplace FxdU.

    Maps ``laplace_fxd_u_eval`` over the leading axis of ``density_vec``.

    density_vec: (3, nt, np) or (3, nsrc)
    Returns: (3, 3, ntrg) with first index over density component.
    """
    components = _flatten_soa(density_vec, "density_vec")

    def eval_component(dens):
        # Geometry and blocking options are shared by every component.
        return laplace_fxd_u_eval(
            X_src,
            X_trg,
            dens,
            area_elem,
            chunk_size=chunk_size,
            target_chunk_size=target_chunk_size,
        )

    batched = jax.vmap(eval_component, in_axes=0, out_axes=0)
    return batched(components)
# [docs]
def biotsavart_fx_u_eval(
    X_src,
    X_trg,
    density_vec,
    area_elem,
    *,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
):
    """Evaluate Biot-Savart FxU by direct quadrature.

    Args:
        X_src: source coordinates, (3, nt, np) or (3, nsrc).
        X_trg: target coordinates, (3, nt, np) or (3, ntrg).
        density_vec: vector density, (3, nt, np) or (3, nsrc).
        area_elem: per-source area element, (nt, np) or (nsrc,); it scales
            every component of the density.
        chunk_size: sources per scan step. ``None`` or ``<= 0`` keeps all
            sources in one block.
        target_chunk_size: targets per block. ``None`` or ``<= 0`` disables
            target blocking.

    Returns:
        Array of shape (3, ntrg). The target axis is always flattened, also
        when ``X_trg`` is given as (3, nt, np).

    Raises:
        ValueError: if ``X_src``/``X_trg``/``density_vec`` do not have
            leading dimension 3 (or have an unsupported rank), or if
            ``area_elem`` does not match the number of sources.
    """
    X_src = _flatten_soa(X_src, "X_src")
    X_trg = _flatten_soa(X_trg, "X_trg")
    density_vec = _flatten_soa(density_vec, "density_vec")
    area_elem = jnp.asarray(area_elem).reshape(-1)
    if X_src.shape[0] != 3 or X_trg.shape[0] != 3 or density_vec.shape[0] != 3:
        raise ValueError("X_src, X_trg, density_vec must be 3D in SoA layout")
    nsrc = X_src.shape[1]
    ntrg = X_trg.shape[1]
    if area_elem.shape[0] != nsrc:
        raise ValueError("area_elem must match source grid size")

    weights = density_vec * area_elem[None, :]  # (3, nsrc)
    Xs = jnp.transpose(X_src, (1, 0))  # (nsrc, 3)
    Xt = jnp.transpose(X_trg, (1, 0))  # (ntrg, 3)

    block_sources = chunk_size is not None and chunk_size > 0
    block_targets = target_chunk_size is not None and target_chunk_size > 0

    if not block_sources and not block_targets:
        dx = Xt[:, None, :] - Xs[None, :, :]
        fvec = jnp.transpose(weights, (1, 0))[None, :, :]
        out = jnp.sum(biotsavart_fx_u(dx, fvec), axis=1)
        return jnp.transpose(out, (1, 0))

    if block_sources:
        # Padded sources sit at the origin with zero weight.
        # NOTE(review): assumes the kernel stays finite for such entries.
        Xs, pad_src = _pad_to_multiple(Xs, 0, chunk_size)
        if pad_src:
            weights = jnp.pad(weights, ((0, 0), (0, pad_src)))
        per_chunk = chunk_size
        n_chunks = Xs.shape[0] // chunk_size
    else:
        # Target blocking only: treat all sources as a single chunk.
        per_chunk = Xs.shape[0]
        n_chunks = 1
    X_chunks = Xs.reshape((n_chunks, per_chunk, 3))
    # (3, n_chunks, per_chunk) -> (n_chunks, per_chunk, 3): chunk-major, AoS.
    W_chunks = weights.reshape((3, n_chunks, per_chunk)).transpose(1, 2, 0)

    # Widest input dtype keeps the scan carry dtype fixed for mixed-precision
    # inputs (lax.scan rejects a carry that gets promoted inside the body).
    acc_dtype = jnp.result_type(Xs.dtype, Xt.dtype, weights.dtype)

    def accumulate(trg_block):
        """Sum the contribution of every source chunk onto ``trg_block``."""

        def body(acc, chunk):
            Xc, Wc = chunk
            dx = trg_block[:, None, :] - Xc[None, :, :]
            return acc + jnp.sum(biotsavart_fx_u(dx, Wc[None, :, :]), axis=1), None

        init = jnp.zeros((trg_block.shape[0], 3), dtype=acc_dtype)
        total, _ = jax.lax.scan(body, init, (X_chunks, W_chunks))
        return total

    if not block_targets:
        return jnp.transpose(accumulate(Xt), (1, 0))

    # Target blocking: padded targets are evaluated and then sliced away.
    Xt, _ = _pad_to_multiple(Xt, 0, target_chunk_size)
    ntrg_pad = Xt.shape[0]
    Xt_chunks = Xt.reshape((ntrg_pad // target_chunk_size, target_chunk_size, 3))
    _, outs = jax.lax.scan(lambda c, x: (c, accumulate(x)), None, Xt_chunks)
    out = outs.reshape((ntrg_pad, 3))[:ntrg]
    return jnp.transpose(out, (1, 0))
# [docs]
def biotsavart_fxd_u_eval(
    X_src,
    X_trg,
    density_vec,
    area_elem,
    *,
    chunk_size: int = 512,
    target_chunk_size: int | None = None,
):
    """Evaluate Biot-Savart FxdU by direct quadrature.

    Args:
        X_src: source coordinates, (3, nt, np) or (3, nsrc).
        X_trg: target coordinates, (3, nt, np) or (3, ntrg).
        density_vec: vector density, (3, nt, np) or (3, nsrc).
        area_elem: per-source area element, (nt, np) or (nsrc,); it scales
            every component of the density.
        chunk_size: sources per scan step. ``None`` or ``<= 0`` keeps all
            sources in one block.
        target_chunk_size: targets per block. ``None`` or ``<= 0`` disables
            target blocking.

    Returns:
        Array of shape (3, 3, ntrg). The target axis is always flattened,
        also when ``X_trg`` is given as (3, nt, np).

    Raises:
        ValueError: if ``X_src``/``X_trg``/``density_vec`` do not have
            leading dimension 3 (or have an unsupported rank), or if
            ``area_elem`` does not match the number of sources.
    """
    X_src = _flatten_soa(X_src, "X_src")
    X_trg = _flatten_soa(X_trg, "X_trg")
    density_vec = _flatten_soa(density_vec, "density_vec")
    area_elem = jnp.asarray(area_elem).reshape(-1)
    if X_src.shape[0] != 3 or X_trg.shape[0] != 3 or density_vec.shape[0] != 3:
        raise ValueError("X_src, X_trg, density_vec must be 3D in SoA layout")
    nsrc = X_src.shape[1]
    ntrg = X_trg.shape[1]
    if area_elem.shape[0] != nsrc:
        raise ValueError("area_elem must match source grid size")

    weights = density_vec * area_elem[None, :]  # (3, nsrc)
    Xs = jnp.transpose(X_src, (1, 0))  # (nsrc, 3)
    Xt = jnp.transpose(X_trg, (1, 0))  # (ntrg, 3)

    block_sources = chunk_size is not None and chunk_size > 0
    block_targets = target_chunk_size is not None and target_chunk_size > 0

    if not block_sources and not block_targets:
        dx = Xt[:, None, :] - Xs[None, :, :]
        fvec = jnp.transpose(weights, (1, 0))[None, :, :]
        out = jnp.sum(biotsavart_fxd_u(dx, fvec), axis=1)
        return jnp.transpose(out, (1, 2, 0))

    if block_sources:
        # Padded sources sit at the origin with zero weight.
        # NOTE(review): assumes the kernel stays finite for such entries.
        Xs, pad_src = _pad_to_multiple(Xs, 0, chunk_size)
        if pad_src:
            weights = jnp.pad(weights, ((0, 0), (0, pad_src)))
        per_chunk = chunk_size
        n_chunks = Xs.shape[0] // chunk_size
    else:
        # Target blocking only: treat all sources as a single chunk.
        per_chunk = Xs.shape[0]
        n_chunks = 1
    X_chunks = Xs.reshape((n_chunks, per_chunk, 3))
    # (3, n_chunks, per_chunk) -> (n_chunks, per_chunk, 3): chunk-major, AoS.
    W_chunks = weights.reshape((3, n_chunks, per_chunk)).transpose(1, 2, 0)

    # Widest input dtype keeps the scan carry dtype fixed for mixed-precision
    # inputs (lax.scan rejects a carry that gets promoted inside the body).
    acc_dtype = jnp.result_type(Xs.dtype, Xt.dtype, weights.dtype)

    def accumulate(trg_block):
        """Sum all source chunks onto ``trg_block``; returns (n, 3, 3)."""

        def body(acc, chunk):
            Xc, Wc = chunk
            dx = trg_block[:, None, :] - Xc[None, :, :]
            return acc + jnp.sum(biotsavart_fxd_u(dx, Wc[None, :, :]), axis=1), None

        init = jnp.zeros((trg_block.shape[0], 3, 3), dtype=acc_dtype)
        total, _ = jax.lax.scan(body, init, (X_chunks, W_chunks))
        return total

    if not block_targets:
        return jnp.transpose(accumulate(Xt), (1, 2, 0))

    # Target blocking: padded targets are evaluated and then sliced away.
    Xt, _ = _pad_to_multiple(Xt, 0, target_chunk_size)
    ntrg_pad = Xt.shape[0]
    Xt_chunks = Xt.reshape((ntrg_pad // target_chunk_size, target_chunk_size, 3))
    _, outs = jax.lax.scan(lambda c, x: (c, accumulate(x)), None, Xt_chunks)
    out = outs.reshape((ntrg_pad, 3, 3))[:ntrg]
    return jnp.transpose(out, (1, 2, 0))
# [docs]
def laplace_dx_u_eval(
    X_src,
    n_src,
    X_trg,
    density,
    area_elem,
    *,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
):
    """Evaluate Laplace DxU (double-layer) by direct quadrature.

    Args:
        X_src: source coordinates, (3, nt, np) or (3, nsrc).
        n_src: source normals, same layout as ``X_src``.
        X_trg: target coordinates, (3, nt, np) or (3, ntrg).
        density: scalar density, (nt, np) or (nsrc,).
        area_elem: per-source area element, (nt, np) or (nsrc,).
        chunk_size: sources per scan step. ``None`` or ``<= 0`` keeps all
            sources in one block.
        target_chunk_size: targets per block. ``None`` or ``<= 0`` disables
            target blocking.

    Returns:
        Array of shape (1, ntrg) with the target axis flattened.

    Raises:
        ValueError: if ``X_src``/``X_trg``/``n_src`` do not have leading
            dimension 3 (or have an unsupported rank), or if
            ``density``/``area_elem`` do not match the number of sources.
    """
    X_src = _flatten_soa(X_src, "X_src")
    n_src = _flatten_soa(n_src, "n_src")
    X_trg = _flatten_soa(X_trg, "X_trg")
    density = jnp.asarray(density).reshape(-1)
    area_elem = jnp.asarray(area_elem).reshape(-1)
    if X_src.shape[0] != 3 or X_trg.shape[0] != 3 or n_src.shape[0] != 3:
        raise ValueError("X_src, X_trg, n_src must be 3D in SoA layout")
    nsrc = X_src.shape[1]
    ntrg = X_trg.shape[1]
    if density.shape[0] != nsrc or area_elem.shape[0] != nsrc:
        raise ValueError("density/area_elem must match source grid size")

    weights = density * area_elem
    Xs = jnp.transpose(X_src, (1, 0))  # (nsrc, 3)
    Ns = jnp.transpose(n_src, (1, 0))  # (nsrc, 3)
    Xt = jnp.transpose(X_trg, (1, 0))  # (ntrg, 3)

    block_sources = chunk_size is not None and chunk_size > 0
    block_targets = target_chunk_size is not None and target_chunk_size > 0

    if not block_sources and not block_targets:
        dx = Xt[:, None, :] - Xs[None, :, :]
        out = jnp.sum(laplace_dx_u(dx, Ns[None, :, :], weights), axis=1)
        return out.reshape((1, ntrg))

    if block_sources:
        # Padded sources sit at the origin with zero normal and zero weight.
        # NOTE(review): assumes the kernel stays finite for such entries.
        Xs, pad_src = _pad_to_multiple(Xs, 0, chunk_size)
        if pad_src:
            Ns = jnp.pad(Ns, ((0, pad_src), (0, 0)))
            weights = jnp.pad(weights, (0, pad_src))
        per_chunk = chunk_size
        n_chunks = Xs.shape[0] // chunk_size
    else:
        # Target blocking only: treat all sources as a single chunk.
        per_chunk = Xs.shape[0]
        n_chunks = 1
    X_chunks = Xs.reshape((n_chunks, per_chunk, 3))
    N_chunks = Ns.reshape((n_chunks, per_chunk, 3))
    W_chunks = weights.reshape((n_chunks, per_chunk))

    # Widest input dtype keeps the scan carry dtype fixed for mixed-precision
    # inputs (lax.scan rejects a carry that gets promoted inside the body).
    acc_dtype = jnp.result_type(Xs.dtype, Ns.dtype, Xt.dtype, weights.dtype)

    def accumulate(trg_block):
        """Sum the contribution of every source chunk onto ``trg_block``."""

        def body(acc, chunk):
            Xc, Nc, Wc = chunk
            dx = trg_block[:, None, :] - Xc[None, :, :]
            return acc + jnp.sum(laplace_dx_u(dx, Nc[None, :, :], Wc), axis=1), None

        init = jnp.zeros((trg_block.shape[0],), dtype=acc_dtype)
        total, _ = jax.lax.scan(body, init, (X_chunks, N_chunks, W_chunks))
        return total

    if not block_targets:
        return accumulate(Xt).reshape((1, ntrg))

    # Target blocking: padded targets are evaluated and then sliced away.
    Xt, _ = _pad_to_multiple(Xt, 0, target_chunk_size)
    ntrg_pad = Xt.shape[0]
    Xt_chunks = Xt.reshape((ntrg_pad // target_chunk_size, target_chunk_size, 3))
    _, outs = jax.lax.scan(lambda c, x: (c, accumulate(x)), None, Xt_chunks)
    out = outs.reshape((ntrg_pad,))[:ntrg]
    return out.reshape((1, ntrg))
# [docs]
def computeB_offsurface_baseline(
    X_src,
    BdotN,
    J,
    Xt,
    upsample_factor: int = 1,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
    ext: bool = True,
):
    """Baseline off-surface evaluation using direct quadrature.

    This mirrors ExtVacuumField behavior (no singular correction) with
    optional upsampling for improved accuracy.

    The result is ``sign * (gradG - bs)``, where ``gradG`` is the Laplace
    FxdU sum with density ``BdotN``, ``bs`` is the Biot-Savart FxU sum with
    density ``J``, and ``sign`` is ``+1`` when ``ext`` is true, ``-1``
    otherwise. ``upsample_factor > 1`` refines both grid axes of the source
    data by that factor before the area element is computed.
    """
    src = jnp.asarray(X_src)
    normal_flux = jnp.asarray(BdotN)
    current = jnp.asarray(J)

    nt_in, np_in = src.shape[1], src.shape[2]
    if upsample_factor > 1:
        nt_fine = nt_in * upsample_factor
        np_fine = np_in * upsample_factor
        src = upsample(src, nt_in, np_in, nt_fine, np_fine)
        # upsample works on a leading component axis; add and strip one.
        normal_flux = upsample(normal_flux[None, ...], nt_in, np_in, nt_fine, np_fine)[0]
        current = upsample(current, nt_in, np_in, nt_fine, np_fine)

    tangents = grad2d(src, src.shape[1], src.shape[2])
    _, area = surf_normal_area_elem(tangents, src)

    blocking = dict(chunk_size=chunk_size, target_chunk_size=target_chunk_size)
    grad_single_layer = laplace_fxd_u_eval(src, Xt, normal_flux, area, **blocking)
    biot_savart = biotsavart_fx_u_eval(src, Xt, current, area, **blocking)

    sign = 1.0 if ext else -1.0
    return sign * (grad_single_layer - biot_savart)
# [docs]
def computeB_offsurface_adaptive(
    X_src,
    BdotN,
    J,
    Xt,
    digits: int = 5,
    max_Nt: int = -1,
    max_Np: int = -1,
    ext: bool = True,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
):
    """Adaptive off-surface evaluation matching ExtVacuumField logic.

    The source grid is first refined by ``_offsurface_adapt_grid`` (driven by
    ``digits``, ``max_Nt`` and ``max_Np``); the field is then evaluated on
    the refined grid as ``sign * (gradG - bs)`` with ``sign = +1`` when
    ``ext`` is true and ``-1`` otherwise.
    """
    blocking = dict(chunk_size=chunk_size, target_chunk_size=target_chunk_size)

    src, normal_flux, current, area = _offsurface_adapt_grid(
        X_src,
        BdotN,
        J,
        Xt,
        digits=digits,
        max_Nt=max_Nt,
        max_Np=max_Np,
        **blocking,
    )

    grad_single_layer = laplace_fxd_u_eval(src, Xt, normal_flux, area, **blocking)
    biot_savart = biotsavart_fx_u_eval(src, Xt, current, area, **blocking)

    sign = 1.0 if ext else -1.0
    return sign * (grad_single_layer - biot_savart)
# [docs]
def computeB_offsurface_adaptive_schedule(
    X_src,
    BdotN,
    J,
    Xt,
    *,
    levels: tuple[tuple[int, int], ...],
    digits: int = 5,
    ext: bool = True,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
):
    """JIT-friendly adaptive off-surface evaluation with fixed refinement schedule.

    The refinement schedule is provided as a static tuple of (Nt, Np) pairs.
    Shapes are static per-level, so this function can be JIT-compiled with
    ``levels`` marked static. The method updates the result only while the
    double-layer self-test error exceeds the tolerance.

    Raises:
        ValueError: if ``levels`` is empty.
    """
    X_src = jnp.asarray(X_src)
    BdotN = jnp.asarray(BdotN)
    J = jnp.asarray(J)
    Xt = jnp.asarray(Xt)
    if len(levels) == 0:
        raise ValueError("levels must contain at least one (Nt, Np) pair")

    base_nt = int(X_src.shape[1])
    base_np = int(X_src.shape[2])
    tol = 10.0 ** (-digits)
    sign = 1.0 if ext else -1.0
    blocking = dict(chunk_size=chunk_size, target_chunk_size=target_chunk_size)

    def field_and_error(nt, npol):
        """Evaluate the field and the self-test error on an (nt, npol) grid."""
        src = resample(X_src, base_nt, base_np, nt, npol)
        flux = resample(BdotN[None, ...], base_nt, base_np, nt, npol)[0]
        current = resample(J, base_nt, base_np, nt, npol)
        tangents = grad2d(src, nt, npol)
        normal, area = surf_normal_area_elem(tangents, src)

        # Self-test: double-layer sum of a unit density at the targets.
        unit = jnp.ones((nt, npol), dtype=src.dtype)
        U = jnp.asarray(laplace_dx_u_eval(src, normal, Xt, unit, area, **blocking)).reshape(-1)
        # Error = worst-case distance of U from the nearer of 0 and 1.
        err = jnp.max(jnp.minimum(jnp.abs(1.0 - U), jnp.abs(U)))

        grad_single_layer = laplace_fxd_u_eval(src, Xt, flux, area, **blocking)
        biot_savart = biotsavart_fx_u_eval(src, Xt, current, area, **blocking)
        return sign * (grad_single_layer - biot_savart), err

    first_nt, first_np = levels[0]
    best = field_and_error(int(first_nt), int(first_np))

    for level_nt, level_np in levels[1:]:
        nt_i = int(level_nt)
        np_i = int(level_np)

        def refine(state):
            # Traced immediately by lax.cond below, so nt_i/np_i are the
            # values of the current loop iteration.
            return field_and_error(nt_i, np_i)

        def keep(state):
            return state

        best = jax.lax.cond(best[1] > tol, refine, keep, operand=best)

    return best[0]
# [docs]
def computeGradB_offsurface_adaptive_schedule(
    X_src,
    BdotN,
    J,
    Xt,
    *,
    levels: tuple[tuple[int, int], ...],
    digits: int = 5,
    ext: bool = True,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
):
    """JIT-friendly adaptive off-surface GradB evaluation with fixed schedule.

    ``levels`` is a tuple of (Nt, Np) grid sizes tried in order; a later
    level replaces the current result only while the double-layer self-test
    error is above ``10 ** -digits``. ``Xt`` is indexed as ``Xt.shape[1]``
    for the target count, i.e. it is expected in (3, ntrg) layout.

    Returns:
        Array of shape (3, 3, ntrg).

    Raises:
        ValueError: if ``levels`` is empty.
    """
    X_src = jnp.asarray(X_src)
    BdotN = jnp.asarray(BdotN)
    J = jnp.asarray(J)
    Xt = jnp.asarray(Xt)
    if len(levels) == 0:
        raise ValueError("levels must contain at least one (Nt, Np) pair")

    base_nt = int(X_src.shape[1])
    base_np = int(X_src.shape[2])
    tol = 10.0 ** (-digits)
    sign = 1.0 if ext else -1.0
    blocking = dict(chunk_size=chunk_size, target_chunk_size=target_chunk_size)

    def gradient_and_error(nt, npol):
        """Evaluate GradB and the self-test error on an (nt, npol) grid."""
        src = resample(X_src, base_nt, base_np, nt, npol)
        flux = resample(BdotN[None, ...], base_nt, base_np, nt, npol)[0]
        current = resample(J, base_nt, base_np, nt, npol)
        tangents = grad2d(src, nt, npol)
        normal, area = surf_normal_area_elem(tangents, src)

        # Self-test: double-layer sum of a unit density at the targets.
        unit = jnp.ones((nt, npol), dtype=src.dtype)
        U = jnp.asarray(laplace_dx_u_eval(src, normal, Xt, unit, area, **blocking)).reshape(-1)
        # Error = worst-case distance of U from the nearer of 0 and 1.
        err = jnp.max(jnp.minimum(jnp.abs(1.0 - U), jnp.abs(U)))

        ntrg = Xt.shape[1]
        # Axes: (density component, 3, 3, target).
        hess_J = laplace_fxd2_u_eval_vec(src, Xt, current, area, **blocking)
        hess_J = jnp.asarray(hess_J).reshape((3, 3, 3, ntrg))
        hess_flux = laplace_fxd2_u_eval(src, Xt, flux, area, **blocking)
        hess_flux = jnp.asarray(hess_flux).reshape((3, 3, ntrg))

        # Row k is the antisymmetric combination over the cyclic pair
        # (k+1, k+2) of the leading two axes of hess_J.
        rows = []
        for k in range(3):
            k1 = (k + 1) % 3
            k2 = (k + 2) % 3
            rows.append(hess_J[k1, k2] - hess_J[k2, k1])
        gradB = jnp.stack(rows, axis=0) + hess_flux
        return gradB * sign, err

    first_nt, first_np = levels[0]
    best = gradient_and_error(int(first_nt), int(first_np))

    for level_nt, level_np in levels[1:]:
        nt_i = int(level_nt)
        np_i = int(level_np)

        def refine(state):
            # Traced immediately by lax.cond below, so nt_i/np_i are the
            # values of the current loop iteration.
            return gradient_and_error(nt_i, np_i)

        def keep(state):
            return state

        best = jax.lax.cond(best[1] > tol, refine, keep, operand=best)

    return best[0]
def _offsurface_adapt_grid(
    X_src,
    BdotN,
    J,
    Xt,
    *,
    digits: int = 5,
    max_Nt: int = -1,
    max_Np: int = -1,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
):
    """Upsample source grid until the double-layer self-test meets tolerance.

    Each pass evaluates the double-layer sum of a unit density at ``Xt`` and
    measures how far it is from the nearer of 0 and 1. If that exceeds
    ``10 ** -digits`` both grid sizes are doubled (clamped to ``max_Nt`` /
    ``max_Np`` when those are positive) and the data are upsampled.

    Returns:
        ``(X_src, BdotN, J, area_elem)`` on the final grid.

    NOTE(review): the tolerance test uses a Python ``if`` on a JAX value, so
    this helper needs concrete (non-traced) inputs. With both ``max_Nt`` and
    ``max_Np`` non-positive the loop only ends once the tolerance is met.
    """
    X_src = jnp.asarray(X_src)
    BdotN = jnp.asarray(BdotN)
    J = jnp.asarray(J)
    Xt = jnp.asarray(Xt)

    nt, npol = X_src.shape[1], X_src.shape[2]
    tol = 10.0 ** (-digits)

    while True:
        tangents = grad2d(X_src, nt, npol)
        normal, area_elem = surf_normal_area_elem(tangents, X_src)
        unit = jnp.ones((nt, npol), dtype=X_src.dtype)
        U = jnp.asarray(
            laplace_dx_u_eval(
                X_src,
                normal,
                Xt,
                unit,
                area_elem,
                chunk_size=chunk_size,
                target_chunk_size=target_chunk_size,
            )
        ).reshape(-1)
        err = jnp.max(jnp.minimum(jnp.abs(1.0 - U), jnp.abs(U)))
        if err <= tol:
            return X_src, BdotN, J, area_elem

        # Double the resolution, honouring the optional caps.
        next_nt = min(nt * 2, max_Nt) if max_Nt > 0 else nt * 2
        next_np = min(npol * 2, max_Np) if max_Np > 0 else npol * 2
        if next_nt == nt and next_np == npol:
            # Both caps reached: return the best grid available.
            return X_src, BdotN, J, area_elem

        X_src = upsample(X_src, nt, npol, next_nt, next_np)
        BdotN = upsample(BdotN[None, ...], nt, npol, next_nt, next_np)[0]
        J = upsample(J, nt, npol, next_nt, next_np)
        nt, npol = next_nt, next_np
def _surface_cond(dX, nt: int, npol: int):
dX = jnp.asarray(dX)
xt = dX[0]
xp = dX[1]
yt = dX[2]
yp = dX[3]
zt = dX[4]
zp = dX[5]
m00 = (xt * xt + yt * yt + zt * zt) / (nt * nt)
m11 = (xp * xp + yp * yp + zp * zp) / (npol * npol)
ratio = jnp.sqrt(m00 / m11)
amin = jnp.min(ratio)
amax = jnp.max(ratio)
return jnp.sqrt(amax / amin)
def _build_patch_indices(t_idx, p_idx, nt: int, npol: int, patch_dim0: int):
patch_dim = 2 * patch_dim0 + 1
dt = jnp.arange(-patch_dim0, patch_dim0 + 1, dtype=jnp.int32)
dp = jnp.arange(-patch_dim0, patch_dim0 + 1, dtype=jnp.int32)
tt = (t_idx[:, None, None] + dt[None, :, None]) % nt
pp = (p_idx[:, None, None] + dp[None, None, :]) % npol
idx = (tt * npol + pp).reshape((t_idx.shape[0], patch_dim * patch_dim)).astype(jnp.int32)
return idx
def _interp_patch(values, precomp):
    """Interpolate patch-grid values to the polar nodes of ``precomp``.

    ``values`` is ``(dof, Ngrid)``. For each polar node an
    ``INTERP_ORDER x INTERP_ORDER`` stencil of grid values is gathered
    through ``precomp.interp_idx`` and combined with the matching weights
    in ``precomp.M_G2P``. ``values`` is cast to the weight dtype first.

    Returns:
        Array of shape ``(dof, precomp.npolar)``.
    """
    table = precomp.M_G2P
    ncomp = values.shape[0]
    if values.dtype != table.dtype:
        values = values.astype(table.dtype)
    flat_idx = precomp.interp_idx.reshape(-1)
    stencil = jnp.take(values, flat_idx, axis=1)
    stencil = stencil.reshape((ncomp, precomp.npolar, INTERP_ORDER, INTERP_ORDER))
    # Weighted sum over both stencil axes.
    return jnp.sum(stencil * table[None, ...], axis=(2, 3))
def _interp_patch_blocked(values, precomp, block_size: int):
    """Blocked variant of ``_interp_patch``.

    Polar nodes are processed ``block_size`` at a time inside a
    ``lax.scan``, so only one block of gathered stencils is live at once.
    A ``block_size`` that is non-positive or not smaller than
    ``precomp.npolar`` falls back to ``_interp_patch``.

    Args:
        values: ``(dof, Ngrid)`` grid values; cast to the weight dtype.
        precomp: precomputed tables providing ``interp_idx``, ``M_G2P`` and
            ``npolar``.
        block_size: number of polar nodes per scan step.

    Returns:
        Array of shape ``(dof, precomp.npolar)``.
    """
    ncomp = values.shape[0]
    table = precomp.M_G2P
    if values.dtype != table.dtype:
        values = values.astype(table.dtype)
    npolar = precomp.npolar
    block_size = int(block_size)
    if not (0 < block_size < npolar):
        return _interp_patch(values, precomp)

    stencil = (INTERP_ORDER, INTERP_ORDER)
    idx = precomp.interp_idx.reshape((npolar,) + stencil)
    extra = (-npolar) % block_size
    if extra:
        # Padded nodes use index 0 with zero weight; they are sliced off below.
        tail = ((0, extra), (0, 0), (0, 0))
        idx = jnp.pad(idx, tail)
        table = jnp.pad(table, tail)
    npolar_pad = npolar + extra
    nblocks = npolar_pad // block_size
    idx_blocks = idx.reshape((nblocks, block_size) + stencil)
    w_blocks = table.reshape((nblocks, block_size) + stencil)

    def interp_block(carry, block):
        block_idx, block_w = block
        picked = jnp.take(values, block_idx, axis=1)
        return carry, jnp.sum(picked * block_w[None, ...], axis=(2, 3))

    _, per_block = jax.lax.scan(interp_block, None, (idx_blocks, w_blocks))
    # (nblocks, dof, block) -> (dof, nblocks * block), then drop the padding.
    merged = jnp.transpose(per_block, (1, 0, 2)).reshape((ncomp, npolar_pad))
    return merged[:, :npolar]
def _resolve_interp_block_size(interp_block_size, npolar: int, mode: str):
if interp_block_size is None:
return None
if isinstance(interp_block_size, str) and interp_block_size.lower() == "auto":
if npolar <= 256:
return None
if mode == "gradb":
return 32
return 64
return int(interp_block_size)
# [docs]
def laplace_fxd_u_eval_singular(
    X_src,
    dX_src,
    density,
    trg_nt: int,
    trg_np: int,
    nfp: int,
    X_trg=None,
    digits: int = 5,
    patch_dim0: int | None = None,
    rad_dim: int | None = None,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool = False,
):
    """Evaluate Laplace FxdU with singular correction on surface targets.

    The result is the plain direct-quadrature sum over the whole source grid
    plus a per-target correction computed on a local patch of the grid
    around that target.

    Args:
        X_src: source coordinates, (3, nt, np).
        dX_src: parametric derivatives of ``X_src``, reshaped to (6, nt*np);
            presumably ordered (x_t, x_p, y_t, y_p, z_t, z_p) as in
            ``_surface_cond`` -- TODO confirm.
        density: scalar density on the source grid, (nt, np).
        trg_nt, trg_np: target grid size within one field period.
        nfp: number of field periods covered by the source grid.
        X_trg: optional explicit target coordinates; defaults to
            ``field_period_target_coords(X_src, trg_nt, trg_np, nfp)``.
        digits: passed to ``select_patch_dim`` when ``patch_dim0`` is None.
        patch_dim0: patch half-width in grid points; chosen from ``digits``
            and the surface condition measure when None (this path calls
            ``float()`` on a JAX value and therefore needs concrete inputs).
        rad_dim: passed to ``precompute_singular``; defaults to
            ``int(patch_dim0 * 1.6)``.
        chunk_size: source chunk size of the direct-quadrature part.
        target_chunk_size: targets per block for both the direct part and
            the correction; ``None`` or ``<= 0`` disables blocking.
        patch_idx: optional precomputed (ntrg, ngrid) flat patch indices.
        orient: accepted for signature compatibility; this routine does not
            use surface normals, so the value is ignored.
        pou_dtype, patch_dtype: forwarded to ``precompute_singular``.
        interp_block_size: ``None``, ``"auto"`` or an int; see
            ``_resolve_interp_block_size``.
        remat: wrap the per-target correction in ``jax.checkpoint``.

    Returns:
        Array of shape (3, trg_nt, trg_np).
    """
    X_src = jnp.asarray(X_src)
    dX_src = jnp.asarray(dX_src)
    density = jnp.asarray(density)
    nt = X_src.shape[1]
    npol = X_src.shape[2]
    if X_trg is None:
        X_trg = field_period_target_coords(X_src, trg_nt, trg_np, nfp)
    else:
        X_trg = jnp.asarray(X_trg)

    # Uncorrected quadrature over the whole source surface.
    _, area_elem = surf_normal_area_elem(dX_src, X_src)
    base = laplace_fxd_u_eval(
        X_src,
        X_trg,
        density,
        area_elem,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )

    if patch_dim0 is None:
        cond_val = float(_surface_cond(dX_src, nt, npol))
        patch_dim0 = select_patch_dim(digits, cond_val)
    if rad_dim is None:
        rad_dim = int(patch_dim0 * 1.6)
    precomp = precompute_singular(
        patch_dim0,
        rad_dim,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
    )
    ngrid = precomp.ngrid
    npolar = precomp.npolar
    # Resolve "auto" once, before the closure below reads it.
    block = _resolve_interp_block_size(interp_block_size, npolar, "b")

    if patch_idx is None:
        # Patch centres: the strided source indices of the targets.
        skip_nt = nt // (nfp * trg_nt)
        skip_np = npol // trg_np
        t_idx = jnp.arange(trg_nt) * skip_nt
        p_idx = jnp.arange(trg_np) * skip_np
        tt, pp = jnp.meshgrid(t_idx, p_idx, indexing="ij")
        patch_idx = _build_patch_indices(tt.reshape(-1), pp.reshape(-1), nt, npol, patch_dim0)

    X_flat = X_src.reshape((3, -1))
    dX_flat = dX_src.reshape((6, -1))
    dens_flat = density.reshape(-1)

    def gather(values, idx):
        # (dof, N) gathered per target -> (ntrg, dof, ngrid)
        return jax.vmap(lambda ii: values[:, ii])(idx)

    invNt = 1.0 / nt
    invNp = 1.0 / npol

    def cross_norm(g):
        """Length of the cross product encoded by the six rows of ``g``."""
        c0 = g[2] * g[5] - g[3] * g[4]
        c1 = g[4] * g[1] - g[5] * g[0]
        c2 = g[0] * g[3] - g[1] * g[2]
        return jnp.sqrt(c0 * c0 + c1 * c1 + c2 * c2)

    def corr_one(Gi, Ggi, GiF, TrgCoord):
        """Correction for one target. Gi: (3, Ngrid), Ggi: (6, Ngrid)."""
        Ga = cross_norm(Ggi) * invNt * invNp

        # Scale derivatives: even rows by 1/nt, odd rows by 1/npol.
        Ggs = Ggi
        for row in range(6):
            Ggs = Ggs.at[row].multiply(invNt if row % 2 == 0 else invNp)

        # Grid kernel, weighted by the area element and precomp.Gpou.
        dx = TrgCoord[None, :] - Gi.T
        MGrid = laplace_fxd_u(dx, jnp.ones((ngrid,)))
        MGrid = MGrid * (Ga * precomp.Gpou)[:, None]

        # Polar interpolation of positions and scaled derivatives.
        if block is None:
            P = _interp_patch(Gi, precomp)  # (3, Npolar)
            Pg = _interp_patch(Ggs, precomp)  # (6, Npolar)
        else:
            P = _interp_patch_blocked(Gi, precomp, block)
            Pg = _interp_patch_blocked(Ggs, precomp, block)
        if P.dtype != TrgCoord.dtype:
            P = P.astype(TrgCoord.dtype)
        if Pg.dtype != TrgCoord.dtype:
            Pg = Pg.astype(TrgCoord.dtype)

        rp = cross_norm(Pg)
        dxp = TrgCoord[None, :] - P.T
        MPolar = laplace_fxd_u(dxp, jnp.ones((npolar,)))
        MPolar = MPolar * (rp * precomp.Ppou)[:, None]

        # Scatter polar contributions back to the grid nodes they were
        # interpolated from, one vector component at a time.
        idx = precomp.interp_idx.reshape(-1)
        w = precomp.M_G2P.reshape((npolar, -1))
        for k in range(3):
            contrib = (MPolar[:, k:k + 1] * w).reshape(-1)
            MGrid = MGrid.at[idx, k].add(contrib)
        return jnp.sum(GiF[:, None] * MGrid, axis=0)

    Trg_flat = X_trg.reshape((3, -1)).T
    corr_fn = jax.checkpoint(corr_one) if remat else corr_one

    if target_chunk_size is None or target_chunk_size <= 0:
        G = gather(X_flat, patch_idx)  # (Ntrg, 3, Ngrid)
        Gg = gather(dX_flat, patch_idx)  # (Ntrg, 6, Ngrid)
        GF = jax.vmap(lambda idx: dens_flat[idx])(patch_idx)  # (Ntrg, Ngrid)
        corr = jax.vmap(corr_fn)(G, Gg, GF, Trg_flat)
    else:
        ntrg = Trg_flat.shape[0]
        pad = (-ntrg) % target_chunk_size
        if pad:
            # Padded targets reuse patch index 0 and sit at the origin;
            # their rows are dropped after the scan.
            patch_idx = jnp.pad(patch_idx, ((0, pad), (0, 0)))
            Trg_flat = jnp.pad(Trg_flat, ((0, pad), (0, 0)))
        ntrg_pad = Trg_flat.shape[0]
        n_chunks = ntrg_pad // target_chunk_size
        patch_chunks = patch_idx.reshape((n_chunks, target_chunk_size, -1))
        trg_chunks = Trg_flat.reshape((n_chunks, target_chunk_size, 3))

        def scan_fn(carry, xs):
            pidx_chunk, trg_chunk = xs
            G = gather(X_flat, pidx_chunk)
            Gg = gather(dX_flat, pidx_chunk)
            GF = jax.vmap(lambda idx: dens_flat[idx])(pidx_chunk)
            return carry, jax.vmap(corr_fn)(G, Gg, GF, trg_chunk)

        _, corr_chunks = jax.lax.scan(scan_fn, None, (patch_chunks, trg_chunks))
        corr = corr_chunks.reshape((ntrg_pad, -1))[:ntrg]

    corr = corr.T.reshape((3, trg_nt, trg_np))
    base = base.reshape((3, trg_nt, trg_np))
    return base + corr
# [docs]
def laplace_fxd_u_eval_vec_singular(
    X_src,
    dX_src,
    density_vec,
    trg_nt: int,
    trg_np: int,
    nfp: int,
    X_trg=None,
    digits: int = 5,
    patch_dim0: int | None = None,
    rad_dim: int | None = None,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool = False,
):
    """Vector-density wrapper for ``laplace_fxd_u_eval_singular``.

    Maps the scalar routine over the leading axis of ``density_vec``; every
    other argument is forwarded unchanged and shared by all components.
    Results are stacked along a new leading axis.
    """
    components = jnp.asarray(density_vec)

    def eval_component(dens):
        return laplace_fxd_u_eval_singular(
            X_src,
            dX_src,
            dens,
            trg_nt,
            trg_np,
            nfp,
            X_trg=X_trg,
            digits=digits,
            patch_dim0=patch_dim0,
            rad_dim=rad_dim,
            chunk_size=chunk_size,
            target_chunk_size=target_chunk_size,
            patch_idx=patch_idx,
            orient=orient,
            pou_dtype=pou_dtype,
            patch_dtype=patch_dtype,
            interp_block_size=interp_block_size,
            remat=remat,
        )

    batched = jax.vmap(eval_component, in_axes=0, out_axes=0)
    return batched(components)
# [docs]
def laplace_dx_u_eval_singular(
    X_src,
    dX_src,
    density,
    trg_nt: int,
    trg_np: int,
    nfp: int,
    X_trg=None,
    digits: int = 5,
    patch_dim0: int | None = None,
    rad_dim: int | None = None,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool = False,
):
    """Evaluate Laplace DxU with singular correction on surface targets.

    The result is the direct-quadrature value returned by
    ``laplace_dx_u_eval`` plus a per-target correction. For every target the
    kernel is evaluated on a local patch of source nodes (weighted by
    ``precomp.Gpou``) and on the interpolated polar nodes (weighted by
    ``precomp.Ppou``); the polar values are scattered back onto the patch
    grid through ``precomp.M_G2P`` / ``precomp.interp_idx`` and contracted
    with the density gathered on the patch.

    Args:
        X_src: source coordinates, indexed as ``(3, nt, np)``.
        dX_src: surface derivatives, flattened here to ``(6, nt * np)``.
            Rows 0/2/4 are scaled by ``1/nt`` and rows 1/3/5 by ``1/np``
            (presumably the two parametric derivatives of x, y, z -- confirm
            against the producer of ``dX_src``).
        density: one value per source node; flattened to ``(nt * np,)``.
        trg_nt, trg_np: target grid size; the output is reshaped to
            ``(1, trg_nt, trg_np)``.
        nfp: number of field periods, used for the target stride
            ``nt // (nfp * trg_nt)``.
        X_trg: optional target coordinates. When ``None`` they are taken
            from ``field_period_target_coords(X_src, trg_nt, trg_np, nfp)``.
        digits: forwarded to ``select_patch_dim`` when ``patch_dim0`` is None.
        patch_dim0: patch size parameter; chosen from ``digits`` and the
            surface condition estimate when ``None``.
        rad_dim: radial size parameter; defaults to ``int(patch_dim0 * 1.6)``.
        chunk_size: forwarded to ``laplace_dx_u_eval``.
        target_chunk_size: forwarded to ``laplace_dx_u_eval``; when positive
            it also makes the correction run in ``lax.scan`` chunks of that
            many targets.
        patch_idx: optional precomputed 2-D array of flattened source indices,
            one row per target. Built with ``_build_patch_indices`` when None.
        orient: sign applied to the unit normals. When ``None`` it is obtained
            from ``normal_orientation`` and converted with ``float``.
        pou_dtype, patch_dtype: forwarded to ``precompute_singular``.
        interp_block_size: passed through ``_resolve_interp_block_size``; a
            resolved value of ``None`` selects the unblocked ``_interp_patch``.
        remat: wrap the per-target correction in ``jax.checkpoint``.

    Returns:
        Array of shape ``(1, trg_nt, trg_np)``.

    Note:
        The defaults for ``patch_dim0`` and ``orient`` go through ``float``,
        so those quantities must be concrete values when they are not
        supplied by the caller.
    """
    X_src = jnp.asarray(X_src)
    dX_src = jnp.asarray(dX_src)
    density = jnp.asarray(density)
    nt = X_src.shape[1]
    npol = X_src.shape[2]

    normal, area_elem = surf_normal_area_elem(dX_src, X_src)
    if X_trg is None:
        X_trg = field_period_target_coords(X_src, trg_nt, trg_np, nfp)
    else:
        X_trg = jnp.asarray(X_trg)

    # Smooth (uncorrected) direct-sum part.
    base = laplace_dx_u_eval(
        X_src,
        normal,
        X_trg,
        density,
        area_elem,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )

    if patch_dim0 is None:
        cond_val = float(_surface_cond(dX_src, nt, npol))
        patch_dim0 = select_patch_dim(digits, cond_val)
    if rad_dim is None:
        rad_dim = int(patch_dim0 * 1.6)
    precomp = precompute_singular(
        patch_dim0,
        rad_dim,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
    )
    ngrid = precomp.ngrid

    if patch_idx is None:
        # Source-grid indices of the targets; only needed to build patches.
        skip_nt = nt // (nfp * trg_nt)
        skip_np = npol // trg_np
        t_idx = jnp.arange(trg_nt) * skip_nt
        p_idx = jnp.arange(trg_np) * skip_np
        tt, pp = jnp.meshgrid(t_idx, p_idx, indexing="ij")
        patch_idx = _build_patch_indices(
            tt.reshape(-1), pp.reshape(-1), nt, npol, patch_dim0
        )

    X_flat = X_src.reshape((3, -1))
    dX_flat = dX_src.reshape((6, -1))
    dens_flat = density.reshape(-1)

    if orient is None:
        orient = float(normal_orientation(X_src, normal))
    invNt = 1.0 / nt
    invNp = 1.0 / npol

    # Resolve the block size before defining the closure that reads it.
    block_size = _resolve_interp_block_size(interp_block_size, precomp.npolar, "b")

    def corr_one(Gi, Ggi, GiF, TrgCoord):
        # Gi: patch coordinates (3, ngrid); Ggi: patch derivatives (6, ngrid);
        # GiF: density on the patch; TrgCoord: (3,) target position.
        n0 = Ggi[2] * Ggi[5] - Ggi[3] * Ggi[4]
        n1 = Ggi[4] * Ggi[1] - Ggi[5] * Ggi[0]
        n2 = Ggi[0] * Ggi[3] - Ggi[1] * Ggi[2]
        r = jnp.sqrt(n0 * n0 + n1 * n1 + n2 * n2)
        Ga = r * invNt * invNp
        inv_r = 1.0 / r
        Gn = jnp.stack([n0, n1, n2], axis=0) * inv_r * orient

        # Scale gradients: even rows by 1/nt, odd rows by 1/np.
        Ggs = Ggi.at[0::2].multiply(invNt)
        Ggs = Ggs.at[1::2].multiply(invNp)

        # Kernel on the patch grid.
        dx = TrgCoord[None, :] - Gi.T
        MGrid = laplace_dx_u(dx, Gn.T, jnp.ones((ngrid,)))
        MGrid = MGrid * (Ga * precomp.Gpou)
        MGrid = MGrid.reshape((ngrid,))

        # Interpolate coordinates and scaled derivatives to the polar nodes.
        if block_size is None:
            P = _interp_patch(Gi, precomp)  # (3, Npolar)
            Pg = _interp_patch(Ggs, precomp)  # (6, Npolar)
        else:
            P = _interp_patch_blocked(Gi, precomp, block_size)
            Pg = _interp_patch_blocked(Ggs, precomp, block_size)
        if P.dtype != TrgCoord.dtype:
            P = P.astype(TrgCoord.dtype)
        if Pg.dtype != TrgCoord.dtype:
            Pg = Pg.astype(TrgCoord.dtype)

        n0p = Pg[2] * Pg[5] - Pg[3] * Pg[4]
        n1p = Pg[4] * Pg[1] - Pg[5] * Pg[0]
        n2p = Pg[0] * Pg[3] - Pg[1] * Pg[2]
        rp = jnp.sqrt(n0p * n0p + n1p * n1p + n2p * n2p)
        inv_rp = 1.0 / rp
        Pn = jnp.stack([n0p, n1p, n2p], axis=0) * inv_rp * orient

        # Kernel on the polar nodes.
        dxp = TrgCoord[None, :] - P.T
        MPolar = laplace_dx_u(dxp, Pn.T, jnp.ones((precomp.npolar,)))
        MPolar = MPolar * (rp * precomp.Ppou)

        # Scatter polar contributions back to the grid.
        idx = precomp.interp_idx.reshape(-1)
        w = precomp.M_G2P.reshape((precomp.npolar, -1))
        contrib = (MPolar[:, None] * w).reshape(-1)
        MGrid = MGrid.at[idx].add(contrib)
        return jnp.sum(GiF * MGrid)

    corr_fn = jax.checkpoint(corr_one) if remat else corr_one

    def gather(values, idx):
        return jax.vmap(lambda ii: values[:, ii])(idx)

    def eval_batch(pidx, trg):
        # Gather patch data for a batch of targets and apply the correction.
        G = gather(X_flat, pidx)
        Gg = gather(dX_flat, pidx)
        GF = jax.vmap(lambda ii: dens_flat[ii])(pidx)
        return jax.vmap(corr_fn)(G, Gg, GF, trg)

    Trg_flat = X_trg.reshape((3, -1)).T
    if target_chunk_size is None or target_chunk_size <= 0:
        corr = eval_batch(patch_idx, Trg_flat)
    else:
        ntrg = Trg_flat.shape[0]
        pad = (-ntrg) % target_chunk_size
        if pad:
            # Zero-padded rows are evaluated and then dropped by [:ntrg].
            patch_idx = jnp.pad(patch_idx, ((0, pad), (0, 0)))
            Trg_flat = jnp.pad(Trg_flat, ((0, pad), (0, 0)))
        ntrg_pad = Trg_flat.shape[0]
        n_chunks = ntrg_pad // target_chunk_size
        patch_chunks = patch_idx.reshape((n_chunks, target_chunk_size, -1))
        trg_chunks = Trg_flat.reshape((n_chunks, target_chunk_size, 3))

        def scan_fn(carry, xs):
            pidx_chunk, trg_chunk = xs
            return carry, eval_batch(pidx_chunk, trg_chunk)

        _, corr_chunks = jax.lax.scan(scan_fn, None, (patch_chunks, trg_chunks))
        corr = corr_chunks.reshape((ntrg_pad,))[:ntrg]

    corr = corr.reshape((1, trg_nt, trg_np))
    base = base.reshape((1, trg_nt, trg_np))
    return base + corr
[docs]
def laplace_fxd2_u_eval_singular(
    X_src,
    dX_src,
    density,
    trg_nt: int,
    trg_np: int,
    nfp: int,
    X_trg=None,
    digits: int = 5,
    patch_dim0: int | None = None,
    rad_dim: int | None = None,
    hedgehog_order: int = 8,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool = False,
    scan_targets: bool = False,
):
    """Evaluate Laplace Fxd2U with singular correction (Hedgehog).

    The result is the direct-quadrature value returned by
    ``laplace_fxd2_u_eval`` plus a per-target correction with nine components.
    On the polar nodes the kernel is evaluated at ``hedgehog_order`` points
    offset from the target along the scaled normal of the patch centre node,
    and the offsets are combined with ``precomp.hedgehog_wts``.

    Args:
        X_src: source coordinates, indexed as ``(3, nt, np)``.
        dX_src: surface derivatives, flattened here to ``(6, nt * np)``.
            Rows 0/2/4 are scaled by ``1/nt`` and rows 1/3/5 by ``1/np``
            (presumably the two parametric derivatives of x, y, z -- confirm
            against the producer of ``dX_src``).
        density: one value per source node; flattened to ``(nt * np,)``.
        trg_nt, trg_np: target grid size; the output is reshaped to
            ``(9, trg_nt, trg_np)``.
        nfp: number of field periods, used for the target stride
            ``nt // (nfp * trg_nt)``.
        X_trg: optional target coordinates. When ``None`` they are taken
            from ``field_period_target_coords(X_src, trg_nt, trg_np, nfp)``.
        digits: forwarded to ``select_patch_dim`` when ``patch_dim0`` is None.
        patch_dim0: patch size parameter; chosen from ``digits`` and the
            surface condition estimate when ``None``.
        rad_dim: radial size parameter; defaults to ``int(patch_dim0 * 1.6)``.
        hedgehog_order: number of offset points per target; also forwarded to
            ``precompute_singular``. Offsets are the first ``hedgehog_order``
            entries of ``1..16``, so a value above 16 yields fewer offsets
            than requested.
        chunk_size: forwarded to ``laplace_fxd2_u_eval``.
        target_chunk_size: forwarded to ``laplace_fxd2_u_eval``; when positive
            it also makes the correction run in ``lax.scan`` chunks of that
            many targets.
        patch_idx: optional precomputed 2-D array of flattened source indices,
            one row per target. Built with ``_build_patch_indices`` when None.
        orient: sign applied to the offset direction. When ``None`` it is
            obtained from ``normal_orientation`` and converted with ``float``.
        pou_dtype, patch_dtype: forwarded to ``precompute_singular``.
        interp_block_size: passed through ``_resolve_interp_block_size``; a
            resolved value of ``None`` selects the unblocked ``_interp_patch``.
        remat: wrap the per-target correction in ``jax.checkpoint``.
        scan_targets: If True, use a ``lax.scan`` loop over targets instead
            of ``vmap`` in the singular correction. This can reduce peak
            memory by avoiding large broadcasted temporaries, at the cost
            of lower parallelism per chunk.

    Returns:
        Array of shape ``(9, trg_nt, trg_np)``.

    Note:
        The defaults for ``patch_dim0`` and ``orient`` go through ``float``,
        so those quantities must be concrete values when they are not
        supplied by the caller.
    """
    X_src = jnp.asarray(X_src)
    dX_src = jnp.asarray(dX_src)
    density = jnp.asarray(density)
    nt = X_src.shape[1]
    npol = X_src.shape[2]

    # One evaluation serves both the quadrature weights and the orientation.
    normal, area_elem = surf_normal_area_elem(dX_src, X_src)
    if X_trg is None:
        X_trg = field_period_target_coords(X_src, trg_nt, trg_np, nfp)
    else:
        X_trg = jnp.asarray(X_trg)

    # Smooth (uncorrected) direct-sum part.
    base = laplace_fxd2_u_eval(
        X_src,
        X_trg,
        density,
        area_elem,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
    )

    if patch_dim0 is None:
        cond_val = float(_surface_cond(dX_src, nt, npol))
        patch_dim0 = select_patch_dim(digits, cond_val)
    if rad_dim is None:
        rad_dim = int(patch_dim0 * 1.6)
    precomp = precompute_singular(
        patch_dim0,
        rad_dim,
        hedgehog_order,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
    )
    patch_dim = precomp.patch_dim
    ngrid = precomp.ngrid
    # Flat index of the patch node whose normal sets the offset direction.
    center = patch_dim0 * patch_dim + patch_dim0

    if patch_idx is None:
        # Source-grid indices of the targets; only needed to build patches.
        skip_nt = nt // (nfp * trg_nt)
        skip_np = npol // trg_np
        t_idx = jnp.arange(trg_nt) * skip_nt
        p_idx = jnp.arange(trg_np) * skip_np
        tt, pp = jnp.meshgrid(t_idx, p_idx, indexing="ij")
        patch_idx = _build_patch_indices(
            tt.reshape(-1), pp.reshape(-1), nt, npol, patch_dim0
        )

    X_flat = X_src.reshape((3, -1))
    dX_flat = dX_src.reshape((6, -1))
    dens_flat = density.reshape(-1)

    if orient is None:
        orient = float(normal_orientation(X_src, normal))
    invNt = 1.0 / nt
    invNp = 1.0 / npol
    interp_nds = jnp.arange(1, 17, dtype=X_src.dtype)
    interp_nds = interp_nds[:hedgehog_order]

    # Resolve the block size before defining the closure that reads it.
    block_size = _resolve_interp_block_size(
        interp_block_size, precomp.npolar, "gradb"
    )

    def corr_one(Gi, Ggi, GiF, TrgCoord):
        # Gi: (3, Ngrid), Ggi: (6, Ngrid)
        n0 = Ggi[2] * Ggi[5] - Ggi[3] * Ggi[4]
        n1 = Ggi[4] * Ggi[1] - Ggi[5] * Ggi[0]
        n2 = Ggi[0] * Ggi[3] - Ggi[1] * Ggi[2]
        r = jnp.sqrt(n0 * n0 + n1 * n1 + n2 * n2)
        Ga = r * invNt * invNp

        # Scale gradients: even rows by 1/nt, odd rows by 1/np.
        Ggs = Ggi.at[0::2].multiply(invNt)
        Ggs = Ggs.at[1::2].multiply(invNp)

        # Grid kernel.
        dx = TrgCoord[None, :] - Gi.T
        MGrid = laplace_fxd2_u(dx, jnp.ones((ngrid,)))
        MGrid = MGrid * (Ga * precomp.Gpou)[:, None, None]
        MGrid = MGrid.reshape((ngrid, 9))

        # Polar interpolation.
        if block_size is None:
            P = _interp_patch(Gi, precomp)  # (3, Npolar)
            Pg = _interp_patch(Ggs, precomp)  # (6, Npolar)
        else:
            P = _interp_patch_blocked(Gi, precomp, block_size)
            Pg = _interp_patch_blocked(Ggs, precomp, block_size)
        if P.dtype != TrgCoord.dtype:
            P = P.astype(TrgCoord.dtype)
        if Pg.dtype != TrgCoord.dtype:
            Pg = Pg.astype(TrgCoord.dtype)
        n0p = Pg[2] * Pg[5] - Pg[3] * Pg[4]
        n1p = Pg[4] * Pg[1] - Pg[5] * Pg[0]
        n2p = Pg[0] * Pg[3] - Pg[1] * Pg[2]
        rp = jnp.sqrt(n0p * n0p + n1p * n1p + n2p * n2p)

        # Hedgehog target coordinates: offsets along the centre-node normal.
        gc = Ggi[:, center]
        ntrg0 = gc[2] * gc[5] - gc[3] * gc[4]
        ntrg1 = gc[4] * gc[1] - gc[5] * gc[0]
        ntrg2 = gc[0] * gc[3] - gc[1] * gc[2]
        rtrg = jnp.sqrt(ntrg0 * ntrg0 + ntrg1 * ntrg1 + ntrg2 * ntrg2)
        scal = jnp.sqrt(rtrg * invNt * invNp) * orient / rtrg * (-20.0 / precomp.rad_dim)
        nvec = jnp.array([ntrg0, ntrg1, ntrg2]) * scal
        TrgCoordPolar = TrgCoord[None, :] + interp_nds[:, None] * nvec[None, :]

        dxp = TrgCoordPolar[None, :, :] - P.T[:, None, :]
        MPolar = laplace_fxd2_u(dxp, jnp.ones((precomp.npolar, hedgehog_order)))
        MPolar = MPolar * (rp * precomp.Ppou)[:, None, None, None]
        MPolar = MPolar.reshape((precomp.npolar, hedgehog_order, 9))
        # Combine the offset evaluations with the precomputed weights.
        MPolar = jnp.tensordot(MPolar, precomp.hedgehog_wts, axes=(1, 0))

        # Scatter polar contributions back to grid, one component at a time.
        idx = precomp.interp_idx.reshape(-1)
        w = precomp.M_G2P.reshape((precomp.npolar, -1))
        for k in range(9):
            contrib = (MPolar[:, k:k + 1] * w).reshape(-1)
            MGrid = MGrid.at[idx, k].add(contrib)
        return jnp.sum(GiF[:, None] * MGrid, axis=0)

    corr_fn = jax.checkpoint(corr_one) if remat else corr_one

    def gather(values, idx):
        return jax.vmap(lambda ii: values[:, ii])(idx)

    def eval_batch(pidx, trg):
        # Gather patch data for a batch of targets and apply the correction,
        # either sequentially (scan) or in parallel (vmap).
        G = gather(X_flat, pidx)
        Gg = gather(dX_flat, pidx)
        GF = jax.vmap(lambda ii: dens_flat[ii])(pidx)
        if scan_targets:
            def scan_target(carry, xs):
                Gi, Ggi, GiF, TrgCoord = xs
                return carry, corr_fn(Gi, Ggi, GiF, TrgCoord)

            _, out = jax.lax.scan(scan_target, None, (G, Gg, GF, trg))
            return out
        return jax.vmap(corr_fn)(G, Gg, GF, trg)

    Trg_flat = X_trg.reshape((3, -1)).T
    if target_chunk_size is None or target_chunk_size <= 0:
        corr = eval_batch(patch_idx, Trg_flat)
    else:
        ntrg = Trg_flat.shape[0]
        pad = (-ntrg) % target_chunk_size
        if pad:
            # Zero-padded rows are evaluated and then dropped by [:ntrg].
            patch_idx = jnp.pad(patch_idx, ((0, pad), (0, 0)))
            Trg_flat = jnp.pad(Trg_flat, ((0, pad), (0, 0)))
        ntrg_pad = Trg_flat.shape[0]
        n_chunks = ntrg_pad // target_chunk_size
        patch_chunks = patch_idx.reshape((n_chunks, target_chunk_size, -1))
        trg_chunks = Trg_flat.reshape((n_chunks, target_chunk_size, 3))

        def scan_fn(carry, xs):
            pidx_chunk, trg_chunk = xs
            return carry, eval_batch(pidx_chunk, trg_chunk)

        _, corr_chunks = jax.lax.scan(scan_fn, None, (patch_chunks, trg_chunks))
        corr = corr_chunks.reshape((ntrg_pad, -1))[:ntrg]

    corr = corr.T.reshape((9, trg_nt, trg_np))
    base = base.reshape((9, trg_nt, trg_np))
    return base + corr
[docs]
def laplace_fxd2_u_eval_vec_singular(
    X_src,
    dX_src,
    density_vec,
    trg_nt: int,
    trg_np: int,
    nfp: int,
    X_trg=None,
    digits: int = 5,
    patch_dim0: int | None = None,
    rad_dim: int | None = None,
    hedgehog_order: int = 8,
    chunk_size: int = 1024,
    target_chunk_size: int | None = None,
    patch_idx=None,
    orient: float | None = None,
    pou_dtype=None,
    patch_dtype=None,
    interp_block_size: int | str | None = "auto",
    remat: bool = False,
    scan_targets: bool = False,
):
    """Apply ``laplace_fxd2_u_eval_singular`` to every component of a density.

    ``density_vec`` is mapped over its leading axis with ``jax.vmap``; the
    geometry and all options are shared by every component, and the
    per-component results are stacked along axis 0 of the output.

    Args:
        density_vec: stacked densities; axis 0 enumerates the components.
        scan_targets: forwarded to ``laplace_fxd2_u_eval_singular``.

    All remaining arguments are forwarded unchanged to
    ``laplace_fxd2_u_eval_singular``.
    """
    # Options shared by every density component.
    shared_kwargs = dict(
        X_trg=X_trg,
        digits=digits,
        patch_dim0=patch_dim0,
        rad_dim=rad_dim,
        hedgehog_order=hedgehog_order,
        chunk_size=chunk_size,
        target_chunk_size=target_chunk_size,
        patch_idx=patch_idx,
        orient=orient,
        pou_dtype=pou_dtype,
        patch_dtype=patch_dtype,
        interp_block_size=interp_block_size,
        remat=remat,
        scan_targets=scan_targets,
    )

    def eval_component(dens):
        return laplace_fxd2_u_eval_singular(
            X_src, dX_src, dens, trg_nt, trg_np, nfp, **shared_kwargs
        )

    batched_eval = jax.vmap(eval_component, in_axes=0, out_axes=0)
    return batched_eval(jnp.asarray(density_vec))