"""JAX surface operators matching BIEST conventions."""
from __future__ import annotations
import jax
import jax.numpy as jnp
# One full turn, 2*pi radians.
TWOPI = jnp.pi * 2.0
[docs]
def fft_r2c(x, nt: int, npol: int):
    """Forward real-to-complex FFT with orthonormal ("ortho") scaling.

    The transform runs over the trailing two axes of ``x`` (toroidal, then
    poloidal); the last axis shrinks to ``x.shape[-1] // 2 + 1`` bins.

    ``nt`` and ``npol`` are accepted for symmetry with ``fft_c2r`` but are
    not used: the transform sizes come from ``x.shape``.
    """
    return jnp.fft.rfft2(jnp.asarray(x), axes=(-2, -1), norm="ortho")
[docs]
def fft_c2r(y, nt: int, npol: int):
    """Inverse complex-to-real FFT with orthonormal ("ortho") scaling.

    Inverts ``fft_r2c`` over the trailing two axes. The real output always
    has trailing shape ``(nt, npol)``. The explicit size is required because
    the poloidal length cannot be recovered from the ``npol // 2 + 1``
    half-spectrum alone.
    """
    spectrum = jnp.asarray(y)
    grid_shape = (nt, npol)
    return jnp.fft.irfft2(spectrum, s=grid_shape, axes=(-2, -1), norm="ortho")
[docs]
def rotate_toroidal(X, nt: int, npol: int, dtheta):
    """Rotate field in toroidal angle by dtheta (BIEST SurfaceOp::RotateToroidal).

    Every Fourier mode with signed toroidal index ``m`` is multiplied by
    ``exp(1j * m * dtheta)``; the poloidal spectrum is left untouched.

    Args:
        X: field of shape (dof, nt, npol).
        nt: toroidal grid size (axis -2).
        npol: poloidal grid size (axis -1).
        dtheta: rotation angle. May be a Python/NumPy scalar or a JAX
            scalar, including a traced one inside ``jax.jit``/``jax.grad``.

    Returns:
        Real array of shape (dof, nt, npol). ``X`` itself is returned
        unconverted when ``nt == 0`` or ``dtheta`` is a plain numeric zero.
    """
    import numbers

    # Short-circuit only for plain numeric zeros. A JAX value must go through
    # the phase multiply: comparing a traced value in Python raises under
    # jit, and skipping the rotation at dtheta == 0 would make the
    # derivative with respect to dtheta vanish under grad.
    if (isinstance(dtheta, numbers.Real) and dtheta == 0) or nt == 0:
        return X
    X = jnp.asarray(X)
    coeff = fft_r2c(X, nt, npol)
    # Match BIEST frequency indexing: t - (t > Nt/2 ? Nt : 0)
    m = jnp.arange(nt)
    m = jnp.where(m > (nt // 2), m - nt, m)
    phase = jnp.exp(1j * m * dtheta)
    # Broadcast the per-toroidal-row phase over the dof and poloidal axes.
    coeff = coeff * phase[None, :, None]
    return fft_c2r(coeff, nt, npol)
[docs]
def complete_vec_field(Y, is_surf: bool, half_period: bool, nfp: int, nt: int, npol: int, dtheta: float):
    """Match BIEST SurfaceOp::CompleteVecField.

    Expand a field sampled on one field period (or half of one) to all
    ``nfp`` field periods.

    Args:
        Y: field of shape (dof, nt, npol).
        is_surf: only used when ``half_period`` is set and ``dof == 3``.
            The mirrored vectors are mapped ``(x, y, z) -> s * (x, -y, -z)``
            with ``s = +1`` if True and ``s = -1`` otherwise, before being
            rotated by ``2*pi/nfp``.
        half_period: if True, ``Y`` is first completed with its mirrored
            (stellarator-symmetric) copy, doubling ``nt``.
        nfp: number of field periods; must be positive.
        nt: toroidal grid size of ``Y`` (axis 1).
        npol: poloidal grid size of ``Y`` (axis 2).
        dtheta: toroidal rotation applied to the completed field through
            ``rotate_toroidal``; skipped when it is a plain numeric zero.

    Returns:
        X of shape (dof, nfp*nt', npol), where nt' = 2*nt if
        ``half_period`` is set and nt otherwise. For dof == 3, the copy for
        period j has its x/y components rotated by ``2*pi*j/nfp`` (z is
        unchanged); any other dof is tiled unchanged.

    Raises:
        ValueError: if ``nfp`` is not positive.
    """
    import numbers

    # Validate before anything divides by nfp (the half-period path does).
    if nfp <= 0:
        raise ValueError("nfp must be positive")
    Y = jnp.asarray(Y)
    if half_period:
        Y = _append_mirrored_half(Y, is_surf, nfp, nt, npol)
        nt = nt * 2
    X = _replicate_field_periods(Y, nfp, nt, npol)
    # Skip the FFT round trip only for a plain numeric zero; JAX values
    # (possibly traced under jit) are passed on to rotate_toroidal.
    if not (isinstance(dtheta, numbers.Real) and dtheta == 0):
        X = rotate_toroidal(X, nfp * nt, npol, dtheta)
    return X


def _append_mirrored_half(Y, is_surf, nfp, nt, npol):
    """Concatenate half-period data Y (dof, nt, npol) with its mirror -> (dof, 2*nt, npol)."""
    # Mirror indices: toroidal t -> nt - 1 - t, poloidal p -> -p (mod npol).
    t_mirror = nt - jnp.arange(nt) - 1
    p_mirror = (-jnp.arange(npol)) % npol
    Y_tail = Y[:, t_mirror[:, None], p_mirror[None, :]]
    if Y.shape[0] == 3:
        sign = 1.0 if is_surf else -1.0
        x = Y_tail[0] * sign
        y = -Y_tail[1] * sign
        z = -Y_tail[2] * sign
        cos_a = jnp.cos(TWOPI / nfp)
        sin_a = jnp.sin(TWOPI / nfp)
        Y_tail = jnp.stack([x * cos_a - y * sin_a, x * sin_a + y * cos_a, z], axis=0)
    return jnp.concatenate([Y, Y_tail], axis=1)


def _replicate_field_periods(Y, nfp, nt, npol):
    """Repeat Y (dof, nt, npol) over nfp periods -> (dof, nfp*nt, npol), rotating x/y when dof == 3."""
    if Y.shape[0] != 3:
        return jnp.tile(Y, (1, nfp, 1))
    j = jnp.arange(nfp, dtype=Y.dtype)
    cost = jnp.cos(TWOPI * j / nfp)[:, None, None]
    sint = jnp.sin(TWOPI * j / nfp)[:, None, None]
    x0, y0, z0 = Y[0][None], Y[1][None], Y[2][None]
    x = cost * x0 - sint * y0
    y = sint * x0 + cost * y0
    z = jnp.broadcast_to(z0, x.shape)
    # Period j occupies rows j*nt .. (j+1)*nt - 1 after flattening.
    return jnp.stack([c.reshape((nfp * nt, npol)) for c in (x, y, z)], axis=0)
[docs]
def upsample(X0, nt0: int, np0: int, nt1: int, np1: int):
    """Upsample using Fourier zero-padding (BIEST SurfaceOp::Upsample).

    Resizes a real field from an (nt0, np0) grid to an (nt1, np1) grid by
    copying the Fourier coefficients the two grids share and leaving the
    remaining target coefficients zero. Despite the name, a smaller target
    grid is also handled: the shared coefficient range is the ``min`` of
    the two spectra along each axis.

    Args:
        X0: field of shape (dof, nt0, np0).
        nt0, np0: source toroidal / poloidal grid sizes.
        nt1, np1: target toroidal / poloidal grid sizes.

    Returns:
        Real array of shape (dof, nt1, np1).
    """
    X0 = jnp.asarray(X0)
    dof = int(X0.shape[0])
    coeff0 = fft_r2c(X0, nt0, np0)
    # Spectrum extents: the toroidal (complex) axis keeps all nt bins, the
    # poloidal (r2c) axis keeps np // 2 + 1 bins.
    nt0_ = nt0
    np0_ = np0 // 2 + 1
    nt1_ = nt1
    np1_ = np1 // 2 + 1
    coeff1 = jnp.zeros((dof, nt1_, np1_), dtype=coeff0.dtype)
    # fft_r2c/fft_c2r use orthonormal scaling, so coefficient amplitudes must
    # be rescaled by sqrt(N1 / N0) to preserve the sampled values.
    scale = jnp.sqrt(jnp.asarray(nt1 * np1, dtype=coeff0.real.dtype)) / jnp.sqrt(jnp.asarray(nt0 * np0, dtype=coeff0.real.dtype))
    # Number of coefficients shared by both spectra along each axis.
    ntt = min(nt0_, nt1_)
    npp = min(np0_, np1_)
    # Toroidal rows copied from the front (0 .. ntt//2, non-negative
    # frequencies) and from the back (last ntt//2 rows, negative frequencies).
    t_pos = jnp.arange(0, ntt // 2 + 1)
    t_neg = jnp.arange(0, ntt // 2)
    p_idx = jnp.arange(0, npp)
    # Extra per-bin weights for the Nyquist bin of an even-sized grid:
    # 0.5 when the even source grid is the smaller one (its Nyquist
    # coefficient lands in both the +/- rows of the target), 2.0 when the
    # even target grid is the smaller one. The `< t_pos.size` /
    # `< t_neg.size` clauses below are always true.
    scale_t_pos = jnp.ones_like(t_pos, dtype=coeff0.real.dtype)
    scale_t_neg = jnp.ones_like(t_neg, dtype=coeff0.real.dtype)
    scale_p = jnp.ones_like(p_idx, dtype=coeff0.real.dtype)
    if (nt0 % 2 == 0) and (nt0_ < nt1_) and (ntt // 2 < t_pos.size):
        scale_t_pos = scale_t_pos.at[ntt // 2].set(0.5)
    if (nt1 % 2 == 0) and (nt1_ < nt0_) and (ntt // 2 < t_pos.size):
        scale_t_pos = scale_t_pos.at[ntt // 2].set(2.0)
    if (nt0 % 2 == 0) and (nt0_ < nt1_) and (ntt // 2 - 1 < t_neg.size) and (ntt // 2 - 1 >= 0):
        scale_t_neg = scale_t_neg.at[ntt // 2 - 1].set(0.5)
    if (nt1 % 2 == 0) and (nt1_ < nt0_) and (ntt // 2 - 1 < t_neg.size) and (ntt // 2 - 1 >= 0):
        scale_t_neg = scale_t_neg.at[ntt // 2 - 1].set(2.0)
    if (np0 % 2 == 0) and (np0_ < np1_) and (npp - 1 >= 0):
        scale_p = scale_p.at[npp - 1].set(0.5)
    if (np1 % 2 == 0) and (np1_ < np0_) and (npp - 1 >= 0):
        scale_p = scale_p.at[npp - 1].set(2.0)
    # Positive frequencies
    coeff1 = coeff1.at[:, t_pos[:, None], p_idx[None, :]].set(
        coeff0[:, t_pos[:, None], p_idx[None, :]] * (scale * scale_t_pos[:, None] * scale_p[None, :])[None, :, :]
    )
    # Negative frequencies (toroidal)
    # NOTE(review): when nt1 is even and nt1 < nt0, the last t_neg entry maps
    # to target row nt1//2, the same row the last t_pos entry wrote above, so
    # this write overwrites it. Presumably this mirrors BIEST's loop order;
    # confirm.
    if t_neg.size > 0:
        coeff1 = coeff1.at[:, (nt1_ - t_neg - 1)[:, None], p_idx[None, :]].set(
            coeff0[:, (nt0_ - t_neg - 1)[:, None], p_idx[None, :]] * (scale * scale_t_neg[:, None] * scale_p[None, :])[None, :, :]
        )
    X1 = fft_c2r(coeff1, nt1, np1)
    # Floating-point correction for integer upsample ratios
    # When nt1 and np1 are integer multiples of nt0 and np0, source sample
    # (t, p) coincides with target sample (t*ut, p*up); overwrite those
    # samples with the exact source values to remove FFT round-off.
    ut = nt1 // nt0
    up = np1 // np0
    if nt1 == nt0 * ut and np1 == np0 * up:
        t_idx = jnp.arange(nt0)
        p_idx = jnp.arange(np0)
        tt = (t_idx * ut)[:, None]
        pp = (p_idx * up)[None, :]
        X1 = X1.at[:, tt, pp].set(X0[:, t_idx[:, None], p_idx[None, :]])
    return X1
[docs]
def resample(X0, nt0: int, np0: int, nt1: int, np1: int):
    """Resample using upsample + decimation (BIEST SurfaceOp::Resample).

    Along each axis a stride ``k = ceil(n_src / n_dst)`` is chosen.
    ``X0`` (dof, nt0, np0) is Fourier-upsampled to the grid
    ``(k_t * nt1, k_p * np1)``, then every k-th sample is kept, giving
    an array of shape (dof, nt1, np1).
    """
    import math

    def _stride(n_src, n_dst):
        # For positive sizes: smallest integer k with k * n_dst >= n_src.
        return int(math.ceil(n_src / float(n_dst)))

    k_t = _stride(nt0, nt1)
    k_p = _stride(np0, np1)
    fine = upsample(X0, nt0, np0, k_t * nt1, k_p * np1)
    return fine[:, ::k_t, ::k_p]
[docs]
def grad2d(X, nt: int, npol: int):
    """Spectral surface derivatives (BIEST SurfaceOp::Grad2D).

    Each component of ``X`` (shape (dof, nt, npol)) is differentiated along
    the toroidal (axis -2) and poloidal (axis -1) directions. This is done
    by multiplying its Fourier coefficients by ``-2j*pi*k``, with signed
    toroidal wavenumbers and non-negative poloidal (r2c) wavenumbers.

    Returns:
        Array of shape (dof * 2, nt, npol) ordered as
        [X0_t, X0_p, X1_t, X1_p, ...], cast to ``X``'s dtype.
    """
    X = jnp.asarray(X)
    n_comp = int(X.shape[0])
    spec = fft_r2c(X, nt, npol)
    real_dtype = spec.real.dtype
    # Signed toroidal wavenumbers (BIEST indexing: t - Nt for t > Nt/2).
    tor = jnp.arange(nt, dtype=real_dtype)
    tor = jnp.where(tor > (nt // 2), tor - nt, tor)
    pol = jnp.arange(npol // 2 + 1, dtype=real_dtype)
    # NOTE(review): with NumPy's FFT sign convention, d/ds exp(2*pi*i*k*s) is
    # +2*pi*i*k; the -2*pi*i factor presumably mirrors BIEST — confirm.
    # (Normals built from the cross product of both derivatives are
    # unaffected by a common sign flip.)
    ik = -1j * TWOPI
    d_tor = fft_c2r(spec * ik * tor[None, :, None], nt, npol)
    d_pol = fft_c2r(spec * ik * pol[None, None, :], nt, npol)
    # Interleave per component: (dof, 2, nt, npol) -> (2 * dof, nt, npol).
    paired = jnp.stack([d_tor, d_pol], axis=1)
    return paired.reshape((2 * n_comp, nt, npol)).astype(X.dtype)
[docs]
def surf_normal_area_elem(dX, X=None, *, return_orientation: bool = False):
    """Compute unit normal and area element (BIEST SurfNormalAreaElem).

    Args:
        dX: (6, nt, npol) derivatives interleaved per component, i.e.
            [x_t, x_p, y_t, y_p, z_t, z_p].
        X: optional (3, nt, npol) coordinates. When given, the normal is
            multiplied by ``normal_orientation(X, normal)``.
        return_orientation: also return the orientation factor (1.0 when
            ``X`` is None).

    Returns:
        ``(normal, area_elem)`` or ``(normal, area_elem, orient)``. Here
        ``normal`` is the normalized cross product x_t × x_p, with shape
        (3, nt, npol), and ``area_elem`` is its length divided by
        ``nt * npol``.
    """
    dX = jnp.asarray(dX)
    nt, npol = dX.shape[1], dX.shape[2]
    tx, ty, tz = (dX[k] for k in (0, 2, 4))
    px, py, pz = (dX[k] for k in (1, 3, 5))
    cross = jnp.stack(
        [ty * pz - py * tz, tz * px - pz * tx, tx * py - px * ty],
        axis=0,
    )
    area = jnp.sqrt((cross * cross).sum(axis=0))
    normal = cross / area
    area_elem = area / float(nt * npol)
    orient = 1.0 if X is None else normal_orientation(X, normal)
    if X is not None:
        normal = normal * orient
    return (normal, area_elem, orient) if return_orientation else (normal, area_elem)
[docs]
def normal_orientation(X, normal):
    """Return +1 or -1 orientation used by BIEST for normals.

    Finds the grid point with the largest x-coordinate (``X[0]``). Returns
    -1.0 if the x-component of ``normal`` at that point is negative, and
    1.0 otherwise.
    """
    coords = jnp.asarray(X)
    normals = jnp.asarray(normal)
    x_values = jnp.ravel(coords[0])
    nx_values = jnp.ravel(normals[0])
    nx_at_max = nx_values[jnp.argmax(x_values)]
    return jnp.where(nx_at_max < 0, -1.0, 1.0)
[docs]
def dot_prod(A, B):
    """Component-wise dot product of two SoA vector fields.

    A, B: arrays of shape (3, nt, npol) with the component axis first.
    Returns an (nt, npol) array.
    """
    a = jnp.asarray(A)
    b = jnp.asarray(B)
    ax, ay, az = a[0], a[1], a[2]
    bx, by, bz = b[0], b[1], b[2]
    return ax * bx + ay * by + az * bz
[docs]
def cross_prod(A, B):
    """Component-wise cross product A x B of two SoA vector fields.

    A, B: arrays of shape (3, nt, npol) with the component axis first.
    Returns a (3, nt, npol) array.
    """
    a = jnp.asarray(A)
    b = jnp.asarray(B)
    ax, ay, az = a[0], a[1], a[2]
    bx, by, bz = b[0], b[1], b[2]
    cx = ay * bz - by * az
    cy = az * bx - bz * ax
    cz = ax * by - bx * ay
    return jnp.stack([cx, cy, cz], axis=0)