# Source code for virtual_casing_jax.singular_quadrature

"""Singular quadrature utilities (POU + polar correction)."""
from __future__ import annotations

from dataclasses import dataclass
import functools
import math
import numpy as np
import jax.numpy as jnp


# Number of points per axis of the tensor-product Lagrange stencil used to
# interpolate patch-grid samples onto the polar quadrature nodes (see
# precompute_singular / _lagrange_interp). The patch must have at least this
# many points per side for every stencil to fit inside the grid.
INTERP_ORDER = 12


@dataclass(frozen=True)
class SingularPrecomp:
    """Immutable bundle of sizes and tables for the singular correction.

    Instances are produced by ``precompute_singular``; the shape notes below
    describe the arrays as that function builds them.
    """

    # --- integer sizes -------------------------------------------------
    patch_dim0: int  # patch half-width; patch_dim == 2 * patch_dim0 + 1
    hedgehog_order: int  # > 1 enables the squared-radius / extrapolation mode
    rad_dim_base: int  # radial node count as requested by the caller
    rad_dim: int  # actual radial node count (3 * rad_dim_base in hedgehog mode)
    ang_dim: int  # angular node count (2 * rad_dim_base)
    patch_dim: int  # grid points per side of the square patch
    ngrid: int  # patch_dim * patch_dim
    npolar: int  # rad_dim * ang_dim

    # --- radial quadrature on [0, 1] -----------------------------------
    qx: jnp.ndarray  # radial nodes, shape (rad_dim,)
    qw: jnp.ndarray  # radial weights, shape (rad_dim,)

    # --- partition-of-unity samples ------------------------------------
    Gpou: jnp.ndarray  # negated POU on the flattened grid, shape (ngrid,)
    Ppou: jnp.ndarray  # POU times polar quadrature weight, shape (npolar,)

    # --- grid -> polar interpolation -----------------------------------
    I_G2P: jnp.ndarray  # flat grid index of each stencil's first corner, (npolar,)
    M_G2P: jnp.ndarray  # Lagrange weights, (npolar, INTERP_ORDER, INTERP_ORDER)
    interp_idx: jnp.ndarray  # flat grid index per stencil point, same shape as M_G2P

    # --- hedgehog extrapolation ----------------------------------------
    hedgehog_wts: jnp.ndarray  # (hedgehog_order,) weights, or [1.0] when disabled
def _legpoly_and_deriv(x, degree: int):
    """Evaluate the Legendre polynomial ``P_degree`` and its derivative at ``x``.

    Uses the three-term recurrence
    ``n P_n = (2n - 1) x P_{n-1} - (n - 1) P_{n-2}`` together with its
    differentiated form. ``x`` is expected to be a NumPy array (it is copied
    with ``.copy()``).

    Returns:
        ``(P_degree(x), P_degree'(x))`` as arrays shaped like ``x``.
    """
    if degree == 0:
        return np.ones_like(x), np.zeros_like(x)
    # Seed with P_0 = 1, P_1 = x; for degree 1 the loop below does nothing.
    prev_val, cur_val = np.ones_like(x), x.copy()
    prev_der, cur_der = np.zeros_like(x), np.ones_like(x)
    for n in range(2, degree + 1):
        a = (2 * n - 1) / n
        b = -(n - 1) / n
        next_val = a * x * cur_val + b * prev_val
        next_der = a * (cur_val + x * cur_der) + b * prev_der
        prev_val, cur_val = cur_val, next_val
        prev_der, cur_der = cur_der, next_der
    return cur_val, cur_der


def _legendre_rule_01(order: int):
    """Gauss-Legendre nodes/weights on [0, 1] matching SCTL.

    The roots of ``P_order`` are found by Newton iteration (at most five
    steps, stopping early once the largest update is below machine epsilon)
    from an asymptotic initial guess. They are then mapped from [-1, 1] to
    [0, 1]. The returned weights are the [-1, 1] weights halved.

    Returns:
        ``(nodes, weights)``, both float64 arrays of length ``order``.
    """
    shrink = 1 - 1.0 / (8 * order * order) + 1.0 / (8 * order * order * order)
    roots = np.empty(order, dtype=np.float64)
    # The leading minus makes the initial guesses ascend from near -1 to +1.
    roots[:] = [
        -shrink * math.cos(math.pi * (4 * k + 3) / (4 * order + 2))
        for k in range(order)
    ]
    tol = np.finfo(np.float64).eps
    for _ in range(5):
        val, der = _legpoly_and_deriv(roots, order)
        step = val / der
        roots = roots - step
        if np.max(np.abs(step)) < tol:
            break
    _, der = _legpoly_and_deriv(roots, order)
    weights = 1.0 / ((1.0 - roots * roots) * (der * der))
    return 0.5 * (roots + 1.0), weights


