260 lines
8.9 KiB
Python
260 lines
8.9 KiB
Python
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from scipy.fft import fft, ifft
|
||
|
||
|
||
class GCCPHATTOAGDLocalizer3D:
    """Planar-array 3-D sound source localizer.

    Pipeline: GCC-PHAT estimates each microphone's time of arrival (TOA)
    against the known source waveform, then gradient descent fits a 3-D
    source position (restricted to z > 0) to those TOAs. The array is a
    uniform ``grid_shape`` grid in the z = 0 plane (5x5 by default).
    """

    def __init__(
        self,
        fs=24000,
        c_sound=343.0,
        grid_shape=(5, 5),
        spacing=0.04,
        seed=7,
    ):
        """Configure the array geometry and the simulation.

        Parameters
        ----------
        fs : float
            Sample rate in Hz.
        c_sound : float
            Speed of sound in m/s.
        grid_shape : tuple of int
            ``(nx, ny)`` microphones along x and y.
        spacing : float
            Distance between neighbouring microphones, in metres.
        seed : int
            Seed of the random generator used for signal synthesis and noise.
        """
        self.fs = fs
        self.c = c_sound
        self.grid_shape = grid_shape
        self.spacing = spacing
        self.rng = np.random.default_rng(seed)

        self.mic_pos = self._build_xy_plane_array(grid_shape, spacing)
        self.num_mics = self.mic_pos.shape[0]

    def _build_xy_plane_array(self, grid_shape, spacing):
        """Return ``(nx * ny, 3)`` positions of a grid centred on the origin in z = 0."""
        nx, ny = grid_shape
        xs = (np.arange(nx) - (nx - 1) / 2.0) * spacing
        ys = (np.arange(ny) - (ny - 1) / 2.0) * spacing
        xx, yy = np.meshgrid(xs, ys, indexing="xy")
        zz = np.zeros_like(xx)
        # Row-major ravel of a (ny, nx) grid: mic index = iy * nx + ix.
        return np.column_stack((xx.ravel(), yy.ravel(), zz.ravel()))

    def generate_source_signal(self, duration=0.08, f_low=300.0, f_high=4000.0):
        """Synthesize a random band-limited source waveform with unit peak.

        Bins in ``[f_low, min(f_high, fs/2 - 1)]`` get random amplitudes in
        [0.2, 1.0) and uniform random phases. Weak white Gaussian noise (3 % of
        the band signal's std, at least 1e-4) is then added, and the result is
        normalized to a peak of 1.

        Raises
        ------
        ValueError
            If ``duration`` yields no samples or no frequency bin in the band.
        """
        n = int(self.fs * duration)
        if n <= 0:
            raise ValueError("duration 必须大于 0")

        freqs = np.fft.rfftfreq(n, d=1.0 / self.fs)
        spectrum = np.zeros(freqs.size, dtype=np.complex128)
        band_mask = (freqs >= f_low) & (freqs <= min(f_high, self.fs / 2 - 1.0))
        band_count = np.count_nonzero(band_mask)
        if band_count == 0:
            raise ValueError("duration 过短,无法形成有效宽带信号")

        amp = self.rng.uniform(0.2, 1.0, size=band_count)
        phase = self.rng.uniform(0.0, 2.0 * np.pi, size=band_count)
        spectrum[band_mask] = amp * np.exp(1j * phase)

        signal = np.fft.irfft(spectrum, n=n)
        noise_std = max(0.03 * np.std(signal), 1e-4)
        signal = signal + self.rng.normal(0.0, noise_std, size=n)

        peak = np.max(np.abs(signal))
        if peak > 1e-12:
            signal = signal / peak
        return signal

    def apply_delay_freq_domain(self, signal, delay_sec):
        """Delay ``signal`` by ``delay_sec`` seconds (fractional delays allowed).

        The delay is applied as a linear phase in the frequency domain on a
        zero-padded copy. The output has the same length as ``signal``.
        """
        n = len(signal)
        # Pad by at least the delay (and never less than n) so the shifted
        # signal cannot wrap circularly back into the first n output samples.
        pad = max(n, int(np.ceil(abs(delay_sec) * self.fs)) + 1)
        padded = np.pad(signal, (0, pad))
        n_fft = len(padded)

        sig_fft = fft(padded)
        freqs = np.fft.fftfreq(n_fft, d=1.0 / self.fs)
        phase_shift = np.exp(-1j * 2.0 * np.pi * freqs * delay_sec)
        delayed = ifft(sig_fft * phase_shift).real
        return delayed[:n]

    def simulate_mic_signals(self, src_sig, source_pos, snr_db=20.0):
        """Simulate every microphone's recording of a point source.

        Each channel is ``src_sig`` delayed by distance / c and scaled by
        1 / distance (distance floored at 1e-3 m). White Gaussian noise is then
        added at ``snr_db`` relative to that channel's mean power.

        Returns
        -------
        mic_signals : ndarray, shape (num_mics, len(src_sig))
        arrival_times : ndarray, shape (num_mics,)
            True propagation delays in seconds.

        Raises
        ------
        ValueError
            If ``source_pos`` is not a 3-vector or its z is not positive.
        """
        source_pos = np.asarray(source_pos, dtype=float)
        if source_pos.shape != (3,):
            raise ValueError("source_pos 必须是长度为 3 的向量 [x, y, z]")
        if source_pos[2] <= 0:
            raise ValueError("仅考虑前方声源,请保证 z > 0")

        vec = source_pos[None, :] - self.mic_pos
        dists = np.linalg.norm(vec, axis=1)
        arrival_times = dists / self.c

        mic_signals = []
        for dist, delay in zip(dists, arrival_times):
            delayed = self.apply_delay_freq_domain(src_sig, delay)
            atten = 1.0 / max(dist, 1e-3)
            mic_signals.append(delayed * atten)
        mic_signals = np.asarray(mic_signals)

        for i in range(self.num_mics):
            sig_power = np.mean(mic_signals[i] ** 2)
            noise_power = sig_power / (10.0 ** (snr_db / 10.0))
            noise_std = np.sqrt(max(noise_power, 1e-12))
            mic_signals[i] += self.rng.normal(0.0, noise_std, size=mic_signals.shape[1])

        return mic_signals, arrival_times

    def gcc_phat_delay(self, sig, refsig, max_tau=0.01, interp=16):
        """Estimate the delay of ``sig`` relative to ``refsig`` with GCC-PHAT.

        Parameters
        ----------
        sig, refsig : array_like
            1-D real signals sampled at ``self.fs``.
        max_tau : float or None
            Largest |delay| in seconds to search. ``None`` searches every lag
            the zero-padded correlation supports.
        interp : int
            Upsampling factor of the correlation. The delay resolution is
            ``1 / (interp * fs)`` seconds.

        Returns
        -------
        float
            Delay in seconds; positive when ``sig`` lags ``refsig``.
        """
        sig = np.asarray(sig, dtype=float)
        refsig = np.asarray(refsig, dtype=float)
        # Pad to the sum of the lengths so the correlation is linear, not circular.
        n = sig.shape[0] + refsig.shape[0]
        n_interp = interp * n

        cross_power = np.fft.rfft(sig, n=n) * np.conj(np.fft.rfft(refsig, n=n))
        # PHAT weighting: keep only the phase of each bin.
        cross_power /= np.abs(cross_power) + 1e-15

        # irfft zero-pads the one-sided spectrum above the original Nyquist bin,
        # which is proper band-limited interpolation. Padding a two-sided
        # spectrum at its end (ifft(X, n=...)) misplaces the negative-frequency
        # bins and biases sub-sample peak locations.
        corr = np.fft.irfft(cross_power, n=n_interp)

        max_shift = n_interp // 2
        if max_tau is not None:
            max_shift = min(int(interp * self.fs * max_tau), max_shift)
        max_shift = max(max_shift, 0)

        # Negative lags sit at the end of the circular correlation and
        # non-negative lags at the start. An explicit start index is used
        # because corr[-0:] would return the whole array.
        corr_window = np.concatenate((corr[n_interp - max_shift:], corr[: max_shift + 1]))
        shift = int(np.argmax(corr_window)) - max_shift
        return shift / float(interp * self.fs)

    def estimate_toa_with_gcc_phat(self, mic_signals, src_sig, max_tau=0.01, interp=16):
        """Return the per-microphone TOA (seconds) of ``src_sig`` in each of the first ``num_mics`` rows."""
        return np.array(
            [
                self.gcc_phat_delay(mic_signals[i], src_sig, max_tau=max_tau, interp=interp)
                for i in range(self.num_mics)
            ],
            dtype=float,
        )

    def predict_toa(self, source_pos):
        """Return the free-space propagation time (s) from ``source_pos`` to every mic."""
        source_pos = np.asarray(source_pos, dtype=float)
        dists = np.linalg.norm(source_pos[None, :] - self.mic_pos, axis=1)
        return dists / self.c

    def localize_with_gradient_descent(
        self,
        measured_toa,
        init_pos=None,
        lr=0.10,
        max_iters=2000,
        decay=0.001,
        z_min=0.02,
        tol=1e-9,
    ):
        """Fit a 3-D source position to measured TOAs by projected gradient descent.

        The loss is ``0.5 * mean((|x - p_i| - c * toa_i) ** 2)``. The step size
        is ``lr / (1 + decay * iter)``, and z is clamped to ``>= z_min`` after
        every step. Iteration stops once the gradient norm drops below ``tol``.

        NOTE(review): with a small planar aperture the lateral directions are
        weakly constrained, so plain gradient descent may converge slowly there.
        Inspect ``loss_history`` before trusting the estimate.

        Parameters
        ----------
        measured_toa : array_like, shape (num_mics,)
            Arrival times in seconds.
        init_pos : array_like of 3 floats, optional
            Starting point; defaults to ``[0, 0, 0.40]``.

        Returns
        -------
        x : ndarray, shape (3,)
            Estimated position.
        loss_history : ndarray
            Loss evaluated before each update.
        pred_toa : ndarray, shape (num_mics,)
            TOAs predicted from ``x``.

        Raises
        ------
        ValueError
            If ``init_pos`` is not a 3-vector.
        """
        if init_pos is None:
            init_pos = (0.0, 0.0, 0.40)
        x = np.asarray(init_pos, dtype=float).copy()
        if x.shape != (3,):
            raise ValueError("init_pos 必须是长度为 3 的向量 [x, y, z]")

        x[2] = max(x[2], z_min)
        target_ranges = np.asarray(measured_toa, dtype=float) * self.c

        loss_history = []
        for i in range(max_iters):
            vec = x[None, :] - self.mic_pos
            # Floor the distances to avoid dividing by zero if x sits on a mic.
            dists = np.maximum(np.linalg.norm(vec, axis=1), 1e-8)

            err = dists - target_ranges
            loss_history.append(0.5 * np.mean(err ** 2))

            # d/dx of the loss: mean(err_i * unit vector from mic i to x).
            grad = np.mean(err[:, None] * (vec / dists[:, None]), axis=0)
            step = lr / (1.0 + decay * i)
            x = x - step * grad
            # Project back into the front half-space.
            x[2] = max(x[2], z_min)

            if np.linalg.norm(grad) < tol:
                break

        return x, np.asarray(loss_history), self.predict_toa(x)

    def visualize(self, true_pos, est_pos, loss_history, measured_toa, fitted_toa):
        """Plot the geometry, the loss curve and measured vs. fitted TOAs side by side."""
        loss_history = np.asarray(loss_history, dtype=float)
        measured_toa = np.asarray(measured_toa, dtype=float)
        fitted_toa = np.asarray(fitted_toa, dtype=float)

        fig = plt.figure(figsize=(16, 5))

        ax1 = fig.add_subplot(1, 3, 1, projection="3d")
        ax1.scatter(self.mic_pos[:, 0], self.mic_pos[:, 1], self.mic_pos[:, 2], c="royalblue", s=28, label="Microphones")
        ax1.scatter(true_pos[0], true_pos[1], true_pos[2], c="limegreen", marker="*", s=220, label="True source")
        ax1.scatter(est_pos[0], est_pos[1], est_pos[2], c="crimson", marker="^", s=120, label="Estimated source")
        ax1.plot(
            [true_pos[0], est_pos[0]],
            [true_pos[1], est_pos[1]],
            [true_pos[2], est_pos[2]],
            "k--",
            linewidth=1.2,
        )
        ax1.set_title("3D Localization (z > 0)")
        ax1.set_xlabel("x (m)")
        ax1.set_ylabel("y (m)")
        ax1.set_zlabel("z (m)")
        ax1.legend(loc="upper left")
        ax1.view_init(elev=24, azim=-55)

        ax2 = fig.add_subplot(1, 3, 2)
        ax2.plot(loss_history, color="tab:orange", linewidth=2)
        # A log scale requires strictly positive, non-empty data.
        if loss_history.size and np.all(loss_history > 0):
            ax2.set_yscale("log")
        ax2.set_title("Gradient Descent Loss")
        ax2.set_xlabel("Iteration")
        ax2.set_ylabel("MSE of range error (m^2)")
        ax2.grid(alpha=0.3)

        ax3 = fig.add_subplot(1, 3, 3)
        mic_index = np.arange(self.num_mics)
        ax3.plot(mic_index, measured_toa * 1e6, "o-", label="Measured TOA (GCC-PHAT)")
        ax3.plot(mic_index, fitted_toa * 1e6, "x--", label="Predicted TOA")
        ax3.set_title(f"TOA Fit Across {self.num_mics} Mics")
        ax3.set_xlabel("Mic index")
        ax3.set_ylabel("Arrival time (us)")
        ax3.grid(alpha=0.3)
        ax3.legend()

        fig.tight_layout()
        plt.show()
