
Matrix-Matrix multiplication

See Matrix-Vector multiplication for a detailed explanation.

$$\begin{pmatrix} a & b & c\\ d & e & f \end{pmatrix} \cdot \begin{pmatrix} g & h\\ i & j \\ k & l \end{pmatrix} = \begin{pmatrix} a \\ d\end{pmatrix} \cdot \begin{pmatrix} g & h\end{pmatrix} + \begin{pmatrix} b \\ e\end{pmatrix} \cdot \begin{pmatrix} i & j\end{pmatrix} + \begin{pmatrix} c \\ f\end{pmatrix} \cdot \begin{pmatrix} k & l\end{pmatrix}$$
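The decomposition above can be checked numerically: the product equals the sum, over k, of column k of the left matrix times row k of the right matrix. The following is a minimal sketch using nested Python lists; the function names `matmul` and `outer_product_sum` and the sample values are illustrative assumptions, not part of the notes.

```python
def matmul(A, B):
    """Reference product: entry (r, c) is row r of A dotted with column c of B."""
    return [[sum(A[r][k] * B[k][c] for k in range(len(B)))
             for c in range(len(B[0]))]
            for r in range(len(A))]


def outer_product_sum(A, B):
    """Accumulate (column k of A) times (row k of B) for every k."""
    rows, inner, cols = len(A), len(B), len(B[0])
    C = [[0] * cols for _ in range(rows)]
    for k in range(inner):          # one outer product per k
        for r in range(rows):
            for c in range(cols):
                C[r][c] += A[r][k] * B[k][c]
    return C


# Hypothetical sample values for a 2x3 times 3x2 product.
A = [[1, 2, 3],
     [4, 5, 6]]
B = [[7, 8],
     [9, 10],
     [11, 12]]

print(outer_product_sum(A, B) == matmul(A, B))
```

Both functions perform the same multiply-adds; only the loop order differs, which is what makes the blockwise formulation later in these notes possible.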

Sequential solution in $O(n^3)$
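As a point of reference, the sequential baseline could look like the sketch below. It assumes square matrices stored as nested Python lists; the name `matmul_sequential` and the returned operation counter are illustrative additions that make the cubic cost visible.

```python
def matmul_sequential(A, B):
    """Triple-loop product of two n x n matrices.

    Returns the result and the number of multiply-add steps performed,
    which is exactly n * n * n.
    """
    n = len(A)
    C = [[0] * n for _ in range(n)]
    multiply_adds = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                C[i][j] += A[i][k] * B[k][j]
                multiply_adds += 1
    return C, multiply_adds


C, ops = matmul_sequential([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(C, ops)
```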

Assume $A, B, C$ are $n \times n$ matrices.

and also $n \gg p, \quad p \mid n$

MPI processes are organised into a $\sqrt p \times \sqrt p$ grid through MPI_Cart_create.

Each process $(i,j)$ contains submatrices $A', B', C'$

There are

$\sqrt p$ communicators for rows

$\sqrt p$ communicators for columns
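The process layout described above can be illustrated without MPI. The sketch below assumes the row-major rank ordering that MPI Cartesian topologies use and lists which ranks would share a row or column communicator. The function name `grid_layout` is hypothetical. In a real MPI program the sub-communicators might be derived with `MPI_Cart_sub` or `MPI_Comm_split`, neither of which the notes mention.

```python
import math


def grid_layout(p):
    """Map ranks 0..p-1 onto a sqrt(p) x sqrt(p) grid in row-major order.

    Returns (coords, row_groups, col_groups): coords[rank] is (i, j),
    row_groups[i] lists the ranks in grid row i, and col_groups[j]
    lists the ranks in grid column j.
    """
    q = math.isqrt(p)
    if q * q != p:
        raise ValueError("p must be a perfect square")
    coords = {rank: divmod(rank, q) for rank in range(p)}
    row_groups = [[i * q + j for j in range(q)] for i in range(q)]
    col_groups = [[i * q + j for i in range(q)] for j in range(q)]
    return coords, row_groups, col_groups


coords, row_groups, col_groups = grid_layout(9)
print(coords[5], row_groups[1], col_groups[2])
```

With p = 9 this yields three row groups and three column groups, matching the count of communicators stated above.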

Folklore: Blockwise algorithm

Analysis

Work: $O(n^3/p + n^2/\sqrt p)$

Communication: $2 \cdot (\log_2(\sqrt p) \cdot \alpha + \beta \cdot (\sqrt p - 1) \cdot n^2/p)$

Space: $\sqrt p \cdot n^2$

Very inefficient: space increases by a factor of $\sqrt p$ over the sequential algorithm.

Speedup: $p$ when $p \in O(n^2)$

$q = \sqrt p$

Process $(i,j)$ computes $C'_{i,j} = \left(A'_{i,0}, A'_{i,1}, \ldots, A'_{i,q-1}\right) \cdot \left(B'_{0,j}, B'_{1,j}, \ldots, B'_{q-1,j}\right)^T$

Algorithm

  1. Get data

    Allgather $A_{i,*}$ on rows

    Allgather $B_{*,j}$ on columns

  2. Locally compute (sequentially)

    $C_{i,j} = \sum_{l=0}^{\sqrt p - 1} A_{i,l} \cdot B_{l,j}$
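The two steps above can be simulated in a single Python process to check that the blockwise scheme reproduces the ordinary product. This is only a sketch: the "Allgather" is modelled as copying references to the blocks of one grid row or column, no real communication takes place, and all function names are hypothetical.

```python
def matmul(A, B):
    """Plain product of two rectangular matrices given as nested lists."""
    return [[sum(A[r][k] * B[k][c] for k in range(len(B)))
             for c in range(len(B[0]))]
            for r in range(len(A))]


def split_blocks(M, q):
    """Cut an n x n matrix into q x q square blocks; blocks[i][j] is one block."""
    b = len(M) // q
    return [[[row[j * b:(j + 1) * b] for row in M[i * b:(i + 1) * b]]
             for j in range(q)]
            for i in range(q)]


def join_blocks(blocks):
    """Inverse of split_blocks: reassemble the full matrix."""
    rows = []
    for block_row in blocks:
        for r in range(len(block_row[0])):
            row = []
            for block in block_row:
                row.extend(block[r])
            rows.append(row)
    return rows


def blockwise_multiply(A, B, q):
    """Simulate the blockwise algorithm on a q x q process grid."""
    Ab, Bb = split_blocks(A, q), split_blocks(B, q)
    b = len(A) // q
    Cb = [[None] * q for _ in range(q)]
    for i in range(q):
        for j in range(q):
            # Step 1: process (i, j) gathers block row i of A and block column j of B.
            a_row = [Ab[i][l] for l in range(q)]
            b_col = [Bb[l][j] for l in range(q)]
            # Step 2: local sequential computation of C_{i,j}.
            acc = [[0] * b for _ in range(b)]
            for l in range(q):
                part = matmul(a_row[l], b_col[l])
                for r in range(b):
                    for c in range(b):
                        acc[r][c] += part[r][c]
            Cb[i][j] = acc
    return join_blocks(Cb)


A = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
B = [[1, 0, 2, 0], [0, 1, 0, 2], [3, 0, 1, 0], [0, 3, 0, 1]]
print(blockwise_multiply(A, B, 2) == matmul(A, B))
```

Each simulated process holds q blocks of A and q blocks of B at once, which is where the extra space of this scheme comes from.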

Scalable Universal Matrix-Multiplication Algorithm (SUMMA)

Analysis

Work: $O(n^3/p + n^2/\sqrt p)$

Communication: $2 \cdot \sqrt p \cdot (\log_2(\sqrt p) \cdot \alpha + \beta \cdot n^2/p)$

Space: $2 \cdot n^2/p$

Speedup: $p$ when $p \in O(n^2)$
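To compare the communication terms of the two algorithms, the formulas from both Analysis sections can be evaluated directly. The sketch below assumes the usual reading of the symbols, with α as the per-message startup cost and β as the transfer cost per matrix element; the notes do not define them explicitly. The function names and parameter values are hypothetical.

```python
import math


def blockwise_comm(n, p, alpha, beta):
    """Communication term of the blockwise algorithm, as stated in its analysis."""
    q = math.sqrt(p)
    return 2 * (math.log2(q) * alpha + beta * (q - 1) * n**2 / p)


def summa_comm(n, p, alpha, beta):
    """Communication term of SUMMA, as stated in its analysis."""
    q = math.sqrt(p)
    return 2 * q * (math.log2(q) * alpha + beta * n**2 / p)


# Hypothetical machine parameters, chosen only for illustration.
n, p, alpha, beta = 64, 16, 5.0, 0.1
print(blockwise_comm(n, p, alpha, beta), summa_comm(n, p, alpha, beta))
```

For the values tried here, SUMMA's communication term is not smaller than the blockwise one; judging by these formulas, its advantage lies in the Space entry rather than in communication.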

Idea: Pipelining Allgather operations

$C_{i,j}$ is computed in $\sqrt p$ communication rounds.

Main algorithm

In round $l \in [0;~\sqrt p - 1]$ we

  1. broadcast $A_{i,l}$ in the row communicator
  2. broadcast $B_{l,j}$ in the column communicator
  3. update $C_{i,j} \texttt{ += } A_{i,l} \cdot B_{l,j}$
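The round structure above can be simulated in one Python process to confirm that accumulating one block product per round yields the full product. This is a sketch under simplifying assumptions: the broadcasts are modelled as direct reads of the owning process's block, no MPI is involved, and all function names are hypothetical.

```python
def matmul(A, B):
    """Plain product of two rectangular matrices given as nested lists."""
    return [[sum(A[r][k] * B[k][c] for k in range(len(B)))
             for c in range(len(B[0]))]
            for r in range(len(A))]


def split_blocks(M, q):
    """Cut an n x n matrix into q x q square blocks; blocks[i][j] is one block."""
    b = len(M) // q
    return [[[row[j * b:(j + 1) * b] for row in M[i * b:(i + 1) * b]]
             for j in range(q)]
            for i in range(q)]


def join_blocks(blocks):
    """Inverse of split_blocks: reassemble the full matrix."""
    rows = []
    for block_row in blocks:
        for r in range(len(block_row[0])):
            row = []
            for block in block_row:
                row.extend(block[r])
            rows.append(row)
    return rows


def summa_multiply(A, B, q):
    """Simulate SUMMA on a q x q process grid, one block product per round."""
    Ab, Bb = split_blocks(A, q), split_blocks(B, q)
    b = len(A) // q
    Cb = [[[[0] * b for _ in range(b)] for _ in range(q)] for _ in range(q)]
    for l in range(q):                      # communication round l
        for i in range(q):
            for j in range(q):
                a_recv = Ab[i][l]           # broadcast by process (i, l) in row i
                b_recv = Bb[l][j]           # broadcast by process (l, j) in column j
                part = matmul(a_recv, b_recv)
                for r in range(b):          # local update C += A * B
                    for c in range(b):
                        Cb[i][j][r][c] += part[r][c]
    return join_blocks(Cb)


A = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
B = [[1, 0, 2, 0], [0, 1, 0, 2], [3, 0, 1, 0], [0, 3, 0, 1]]
print(summa_multiply(A, B, 2) == matmul(A, B))
```

In contrast to the blockwise scheme, each simulated process only needs one received block of A and one of B per round, in addition to its own blocks.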