# File: MultiMicLocation/Simulation/mic_array_gd_localization.py
# (Repository viewer metadata: 260 lines, 8.9 KiB, Python.)
# Viewer warning: this file contains Unicode characters that might be confused
# with other characters (presumably the Chinese text in strings and printed
# output); if that is intentional, the warning can safely be ignored.
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft, ifft
class GCCPHATTOAGDLocalizer3D:
    """Planar microphone-array simulator and 3D sound-source localizer.

    Pipeline: the time of arrival (TOA) at every microphone is estimated with
    GCC-PHAT against the known source waveform, then the source position is
    found by (projected) gradient descent on the squared range residuals.
    The microphones lie in the z = 0 plane on a regular grid (5x5 by default),
    so only sources in front of the array (z > 0) are considered.
    """

    def __init__(
        self,
        fs=24000,
        c_sound=343.0,
        grid_shape=(5, 5),
        spacing=0.04,
        seed=7,
    ):
        """Configure sampling, array geometry and the random generator.

        Args:
            fs: Sampling rate in Hz.
            c_sound: Speed of sound; TOAs are distances divided by this value.
            grid_shape: ``(nx, ny)`` number of microphones along x and y.
            spacing: Distance between neighbouring microphones (same unit as positions).
            seed: Seed of the generator used for the source waveform and all noise.
        """
        self.fs = fs
        self.c = c_sound
        self.grid_shape = grid_shape
        self.spacing = spacing
        self.rng = np.random.default_rng(seed)
        self.mic_pos = self._build_xy_plane_array(grid_shape, spacing)  # (nx * ny, 3)
        self.num_mics = self.mic_pos.shape[0]

    def _build_xy_plane_array(self, grid_shape, spacing):
        """Return an ``(nx * ny, 3)`` array of grid positions centred on the origin, z = 0."""
        nx, ny = grid_shape
        xs = (np.arange(nx) - (nx - 1) / 2.0) * spacing
        ys = (np.arange(ny) - (ny - 1) / 2.0) * spacing
        xx, yy = np.meshgrid(xs, ys, indexing="xy")
        zz = np.zeros_like(xx)
        return np.column_stack((xx.ravel(), yy.ravel(), zz.ravel()))

    def generate_source_signal(self, duration=0.08, f_low=300.0, f_high=4000.0):
        """Generate a broadband test waveform normalised to unit peak amplitude.

        Random amplitudes (0.2-1.0) and phases are placed on the rFFT bins in
        ``[f_low, min(f_high, fs/2 - 1)]`` Hz. White noise with 3 % of the clean
        signal's standard deviation (at least 1e-4) is then added.

        Args:
            duration: Length in seconds.
            f_low: Lower band edge in Hz.
            f_high: Upper band edge in Hz.

        Returns:
            1-D float array of ``int(fs * duration)`` samples.

        Raises:
            ValueError: If the length is not positive or no bin falls in the band.
        """
        n = int(self.fs * duration)
        if n <= 0:
            raise ValueError("duration 必须大于 0")
        freqs = np.fft.rfftfreq(n, d=1.0 / self.fs)
        spectrum = np.zeros(freqs.size, dtype=np.complex128)
        band_mask = (freqs >= f_low) & (freqs <= min(f_high, self.fs / 2 - 1.0))
        band_count = np.count_nonzero(band_mask)
        if band_count == 0:
            raise ValueError("duration 过短,无法形成有效宽带信号")
        amp = self.rng.uniform(0.2, 1.0, size=band_count)
        phase = self.rng.uniform(0.0, 2.0 * np.pi, size=band_count)
        spectrum[band_mask] = amp * np.exp(1j * phase)
        signal = np.fft.irfft(spectrum, n=n)
        noise_std = max(0.03 * np.std(signal), 1e-4)
        signal = signal + self.rng.normal(0.0, noise_std, size=n)
        peak = np.max(np.abs(signal))
        if peak > 1e-12:
            signal = signal / peak
        return signal

    def apply_delay_freq_domain(self, signal, delay_sec):
        """Delay ``signal`` by ``delay_sec`` seconds (may be fractional) via a linear phase shift.

        The signal is zero-padded to twice its length before the FFT, so delays
        shorter than the signal length do not wrap around. The output is
        truncated back to the input length.
        """
        n = len(signal)
        padded = np.pad(signal, (0, n))
        n_fft = len(padded)
        sig_fft = fft(padded)
        freqs = np.fft.fftfreq(n_fft, d=1.0 / self.fs)
        phase_shift = np.exp(-1j * 2.0 * np.pi * freqs * delay_sec)
        delayed = ifft(sig_fft * phase_shift).real
        return delayed[:n]

    def simulate_mic_signals(self, src_sig, source_pos, snr_db=20.0):
        """Simulate every microphone's recording of a point source (free field).

        Each channel is the source delayed by distance / c and scaled by
        1 / distance (distance floored at 1e-3). White noise is added per
        channel to reach ``snr_db``.

        Args:
            src_sig: 1-D source waveform.
            source_pos: ``[x, y, z]`` with ``z > 0``.
            snr_db: Per-channel signal-to-noise ratio in dB.

        Returns:
            ``(mic_signals, arrival_times)``: ``(num_mics, len(src_sig))`` array
            and the true TOA of each microphone in seconds.

        Raises:
            ValueError: If ``source_pos`` is not a 3-vector or ``z <= 0``.
        """
        source_pos = np.asarray(source_pos, dtype=float)
        if source_pos.shape != (3,):
            raise ValueError("source_pos 必须是长度为 3 的向量 [x, y, z]")
        if source_pos[2] <= 0:
            raise ValueError("仅考虑前方声源,请保证 z > 0")
        vec = source_pos[None, :] - self.mic_pos
        dists = np.linalg.norm(vec, axis=1)
        arrival_times = dists / self.c
        mic_signals = []
        for dist, delay in zip(dists, arrival_times):
            delayed = self.apply_delay_freq_domain(src_sig, delay)
            atten = 1.0 / max(dist, 1e-3)
            mic_signals.append(delayed * atten)
        mic_signals = np.asarray(mic_signals)
        for i in range(self.num_mics):
            sig_power = np.mean(mic_signals[i] ** 2)
            noise_power = sig_power / (10.0 ** (snr_db / 10.0))
            noise_std = np.sqrt(max(noise_power, 1e-12))
            mic_signals[i] += self.rng.normal(0.0, noise_std, size=mic_signals.shape[1])
        return mic_signals, arrival_times

    def gcc_phat_delay(self, sig, refsig, max_tau=0.01, interp=16):
        """Estimate the delay of ``sig`` relative to ``refsig`` with GCC-PHAT.

        Args:
            sig: 1-D real signal, presumably a delayed copy of ``refsig``.
            refsig: 1-D real reference signal.
            max_tau: Largest |delay| in seconds to search; ``None`` searches all lags.
            interp: Upsampling factor of the correlation. The delay resolution
                is ``1 / (interp * fs)`` seconds.

        Returns:
            Delay in seconds, positive when ``sig`` lags ``refsig``.
        """
        sig = np.asarray(sig)
        refsig = np.asarray(refsig)
        # Zero-pad to the combined length so the correlation is linear, not circular.
        n = sig.shape[0] + refsig.shape[0]
        sig_fft = np.fft.rfft(sig, n=n)
        ref_fft = np.fft.rfft(refsig, n=n)
        cross_power = sig_fft * np.conj(ref_fft)
        cross_power /= np.abs(cross_power) + 1e-15  # PHAT weighting: keep phase only
        # irfft zero-pads the one-sided spectrum above the highest frequency,
        # which is proper band-limited interpolation. A complex ifft with a
        # larger n would append zeros after the negative-frequency bins and
        # corrupt the correlation at fractional lags.
        corr = np.fft.irfft(cross_power, n=interp * n)
        max_shift = int(interp * n / 2)
        if max_tau is not None:
            max_shift = min(int(interp * self.fs * max_tau), max_shift)
        max_shift = max(max_shift, 0)
        # Negative lags sit at the end of corr. Slicing from an explicit start
        # index stays correct for max_shift == 0 (corr[-0:] is the whole array).
        corr_window = np.concatenate((corr[corr.size - max_shift:], corr[: max_shift + 1]))
        shift = int(np.argmax(corr_window)) - max_shift
        return shift / float(interp * self.fs)

    def estimate_toa_with_gcc_phat(self, mic_signals, src_sig, max_tau=0.01, interp=16):
        """Estimate each microphone's TOA (seconds) as its GCC-PHAT delay w.r.t. ``src_sig``.

        ``max_tau`` and ``interp`` are forwarded to :meth:`gcc_phat_delay`.
        """
        toa = np.zeros(self.num_mics)
        for i in range(self.num_mics):
            toa[i] = self.gcc_phat_delay(mic_signals[i], src_sig, max_tau=max_tau, interp=interp)
        return toa

    def predict_toa(self, source_pos):
        """Return the free-field TOA (seconds) from ``source_pos`` to every microphone."""
        source_pos = np.asarray(source_pos, dtype=float)
        vec = source_pos[None, :] - self.mic_pos
        dists = np.linalg.norm(vec, axis=1)
        return dists / self.c

    def localize_with_gradient_descent(
        self,
        measured_toa,
        init_pos=None,
        lr=0.10,
        max_iters=2000,
        decay=0.001,
        z_min=0.02,
        tol=1e-9,
    ):
        """Fit a source position to measured TOAs by projected gradient descent.

        Minimises ``0.5 * mean((|x - m_i| - c * toa_i) ** 2)`` subject to
        ``z >= z_min``, using step size ``lr / (1 + decay * iteration)``.

        NOTE(review): with a small array aperture relative to the source range
        the problem is poorly conditioned across range, so plain gradient
        descent may need many iterations; tune ``lr``/``max_iters`` accordingly.

        Args:
            measured_toa: Per-microphone TOA in seconds.
            init_pos: Starting ``[x, y, z]``; ``None`` means ``[0.0, 0.0, 0.40]``.
            lr: Initial step size.
            max_iters: Maximum number of iterations.
            decay: Step-size decay rate.
            z_min: Lower bound for z (keeps the estimate in front of the array).
            tol: Stop when the projected gradient norm falls below this value.

        Returns:
            ``(est_pos, loss_history, pred_toa)``: the estimate, the loss before
            each update, and the TOAs predicted from the estimate.

        Raises:
            ValueError: If ``init_pos`` is not a 3-vector.
        """
        if init_pos is None:
            init_pos = (0.0, 0.0, 0.40)  # default kept out of the signature (no shared array)
        x = np.asarray(init_pos, dtype=float).copy()
        if x.shape != (3,):
            raise ValueError("init_pos 必须是长度为 3 的向量 [x, y, z]")
        x[2] = max(x[2], z_min)
        target_ranges = np.asarray(measured_toa, dtype=float) * self.c
        loss_history = []
        for i in range(max_iters):
            vec = x[None, :] - self.mic_pos
            dists = np.linalg.norm(vec, axis=1)
            dists = np.maximum(dists, 1e-8)
            err = dists - target_ranges
            loss = 0.5 * np.mean(err ** 2)
            loss_history.append(loss)
            grad = np.mean(err[:, None] * (vec / dists[:, None]), axis=0)
            # Projected gradient: on the bound z == z_min a gradient pushing z
            # lower cannot be followed, so it must not block convergence.
            proj_grad = grad.copy()
            if x[2] <= z_min and grad[2] > 0:
                proj_grad[2] = 0.0
            step = lr / (1.0 + decay * i)
            x = x - step * grad
            x[2] = max(x[2], z_min)
            if np.linalg.norm(proj_grad) < tol:
                break
        pred_toa = self.predict_toa(x)
        return x, np.asarray(loss_history), pred_toa

    def visualize(self, true_pos, est_pos, loss_history, measured_toa, fitted_toa):
        """Show geometry, loss curve and per-microphone TOA fit in one figure."""
        loss_history = np.asarray(loss_history, dtype=float)
        measured_toa = np.asarray(measured_toa, dtype=float)
        fitted_toa = np.asarray(fitted_toa, dtype=float)
        fig = plt.figure(figsize=(16, 5))
        ax1 = fig.add_subplot(1, 3, 1, projection="3d")
        ax1.scatter(self.mic_pos[:, 0], self.mic_pos[:, 1], self.mic_pos[:, 2], c="royalblue", s=28, label="Microphones")
        ax1.scatter(true_pos[0], true_pos[1], true_pos[2], c="limegreen", marker="*", s=220, label="True source")
        ax1.scatter(est_pos[0], est_pos[1], est_pos[2], c="crimson", marker="^", s=120, label="Estimated source")
        ax1.plot(
            [true_pos[0], est_pos[0]],
            [true_pos[1], est_pos[1]],
            [true_pos[2], est_pos[2]],
            "k--",
            linewidth=1.2,
        )
        ax1.set_title("3D Localization (z > 0)")
        ax1.set_xlabel("x (m)")
        ax1.set_ylabel("y (m)")
        ax1.set_zlabel("z (m)")
        ax1.legend(loc="upper left")
        ax1.view_init(elev=24, azim=-55)
        ax2 = fig.add_subplot(1, 3, 2)
        ax2.plot(loss_history, color="tab:orange", linewidth=2)
        # A log axis needs strictly positive data (an exact fit can give loss 0).
        if loss_history.size and np.all(loss_history > 0):
            ax2.set_yscale("log")
        ax2.set_title("Gradient Descent Loss")
        ax2.set_xlabel("Iteration")
        ax2.set_ylabel("MSE of range error (m^2)")
        ax2.grid(alpha=0.3)
        ax3 = fig.add_subplot(1, 3, 3)
        mic_index = np.arange(self.num_mics)
        ax3.plot(mic_index, measured_toa * 1e6, "o-", label="Measured TOA (GCC-PHAT)")
        ax3.plot(mic_index, fitted_toa * 1e6, "x--", label="Predicted TOA")
        ax3.set_title(f"TOA Fit Across {self.num_mics} Mics")
        ax3.set_xlabel("Mic index")
        ax3.set_ylabel("Arrival time (us)")
        ax3.grid(alpha=0.3)
        ax3.legend()
        fig.tight_layout()
        plt.show()