|
||
|
||
|
||
def main():
    """Run the end-to-end demo: simulate a source, estimate TOAs, localize, report and plot."""
    localizer = GCCPHATTOAGDLocalizer3D(
        fs=24000, c_sound=343.0, grid_shape=(5, 5), spacing=0.04, seed=7
    )
    true_source = np.array([0.12, -0.08, 0.55])

    # Synthesize the emitted waveform and the noisy recording at every microphone.
    src_signal = localizer.generate_source_signal(duration=0.08)
    mic_signals, true_arrival_times = localizer.simulate_mic_signals(
        src_signal, true_source, snr_db=20.0
    )

    # Stage 1: per-microphone TOA from GCC-PHAT against the known waveform.
    measured_toa = localizer.estimate_toa_with_gcc_phat(mic_signals, src_signal)

    # Stage 2: fit a 3-D position to the measured TOAs.
    gd_settings = dict(
        init_pos=np.array([0.0, 0.0, 0.40]),
        lr=0.10,
        max_iters=2000,
        decay=0.001,
        z_min=0.02,
        tol=1e-9,
    )
    est_source, loss_history, fitted_toa = localizer.localize_with_gradient_descent(
        measured_toa, **gd_settings
    )

    def rms_microseconds(a, b):
        # Root-mean-square difference of two TOA vectors, converted from s to us.
        return np.sqrt(np.mean((a - b) ** 2)) * 1e6

    pos_err = np.linalg.norm(est_source - true_source)
    toa_rmse_us = rms_microseconds(measured_toa, true_arrival_times)
    fit_rmse_us = rms_microseconds(fitted_toa, measured_toa)

    summary = (
        "=== 5x5 平面阵列 + GCC-PHAT(TOA) + 梯度下降定位 ===",
        f"真实声源位置 [x, y, z] (m): {true_source}",
        f"估计声源位置 [x, y, z] (m): {est_source}",
        f"位置误差: {pos_err:.4f} m",
        f"GCC-PHAT 到达时间估计 RMSE (相对真值): {toa_rmse_us:.2f} us",
        f"估计位置反推 TOA 拟合 RMSE: {fit_rmse_us:.2f} us",
        f"迭代次数: {len(loss_history)}",
    )
    for line in summary:
        print(line)

    localizer.visualize(
        true_pos=true_source,
        est_pos=est_source,
        loss_history=loss_history,
        measured_toa=measured_toa,
        fitted_toa=fitted_toa,
    )


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|