{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hand-written digit classification\n",
"2018-10-23 Jeff Fessler, University of Michigan "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# packages needed\n",
"using Plots\n",
"using LinearAlgebra # svd, norm, etc.\n",
"using Statistics: mean\n",
"using Random"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the data and look at it "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"size(d0) = (28, 27, 1000)\n"
]
}
],
"source": [
"# read the MNIST data file for 0 and 1 digits\n",
"# download from web if needed\n",
"file0 = \"data4\"\n",
"file1 = \"data9\"\n",
"if !isfile(file0)\n",
" download(\"http://cis.jhu.edu/~sachin/digit/\" * file0, file0)\n",
"end\n",
"if !isfile(file1)\n",
" download(\"http://cis.jhu.edu/~sachin/digit/\" * file1, file1)\n",
"end\n",
"\n",
"nx = 28 # original image size\n",
"ny = 28\n",
"nrep = 1000\n",
"\n",
"d0 = Array{UInt8}(undef, (nx,ny,nrep))\n",
"read!(file0, d0) # load images\n",
"\n",
"d1 = Array{UInt8}(undef, (nx,ny,nrep))\n",
"read!(file1, d1) # load images\n",
"\n",
"iy = 2:ny\n",
"d0 = d0[:,iy,:] # Make images non-square to help debug\n",
"d1 = d1[:,iy,:]\n",
"ny = length(iy)\n",
"\n",
"# Convert images to Float32 to avoid overflow errors\n",
"d0 = Array{Float32}(d0)\n",
"d1 = Array{Float32}(d1)\n",
"\n",
"@show size(d0);"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"imshow3 (generic function with 1 method)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# function to display mosaic of multiple images\n",
"function imshow3(x)\n",
" tmp = permutedims(x, [1, 3, 2])\n",
" tmp = reshape(tmp, :, ny)\n",
" return heatmap(1:size(tmp,1), 1:ny, tmp,\n",
" xtick=[1,nx], ytick=[1,ny], yflip=true,\n",
" color=:grays, transpose=true, aspect_ratio=1)\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# look at sorted and unsorted images to show (un)supervised\n",
"Random.seed!(0)\n",
"nrow = 4\n",
"ncol = 6\n",
"t0 = d0[:,:,1:Int(nrow*ncol/2)]\n",
"t0[:,:,6] = d0[:,:,222] # include one ambiguous case\n",
"t1 = d1[:,:,1:Int(nrow*ncol/2)]\n",
"tmp = cat(t0, t1, dims=3)\n",
"\n",
"tmp = tmp[:,:,randperm(size(tmp,3))] # for unsupervised\n",
"\n",
"pl = []\n",
"for ii=1:nrow\n",
" p = imshow3(tmp[:,:,(1:ncol) .+ (ii-1)*ncol])\n",
" plot!(p, colorbar=:none)\n",
" for jj=1:(ncol-1)\n",
" c = :yellow # unsup\n",
"# c = ii <= nrow/2 ? :blue : :red\n",
" plot!([1; 1]*jj*nx, [1; ny], label=\"\", color=c, xtick=[], ytick=[], axis=:off)\n",
" end\n",
" push!(pl, p)\n",
"end\n",
"plot(pl..., layout=(nrow,1))\n",
"#savefig(\"02-digit-rand.pdf\")\n",
"#savefig(\"02-digit-sort.pdf\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# use some data for training, and some for test\n",
"ntrain = 100\n",
"ntest = nrep - ntrain\n",
"train0 = d0[:,:,1:ntrain] # training data\n",
"train1 = d1[:,:,1:ntrain]\n",
"test0 = d0[:,:,(ntrain+1):end] # testing data\n",
"test1 = d1[:,:,(ntrain+1):end];"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# svd for singular vectors and low-rank subspace approximation\n",
"u0, _, _ = svd(reshape(train0, nx*ny, :))\n",
"u1, _, _ = svd(reshape(train1, nx*ny, :))\n",
"r0 = 3 # selected ranks\n",
"r1 = 3\n",
"q0 = reshape(u0[:,1:r0], nx, ny, :)\n",
"q1 = reshape(u1[:,1:r1], nx, ny, :)\n",
"p0 = imshow3(q0)\n",
"p1 = imshow3(q1)\n",
"plot(p0, p1, layout=(2,1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Examine how well the first left singular vectors separate the two classes "
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"