def main():
    """Run the demo end to end and report the results.

    Simulates a broadband source in front of a 5x5 array, estimates the
    per-microphone TOAs with GCC-PHAT, localizes the source with gradient
    descent, prints error metrics and opens the diagnostic figure.
    """
    array_cfg = {
        "fs": 24000,
        "c_sound": 343.0,
        "grid_shape": (5, 5),
        "spacing": 0.04,
        "seed": 7,
    }
    solver_cfg = {
        "init_pos": np.array([0.0, 0.0, 0.40]),
        "lr": 0.10,
        "max_iters": 2000,
        "decay": 0.001,
        "z_min": 0.02,
        "tol": 1e-9,
    }
    source_xyz = np.array([0.12, -0.08, 0.55])

    loc = GCCPHATTOAGDLocalizer3D(**array_cfg)

    # The localizer's RNG is consumed first by the source waveform, then by the
    # per-microphone noise; keep this call order for reproducible results.
    waveform = loc.generate_source_signal(duration=0.08)
    recordings, toa_truth = loc.simulate_mic_signals(waveform, source_xyz, snr_db=20.0)
    toa_gcc = loc.estimate_toa_with_gcc_phat(recordings, waveform)
    source_hat, losses, toa_fit = loc.localize_with_gradient_descent(toa_gcc, **solver_cfg)

    def rms_us(a, b):
        # Root-mean-square difference, converted from seconds to microseconds.
        return np.sqrt(np.mean((a - b) ** 2)) * 1e6

    report = (
        "=== 5x5 平面阵列 + GCC-PHAT(TOA) + 梯度下降定位 ===",
        f"真实声源位置 [x, y, z] (m): {source_xyz}",
        f"估计声源位置 [x, y, z] (m): {source_hat}",
        f"位置误差: {np.linalg.norm(source_hat - source_xyz):.4f} m",
        f"GCC-PHAT 到达时间估计 RMSE (相对真值): {rms_us(toa_gcc, toa_truth):.2f} us",
        f"估计位置反推 TOA 拟合 RMSE: {rms_us(toa_fit, toa_gcc):.2f} us",
        f"迭代次数: {len(losses)}",
    )
    for line in report:
        print(line)

    loc.visualize(
        true_pos=source_xyz,
        est_pos=source_hat,
        loss_history=losses,
        measured_toa=toa_gcc,
        fitted_toa=toa_fit,
    )
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()