def _pou_fn(patch_dim: int):
    """Return the radial POU window ``r -> exp(-36 * r**power)``.

    ``power`` is 10 when ``patch_dim > 45``, 8 when ``patch_dim > 20`` and
    6 otherwise. Negative ``r`` maps to 1.0.
    """
    power = 10 if patch_dim > 45 else (8 if patch_dim > 20 else 6)

    def pou(r):
        return 1.0 if r < 0 else math.exp(-36.0 * r ** power)

    return pou


def _lagrange_interp(z0, z1, i0, i1):
    """Evaluate the tensor-product Lagrange basis function ``(i0, i1)``.

    The basis lives on the uniform ``INTERP_ORDER x INTERP_ORDER`` grid over
    [0, 1]^2 (nodes at ``j / (INTERP_ORDER - 1)``) and is evaluated at the
    point ``(z0, z1)``.
    """
    h = 1.0 / (INTERP_ORDER - 1)
    prod = 1.0
    # One running product over both axes, first axis 0 then axis 1.
    for z, i in ((z0, i0), (z1, i1)):
        node_i = i * h
        for j in range(INTERP_ORDER):
            if j == i:
                continue
            node_j = j * h
            prod *= (z - node_j) / (node_i - node_j)
    return prod
def _grid_pou(patch_dim: int, pou) -> np.ndarray:
    """Sample ``-pou(r)`` on the flattened ``patch_dim x patch_dim`` grid.

    ``r`` is the distance from the centre node in units of the patch
    half-width ``(patch_dim - 1) // 2``. Entry ``i * patch_dim + j`` is
    row ``i``, column ``j``.
    """
    patch_rad = (patch_dim - 1) // 2
    h = 1.0 / patch_rad
    gpou = np.zeros(patch_dim * patch_dim, dtype=np.float64)
    for i in range(patch_dim):
        dr0 = (i - patch_rad) * h
        for j in range(patch_dim):
            dr1 = (j - patch_rad) * h
            r = math.sqrt(dr0 * dr0 + dr1 * dr1)
            gpou[i * patch_dim + j] = -pou(r)
    return gpou


def _polar_pou(qx, qw, ang_dim: int, patch_rad: int, pou) -> np.ndarray:
    """POU times the polar quadrature weight on the (radial, angular) grid.

    Ring ``i`` holds ``pou(qx[i]) * dr * r_dtheta`` at every angle, where
    ``dr = qw[i] * patch_rad`` and ``r_dtheta = qx[i] * patch_rad * dt``
    with ``dt = 2*pi / ang_dim``. Entry ``i * ang_dim + j`` is ring ``i``,
    angle ``j``.
    """
    dt = 2.0 * math.pi / ang_dim
    ppou = np.zeros(len(qx) * ang_dim, dtype=np.float64)
    for i in range(len(qx)):
        dr = qw[i] * patch_rad
        rdt = qx[i] * patch_rad * dt
        # The value does not depend on the angle index: fill the whole ring.
        ppou[i * ang_dim:(i + 1) * ang_dim] = pou(qx[i]) * dr * rdt
    return ppou


