HAMA Blog

Statistics & Machine Learning & Deep Learning

Wavelet Regression with Python

[하마] 이승현 (wowlsh93@gmail.com) 2016. 7. 4. 17:36

Wavelet Regression in Python

Last week I needed to get my head around wavelet regression techniques for a project I am working on. This post shows how to do basic wavelet regression in Python using PyWavelets. For references I used Chapter 9 of Wasserman's All of Nonparametric Statistics, Ogden's Essential Wavelets for Statistical Applications and Data Analysis, and Donoho and Johnstone's "Ideal spatial adaptation by wavelet shrinkage" [pdf]. I also recommend "An Introduction to Wavelets" [pdf] by Amara Graps for a high-level overview and a brief historical treatment of the development of wavelet analysis. These references provide much more detail and guidance than I will here.

Wavelets are mathematical functions that provide an orthonormal basis for functions in the $L^2$ space. The use of wavelets is akin to the use of sines and cosines to represent $L^2$ functions in Fourier analysis. In fact, many treatments of wavelets start the discussion by reviewing Fourier series representations of functions. Legendre polynomials are another complete orthonormal system.

The simplest wavelet is the Haar wavelet. It is rarely used in practice, but it is convenient for fixing ideas. The Haar function is given by

$$\psi(x) = \begin{cases} 1, & x \in [0, \tfrac{1}{2}) \\ -1, & x \in [\tfrac{1}{2}, 1) \\ 0, & \text{otherwise} \end{cases}$$

In [1]:
import numpy as np  # needed by the test functions and plots in later cells
import pywt
import matplotlib.pyplot as plt
In [2]:
w = pywt.Wavelet('haar')  # PyWavelets lists its wavelet names in lowercase
phi, psi, x = w.wavefun(level=10)

fig, ax = plt.subplots()
ax.set_xlim(-.02,1.02)
ax.plot(x, psi);

The location of this mother wavelet in the domain and its scale can be controlled by a translation index $k$ and a dilation index $j$, respectively.

$$\psi_{j,k}(x) = 2^{j/2}\,\psi(2^j x - k)$$
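To make the dilation and translation concrete, here is a minimal NumPy sketch that evaluates the Haar wavelet at a few (j, k) pairs and checks orthonormality numerically. The helper names (haar_psi, haar_psi_jk, inner) are made up for this illustration and are not part of PyWavelets; the inner product is approximated with a midpoint rule on [0, 1).

```python
import numpy as np

def haar_psi(x):
    """Haar mother wavelet: +1 on [0, 1/2), -1 on [1/2, 1), 0 elsewhere."""
    x = np.asarray(x, dtype=float)
    return np.where((x >= 0) & (x < 0.5), 1.0,
                    np.where((x >= 0.5) & (x < 1), -1.0, 0.0))

def haar_psi_jk(x, j, k):
    """Dilated and translated wavelet 2^(j/2) * psi(2^j x - k)."""
    return 2.0 ** (j / 2.0) * haar_psi(2.0 ** j * np.asarray(x, dtype=float) - k)

def inner(f, g, n=2**12):
    """Midpoint-rule approximation of the L2 inner product on [0, 1)."""
    x = (np.arange(n) + 0.5) / n
    return np.mean(f(x) * g(x))

# Squared norm of one basis function (should be close to 1)
norm_sq = inner(lambda x: haar_psi_jk(x, 3, 5), lambda x: haar_psi_jk(x, 3, 5))
# Two translates at the same level have disjoint supports (should be close to 0)
same_level = inner(lambda x: haar_psi_jk(x, 3, 5), lambda x: haar_psi_jk(x, 3, 2))
# Overlapping supports at different levels still integrate to zero
across_levels = inner(lambda x: haar_psi_jk(x, 1, 0), lambda x: haar_psi_jk(x, 3, 1))
print(norm_sq, same_level, across_levels)
```

The factor 2^(j/2) is what keeps the norm at 1 as the support shrinks to width 2^(-j).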

A wavelet system is fully defined once a scaling function, or father wavelet, is defined. The scaling function ensures that the orthonormal basis covers the original space of a function without having to use an infinite number of mother wavelets. The father wavelet for the Haar system is

$$\phi(x) = I_{[0,1)}(x)$$

where $I_A(x)$ is an indicator function that equals 1 if $x$ is in the set $A$ and 0 otherwise. Together, the dilations and translations of these functions cover the original function space. While the Haar function is a good introduction, the rest of the code will use the Daubechies wavelet with N=8, which is 'db8' in PyWavelets (the least-asymmetric variant of the family is available there as 'sym8').

In [3]:
db8 = pywt.Wavelet('db8')
scaling, wavelet, x = db8.wavefun()

fig, axes = plt.subplots(1, 2, sharey=True, figsize=(8,6))
ax1, ax2 = axes

ax1.plot(x, scaling);
ax1.set_title('Scaling function, N=8');
ax1.set_ylim(-1.2, 1.2);

ax2.set_title('Wavelet, N=8');
ax2.tick_params(labelleft=False);
ax2.plot(x-x.mean(), wavelet);

fig.tight_layout()

For the current purposes, the objective of wavelet analysis is to recover the unknown function f from noisy data

$$y_i = f(x_i) + e_i, \quad i \in \{1, \ldots, n\}$$

with $x_i = i/n$ without loss of generality and $e_i \sim N(0, \sigma^2)$. The expansion of the function $f$ as a finite sum of wavelets can be achieved as

$$f_J(x) = \alpha\,\phi(x) + \sum_{j=0}^{J-1} \sum_{k=0}^{2^j-1} \beta_{jk}\,\psi_{jk}(x)$$

where

$$\alpha = \int_0^1 f(x)\,\phi(x)\,dx, \qquad \beta_{jk} = \int_0^1 f(x)\,\psi_{jk}(x)\,dx.$$

$\alpha$ and $\beta_{jk}$ are called the scaling coefficients and the detail coefficients, respectively. The basic idea is that the detail coefficients capture the finer details of the function while the scaling, or smoothing, coefficients capture the overall functional form.
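As a rough illustration of the expansion, the sketch below approximates the scaling and detail coefficients by numerical integration and rebuilds f_J for a few values of J. It uses the Haar system rather than 'db8' because the Haar functions have a closed form. The test function f and the helper names are assumptions made for this example only.

```python
import numpy as np

def haar_psi_jk(x, j, k):
    """Haar wavelet 2^(j/2) * psi(2^j x - k), psi = +1 on [0, 1/2), -1 on [1/2, 1)."""
    u = 2.0 ** j * np.asarray(x, dtype=float) - k
    psi = np.where((u >= 0) & (u < 0.5), 1.0, 0.0) - np.where((u >= 0.5) & (u < 1), 1.0, 0.0)
    return 2.0 ** (j / 2.0) * psi

def haar_expansion(f, J, n=2**12):
    """Evaluate f and its J-level Haar expansion f_J on a midpoint grid of [0, 1)."""
    x = (np.arange(n) + 0.5) / n
    fx = f(x)
    alpha = np.mean(fx)               # integral of f * phi, with phi = 1 on [0, 1)
    approx = np.full(n, alpha)
    for j in range(J):
        for k in range(2 ** j):
            psi = haar_psi_jk(x, j, k)
            beta = np.mean(fx * psi)  # midpoint-rule integral of f * psi_jk
            approx += beta * psi
    return x, fx, approx

# Hypothetical smooth test function, chosen only for illustration
f = lambda x: np.sin(2 * np.pi * x) + x

errors = []
for J in (1, 3, 5, 7):
    x, fx, approx = haar_expansion(f, J)
    errors.append(np.sqrt(np.mean((fx - approx) ** 2)))  # root mean squared error
print(errors)
```

Each additional level halves the width of the piecewise-constant pieces, so the error should shrink as J grows.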

Donoho and Johnstone use four functions that mimic properties of empirical data in domains where wavelets might be useful: Bumps, Blocks, HeaviSine, and Doppler. For completeness, they are defined here.

In [4]:
def doppler(x):
    """
    Doppler test function of Donoho and Johnstone.

    Parameters
    ----------
    x : array-like
        Points in the domain [0, 1].
    """
    if not np.all((x >= 0) & (x <= 1)):
        raise ValueError("Domain of doppler is x in [0, 1]")
    return np.sqrt(x*(1-x))*np.sin((2.1*np.pi)/(x+.05))
 
def blocks(x):
    """
    Piecewise constant function with jumps at t.
 
    The constant scale factor is not present in Donoho and Johnstone.
    """
    K = lambda x : (1 + np.sign(x))/2.
    t = np.array([[.1, .13, .15, .23, .25, .4, .44, .65, .76, .78, .81]]).T
    h = np.array([[4, -5, 3, -4, 5, -4.2, 2.1, 4.3, -3.1, 2.1, -4.2]]).T
    return 3.655606 * np.sum(h*K(x-t), axis=0)
 
def bumps(x):
    """
    A sum of bumps with locations t at the same places as jumps in blocks.
    The heights h and widths s vary and the individual bumps are of the
    form K(x) = 1/(1+|x|)**4
    """
    K = lambda x : (1. + np.abs(x)) ** -4.
    t = np.array([[.1, .13, .15, .23, .25, .4, .44, .65, .76, .78, .81]]).T
    h = np.array([[4, 5, 3, 4, 5, 4.2, 2.1, 4.3, 3.1, 2.1, 4.2]]).T
    w = np.array([[.005, .005, .006, .01, .01, .03, .01, .01, .005, .008, .005]]).T
    return np.sum(h*K((x-t)/w), axis=0)
 
def heavisine(x):
    """
    Sinusoid 4*sin(4*pi*x) with two jumps at x = .3 and .72
    """
    return 4 * np.sin(4*np.pi*x) - np.sign(x - .3) - np.sign(.72 - x)
In [5]:
x = np.linspace(0,1,2**11)
dop = doppler(x)
blk = blocks(x)
bmp = bumps(x)
hsin = heavisine(x)
In [6]:
fig, axes = plt.subplots(2, 2, figsize=(10,10))
ax1 = axes[0,0]
ax2 = axes[0,1]
ax3 = axes[1,0]
ax4 = axes[1,1]

ax1.plot(x,dop)
ax1.set_title("Doppler")

ax2.plot(x,blk)
ax2.set_title("Blocks")

ax3.plot(x,bmp)
ax3.set_title("Bumps")

ax4.set_title("HeaviSine")
ax4.plot(x,hsin)

for ax in fig.axes:
    ax.tick_params(labelbottom=False, labelleft=False, bottom=False, 
                   top=False, left=False, right=False)

fig.tight_layout();

The wavelet coefficients fully describe the data. The next steps are to generate data from the noisy Blocks function, apply a discrete wavelet transform to recover the smoothing and detail coefficients, and then look at a common diagnostic plot of the wavelet coefficients. The plotting helper, coef_pyramid_plot, is defined first.

In [7]:
def coef_pyramid_plot(coefs, first=0, scale='uniform', ax=None):
    """
    Parameters
    ----------
    coefs : array-like
        Wavelet Coefficients. Expects an iterable in order Cdn, Cdn-1, ...,
        Cd1, Cd0.
    first : int, optional
        The first level to plot.
    scale : str {'uniform', 'level'}, optional
        Scale the coefficients using the same scale or independently by
        level.
    ax : Axes, optional
        Matplotlib Axes instance

    Returns
    -------
    Figure : Matplotlib figure instance
        Either the parent figure of `ax` or a new pyplot.Figure instance if
        `ax` is None.
    """

    if ax is None:
        import matplotlib.pyplot as plt
        fig = plt.figure()
        # facecolor replaces the axisbg keyword of older Matplotlib releases
        ax = fig.add_subplot(111, facecolor='lightgrey')
    else:
        fig = ax.figure

    n_levels = len(coefs)
    n = 2**(n_levels - 1) # assumes periodic

    if scale == 'uniform':
        biggest = [np.max(np.abs(np.hstack(coefs)))] * n_levels
    else:
        # multiply by 2 so the highest bars only take up .5
        biggest = [np.max(np.abs(i))*2 for i in coefs]

    for i in range(first,n_levels):
        x = np.linspace(2**(n_levels - 2 - i), n - 2**(n_levels - 2 - i), 2**i)
        ymin = n_levels - i - 1 + first
        yheight = coefs[i]/biggest[i]
        ymax = yheight + ymin
        ax.vlines(x, ymin, ymax, linewidth=1.1)

    ax.set_xlim(0,n)
    ax.set_ylim(first - 1, n_levels)
    ax.yaxis.set_ticks(np.arange(n_levels-1,first-1,-1))
    ax.yaxis.set_ticklabels(np.arange(first,n_levels))
    ax.tick_params(top=False, right=False, direction='out', pad=6)
    ax.set_ylabel("Levels", fontsize=14)
    ax.grid(True, alpha=.85, color='white', axis='y', linestyle='-')
    ax.set_title('Wavelet Detail Coefficients', fontsize=16,
            position=(.5,1.05))
    fig.subplots_adjust(top=.89)

    return fig

Generate the data and get the coefficients using the multilevel discrete wavelet transform. Plot the true coefficients and the noisy ones.

In [8]:
from scipy import stats
import numpy as np

np.random.seed(12345)
blck = blocks(np.linspace(0,1,2**11))
nblck = blck + stats.norm().rvs(2**11)

# 'periodization' is the current name of the mode formerly called 'per'
true_coefs = pywt.wavedec(blck, 'db8', level=11, mode='periodization')
noisy_coefs = pywt.wavedec(nblck, 'db8', level=11, mode='periodization')

fig, axes = plt.subplots(2, 1, figsize=(9,14), sharex=True)

fig = coef_pyramid_plot(true_coefs[1:], ax=axes[0]) # omit smoothing coefs
axes[0].set_title("True Wavelet Detail Coefficients");

fig = coef_pyramid_plot(noisy_coefs[1:], ax=axes[1]) ;
axes[1].set_title("Noisy Wavelet Detail Coefficients");

fig.tight_layout()

Notice that most of the coefficients of the true signal are zero. This is the idea of sparseness: most functions, smooth or otherwise, have a sparse representation in a wavelet basis. The detail coefficients are non-zero where the Blocks function is not flat. The noisy signal has many more non-zero detail coefficients at the higher resolutions because of the added noise. To recover the signal from the noisy coefficients, you threshold the coefficients. Essentially, thresholding sets many of the coefficients to zero by assuming that they are noise. There are a number of ways to perform thresholding, and details can be found in the given references. I will simply apply soft thresholding using the universal threshold. The universal threshold is defined as

$$\lambda = \hat{\sigma}\sqrt{2\log(N)}$$

where $\hat{\sigma}$ is a robust estimator of the standard deviation of the finest-level detail coefficients. Here, I use the standardized median absolute deviation available from statsmodels.

$$\hat{\sigma} = \mathrm{MAD}(\beta_{J-1,\cdot})$$
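The three ingredients above (the MAD-based scale estimate, the universal threshold, and soft thresholding) are small enough to write out directly. The following is a NumPy-only sketch, not the statsmodels or PyWavelets implementation; the function names are made up for illustration, the constant 0.6745 is the usual rescaling that makes the MAD consistent for Gaussian data, and the simulated `finest` array is only a stand-in for real finest-level coefficients.

```python
import numpy as np

def mad_sigma(d):
    """Median absolute deviation about the median, rescaled by 0.6745
    so that it estimates the standard deviation of Gaussian data."""
    d = np.asarray(d, dtype=float)
    return np.median(np.abs(d - np.median(d))) / 0.6745

def universal_threshold(sigma, n):
    """Universal threshold sigma * sqrt(2 * log(n)) for a signal of length n."""
    return sigma * np.sqrt(2.0 * np.log(n))

def soft_threshold(d, lam):
    """Zero out entries with |d| <= lam and shrink the rest toward zero by lam."""
    d = np.asarray(d, dtype=float)
    return np.sign(d) * np.maximum(np.abs(d) - lam, 0.0)

rng = np.random.default_rng(0)
finest = rng.normal(scale=1.0, size=1024)  # stand-in for pure-noise coefficients
sigma_hat = mad_sigma(finest)
lam = universal_threshold(sigma_hat, 2048)
print(sigma_hat, lam)

# Small values vanish, large values move toward zero by the threshold
shrunk = soft_threshold([-5.0, -1.0, 0.5, 4.5], 2.0)
print(shrunk)
```

Because the median ignores a handful of large coefficients, the scale estimate stays close to the noise level even when a few finest-level coefficients carry signal.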

In [9]:
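# mad replaces stand_mad, which is no longer available in current statsmodels
from statsmodels.robust import mad

sigma = mad(noisy_coefs[-1])
uthresh = sigma*np.sqrt(2*np.log(len(nblck)))

denoised = noisy_coefs[:]

# pywt.threshold replaces the old pywt.thresholding.soft helper
denoised[1:] = [pywt.threshold(i, value=uthresh, mode='soft') for i in denoised[1:]]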

We can recover the signal by applying the inverse discrete wavelet transform to the thresholded coefficients.

In [10]:
signal = pywt.waverec(denoised, 'db8', mode='periodization')

fig, axes = plt.subplots(1, 2, sharey=True, sharex=True,
                         figsize=(10,8))
ax1, ax2 = axes

ax1.plot(signal)
ax1.set_xlim(0,2**10)
ax1.set_title("Recovered Signal")
ax1.margins(.1)

ax2.plot(nblck)
ax2.set_title("Noisy Signal")

for ax in fig.axes:
    ax.tick_params(labelbottom=False, top=False, bottom=False, left=False, 
                 right=False)
    
fig.tight_layout()
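To put a number on the improvement, the whole pipeline can also be sketched without PyWavelets. The example below swaps in a hand-written orthonormal Haar transform instead of 'db8', so it is a simplified stand-in for the cells above rather than a reproduction of them. The step signal and all function names are assumptions made for illustration.

```python
import numpy as np

def haar_dwt(signal):
    """Full orthonormal Haar decomposition of a length-2^m signal.
    Returns [approx, detail_coarsest, ..., detail_finest]."""
    approx = np.asarray(signal, dtype=float)
    details = []
    while approx.size > 1:
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / np.sqrt(2.0))
        approx = (even + odd) / np.sqrt(2.0)
    return [approx] + details[::-1]

def haar_idwt(coefs):
    """Invert haar_dwt by rebuilding one level at a time."""
    approx = coefs[0]
    for detail in coefs[1:]:
        out = np.empty(2 * approx.size)
        out[0::2] = (approx + detail) / np.sqrt(2.0)
        out[1::2] = (approx - detail) / np.sqrt(2.0)
        approx = out
    return approx

def soft(d, lam):
    """Soft thresholding: shrink toward zero by lam, zeroing small values."""
    return np.sign(d) * np.maximum(np.abs(d) - lam, 0.0)

rng = np.random.default_rng(12345)
n = 2 ** 11
x = np.linspace(0, 1, n)
# Hypothetical piecewise-constant signal, simpler than Blocks
truth = np.where(x < 0.3, 2.0, np.where(x < 0.7, -3.0, 4.0))
noisy = truth + rng.normal(size=n)

coefs = haar_dwt(noisy)
finest = coefs[-1]
sigma = np.median(np.abs(finest - np.median(finest))) / 0.6745  # MAD scale estimate
lam = sigma * np.sqrt(2 * np.log(n))                            # universal threshold
denoised = [coefs[0]] + [soft(d, lam) for d in coefs[1:]]       # keep smoothing coefs
recovered = haar_idwt(denoised)

mse_noisy = np.mean((noisy - truth) ** 2)
mse_recovered = np.mean((recovered - truth) ** 2)
print(mse_noisy, mse_recovered)
```

Comparing the two mean squared errors against the known true signal gives a quick check that thresholding removed noise rather than signal.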

PyWavelets is a really cool project. But it could use a little more attention. It would be nice if some of the library functions were wrapped up so that they return wavelet objects that would be easier to work with. For instance, if wavedec returned a discrete wavelet object that the thresholding functions also accepted, this would save me some keystrokes in the above. I do not have time to take on another project, but it is low hanging fruit if someone else does. The code is MIT licensed and might even find a nice home in SciPy.

