# Student version!!

### Block coordinate minimization (BCM) for 1D dictionary learning of wave+spike data

In-class group activity, EECS 598-006 W19

Given a set of training signals $X = [x_1,\ldots,x_L]$,
this notebook performs *dictionary learning* by solving
$$
\hat{D} = \arg \min_D \min_Z \frac{1}{2} \| X - D Z \|^2_F + \beta \|Z\|_0
,$$
subject to the constraint that
each column of $D$ must have unit norm.

This notebook considers signals that are generated
as the sum of
a smooth component
plus a few Kronecker impulses.

This notebook applies BCM to update one column of $D$
at a time,
i.e., one atom,
and then one corresponding column of $C = Z'$,
looping over all columns of $D$
to make one iteration.
This is the "sum of outer products" (SOUP) approach
to dictionary learning,
developed by Sai Ravishankar et al.,
at the University of Michigan.

2019-04-10 Jeff Fessler, University of Michigan

In [None]:
# Packages needed
using MIRT: jim
using FFTW: idct
using LinearAlgebra: I, norm, pinv, svd
using Plots; default(markerstrokecolor=nothing)
using LaTeXStrings
using Random: seed!, randperm

In [None]:
# Define basis functions for generating 1D training signals.
# This 1D problem is small enough that we use the full DCT matrix and identity matrix.
# For a large-scale problem, we would make B1 and B2 a LinearMap instead.
#
N = 2^7 # signal length (number of rows of every basis matrix below)
K1 = 10 # number of dct terms
B1 = idct(Matrix(I,N,N)[:,1:K1], 1) # inverse DCT basis (idct applied to the first K1 identity columns)
K2 = 12 # number of impulse terms
#K2 = N # this would be the ideal value but it runs too slow for in-class work
seed!(0); tmp = sort(randperm(N)[1:K2]) # K2 distinct random impulse locations, in increasing order
B2 = Matrix(I,N,N)[:,tmp] # impulse atoms: the selected columns of the identity matrix
jim([B1 B2]') # display the transpose of the combined basis
;

In [None]:
# plot the K1 inverse-DCT atoms: markers first, then black connecting lines on the same axes
scatter(B1, label="", xlabel=L"n", ylabel=L"D[:,k]", title="$K1 iDCT basis atoms")
plot!(B1, label="", color=:black)

In [None]:
# generate ideal training signals using the basis with a couple spikes
L = 10^4 # number of training signals (columns of Xtrue)
seed!(0)
tmp1 = rand(K1,L) .< 0.05 # random support mask: each entry is selected with probability 0.05
Z1 = 3randn(K1,L) .* tmp1 # sparse coefficients for the smooth (iDCT) atoms, scaled by 3
Xtrue = B1 * Z1
tmp2 = rand(K2,L) .< 0.05 # random support mask: each entry is selected with probability 0.05
Z2 = randn(K2,L) .* tmp2 # sparse coefficients for the impulse atoms
Xtrue .+= B2 * Z2
Dtrue = [B1 B2]
@assert isapprox(sum(Dtrue.^2,dims=1)[:], ones(K1+K2)) # every true atom has unit norm
Ztrue = [Z1; Z2]
pctnz = Z -> round(count(Z .!= 0) / length(Z) * 100, digits=3) # percentage of nonzero entries
@show pctnz(Ztrue)
@assert isapprox(Xtrue, Dtrue*Ztrue) # the signals are exactly synthesized by the true dictionary
jim(Dtrue')
# show the first five noiseless training signals
scatter(Xtrue[:,1:5], label="", xlabel=L"n", ylabel=L"x_l[n] \mbox{ true}", title="Xtrue")
plot!(Xtrue[:,1:5], label="", color=:black)

In [None]:
# make noisy training signals for DL
sig = 0.01 # scale factor applied to the standard-normal noise
seed!(0); X = Xtrue + sig * randn(N,L) # noisy training data used by all later cells
# show the first five noisy training signals
scatter(X[:,1:5], label="", xlabel=L"n", ylabel=L"x_l[n] \mbox{ noisy}")
plot!(X[:,1:5], label="", color=:black)

In [None]:
#savefig("dl-1d-soup-x.pdf")

In [None]:
# utility functions for dictionaries
#
"""
    sorter(D, Dtrue)

Return a vector of column indices of `D` that reorders `D` to best align with `Dtrue`
(for nice display).

The columns of `Dtrue` are processed in order; each one is greedily paired with the
not-yet-used column of `D` having the largest absolute inner product with it.
The indices of any columns of `D` left unpaired are appended at the end, in their
original order.

`D` and `Dtrue` must have the same number of rows, and `D` must have at least as many
columns as `Dtrue`.
"""
function sorter(D, Dtrue)
    @assert size(D,1) == size(Dtrue,1)
    @assert size(D,2) >= size(Dtrue,2)
    ntrue = size(Dtrue, 2)
    remaining = collect(1:size(D, 2)) # columns of D that have not been paired yet
    matched = zeros(Int, ntrue)
    for k in 1:ntrue
        # unnormalized inner products; this assumes the atoms have unit norm
        scores = abs.(vec(Dtrue[:, k]' * D[:, remaining]))
        pick = argmax(scores) # position within `remaining` of the best candidate
        matched[k] = remaining[pick]
        deleteat!(remaining, pick) # a column of D can be paired only once
    end
    return vcat(matched, remaining)
end

# helpers for making the largest-magnitude entry of each atom positive, and for normalizing

# signs(D): 1-row array holding the sign of the largest-magnitude entry in each column of D
signs = function (D)
    peaks = argmax(abs.(D), dims=1) # location of the largest |entry| in every column
    return sign.(D[peaks])
end

# normalize(D): copy of D with every column divided by its Euclidean norm
normalize = function (D)
    colnorms = sqrt.(sum(abs.(D).^2, dims=1))
    return D ./ colnorms
end

# self-test: when the reference atoms are the leading columns of the dictionary,
# sorter should return the identity ordering
seed!(0); tmp = normalize(rand(5,6)) # 6 random unit-norm reference atoms
@assert sorter([tmp normalize(rand(5,3))], tmp) == 1:9 # test

In [None]:
# initial guess of the dictionary D and coefficients Z, via PCA
seed!(0)
#K = 2*N # overcomplete
K = K1 + K2 # cheat for now (use the true number of atoms)
#D0 = randn(N,K) # initial dictionary - random
D0 = svd(X).U[:,1:min(K,N)] # PCA: leading left singular vectors of the training data
if K > N
    # more atoms requested than singular vectors available:
    # pad with random columns, then normalize every column
    D0 = normalize([D0 randn(N,K-N)])
end
D0 = D0[:,sorter(D0, Dtrue)] # reorder atoms to align with Dtrue for display
D0 = D0 .* signs(D0) # flip signs so the largest-magnitude entry of each atom is positive

Z0 = pinv(D0) * X # think about why we use pinv here!
# extrema(Z0)
plot(jim(Dtrue', "Dtrue"), jim(D0', "D0")) # true and initial dictionaries side by side

### Perform BCM for DL via SOUP

In [None]:
# shrinkage functions (at least one of these should be useful)

# soft(z, b): soft thresholding, i.e., shrink |z| by b (never below 0) and keep the sign of z
soft = function (z, b)
    magnitude = max(abs(z) - b, 0)
    return sign(z) * magnitude
end

# hard(z, b): hard thresholding, i.e., keep z only when |z| exceeds sqrt(2b), else zero it
hard = function (z, b)
    keep = abs(z) > sqrt(2b)
    return z * keep
end

In [None]:
# loop over `niter` iterations of BCM here
# update D[:,1] then C[:,1] then D[:,2], ... C[:,K]
# where C = Z'

reg = 0.05 # beta
# cost: 0.5 * squared Frobenius norm of the fit error + reg * number of nonzero coefficients
cost = (D,Z) -> 0.5 * norm(X - D*Z)^2 + reg * count(Z .!= 0) # "0-norm"

niter = 60
costs = zeros(niter + 1) # costs[1] is the initial cost; costs[iter+1] is the cost after iteration iter
D = copy(D0)
Z = copy(Z0)
costs[1] = cost(D, Z)
R = X - D * Z # residual state
C = Z' # lazy adjoint of Z: C shares memory with Z, so writing into C also changes Z
for iter = 1:niter
    for k = 1:K
        ck = @view C[:,k] # k-th column of C (= k-th row of Z), as a view rather than a copy
        norm_ck = norm(ck) # Euclidean norm of the current coefficient column
        
        # put your dictionary atom update code here:

        # ...

        # be sure to keep the residual "R" up-to-date!
        
        # put your coefficient vector update here:

    end
    costs[iter+1] = cost(D, C')

    if mod(iter, 5) == 0
        @show iter, pctnz(C) # monitor progress of sparsity
    end
end
Z = C'
scatter(0:niter, costs, label="cost") # plot cost vs iteration

In [None]:
# sort and sign "correct" the estimated dictionary for nice display
# you should see that the spike part of the dictionary looks very good
# but the smooth part is hard to tell, even when sorted using sorter()
tmp = sorter(D, Dtrue) # atom ordering that best aligns D with Dtrue
D = D[:,tmp]
Z = Z[tmp,:] # permute the coefficient rows the same way as the atoms
tmp = signs(D) # sign of the largest-magnitude entry of each atom
D = D .* tmp # flip atoms so that entry becomes positive
Z = Z .* tmp' # apply the same sign flips to the matching rows of Z
# true, initial, and learned dictionaries side by side
plot(jim(Dtrue', "Dtrue"), jim(D0', L"D_0"), jim(D', L"\hat{D}"), layout=(1,3))

In [None]:
#savefig("dl-1d-soup-d.pdf")

In [None]:
# Plot the estimated dictionary (the entries of the learned atoms versus sample index n)
scatter(D, title=L"\hat{D}", xlabel=L"n", ylabel=L"\hat{d}_k[n]", label="")
#scatter!(legend=:bottomright)

In [None]:
# examine the representation error and the cost function
# think about these values!
# first value on each line: squared norm of the fit residual divided by N*L*sig^2
# second value: the regularized cost function
# first line is for the initial (D0,Z0); second line is for the learned (D,Z)
@show norm(X - D0*Z0)^2/(N*L*sig^2), cost(D0,Z0)
@show norm(X - D*Z)^2/(N*L*sig^2), cost(D,Z);

### Count the number of non-zero coefficients and report that value to Canvas

In [None]:
# number of nonzero entries in the learned coefficient matrix Z
count(Z .!= 0)

In [None]:
# test the learned dictionary by using it for sparse coding some new test data
# first generate the data:
seed!(7) # different seed than the one used for the training data
Ltest = 10^3 # number of test signals
tmp1 = rand(K1,Ltest) .< 0.05 # random support mask: each entry is selected with probability 0.05
Z1test = 3randn(K1,Ltest) .* tmp1 # sparse coefficients for the smooth (iDCT) atoms
Xtest = B1 * Z1test
tmp2 = rand(K2,Ltest) .< 0.05 # random support mask: each entry is selected with probability 0.05
Z2test = randn(K2,Ltest) .* tmp2 # sparse coefficients for the impulse atoms
Xtest .+= B2 * Z2test # noiseless test signals
Ytest = Xtest + sig * randn(N, Ltest) # noise
;

In [None]:
# BCM sparse coding routine
# study this function!
"""
    sparse_code_cd(X, D, reg::Number; Z0 = pinv(D) * X, niter = 100)

Sparse-code the columns of `X` with dictionary `D` by block coordinate minimization of
`0.5*|D Z - X|_F^2 + reg |Z|_0` over `Z`, where the columns of `D` have unit norm.

Each pass updates one row of `Z` at a time (one row per atom) by applying the
file-level `hard` thresholding function to the inner products between that atom
and the residual with that atom's contribution added back.

# Keywords
- `Z0`: initial coefficients (not modified; a copy is updated).
- `niter`: number of passes over all atoms.

Returns the final coefficient matrix `Z`.
"""
function sparse_code_cd(X, D, reg::Number; Z0 = pinv(D) * X, niter = 100)
    (_, natom) = size(D)
    Z = copy(Z0)
    R = X - D * Z0 # residual vectors

    for _ in 1:niter # outer loop over iterations
        for k in 1:natom # inner loop over atoms
            atom = @view D[:, k]
            R .+= atom * Z[k, :]' # add the current contribution of atom k back to the residual
            Z[k, :] = hard.(atom' * R, reg) # threshold the correlations to get the new row of Z
            R .-= atom * Z[k, :]' # subtract the updated contribution of atom k
        end
    end
    return Z
end

In [None]:
# apply BCM sparse coding to test data using the learned dictionary
# (note: the regularization parameter passed here is reg/2, not the training value reg)
Ztest = sparse_code_cd(Ytest, D, reg/2)
;

In [None]:
# see how well the sparse coding works on test data with learned D
# error of the fit relative to the noiseless test signals, shown next to the noise level
@show norm(D * Ztest - Xtest) / norm(Xtest), sig
# fraction of nonzero coefficients in the test-data codes
# (this cell evaluates the test data, so report Ztest rather than the training Z)
@show count(Ztest .!= 0) / length(Ztest)
ls = [3, 6, 8, 12] # a few test signals to display
scatter(Xtest[:,ls], color=:blue, marker=:square, label="true")
scatter!(D * Ztest[:,ls], color=:red, label="fit")

In [None]:
#jim(Xtest[:,1:100]')

### Optional extensions if time permits

* Experiment with the regularization parameter(s) to try to improve the results.
* Try other ways of initializing the dictionary D0.
* Try larger values of K since in practice the appropriate K is unknown.
* Compare hard thresholding (0-norm) to soft thresholding (1-norm).
* Compare the 2-block BCM to the 2K-block BCM used here (one block per atom of D plus one per column of C).