def _grid_to_polar_interp(qx, ang_dim: int, patch_dim: int, index_dtype):
    """Build the patch-grid -> polar-node Lagrange interpolation tables.

    Each polar node (ring ``i0``, angle ``i1``) is placed at unit patch
    coordinates ``0.5 + 0.5 * qx[i0] * (cos, sin)(2*pi*i1/ang_dim)``. An
    ``INTERP_ORDER x INTERP_ORDER`` stencil of grid points around it is
    chosen, with its corner clamped so the stencil stays inside the grid.
    The tensor-product Lagrange weights on that stencil are then stored.

    Returns:
        ``(I_G2P, M_G2P, interp_idx)``:
        - ``I_G2P``: the flat grid index of each stencil's first corner, shape ``(npolar,)``.
        - ``M_G2P``: the weights, shape ``(npolar, INTERP_ORDER, INTERP_ORDER)``.
        - ``interp_idx``: the flat grid index of every stencil point, same shape as ``M_G2P``.
    """
    rad_dim = len(qx)
    npolar = rad_dim * ang_dim
    I_G2P = np.zeros(npolar, dtype=index_dtype)
    M_G2P = np.zeros((npolar, INTERP_ORDER, INTERP_ORDER), dtype=np.float64)
    h_ang = 2.0 * math.pi / ang_dim
    h_int = 1.0 / (INTERP_ORDER - 1)
    half_width = (INTERP_ORDER - 1) / 2
    max_start = patch_dim - INTERP_ORDER  # >= 0, enforced by the caller
    for i0 in range(rad_dim):
        for i1 in range(ang_dim):
            x0 = 0.5 + 0.5 * qx[i0] * math.cos(h_ang * i1)
            x1 = 0.5 + 0.5 * qx[i0] * math.sin(h_ang * i1)
            g0 = x0 * (patch_dim - 1)  # position in grid-index units
            g1 = x1 * (patch_dim - 1)
            # Centre the stencil on the point; int() truncates, and the clamp
            # keeps the stencil inside the grid near the patch edge.
            y0 = max(0, min(int(g0 - half_width), max_start))
            y1 = max(0, min(int(g1 - half_width), max_start))
            # Point coordinates inside the stencil (stencil nodes at j * h_int).
            z0 = (g0 - y0) * h_int
            z1 = (g1 - y1) * h_int
            idx = i0 * ang_dim + i1
            I_G2P[idx] = y0 * patch_dim + y1
            for j0 in range(INTERP_ORDER):
                for j1 in range(INTERP_ORDER):
                    M_G2P[idx, j0, j1] = _lagrange_interp(z0, z1, j0, j1)
    ii = np.arange(INTERP_ORDER)[:, None]
    jj = np.arange(INTERP_ORDER)[None, :]
    interp_idx = I_G2P[:, None, None] + ii * patch_dim + jj
    return I_G2P, M_G2P, interp_idx


def _hedgehog_weights(hedgehog_order: int) -> np.ndarray:
    """Lagrange weights extrapolating from nodes ``1..hedgehog_order`` to 0.

    ``wts[k] = prod_{i != k} x_i / (x_i - x_k)`` with ``x_i = i + 1``; i.e.
    ``wts[k]`` is the ``k``-th Lagrange basis polynomial evaluated at 0.
    Returns ``[1.0]`` when ``hedgehog_order <= 1``.
    """
    if hedgehog_order <= 1:
        return np.ones(1, dtype=np.float64)
    # Previously a fixed arange(1, 17), which capped the order at 16.
    nodes = np.arange(1, hedgehog_order + 1, dtype=np.float64)
    wts = np.zeros(hedgehog_order, dtype=np.float64)
    for k in range(hedgehog_order):
        pn = 1.0
        pd = 1.0
        for i in range(hedgehog_order):
            if i != k:
                pn *= nodes[i]
                pd *= nodes[i] - nodes[k]
        wts[k] = pn / pd
    return wts


