Files
cmdla/Lessons/10-27/lesson.ipynb
2024-07-30 14:43:25 +02:00

331 lines
8.1 KiB
Plaintext
Generated
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "40e2ecf6-a1ee-4d82-924a-e2f763915652",
"metadata": {},
"outputs": [],
"source": [
"using LinearAlgebra, Plots"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "89746093-dc10-4bb2-9646-c84d5db0d8f8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"householder_vector (generic function with 2 methods)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Shared kernel: overwrite `u` (already a private copy of the input) with the\n",
"# normalized reflector vector and return it together with the target scalar s.\n",
"function _householder_vector!(u::AbstractArray{<:AbstractFloat})\n",
"    isempty(u) && throw(ArgumentError(\"householder_vector: input must be non-empty\"))\n",
"    s = norm(u)\n",
"    if iszero(s)\n",
"        # any reflector maps 0 to 0; pick e_1 so that u stays a unit vector (no 0/0 NaNs)\n",
"        fill!(u, zero(eltype(u)))\n",
"        u[1] = one(eltype(u))\n",
"        return u, s\n",
"    end\n",
"    # give s the opposite sign of u[1]: then u[1] - s adds two same-sign numbers (no cancellation)\n",
"    if u[1] ≥ 0\n",
"        s = -s\n",
"    end\n",
"    u[1] -= s\n",
"    u ./= norm(u)\n",
"    return u, s\n",
"end\n",
"\n",
"\"\"\"\n",
"    householder_vector(x) -> (u, s)\n",
"\n",
"Return a unit vector `u` and a scalar `s` such that the reflector\n",
"`H = I - 2 * u * u'` maps the vector `x` to `s * e_1`.\n",
"\n",
"`s` has the opposite sign of `x[1]` (and `|s| = norm(x)`), which keeps the\n",
"computation of `u[1]` free of cancellation. An all-zero `x` does not determine a\n",
"reflector; then `u = e_1` and `s = 0` are returned instead of NaNs.\n",
"Throws an `ArgumentError` for empty input. `x` itself is not modified.\n",
"\"\"\"\n",
"function householder_vector(x::Vector{<:AbstractFloat})::Tuple{Vector, AbstractFloat}\n",
"    return _householder_vector!(copy(x))\n",
"end\n",
"\n",
"# Matrix variant: the entries are used in column-major order as one long vector\n",
"# (Frobenius norm, pivot entry x[1]); the e_1 property is only meaningful for\n",
"# single-column input such as R[k:end, k:k].\n",
"function householder_vector(x::Matrix{<:AbstractFloat})::Tuple{Matrix, AbstractFloat}\n",
"    return _householder_vector!(copy(x))\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "262a769a-aa42-4929-bbf4-7f5a97783810",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([0.7960091839647405, 0.3357514552548967, 0.503627182882345], -3.7416573867739413)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"3-element Vector{Float64}:\n",
" 1.0\n",
" 2.0\n",
" 3.0"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"x = [1.0, 2.0, 3.0]\n",
"display(householder_vector(x))\n",
"# x must come back unchanged: householder_vector reflects a copy of it\n",
"display(x)\n",
"\n",
"# copying once and dividing in place keeps allocations down; to measure it:\n",
"# using BenchmarkTools; @benchmark householder_vector(randn(100_000))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "457b3bcf-a077-42cb-9a6c-f6d1d6e00504",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5-element Vector{Float64}:\n",
" -3.256101502501094\n",
" 3.4967785872994334e-17\n",
" 4.405003154616756e-17\n",
" -3.1688882202679996e-17\n",
" 2.3942585187350707e-16"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"A = randn(5, 4)\n",
"\n",
"# first step of the QR factorization: reflect column 1 onto a multiple of e_1\n",
"R1 = A\n",
"(u1, s1) = householder_vector(R1[:, 1])\n",
"H1 = I - 2 * u1 * u1'\n",
"Q1 = H1\n",
"\n",
"# expected result: s1 * e_1 (entries below the first one are rounding noise)\n",
"display(Q1 * R1[:, 1])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fd98a89e-01b9-4403-8bca-b9e1443b2eea",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×4 Matrix{Float64}:\n",
" -3.2561 0.130609 -0.788793 -0.0946472\n",
" -1.81593e-16 -1.31508 0.468296 0.501379\n",
" 3.27736e-17 -8.21875e-19 0.548462 -1.5704\n",
" 7.34914e-17 6.58877e-17 0.563737 -0.53087\n",
" 1.48462e-16 -1.40266e-16 -0.454505 0.883403"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# second step: row/column 1 are final, reflect the trailing part of column 2\n",
"R2 = Q1 * R1\n",
"\n",
"(u2, s2) = householder_vector(R2[2:end, 2])\n",
"H2 = I - 2 * u2 * u2'\n",
"\n",
"# Q2 = blkdiag(1, H2). LinearAlgebra has no `blkdiag`\n",
"# (see also https://github.com/JuliaArrays/BlockDiagonals.jl); given `blocks`,\n",
"# a collection of matrices, there are two options:\n",
"#   1. cat(blocks..., dims=(1, 2))\n",
"#   2. using SparseArrays; blockdiag(SparseMatrixCSC.(blocks)...)\n",
"# option 2 is slightly faster for subsequent matrix products;\n",
"# performance is ignored here, so option 1 is used\n",
"Q2 = cat(1, H2; dims=(1, 2))\n",
"\n",
"Q2 * R2"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a563f47e-7eda-46c7-9804-080335bcb8a3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×4 Matrix{Float64}:\n",
" -3.2561 0.130609 -0.788793 -0.0946472\n",
" -1.81593e-16 -1.31508 0.468296 0.501379\n",
" 8.88587e-18 -1.10573e-16 -0.908398 1.71961\n",
" 6.42479e-17 2.34191e-17 -1.02761e-16 0.742212\n",
" 1.55915e-16 -1.06026e-16 7.81748e-17 -0.143001"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# third step: the first two rows are final, reflect the rest of column 3\n",
"R3 = Q2 * R2\n",
"\n",
"(u3, s3) = householder_vector(R3[3:end, 3])\n",
"H3 = I - 2 * u3 * u3'\n",
"\n",
"Q3 = cat(Diagonal(ones(2)), H3; dims=(1, 2))  # blkdiag(I_2, H3)\n",
"Q3 * R3"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "85f5eb54-f3fe-40c8-86d0-90e80895b753",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×4 Matrix{Float64}:\n",
" -3.2561 0.130609 -0.788793 -0.0946472\n",
" -1.81593e-16 -1.31508 0.468296 0.501379\n",
" 8.88587e-18 -1.10573e-16 -0.908398 1.71961\n",
" -3.35902e-17 -4.30553e-17 1.15696e-16 -0.755862\n",
" 1.65254e-16 -9.96807e-17 5.73216e-17 7.46032e-18"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# fourth and last step: the first three rows are final, reflect the rest of column 4\n",
"R4 = Q3 * R3\n",
"\n",
"(u4, s4) = householder_vector(R4[4:end, 4])\n",
"H4 = I - 2 * u4 * u4'\n",
"\n",
"Q4 = cat(Diagonal(ones(3)), H4; dims=(1, 2))  # blkdiag(I_3, H4)\n",
"Q4 * R4\n",
"\n",
"# done: all size(A, 2) == 4 columns of A have been reduced"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "eff90e91-7856-4fd7-b2a5-79c0f681147a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"qrfactorization (generic function with 1 method)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"    qrfactorization(A) -> (Q, R)\n",
"\n",
"Householder QR factorization of the m×n matrix `A`: returns an m×m orthogonal `Q`\n",
"and an m×n upper-triangular (trapezoidal when m < n) `R` with `A ≈ Q * R`.\n",
"One reflector is used per column, for `k = 1:min(m, n)`.\n",
"\"\"\"\n",
"function qrfactorization(A::Matrix{<:AbstractFloat})::Tuple{Matrix{<:AbstractFloat}, Matrix{<:AbstractFloat}}\n",
"    (m, n) = size(A)\n",
"    R = copy(A)\n",
"    Q = Matrix{eltype(A)}(I, m, m)\n",
"\n",
"    # stop at min(m, n): when m < n there is no column slice left below row m\n",
"    for k ∈ 1:min(m, n)\n",
"        # a copy, not a view: householder_vector only has Vector/Matrix methods\n",
"        (u, s) = householder_vector(R[k:end, k])\n",
"        # rows k:m of column k become H * R[k:end, k] = s * e_1\n",
"        R[k, k] = s\n",
"        R[k+1:end, k] .= 0\n",
"        # apply H = I - 2uu' to the trailing block B: H * B = B - 2u(u' * B)\n",
"        @views R[k:end, k+1:end] .-= 2 .* u .* (u' * R[k:end, k+1:end])\n",
"        # construct Q = Q * blkdiag(I_{k-1}, H) without forming H: only columns k:m\n",
"        # change, and Q[:, k:m] * H = Q[:, k:m] - 2(Q[:, k:m] * u)u'\n",
"        @views Q[:, k:end] .-= 2 .* (Q[:, k:end] * u) .* u'\n",
"    end\n",
"    return (Q, R)\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4dbe13ff-44f5-4bd2-8452-a9e1477c80ff",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"true"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"true"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"A = randn(Float32, 1000, 20)\n",
"(Q, R) = qrfactorization(A)\n",
"# error bounds: m times the machine epsilon of A's element type (2^-23 for Float32)\n",
"m, tol = size(A, 1), eps(eltype(A))\n",
"(norm(A - Q*R) ≤ m * tol * norm(A)) |> display\n",
"(norm(I - Q*Q') ≤ m * tol) |> display"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.9.4",
"language": "julia",
"name": "julia-1.9"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.9.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}