import numpy as np

def matmul(a, b):
    # Naive triple-loop matrix multiplication, loop order i-j-s
    n = a.shape[0]
    k = a.shape[1]
    m = b.shape[1]
    c = np.zeros((n, m))
    for i in range(n):
        for j in range(m):
            for s in range(k):
                c[i, j] += a[i, s] * b[s, j]
    return c
def matmul_isj(a, b):
    # Same computation with loop order i-s-j: the innermost loop now
    # sweeps along rows of b, which is friendlier to the cache
    n = a.shape[0]
    k = a.shape[1]
    m = b.shape[1]
    c = np.zeros((n, m))
    for i in range(n):
        for s in range(k):
            for j in range(m):
                c[i, j] += a[i, s] * b[s, j]
    return c
import numpy as np
from numba import jit  # Just-in-time compiler for Python, see http://numba.pydata.org

@jit(nopython=True)
def numba_matmul(a, b):
    # The same i-j-s triple loop, compiled to machine code by Numba
    n = a.shape[0]
    k = a.shape[1]
    m = b.shape[1]
    c = np.zeros((n, m))
    for i in range(n):
        for j in range(m):
            for s in range(k):
                c[i, j] += a[i, s] * b[s, j]
    return c
@jit(nopython=True)
def numba_matmul_isj(a, b):
    # Compiled i-s-j version: a[i, s] is hoisted out of the inner loop
    # and b is traversed row by row
    n = a.shape[0]
    k = a.shape[1]
    m = b.shape[1]
    c = np.zeros((n, m))
    for i in range(n):
        for s in range(k):
            mult = a[i, s]
            for j in range(m):
                c[i, j] += mult * b[s, j]
    return c
n = 500
a = np.random.randn(n, n)
b = np.random.randn(n, n)
# The pure Python versions are far too slow to benchmark at this size
# %timeit matmul(a, b)
# %timeit matmul_isj(a, b)
%timeit numba_matmul(a, b)
%timeit numba_matmul_isj(a, b)
%timeit np.dot(a, b)
1 loop, best of 5: 168 ms per loop
The slowest run took 7.12 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 38.2 ms per loop
100 loops, best of 5: 7.87 ms per loop
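As a quick sanity check (a minimal sketch, assuming the arrays a, b and the functions defined above are still in scope), all implementations should agree with np.dot up to floating-point round-off:

assert np.allclose(numba_matmul(a, b), np.dot(a, b))
assert np.allclose(numba_matmul_isj(a, b), np.dot(a, b))
# the pure Python version is checked on a small slice to keep it fast
assert np.allclose(matmul(a[:50, :50], b[:50, :50]), np.dot(a[:50, :50], b[:50, :50]))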
A = np.array([[1, 2, 3], [5, 2, 7]], order="C")
A.dtype
dtype('int64')
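The order keyword only controls how the elements are laid out in memory, not the dtype. A small sketch, using only standard NumPy attributes, to see this:

A_f = np.asfortranarray(A)         # same values, column-major (Fortran) storage
print(A.flags["C_CONTIGUOUS"])     # True
print(A_f.flags["F_CONTIGUOUS"])   # True
print(A_f.dtype)                   # dtype('int64'), unchanged by the storage order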
import torch
A = torch.from_numpy(a)
B = torch.from_numpy(b)
# How to call BLAS/LAPACK functions from torch
torch.__config__.show()
'PyTorch built with:\n - GCC 7.3\n - C++ Version: 201402\n - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications\n - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)\n - OpenMP 201511 (a.k.a. OpenMP 4.5)\n - NNPACK is enabled\n - CPU capability usage: AVX2\n - CUDA Runtime 11.1\n - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86\n - CuDNN 8.0.5\n - Magma 2.5.2\n - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, \n'
Search for the section "BLAS and LAPACK Operations" in https://pytorch.org/docs/1.10.0/torch.html
You will find some strange acronyms like
geqrf
ger
ormqr
These names come from the native LAPACK library and are the standard primitives from which more complicated algorithms are built.
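As a small illustration (a sketch, not from the original notebook, assuming the tensor A and the size n defined above), geqrf computes a QR factorization in the compact Householder form used by LAPACK, and ormqr applies the orthogonal factor Q to another tensor without ever forming Q explicitly:

a_qr, tau = torch.geqrf(A)                  # R in the upper triangle, Householder vectors below it
x = torch.randn(n, 1, dtype=A.dtype)
qx = torch.ormqr(a_qr, tau, x)              # Q @ x with Q applied implicitly
print(torch.allclose(qx.norm(), x.norm()))  # True: Q is orthogonal, so norms are preserved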
%timeit A @ B
The slowest run took 10.87 times longer than the fastest. This could mean that an intermediate result is being cached. 100 loops, best of 5: 6.76 ms per loop
A_gpu = A.to("cuda")
B_gpu = B.to("cuda")
%timeit A_gpu @ B_gpu
The slowest run took 177.63 times longer than the fastest. This could mean that an intermediate result is being cached. 10000 loops, best of 5: 366 µs per loop
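Keep in mind that CUDA kernels are launched asynchronously, so the timing above may partly measure only the kernel launch. A fairer measurement (a sketch using the standard torch.cuda.synchronize call) waits for the GPU to finish:

def matmul_sync(x, y):
    z = x @ y
    torch.cuda.synchronize()    # block until the GPU kernel has actually finished
    return z

%timeit matmul_sync(A_gpu, B_gpu)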
print(A)
tensor([[ 1.2076,  0.2286, -0.1727,  ..., -1.3513,  0.6674,  0.7101],
        [-0.5689,  0.7661,  0.6888,  ..., -1.0185,  0.2321,  1.0647],
        [-0.5084,  1.0221,  0.5718,  ..., -1.0592, -2.0717,  0.5223],
        ...,
        [-0.1754, -0.1383,  1.0748,  ..., -0.0469, -0.1058, -0.9498],
        [-0.6619,  0.5752, -0.1461,  ...,  1.6014, -1.0753,  1.8088],
        [ 0.6649,  0.0372, -1.6637,  ...,  0.2114,  0.2700, -0.6503]],
       dtype=torch.float64)
t = torch.randn((n, n))
print(t)
tensor([[ 0.8519, -0.8576,  1.4654,  ..., -0.3948, -1.4546, -1.8130],
        [ 0.5388,  0.3280, -1.3981,  ..., -2.4956,  1.3326,  0.1441],
        [-1.4787, -0.4034,  0.9097,  ..., -1.2501,  0.2794,  0.7110],
        ...,
        [ 0.8301, -2.0607, -0.6245,  ..., -0.5105, -1.2650, -1.3490],
        [-1.5284,  0.4654,  1.0024,  ..., -0.1602,  1.3207, -0.6443],
        [-1.3054,  0.6352, -0.6033,  ...,  0.4865, -1.4938, -1.9792]])
t.dtype
torch.float32
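Unlike NumPy, torch.randn defaults to single precision (torch.float32). Double precision has to be requested explicitly, either per tensor or globally; a small sketch:

t64 = torch.randn((n, n), dtype=torch.float64)   # explicit double precision
print(t64.dtype)                                 # torch.float64
torch.set_default_dtype(torch.float64)           # or change the global default
print(torch.randn(2, 2).dtype)                   # torch.float64
torch.set_default_dtype(torch.float32)           # restore the default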
import jax
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
A_jax = jnp.array(a)
B_jax = jnp.array(b)
%timeit (A_jax @ B_jax).block_until_ready()
1000 loops, best of 5: 457 µs per loop
A_jax.device_buffer.device()
GpuDevice(id=0, process_index=0)
A_jax.dtype
dtype('float64')
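By default JAX keeps arrays in single precision; the jax_enable_x64 flag set above is what lets A_jax come out as float64. The precision can also be controlled per array (a sketch using only jnp.array's dtype argument):

A_jax32 = jnp.array(a, dtype=jnp.float32)   # single-precision copy of the same matrix
print(A_jax32.dtype)                        # float32
print((A_jax32 @ A_jax32).dtype)            # the product stays in float32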