@functools.lru_cache(maxsize=None)
def precompute_singular(
    patch_dim0: int,
    rad_dim: int,
    hedgehog_order: int = 1,
    pou_dtype=None,
    patch_dtype=None,
    index_dtype=None,
):
    """Precompute the tables for the POU + polar singular-quadrature correction.

    Args:
        patch_dim0: Patch half-width in grid points. The patch is a square
            grid with ``2 * patch_dim0 + 1`` points per side, which must be
            at least ``INTERP_ORDER`` so every interpolation stencil fits.
        rad_dim: Number of Gauss-Legendre radial nodes (tripled when
            ``hedgehog_order > 1``). The angular grid has
            ``2 * rad_dim`` points. Must be >= 1.
        hedgehog_order: When > 1, the radial nodes are squared (with the
            ``2 * s`` Jacobian folded into the weights), and Lagrange weights
            extrapolating from nodes ``1..hedgehog_order`` to 0 are returned.
        pou_dtype: Optional dtype for ``Gpou``, ``Ppou`` and ``hedgehog_wts``.
        patch_dtype: Optional dtype for ``M_G2P``.
        index_dtype: dtype of ``I_G2P`` and ``interp_idx`` (default int32).

    Returns:
        A frozen ``SingularPrecomp``. Results are memoised with
        ``functools.lru_cache``: all arguments must be hashable, and repeated
        identical calls return the same object.

    Raises:
        ValueError: If the patch is smaller than the interpolation stencil,
            or if ``rad_dim < 1``.
    """
    patch_dim = 2 * patch_dim0 + 1
    if patch_dim < INTERP_ORDER:
        # Without this check the stencil indices would run past the grid,
        # and JAX gathers silently clamp out-of-bounds indices.
        raise ValueError(
            f"patch_dim0={patch_dim0} gives a {patch_dim}-point patch, smaller "
            f"than the {INTERP_ORDER}-point interpolation stencil; "
            f"need patch_dim0 >= {INTERP_ORDER // 2}"
        )
    if rad_dim < 1:
        raise ValueError(f"rad_dim must be >= 1, got {rad_dim}")

    pou_dtype = None if pou_dtype is None else np.dtype(pou_dtype)
    patch_dtype = None if patch_dtype is None else np.dtype(patch_dtype)
    index_dtype = np.int32 if index_dtype is None else np.dtype(index_dtype)

    rad_dim_base = rad_dim
    rad_dim = rad_dim_base * (3 if hedgehog_order > 1 else 1)
    ang_dim = 2 * rad_dim_base
    patch_rad = (patch_dim - 1) // 2

    qx, qw = _legendre_rule_01(rad_dim)
    if hedgehog_order > 1:
        # Change of variables t = s**2: dt = 2 s ds, so scale the weights
        # before squaring the nodes.
        qw = qw * (2.0 * qx)
        qx = qx * qx

    pou = _pou_fn(patch_dim)
    Gpou = _grid_pou(patch_dim, pou)
    Ppou = _polar_pou(qx, qw, ang_dim, patch_rad, pou)
    I_G2P, M_G2P, interp_idx = _grid_to_polar_interp(
        qx, ang_dim, patch_dim, index_dtype
    )
    wts = _hedgehog_weights(hedgehog_order)

    # jnp.asarray(x, dtype=None) keeps the default dtype conversion.
    index_jdtype = jnp.dtype(index_dtype)
    return SingularPrecomp(
        patch_dim0=patch_dim0,
        hedgehog_order=hedgehog_order,
        rad_dim_base=rad_dim_base,
        rad_dim=rad_dim,
        ang_dim=ang_dim,
        patch_dim=patch_dim,
        ngrid=patch_dim * patch_dim,
        npolar=rad_dim * ang_dim,
        qx=jnp.asarray(qx),
        qw=jnp.asarray(qw),
        Gpou=jnp.asarray(Gpou, dtype=pou_dtype),
        Ppou=jnp.asarray(Ppou, dtype=pou_dtype),
        I_G2P=jnp.asarray(I_G2P, dtype=index_jdtype),
        M_G2P=jnp.asarray(M_G2P, dtype=patch_dtype),
        interp_idx=jnp.asarray(interp_idx, dtype=index_jdtype),
        hedgehog_wts=jnp.asarray(wts, dtype=pou_dtype),
    )
def select_patch_dim(digits: int, cond: float):
    """Pick a patch size from a target accuracy and a conditioning factor.

    The raw target ``int(digits * cond * 1.6)`` is rounded down to a
    multiple of 4 and capped at 64. Targets below 8 fall back to 6.
    NOTE(review): whether callers use the result as ``patch_dim0`` or as
    ``patch_dim`` is not visible here — confirm at the call site.
    """
    target = int(digits * cond * 1.6)
    if target < 8:
        return 6
    return min(64, (target // 4) * 4)