#!/usr/bin/env python3
"""
Real-time voice noise reduction with a Kalman filter (NumPy + sounddevice)

Requirements:
    pip install numpy sounddevice

Run:
    python realtime_kalman_denoise.py
"""

import numpy as np
import sounddevice as sd

# ——— Kalman filter parameters ———
# Q: process-noise variance of the random-walk model (smaller → the estimate
#    is allowed to drift less, i.e. the model is trusted more).
# R: measurement-noise variance (roughly the variance of the input noise).
# For Q << R the gain converges to about sqrt(Q / R) (≈ 0.0063 with these
# values), so each output moves only ~0.6% toward the new sample.
# NOTE(review): that is very heavy smoothing for audio — confirm it suits speech.
Q, R = 1e-5, 0.25

# Mutable filter state, overwritten by every kalman_filter() call.
P = 1.0       # estimation-error covariance (initial value)
x_prev = 0.0  # previous state estimate (initial value)

def kalman_filter(y):
    """
    Advance the scalar random-walk Kalman filter by one measurement.

    Model:
        xₖ₊₁ = xₖ + wₖ,   wₖ~N(0,Q)
        yₖ   = xₖ + vₖ,   vₖ~N(0,R)

    Reads the module-level noise variances Q and R and overwrites the
    module-level state P (error covariance) and x_prev (last estimate).
    Returns the updated state estimate after incorporating ``y``.
    """
    global P, x_prev

    # Predict: a random walk keeps the mean and inflates the variance by Q.
    prior_var = P + Q

    # Correct: blend prediction and measurement using the Kalman gain.
    gain = prior_var / (prior_var + R)
    innovation = y - x_prev
    estimate = x_prev + gain * innovation
    P = (1 - gain) * prior_var

    # Persist the estimate so the next call predicts from it.
    x_prev = estimate
    return estimate

def audio_callback(indata, outdata, frames, time, status):
    """
    sounddevice stream callback: denoise one block of audio.

    Feeds every sample of the first input channel, in order, through
    kalman_filter() (whose state persists across blocks) and writes the
    filtered samples into the first output channel.
    """
    if status:
        print(f"Stream status: {status}", flush=True)

    # The stream is opened mono, so only channel 0 carries signal.
    mono_in = indata[:, 0]

    # kalman_filter is recursive, so samples must be processed sequentially;
    # fromiter stores each result in an array of the input's dtype.
    denoised = np.fromiter(
        (kalman_filter(sample) for sample in mono_in),
        dtype=mono_in.dtype,
        count=mono_in.shape[0],
    )

    outdata[:, 0] = denoised

def main():
    """Open a mono float32 duplex stream and denoise it until Ctrl+C."""
    stream_settings = {
        "samplerate": 16000,   # Hz (telephone quality)
        "blocksize": 1024,     # samples per callback block
        "dtype": 'float32',
        "channels": 1,         # mono, as audio_callback expects
        "callback": audio_callback,
    }

    try:
        with sd.Stream(**stream_settings):
            print("📢 Real-time Kalman denoising running (Ctrl+C to stop)…")
            # Idle here while sounddevice keeps invoking audio_callback.
            while True:
                sd.sleep(1000)
    except KeyboardInterrupt:
        print("\n🛑 Interrupted by user, exiting.")

if __name__ == "__main__":
    # Start streaming only when run as a script, never on import.
    main()