{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hand-written digit classification\n",
    "2018-10-23 Jeff Fessler, University of Michigan\n",
    "\n",
    "Classify images of hand-written 4s and 9s by finding which low-rank subspace (learned by SVD from training examples of each digit) is nearest to each test image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# packages needed\n",
    "using Plots\n",
    "using LinearAlgebra # svd, norm, etc.\n",
    "using Statistics: mean\n",
    "using Random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the data and look at it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "size(d0) = (28, 27, 1000)\n"
     ]
    }
   ],
   "source": [
    "# read the data files for the two digit classes (4 and 9)\n",
    "# download from web if needed\n",
    "file0 = \"data4\"\n",
    "file1 = \"data9\"\n",
    "for file in (file0, file1)\n",
    "    if !isfile(file)\n",
    "        download(\"http://cis.jhu.edu/~sachin/digit/\" * file, file)\n",
    "    end\n",
    "end\n",
    "\n",
    "nx = 28 # original image size\n",
    "ny = 28\n",
    "nrep = 1000 # number of images read from each file\n",
    "\n",
    "# read nrep images of nx*ny UInt8 pixels from each file\n",
    "d0 = read!(file0, Array{UInt8}(undef, (nx,ny,nrep))) # load images\n",
    "d1 = read!(file1, Array{UInt8}(undef, (nx,ny,nrep)))\n",
    "\n",
    "iy = 2:ny\n",
    "d0 = d0[:,iy,:] # Make images non-square to help debug\n",
    "d1 = d1[:,iy,:]\n",
    "ny = length(iy)\n",
    "\n",
    "# Convert images to Float32 to avoid overflow errors\n",
    "d0 = Array{Float32}(d0)\n",
    "d1 = Array{Float32}(d1)\n",
    "\n",
    "@show size(d0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "imshow3 (generic function with 1 method)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# function to display mosaic of multiple images\n",
    "# x is a 3D array (mx, my, n); the n images are tiled side by side.\n",
    "# Image sizes come from x itself rather than from the globals nx, ny.\n",
    "function imshow3(x)\n",
    "    mx, my = size(x, 1), size(x, 2)\n",
    "    tiles = permutedims(x, [1, 3, 2]) # (mx, n, my)\n",
    "    tiles = reshape(tiles, :, my)     # (mx*n, my): images stacked along dim 1\n",
    "    return heatmap(1:size(tiles,1), 1:my, tiles,\n",
    "        xtick=[1,mx], ytick=[1,my], yflip=true,\n",
    "        color=:grays, transpose=true, aspect_ratio=1)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# look at sorted and unsorted images to show (un)supervised\n",
    "Random.seed!(0)\n",
    "nrow = 4\n",
    "ncol = 6\n",
    "nhalf = div(nrow*ncol, 2) # number of images shown from each class\n",
    "t0 = d0[:,:,1:nhalf]\n",
    "t0[:,:,6] = d0[:,:,222] # include one ambiguous case\n",
    "t1 = d1[:,:,1:nhalf]\n",
    "examples = cat(t0, t1, dims=3)\n",
    "\n",
    "examples = examples[:,:,randperm(size(examples,3))] # for unsupervised\n",
    "\n",
    "pl = []\n",
    "for ii=1:nrow\n",
    "    p = imshow3(examples[:,:,(1:ncol) .+ (ii-1)*ncol])\n",
    "    plot!(p, colorbar=:none)\n",
    "    for jj=1:(ncol-1)\n",
    "        c = :yellow # unsup\n",
    "#       c = ii <= nrow/2 ? :blue : :red\n",
    "        # vertical separator line between adjacent images (drawn on p explicitly)\n",
    "        plot!(p, [1; 1]*jj*nx, [1; ny], label=\"\", color=c, xtick=[], ytick=[], axis=:off)\n",
    "    end\n",
    "    push!(pl, p)\n",
    "end\n",
    "plot(pl..., layout=(nrow,1))\n",
    "#savefig(\"02-digit-rand.pdf\")\n",
    "#savefig(\"02-digit-sort.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use some data for training, and some for test\n",
    "ntrain = 100\n",
    "ntest = nrep - ntrain\n",
    "train0 = d0[:,:,1:ntrain] # training data\n",
    "train1 = d1[:,:,1:ntrain]\n",
    "test0 = d0[:,:,(ntrain+1):end] # testing data\n",
    "test1 = d1[:,:,(ntrain+1):end];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# svd for singular vectors and low-rank subspace approximation\n",
    "u0, _, _ = svd(reshape(train0, nx*ny, :))\n",
    "u1, _, _ = svd(reshape(train1, nx*ny, :))\n",
    "r0 = 3 # selected ranks\n",
    "r1 = 3\n",
    "q0 = reshape(u0[:,1:r0], nx, ny, :)\n",
    "q1 = reshape(u1[:,1:r1], nx, ny, :)\n",
    "p0 = imshow3(q0)\n",
    "p1 = imshow3(q1)\n",
    "plot(p0, p1, layout=(2,1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Examine how well the first left singular vectors separate the two classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coefficient of each image (flattened to a vector) along the vector u;\n",
    "# one matrix-vector product instead of mapslices over every image\n",
    "regress = (data, u) -> reshape(data, :, size(data, 3))' * u\n",
    "scatter(regress(train0, u0[:,1]), regress(train0, u1[:,1]), label=file0)\n",
    "scatter!(regress(train1, u0[:,1]), regress(train1, u1[:,1]), label=file1)\n",
    "plot!(xlabel = file0 * \" U[:,1]\", ylabel = file1 * \" U[:,1]\", legend=:topleft)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classify test digits based on nearest subspace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sum(correct0) / ntest = 0.7722222222222223\n",
      "sum(correct1) / ntest = 0.9311111111111111\n"
     ]
    }
   ],
   "source": [
    "Q0 = reshape(q0, nx*ny, r0)\n",
    "Q1 = reshape(q1, nx*ny, r1)\n",
    "\n",
    "# norm of each column of y after removing its projection onto span(Q)\n",
    "# (Q has orthonormal columns, from the SVD above)\n",
    "resid_norm(Q, y) = vec(mapslices(norm, y - Q * (Q' * y), dims=1))\n",
    "\n",
    "# a test image is classified correctly when it is closer to its own class subspace\n",
    "y0 = reshape(test0, nx*ny, :)\n",
    "correct0 = resid_norm(Q0, y0) .< resid_norm(Q1, y0)\n",
    "@show sum(correct0) / ntest\n",
    "\n",
    "y1 = reshape(test1, nx*ny, :)\n",
    "correct1 = resid_norm(Q0, y1) .> resid_norm(Q1, y1)\n",
    "@show sum(correct1) / ntest;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### If I had more time I would show CNN-based classification here..."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.0.1",
   "language": "julia",
   "name": "julia-1.0"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.0.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}