Linear systems of equations are the basic tool in NLA.
They appear as:
From school we know about linear equations.
A linear system of equations can be written in the form
or simply
$$ A u = f, $$where $A$ is a $3 \times 3$ matrix and $f$ is the right-hand side.
If the system $Au = f$ has
- more equations than unknowns, it is called an overdetermined system (generically, there is no solution);
- fewer equations than unknowns, it is called an underdetermined system (the solution is non-unique; to make it unique, additional assumptions have to be made).
A solution to the linear system of equations with a square matrix $A$
$$A u = f$$exists and is unique for every right-hand side $f$, iff $\det A \ne 0$, or, equivalently, the matrix $A$ is non-singular (has full rank).
In different applications, the typical size of the linear systems can be different.
The main difficulty is that these systems are big: millions or billions of unknowns!
Q: how to work with such matrices?
A: fortunately, those matrices are structured and require $\mathcal{O}(N)$ parameters to be stored.
The most widespread structure is sparsity: sparse matrices have only $\mathcal{O}(N)$ non-zeros!
Example (one of the most famous matrices around, shown here for $n = 5$):
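As an illustration, here is a minimal sketch (assuming the famous example meant here is the standard tridiagonal second-difference matrix; scipy.sparse is used just to show the sparse storage):
import numpy as np
import scipy.sparse as sp

n = 5
# Tridiagonal second-difference matrix: only three non-zero diagonals,
# i.e. O(n) stored parameters instead of n^2 for a dense matrix
A = sp.diags([-np.ones(n - 1), 2 * np.ones(n), -np.ones(n - 1)], offsets=[-1, 0, 1])
print(A.toarray())
print("Number of non-zeros:", A.nnz)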
Important: forget about determinants and Cramer's rule (it is still fine for $2 \times 2$ matrices)!
The main tool is variable elimination.
\begin{align*} &2 y + 3 x = 5 \quad&\longrightarrow \quad &y = 5/2 - 3/2 x \\ &2 x + 3z = 5 \quad&\longrightarrow\quad &z = 5/3 - 2/3 x\\ &z + y = 2 \quad&\longrightarrow\quad & 5/2 + 5/3 - (3/2 + 2/3) x = 2,\\ \end{align*}and that is how you find $x$ (and all previous ones).
This process is called Gaussian elimination and is one of the most widely used algorithms.
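As a quick sanity check of this small example, here is a minimal sketch that solves the same $3 \times 3$ system with jnp.linalg.solve (which performs this elimination, with pivoting, under the hood):
import jax.numpy as jnp

# The 3 x 3 system from the example above, unknowns ordered as (x, y, z)
A = jnp.array([[3., 2., 0.],
               [2., 0., 3.],
               [0., 1., 1.]])
f = jnp.array([5., 5., 2.])
print(jnp.linalg.solve(A, f))  # elimination gives x = y = z = 1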
In the forward step, we express $x_1$ from the first equation and then substitute it into equations $2, \ldots, n$.
Then we express $x_2$ from the (modified) second equation and eliminate it from equations $3, \ldots, n$, and so on.
The important thing is that the pivots (the elements we divide by) are not equal to $0$.
In the backward step, we find $x_n$ from the last equation and then substitute it back, computing $x_{n-1}, \ldots, x_1$ one after another.
Definition: LU-decomposition of the square matrix $A$ is the representation
$$A = LU,$$where $L$ is a lower triangular matrix and $U$ is an upper triangular matrix.
This factorization is non-unique, so it is typical to require that the matrix $L$ has ones on the diagonal.
The main goal of the LU decomposition is to solve linear systems, because
$$ A^{-1} f = (L U)^{-1} f = U^{-1} L^{-1} f, $$and this reduces to the solution of two triangular linear systems: the forward step
$$ L y = f, $$and backward step
$$ U x = y. $$Does the $LU$ decomposition always exist?
Each elimination step requires $\mathcal{O}(n^2)$ operations.
Thus, the cost of the naive algorithm is $\mathcal{O}(n^3)$.
Think a little bit: can the Strassen algorithm help here?
We can try to compute a block version of the LU-decomposition:
$$\begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix} = \begin{pmatrix} L_{11} & 0 \\ L_{21} & L_{22} \end{pmatrix} \begin{pmatrix} U_{11} & U_{12} \\ 0 & U_{22} \end{pmatrix} $$Q: when does such a decomposition exist, i.e., for which class of matrices?
A: it is true for strictly regular matrices.
Definition. A matrix $A$ is called strictly regular if all of its leading principal submatrices (i.e., the submatrices formed by the first $k$ rows and $k$ columns) are non-singular.
In this case, an LU-decomposition always exists. The converse is also true (check!).
Corollary: If $L$ is unit triangular (ones on the diagonal), then $LU$-decomposition is unique.
Proof: Indeed, $L_1 U_1 = L_2 U_2$ means $L_2^{-1} L_1 = U_2 U_1^{-1}$. Here $L_2^{-1} L_1$ is lower triangular with ones on the diagonal, while $U_2 U_1^{-1}$ is upper triangular; a matrix that is both can only be the identity. Thus, $L_2^{-1} L_1 = U_2 U_1^{-1} = I$ and $L_1 = L_2$, $U_1 = U_2$.
Strictly regular matrices have LU-decomposition.
An important subclass of strictly regular matrices is the class of Hermitian positive definite matrices.
Definition. A matrix $A$ is called positive definite if for any $x \ne 0$ we have
$$ (x, Ax) > 0. $$Hermitian positive definite matrices admit the Cholesky decomposition
$$ A = R R^*, $$where $R$ is a lower triangular matrix.
Let us try to prove this fact (on the whiteboard).
It is sometimes referred to as the "square root" of the matrix.
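A minimal numerical sketch (jnp.linalg.cholesky returns the lower triangular factor; the random SPD matrix below is just for illustration):
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
B = jax.random.normal(key, (4, 4))
A = B @ B.T + 4 * jnp.eye(4)          # a symmetric positive definite matrix
R = jnp.linalg.cholesky(A)            # lower triangular "square root"
print(jnp.linalg.norm(R @ R.T - A))   # ~ machine precision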
In many cases, computing LU-decomposition once is a good idea!
Once the decomposition is found (it costs $\mathcal{O}(n^3)$ operations), then solving linear systems with $L$ and $U$ costs only $\mathcal{O}(n^2)$ operations.
Check:
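A minimal sketch of such a check (assuming the SciPy-style interface jax.scipy.linalg.lu_factor / lu_solve, which lets us factorize once and then reuse the factors for several right-hand sides):
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsla

n = 500
A = jax.random.normal(jax.random.PRNGKey(0), (n, n)) + n * jnp.eye(n)  # well-conditioned test matrix

lu, piv = jsla.lu_factor(A)            # O(n^3), done once
for seed in range(3):                  # every subsequent solve is only O(n^2)
    f = jax.random.normal(jax.random.PRNGKey(seed), (n,))
    x = jsla.lu_solve((lu, piv), f)
    print(jnp.linalg.norm(A @ x - f) / jnp.linalg.norm(f))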
What happens if the matrix is not strictly regular, or the pivots in Gaussian elimination are very small?
There is a classical $2 \times 2$ example of a matrix with a bad LU decomposition.
The matrix we look at is
$$ A = \begin{pmatrix} \varepsilon & 1 \\ 1 & 1 \end{pmatrix}, \quad \varepsilon \ll 1. $$
Let us do some demo here.
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
eps = 1e-4  # try also eps = 1.12e-16
a = [[eps, 1], [1.0, 1]]
a = jnp.array(a)
a0 = a.copy()
n = a.shape[0]
L = jnp.zeros((n, n))
U = jnp.zeros((n, n))
for k in range(n):  # eliminate the k-th row
    L = L.at[k, k].set(1.0)
    for i in range(k + 1, n):
        L = L.at[i, k].set(a[i, k] / a[k, k])       # multiplier (entry of L)
        for j in range(k + 1, n):
            a = a.at[i, j].add(-L[i, k] * a[k, j])  # update the trailing submatrix
    for j in range(k, n):
        U = U.at[k, j].set(a[k, j])                 # the k-th row of U
print('L * U - A:\n', jnp.dot(L, U) - a0)
L
L * U - A:
 [[0. 0.]
 [0. 0.]]
DeviceArray([[1.e+00, 0.e+00],
             [1.e+04, 1.e+00]], dtype=float64)
We can do pivoting, i.e. permute rows (and, possibly, columns) so that the pivot $a_{kk}$ we divide by is as large as possible.
The simplest but effective strategy is row pivoting: at each step, select the element of the current column that is maximal in modulus and permute its row onto the diagonal position.
It gives us the decomposition
$$ P A = L U, $$where $P$ is a permutation matrix.
Q. What makes row pivoting good?
A. It guarantees that
$$ | L_{ij}| \le 1, $$but the elements of $U$ can still grow, up to a factor of $2^{n-1}$ (in practice, such growth is very rarely encountered).
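A small sketch of pivoted LU on the same $2 \times 2$ example (assuming jax.scipy.linalg.lu, which mimics scipy.linalg.lu and returns a permutation matrix $P$ with $A = PLU$):
import jax.numpy as jnp
import jax.scipy.linalg as jsla

eps = 1e-4
A = jnp.array([[eps, 1.0],
               [1.0, 1.0]])
P, L, U = jsla.lu(A)                    # LU with row pivoting
print(L)                                # all entries of L are at most 1 in modulus
print(U)                                # no large entries in U either
print(jnp.linalg.norm(P @ L @ U - A))   # reconstruction check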
There is a fundamental problem in solving linear systems which is independent of the algorithm used.
It occurs when the elements of a matrix are represented as floating point numbers, or when there is some measurement noise.
Let us illustrate this issue on the following example.
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
%matplotlib inline
n = 33
a = [[1.0/(i + j + 1) for i in range(n)] for j in range(n)]  # Hilbert matrix
a = jnp.array(a)
rhs = jax.random.normal(jax.random.PRNGKey(0), (n,)) #Right-hand side
x = jnp.linalg.solve(a, rhs) #This function computes LU-factorization and solves linear system
#And check if everything is fine
er = jnp.linalg.norm(a.dot(x) - rhs) / jnp.linalg.norm(rhs)
print(er)
plt.plot(x)
plt.grid(True)
11.1598010212843
What was the problem in the previous example?
Why does the error grow so quickly?
And here is one of the main concepts of numerical linear algebra: the concept of condition number of a matrix.
But before that we have to define the inverse.
The inverse of a square matrix $A$ is the matrix $A^{-1}$ such that
$$ A A^{-1} = A^{-1} A = I, $$where $I$ is the identity matrix (i.e., $I_{ij} = 0$ if $i \ne j$ and $1$ otherwise).
Computing the inverse is equivalent to solving $n$ linear systems
$$ A x_i = e_i, \quad i = 1, \ldots, n, $$where $e_i$ is the $i$-th column of the identity matrix.
If we have computed $A^{-1}$, the solution of linear system
$$Ax = f$$is just $x = A^{-1} f$.
Indeed,
$$ A(A^{-1} f) = (AA^{-1})f = I f = f. $$Neumann series:
If a matrix $F$ is such that $\Vert F \Vert < 1$ holds, then the matrix $(I - F)$ is invertible and
$$(I - F)^{-1} = I + F + F^2 + F^3 + \ldots = \sum_{k=0}^{\infty} F^k.$$Note that it is a matrix version of the geometric progression.
Q: what norm is considered here? What is the "best possible" norm here?
The proof is constructive. First of all, prove that the series $\sum_{k=0}^{\infty} F^k$ converges.
Like in the scalar case, we have
$$ (I - F) \sum_{k=0}^N F^k = (I - F^{N+1}) \rightarrow I, \quad N \to +\infty $$Indeed,
$$ \| (I - F^{N+1}) - I\| = \|F^{N+1}\| \leqslant \|F\|^{N+1} \to 0, \quad N\to +\infty. $$We can also estimate the norm of the inverse:
$$ \left\Vert \sum_{k=0}^N F^k \right\Vert \leq \sum_{k=0}^N \Vert F \Vert^k \Vert I \Vert \leq \frac{\Vert I \Vert}{1 - \Vert F \Vert} $$and moreover,
$$ \frac{\Vert (A + E)^{-1} - A^{-1} \Vert}{\Vert A^{-1} \Vert} \leq \frac{\Vert A^{-1} \Vert \Vert E \Vert \Vert I \Vert}{1 - \Vert A^{-1} E \Vert}, $$provided that $\Vert A^{-1} E \Vert < 1$, so that $A + E = A(I + A^{-1} E)$ is invertible by the Neumann series. As you see, the norm of the inverse enters the estimate.
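A quick numerical check of the Neumann series (a minimal sketch; the random $F$ is simply rescaled so that its spectral norm equals $1/2$):
import jax
import jax.numpy as jnp

n = 20
F = jax.random.normal(jax.random.PRNGKey(0), (n, n))
F = 0.5 * F / jnp.linalg.norm(F, 2)      # now ||F||_2 = 0.5 < 1

S = jnp.zeros((n, n))
Fk = jnp.eye(n)
for k in range(60):                      # partial sum of the Neumann series
    S = S + Fk
    Fk = Fk @ F
print(jnp.linalg.norm(S - jnp.linalg.inv(jnp.eye(n) - F)))  # should be tiny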
Now consider the perturbed linear system:
$$ (A + \Delta A) \widehat{x} = f + \Delta f. $$Then
$$ \widehat{x} - x = (A + \Delta A)^{-1}(f + \Delta f) - A^{-1} f = \left[(I + A^{-1} \Delta A)^{-1} - I\right] A^{-1} f + (I + A^{-1} \Delta A)^{-1} A^{-1} \Delta f, $$therefore
$$ \begin{split} \frac{\Vert \widehat{x} - x \Vert}{\Vert x \Vert} \leq &\frac{1}{\|A^{-1}f\|} \Big[ \frac{\|A^{-1}\|\|\Delta A\|}{1 - \|A^{-1}\Delta A\|}\|A^{-1}f\| + \frac{1}{1 - \|A^{-1} \Delta A\|} \|A^{-1} \Delta f\| \Big] \\ \leq & \frac{\|A\|\|A^{-1}\|}{1 - \|A^{-1}\Delta A\|} \frac{\|\Delta A\|}{\|A\|} + \frac{\|A^{-1}\|}{1 - \|A^{-1}\Delta A\|} \frac{\|\Delta f\|}{\|A^{-1}f\|}\\ \end{split} $$Note that $\|AA^{-1}f\| \leq \|A\|\|A^{-1}f\|$, therefore $\| A^{-1} f \| \geq \frac{\|f\|}{\|A\|}$
Now we are ready to get the final estimate
$$ \begin{split} \frac{\Vert \widehat{x} - x \Vert}{\Vert x \Vert} \leq &\frac{\Vert A \Vert \Vert A^{-1} \Vert}{1 - \|A^{-1}\Delta A\|} \Big(\frac{\Vert\Delta A\Vert}{\Vert A \Vert} + \frac{\Vert \Delta f \Vert}{ \Vert f \Vert}\Big) \leq \\ \leq &\frac{\Vert A \Vert \Vert A^{-1} \Vert}{1 - \|A\|\|A^{-1}\|\frac{\|\Delta A\|}{\|A\|}} \Big(\frac{\Vert\Delta A\Vert}{\Vert A \Vert} + \frac{\Vert \Delta f \Vert}{ \Vert f \Vert}\Big) \equiv \\ \equiv &\frac{\mathrm{cond}(A)}{1 - \mathrm{cond}(A)\frac{\|\Delta A\|}{\|A\|}} \Big(\frac{\Vert\Delta A\Vert}{\Vert A \Vert} + \frac{\Vert \Delta f \Vert}{ \Vert f \Vert}\Big) \end{split} $$The crucial role is played by the condition number $\mathrm{cond}(A) = \Vert A \Vert \Vert A^{-1} \Vert$.
The larger the condition number, the fewer digits we can recover. Note that the condition number is different for different norms.
Note that if $\Delta A = 0$, then
$$ \frac{\Vert \widehat{x} - x \Vert}{\Vert x \Vert} \leq \mathrm{cond}(A) \frac{\Vert \Delta f \Vert}{\Vert f \Vert}. $$
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
%matplotlib inline
n = 50
a = [[1.0/(i + j + 0.01) for i in range(n)] for j in range(n)]  # a Hilbert-like (ill-conditioned) matrix
a = jnp.array(a)
rhs = jax.random.normal(jax.random.PRNGKey(10), [n])
#rhs = jnp.ones(n) #Right-hand side
f = jnp.linalg.solve(a, rhs)
#And check if everything is fine
er = jnp.linalg.norm(a.dot(f) - rhs) / jnp.linalg.norm(rhs)
cn = jnp.linalg.cond(a, 2)
print('Error:', er, 'Log Condition number:', jnp.log10(cn))
u1, s1, v1 = jnp.linalg.svd(a)
cf = u1.T@rhs
cf/s1
Error: 8.981202871503108 Log Condition number: 21.341344339184218
DeviceArray([-2.59701783e-03, -3.62615992e-01, -4.92682651e+00, -1.39317220e+00, -8.60041880e+01, 3.03460550e+02, 9.20961751e+03, -3.41310628e+04, -8.24928647e+04, 5.27596100e+06, 3.30516690e+07, 5.81349493e+08, 3.30494319e+09, 5.91765636e+10, 1.54937154e+11, -1.03870175e+13, 6.97324076e+13, -1.97562227e+13, -5.47702944e+13, 1.94494972e+14, 2.37939836e+13, -6.35816689e+13, -1.25422856e+14, 5.42975038e+13, 7.27578993e+12, -3.22040681e+13, 1.57513412e+14, -1.41688598e+14, -1.34967174e+14, -3.54499321e+13, 3.86080313e+13, -3.02898541e+13, 5.14933533e+13, -1.09592987e+14, -1.92757750e+14, 2.72745301e+13, 7.99814548e+13, 6.06250212e+13, -1.03983362e+14, -8.07437378e+13, 7.07436338e+13, -1.43980174e+14, 4.39149056e+13, -8.82456882e+13, -1.58302961e+13, -1.13785201e+14, 1.38974003e+11, -8.10241429e+13, -6.89874139e+13, 7.33904252e+13], dtype=float64)
And with random right-hand side...
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
%matplotlib inline
n = 100
a = [[1.0/(i + j + 1) for i in range(n)] for j in range(n)]  # Hilbert matrix
a = jnp.array(a)
rhs = jax.random.normal(jax.random.PRNGKey(-1), (n, )) #Right-hand side
f = jnp.linalg.solve(a, rhs)
#And check if everything is fine
er = jnp.linalg.norm(a.dot(f) - rhs) / jnp.linalg.norm(rhs)
cn = jnp.linalg.cond(a)
print('Error:', er, 'Condition number:', cn)
u, s, v = jnp.linalg.svd(a)
rhs = jax.random.normal(jax.random.PRNGKey(1), (n, ))
# rhs = jnp.ones((n,))
plt.plot(u.T.dot(rhs))
plt.grid(True)
plt.xlabel("Index of vector elements", fontsize=20)
plt.ylabel("Elements of vector", fontsize=20)
plt.xticks(fontsize=18)
_ = plt.yticks(fontsize=18)
Error: 17.674315761144477 Condition number: 4.073996146476839e+19
Can you think about an explanation?
An important class of problems is overdetermined linear systems, where the number of equations is greater than the number of unknowns.
The simplest example, which you all know, is linear fitting: fitting a set of 2D points by a line.
Then, a typical approach is to minimize the residual (least squares)
$$ \Vert A x - b \Vert_2 \rightarrow \min_x. $$
The optimality condition is $\nabla \left(\|Ax-b\|_2^2\right) = 0$, where $\nabla$ denotes the gradient. Therefore,
$$ \nabla \left(\|Ax-b\|_2^2\right) = 2(A^*A x - A^*b) = 0. $$Thus,
$$ A^* A x = A^* b. $$The matrix $A^* A$ is called the Gram matrix and the system is called the normal equation.
This is not a good way to solve the problem, since the condition number of $A^* A$ is the square of the condition number of $A$ (check why).
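A small illustration of this effect (a sketch; the diagonal scaling is just one way to make the columns of a random matrix badly scaled):
import jax
import jax.numpy as jnp

m, n = 100, 10
A = jax.random.normal(jax.random.PRNGKey(0), (m, n)) @ jnp.diag(jnp.logspace(0, -3, n))
print(jnp.linalg.cond(A))          # cond(A)
print(jnp.linalg.cond(A.T @ A))    # roughly cond(A)^2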
The matrix $$A^{\dagger} = \lim_{\alpha \rightarrow 0}(\alpha I + A^* A)^{-1} A^*$$ is called Moore-Penrose pseudoinverse of the matrix $A$.
If the matrix $A$ has full column rank, then $A^* A$ is non-singular and we get
$$ A^{\dagger} = (A^* A)^{-1} A^*. $$
Let $A = U \Sigma V^*$ be the SVD of $A$. Then,
$$A^{\dagger} = V \Sigma^{\dagger} U^*,$$where $\Sigma^{\dagger}$ consists of inverses of non-zero singular values of $A$. Indeed,
\begin{align*} A^{\dagger} &= \lim_{\alpha \rightarrow 0}(\alpha I + A^* A)^{-1} A^* = \lim_{\alpha \rightarrow 0}( \alpha VV^* + V \Sigma^2 V^*)^{-1} V \Sigma U^* \\ & = \lim_{\alpha \rightarrow 0}( V(\alpha I + \Sigma^2) V^*)^{-1} V \Sigma U^* = V \lim_{\alpha \rightarrow 0}(\alpha I + \Sigma^2)^{-1} \Sigma U^* = V \Sigma^{\dagger} U^*. \end{align*}Another, more numerically stable way to solve the least squares problem is to use the $QR$ decomposition
$$ A = Q R, $$where $Q$ is unitary and $R$ is upper triangular (details in the next lectures).
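A minimal sketch of least squares via the (reduced) QR decomposition, using jnp.linalg.qr and a triangular solve, compared against jnp.linalg.lstsq:
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsla

m, n = 100, 10
A = jax.random.normal(jax.random.PRNGKey(0), (m, n))
b = jax.random.normal(jax.random.PRNGKey(1), (m,))

Q, R = jnp.linalg.qr(A)                            # reduced QR: Q is m x n, R is n x n
x_qr = jsla.solve_triangular(R, Q.T @ b, lower=False)
x_ls, *_ = jnp.linalg.lstsq(A, b)                  # reference least-squares solution
print(jnp.linalg.norm(x_qr - x_ls))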
The least squares problem can also be written via the residual $r = Ax - b$ in the block form
$$ \begin{pmatrix} 0 & A^* \\ A & -I \end{pmatrix} \begin{pmatrix} x \\ r \end{pmatrix} = \begin{pmatrix} 0 \\ b \end{pmatrix}; $$the total size of this system is $(n + m) \times (n + m)$, and the condition number is the same as for $A$.
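A sketch of solving this augmented block system directly (using jnp.block; the first $n$ entries of the solution should match jnp.linalg.lstsq):
import jax
import jax.numpy as jnp

m, n = 100, 10
A = jax.random.normal(jax.random.PRNGKey(0), (m, n))
b = jax.random.normal(jax.random.PRNGKey(1), (m,))

K = jnp.block([[jnp.zeros((n, n)), A.T],
               [A, -jnp.eye(m)]])            # (n + m) x (n + m) augmented matrix
rhs = jnp.concatenate([jnp.zeros(n), b])
xr = jnp.linalg.solve(K, rhs)
x_ls, *_ = jnp.linalg.lstsq(A, b)
print(jnp.linalg.norm(xr[:n] - x_ls))        # first n entries are the least-squares solution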
Consider a two-dimensional example. Suppose we have a linear model
$$y = ax + b$$and noisy data $(x_1, y_1), \dots, (x_n, y_n)$. Then the linear system for the coefficients looks as follows
$$ \begin{split} a x_1 &+ b &= y_1 \\ &\vdots \\ a x_n &+ b &= y_n \\ \end{split} $$or in a matrix form
$$ \begin{pmatrix} x_1 & 1 \\ \vdots & \vdots \\ x_n & 1 \\ \end{pmatrix} \begin{pmatrix} a \\ b \end{pmatrix} = \begin{pmatrix} y_1 \\ \vdots \\ y_n \\ \end{pmatrix}, $$which is an overdetermined system.
%matplotlib inline
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
a_exact = 1.
b_exact = 2.
n = 10
xi = jnp.arange(n)
yi = a_exact * xi + b_exact + 2 * jax.random.normal(jax.random.PRNGKey(1), (n, ))
A = jnp.array([xi, jnp.ones(n)])     # 2 x n array; its transpose is the n x 2 design matrix
coef = jnp.linalg.pinv(A).T.dot(yi)  # least-squares coefficients, coef is [a, b]
plt.plot(xi, yi, 'o', label='$(x_i, y_i)$')
plt.plot(xi, coef[0]*xi + coef[1], label='Least squares')
plt.legend(loc='best', fontsize=18)
plt.grid(True)
A typical 3D-problem requires a $100 \times 100 \times 100$ discretization
This gives a linear system with $10^6$ unknowns; the right-hand side alone takes $8$ megabytes of memory.
A dense matrix of this size has $10^6 \times 10^6 = 10^{12}$ elements and takes $8$ terabytes of memory.
Fortunately, the matrices in real-life are not dense, but have certain structure: