Source code for virtual_casing_jax.kernels

"""Kernel functions matching BIEST scaling and conventions."""
from __future__ import annotations

import jax.numpy as jnp

FOUR_PI = 4.0 * jnp.pi


def _safe_rinv(r2, eps=1e-30):
    return jnp.where(r2 > eps, 1.0 / jnp.sqrt(r2), 0.0)


[docs] def laplace_fx_u(dx, f): """Laplace single-layer potential: f / (4*pi*r). dx: (..., 3) f: (...,) returns (...,) """ r2 = jnp.sum(dx * dx, axis=-1) rinv = _safe_rinv(r2) return f * rinv / FOUR_PI
[docs] def laplace_fxd_u(dx, f): """Gradient of Laplace single-layer: -(dx * f) / (4*pi*r^3). dx: (..., 3) f: (...,) returns (..., 3) """ r2 = jnp.sum(dx * dx, axis=-1) rinv = _safe_rinv(r2) rinv3 = rinv * rinv * rinv return -(dx * f[..., None]) * rinv3[..., None] / FOUR_PI
[docs] def laplace_fxd2_u(dx, f): """Second derivatives of Laplace single-layer. Returns a (..., 3, 3) tensor with entries: (-delta_ij * r^-3 + 3 r_i r_j r^-5) * f / (4*pi) """ r2 = jnp.sum(dx * dx, axis=-1) rinv = _safe_rinv(r2) rinv2 = rinv * rinv rinv3 = rinv * rinv2 rinv5 = rinv3 * rinv2 # Outer product r_i r_j r_outer = dx[..., :, None] * dx[..., None, :] eye = jnp.eye(3, dtype=dx.dtype) u = -eye * rinv3[..., None, None] + 3.0 * r_outer * rinv5[..., None, None] return u * f[..., None, None] / FOUR_PI
[docs] def laplace_dx_u(dx, n, f): """Laplace double-layer kernel: (-(n·dx) * f) / (4*pi*r^3). dx: (..., 3) n: (..., 3) source normals f: (...,) returns (...,) """ r2 = jnp.sum(dx * dx, axis=-1) rinv = _safe_rinv(r2) rinv3 = rinv * rinv * rinv ndotr = -jnp.sum(n * dx, axis=-1) return f * ndotr * rinv3 / FOUR_PI
[docs] def biotsavart_fx_u(dx, fvec): """Biot-Savart kernel: (f x dx) / (4*pi*r^3). dx: (..., 3) fvec: (..., 3) returns (..., 3) """ r2 = jnp.sum(dx * dx, axis=-1) rinv = _safe_rinv(r2) rinv3 = rinv * rinv * rinv return jnp.cross(fvec, dx) * rinv3[..., None] / FOUR_PI
[docs] def biotsavart_fxd_u(dx, fvec): """Derivative of Biot-Savart kernel (matches BIEST FxdU). Returns (..., 3, 3) tensor. The explicit formula matches the uker_FxdU in BIEST (see kernel.hpp). """ r2 = jnp.sum(dx * dx, axis=-1) rinv = _safe_rinv(r2) rinv2 = rinv * rinv rinv3 = rinv * rinv2 rinv5 = rinv3 * rinv2 x = dx[..., 0] y = dx[..., 1] z = dx[..., 2] z0 = jnp.zeros_like(x) u00 = z0 u10 = 3.0 * z * x * rinv5 u20 = -3.0 * y * x * rinv5 u01 = z0 u11 = 3.0 * z * y * rinv5 u21 = rinv3 - 3.0 * y * y * rinv5 u02 = z0 u12 = -rinv3 + 3.0 * z * z * rinv5 u22 = -3.0 * y * z * rinv5 u03 = -3.0 * z * x * rinv5 u13 = z0 u23 = -rinv3 + 3.0 * x * x * rinv5 u04 = -3.0 * z * y * rinv5 u14 = z0 u24 = 3.0 * x * y * rinv5 u05 = rinv3 - 3.0 * z * z * rinv5 u15 = z0 u25 = 3.0 * x * z * rinv5 u06 = 3.0 * y * x * rinv5 u16 = rinv3 - 3.0 * x * x * rinv5 u26 = z0 u07 = -rinv3 + 3.0 * y * y * rinv5 u17 = -3.0 * x * y * rinv5 u27 = z0 u08 = 3.0 * y * z * rinv5 u18 = -3.0 * x * z * rinv5 u28 = z0 # u is (3,9) with rows 0,1,2 u = jnp.stack( [ jnp.stack([u00, u01, u02, u03, u04, u05, u06, u07, u08], axis=-1), jnp.stack([u10, u11, u12, u13, u14, u15, u16, u17, u18], axis=-1), jnp.stack([u20, u21, u22, u23, u24, u25, u26, u27, u28], axis=-1), ], axis=-2, ) fx = fvec[..., 0] fy = fvec[..., 1] fz = fvec[..., 2] v0 = -( u[..., 0, :] * fx[..., None] + u[..., 1, :] * fy[..., None] + u[..., 2, :] * fz[..., None] ) # v0 shape (..., 9) -> reshape to (..., 3, 3) v = v0.reshape(dx.shape[:-1] + (3, 3)) return v / FOUR_PI