{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c7935100",
   "metadata": {},
   "source": [
    "## OptShrink matrix denoising demo\n",
    "- 2017-11-16 Jeff Fessler, University of Michigan\n",
    "- 2023-05-21 Julia 1.9.0\n",
    "- 2023-09-22 Julia 1.11.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c533838c",
   "metadata": {},
   "outputs": [],
   "source": [
    "using LaTeXStrings\n",
    "using Random: seed!\n",
    "using Plots: default, gui, plot, savefig, scatter!\n",
    "using LinearAlgebra: svd, svdvals, norm, Diagonal, rank\n",
    "using MIRTjim: jim, prompt\n",
    "default(markersize=7, markerstrokecolor=:auto,\n",
    " labelfontsize = 18, legendfontsize = 18, tickfontsize = 14)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4f769b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# you must modify the following directory path:\n",
    "dircode = ENV[\"hw551test\"]\n",
    "include(joinpath(dircode, \"optshrink2.jl\"));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e16712c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make a matrix that has low rank\n",
    "tmp = [\n",
    "    zeros(1,20);\n",
    "    0 1 0 0 0 0 1 0 0 0 1 1 1 1 0 1 1 1 1 0;\n",
    "    0 1 0 0 0 0 1 0 0 0 0 1 0 0 1 0 0 1 0 0;\n",
    "    0 1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 1 0 0;\n",
    "    0 0 1 1 1 1 0 0 0 0 1 1 0 0 0 0 0 1 1 0;\n",
    "    zeros(1,20)\n",
    "]';\n",
    "X = 80 * kron(tmp, ones(5,5));"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3791d15d",
   "metadata": {},
   "source": [
    "Q: What is the maximum rank of X?\n",
    "- A. 1\n",
    "- B. 2\n",
    "- C. 3\n",
    "- D. 4\n",
    "- E. 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0c4a76a",
   "metadata": {},
   "outputs": [],
   "source": [
    "clim = (0,80)\n",
    "cticks = [0,80]\n",
    "siz = (600, 260)\n",
    "args = ( ; clim, cticks, size = siz,\n",
    " xaxis = false, yaxis = false, colorbar = :none, # book\n",
    ")\n",
    "p1 = jim(X,\n",
    " L\"\\mathrm{Original\\ image\\ } X \\mathrm{\\ with\\ rank\\ } %$(rank(X))\";\n",
    " xlabel = \" \", args...,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6993cdb",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# savefig(p1, \"07_optshrink1x.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6024d17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data Y is X + noise\n",
    "seed!(0)\n",
    "Z = 7 * randn(size(X))\n",
    "Y = X + Z\n",
    "nrmse = (Xh) -> norm(Xh - X) / norm(X) * 100\n",
    "nrmse_y = nrmse(Y)\n",
    "p2 = jim(Y,\n",
    " L\"\\mathrm{Noisy\\ image\\ } Y=X+Z \\mathrm{\\ with\\ rank\\ } %$(rank(Y))\";\n",
    " xlabel=\"NRMSE = $(round(nrmse_y, digits=1)) %\",\n",
    " args...,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1edfc202",
   "metadata": {},
   "outputs": [],
   "source": [
    "# savefig(p2, \"07_optshrink1y.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99aeaf4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "rmax = minimum(size(X))\n",
    "ps = plot(title=\"Singular values\", widen=true,\n",
    "    xaxis = (L\"k\", (0,rmax), [1,rank(X),rmax]),\n",
    "    yaxis = (L\"\\sigma_k\", (0,1700), [0, 267, 746, 947, 1672]))\n",
    "scatter!(1:rmax, svdvals(Y), color=:red, label=L\"\\sigma_k(Y)\", marker=:hexagon)\n",
    "scatter!(1:rmax, svdvals(X), color=:blue, label=L\"\\sigma_k(X)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "983e351a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# savefig(ps, \"07_optshrink1s.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd51f349",
   "metadata": {},
   "outputs": [],
   "source": [
    "# try several ranks\n",
    "ranks = 1:15\n",
    "Xos = zeros(size(Y)..., length(ranks))\n",
    "Xlr = copy(Xos)\n",
    "nrmse_os = zeros(length(ranks))\n",
    "nrmse_lr = copy(nrmse_os)\n",
    "for (ir, r) in enumerate(ranks)\n",
    "    Xos[:,:,ir] = optshrink2(Y, r)\n",
    "    U,s,V = svd(Y)\n",
    "    Xlr[:,:,ir] = U[:,1:r] * Diagonal(s[1:r]) * V[:,1:r]'\n",
    "    nrmse_os[ir] = nrmse(Xos[:,:,ir])\n",
    "    nrmse_lr[ir] = nrmse(Xlr[:,:,ir])\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2548b2a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "ir = 4\n",
    "title = L\"\\mathrm{OptShrink\\ image\\ } \\hat{X} \\mathrm{\\ with\\ rank\\ guess\\ %$(ranks[ir])}\"\n",
    "p3 = jim(Xos[:,:,ir], title;\n",
    " xlabel=\"NRMSE = $(round(nrmse_os[ir], digits=1)) %\", args...\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "064a3a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# savefig(p3, \"07_optshrink1g4.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a4ef101",
   "metadata": {},
   "outputs": [],
   "source": [
    "pn = plot(\n",
    "    xaxis = (\"Rank (estimate)\", extrema(ranks), [1,rank(X),maximum(ranks)]),\n",
    "    yaxis = (\"NRMSE %\", (0, 60), [0, 7, 60]), widen = true)\n",
    "scatter!(ranks, nrmse_lr, color=:magenta, label = \"LR NRMSE\")\n",
    "scatter!(ranks, nrmse_os, color=:darkgreen, label = \"OptShrink NRMSE\", marker=:star)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ebf5cfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# savefig(pn, \"07_optshrink1n.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a195fb86",
   "metadata": {},
   "outputs": [],
   "source": [
    "ir = 9\n",
    "title = L\"\\mathrm{OptShrink\\ image\\ } \\hat{X} \\mathrm{\\ with\\ rank\\ guess\\ %$(ranks[ir])}\"\n",
    "p4 = jim(Xos[:,:,ir], title; args...,\n",
    " xlabel = \"NRMSE = $(round(nrmse_os[ir], digits=1)) %\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1161c7da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# savefig(p4, \"07_optshrink1g9.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa9db906",
   "metadata": {},
   "outputs": [],
   "source": [
    "rmax = minimum(size(X))\n",
    "ir = 9\n",
    "pss = deepcopy(ps)\n",
    "label = latexstring(\"\\\\mathrm{OptShrink:\\\\ } \\\\hat{r}=$ir\")\n",
    "scatter!(pss, 1:rmax, svdvals(Xos[:,:,ir]); color=:green, label, marker=:star)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b80732da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# savefig(pss, \"07_optshrink1o.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03e090d9",
   "metadata": {},
   "source": [
    "## Below here for HW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "070bb345",
   "metadata": {},
   "outputs": [],
   "source": [
    "if false # delete for HW\n",
    "include(dircode * \"lr_schatten.jl\") # for HW, put path to your lr_schatten here\n",
    "\n",
    "# try different regularization parameters\n",
    "for l2reg in 8:0.5:12\n",
    "    Xr = lr_schatten(Y, 2^l2reg)\n",
    "    @show l2reg, nrmse(Xr)\n",
    "end\n",
    "\n",
    "# Schatten LR with β=1000\n",
    "β = 1000\n",
    "Xs = lr_schatten(Y, β)\n",
    "@show nrmse(Xs)\n",
    "\n",
    "user = Sys.username()\n",
    "jim(Xs, \"Schatten LR image with p=1/2 and β = $β for $user\";\n",
    " size=(600,250), xlabel=\"NRMSE = $(round(nrmse(Xs), digits=1)) %\")\n",
    "\n",
    "#savefig(\"hsj68im.pdf\")\n",
    "\n",
    "psp = deepcopy(ps)\n",
    "scatter!(psp, 1:rmax, svdvals(Xs), color=:magenta, label=L\"\\mathrm{Schatten }\\ p=1/2\")\n",
    "\n",
    "#savefig(psp, \"hsj68svd.pdf\")\n",
    "end # delete for HW"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "encoding": "# -*- coding: utf-8 -*-"
  },
  "kernelspec": {
   "display_name": "Julia 1.11.7",
   "language": "julia",
   "name": "julia-1.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
