Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions examples/klein_gordon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Klein-Gordon wave equation example.

Solves the Klein-Gordon equation in the frequency domain using the
`klein_gordon` operator provided by j-Wave.

The Klein-Gordon equation for a massive scalar field reads::

(nabla^2 + omega^2/c^2 - m^2) u = f

where *m* is the field mass. Setting *m = 0* recovers the ordinary
Helmholtz equation.

This example sets up a 2-D domain with a point source, solves the
equation for two mass values (massless and massive), and plots the
resulting pressure fields side by side.
Comment on lines +3 to +15
"""

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxdf.discretization import FourierSeries

from jwave.acoustics import klein_gordon
from jwave.geometry import Domain, Medium

# ── Domain setup ──────────────────────────────────────────────────────
N = (128, 128)
dx = (0.1e-3, 0.1e-3) # 0.1 mm grid spacing
domain = Domain(N, dx)

# Homogeneous medium (water-like)
sound_speed = 1500.0 # m/s
medium = Medium(domain=domain, sound_speed=sound_speed, pml_size=15)

# Angular frequency corresponding to 1 MHz
f0 = 1e6
omega = 2 * jnp.pi * f0

# ── Source term ───────────────────────────────────────────────────────
src = jnp.zeros(N)
src = src.at[N[0] // 2, N[1] // 2].set(1.0)
source = FourierSeries(src, domain)

# ── Solve for two mass values ─────────────────────────────────────────
params = klein_gordon.default_params(source, medium, omega=omega)

# Massless (equivalent to Helmholtz)
kg_massless = klein_gordon(source, medium, omega=omega, mass=0.0, params=params)

# Massive (m = k/2, where k = omega / c)
k0 = omega / sound_speed
mass = k0 / 2
kg_massive = klein_gordon(source, medium, omega=omega, mass=mass, params=params)
Comment on lines +43 to +52

# ── Visualisation ─────────────────────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

for ax, field, label in zip(
axes,
[kg_massless, kg_massive],
["Helmholtz (m = 0)", f"Klein-Gordon (m = k/2)"],
):
img = jnp.abs(field.on_grid[..., 0])
ax.imshow(img.T, cmap="inferno", origin="lower")
ax.set_title(label)
ax.set_xlabel("x [grid points]")
ax.set_ylabel("y [grid points]")

plt.tight_layout()
plt.savefig("klein_gordon_example.png", dpi=150)
plt.show()
print("Done.")
1 change: 1 addition & 0 deletions jwave/acoustics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .conversion import db2neper
from .operators import (
helmholtz,
klein_gordon,
laplacian_with_pml,
scale_source_helmholtz,
wavevector,
Expand Down
37 changes: 37 additions & 0 deletions jwave/acoustics/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,43 @@ def helmholtz(u: OnGrid, medium: Medium, *, omega=1.0, params=None) -> OnGrid:
return L + k


def klein_gordon_init_params(u: Field,
medium: Medium,
omega,
*args,
**kwargs):
return helmholtz.default_params(u, medium, omega)


@operator(init_params=klein_gordon_init_params)
def klein_gordon(u: Field,
medium: Medium,
*,
omega=1.0,
mass=0.0,
params=None) -> Field:
r"""Klein-Gordon operator: :math:`(\nabla^2 + k^2 - m^2)u`.

Extends the Helmholtz operator with a mass term. The Klein-Gordon
equation arises in relativistic wave mechanics and describes massive
scalar fields. Setting ``mass=0`` recovers the Helmholtz operator
exactly.

Args:
u (Field): Complex field.
medium (Medium): Medium object.
omega (float): Angular frequency.
mass (float): Mass parameter :math:`m`.
params (None, optional): Parameters for the underlying Helmholtz
operator.

Returns:
Field: Klein-Gordon operator applied to ``u``.
"""
h = helmholtz(u, medium, omega=omega, params=params)
return h - mass**2 * u
Comment on lines +311 to +337


def scale_source_helmholtz(source: Field, medium: Medium) -> Field:
if isinstance(medium.sound_speed, Field):
min_sos = functional(medium.sound_speed)(jnp.amin)
Expand Down
Loading