This page was generated from tools/engine/spmv.ipynb.

JIT Engine: Sparse Matrix x Dense Vector

Most of the previous tutorials have focused on dense tensors. This tutorial will focus on sparse tensors.

In particular, this example will go over how to compile MLIR code for multiplying a sparse matrix with a dense vector into a function callable from Python.

Let’s first import some necessary modules and generate an instance of our JIT engine.


import mlir_graphblas
import mlir_graphblas.sparse_utils
import numpy as np

engine = mlir_graphblas.MlirJitEngine()

State of MLIR’s Current Sparse Tensor Support

MLIR’s sparse tensor support is in its early stages and fairly limited, as it is under active development. For more details on what is currently being worked on, see the MLIR discussion on sparse tensors.

It currently has two noteworthy limitations:

  • MLIR’s sparse tensor functionality in the linalg dialect currently only supports reading from sparse tensors but not storing into sparse tensors. Thus, the functions we write can accept sparse tensors as inputs but will return dense tensors.

  • MLIR’s sparse tensor support only supports a limited number of sparse storage layouts.

This first tutorial will go over the details of MLIR’s sparse tensor support along with how to implement a function that multiplies an MLIR sparse matrix with a dense vector to produce a dense vector.

MLIR’s Sparse Tensor Data Structure Overview

MLIR’s sparse tensors are implemented as structs with several array and vector attributes used to store the tensor’s elements. The source code for the struct representing MLIR’s sparse tensor can be found in the MLIR repository.
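
As a rough mental model, each compressed dimension contributes a pair of "pointers" and "indices" arrays, and a single flat "values" array holds the non-zero entries (a generalization of CSR). The following is a conceptual Python sketch of that layout for a small doubly compressed matrix; the array names and exact layout here are illustrative assumptions, not the actual C++ class:


import numpy as np

# Purely illustrative sketch of the storage scheme -- NOT MLIR's actual C++
# class or field names. Each compressed dimension contributes a "pointers"
# array and an "indices" array, and one flat "values" array holds the
# non-zero entries.
#
# For a doubly compressed 2-D matrix with entries
# {(0, 2): 1.0, (3, 0): 2.0, (3, 4): 3.0}:
dim_pointers = [
    np.array([0, 2]),     # dimension 0: one segment covering 2 non-empty rows
    np.array([0, 1, 3]),  # dimension 1: offsets of each row's column indices
]
dim_indices = [
    np.array([0, 3]),     # dimension 0: coordinates of the non-empty rows
    np.array([2, 0, 4]),  # dimension 1: column coordinate of each stored value
]
stored_values = np.array([1.0, 2.0, 3.0])  # non-zero values in storage order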

The JIT engine provides mlir_graphblas.sparse_utils.MLIRSparseTensor, a wrapper around MLIR’s sparse tensor struct.


# The sparse tensor below looks like this (where the underscores represent zeros):
#
# [[1.2, ___, ___, ___, ___, ___, ___, ___, ___, ___],
#  [___, ___, ___, 3.4, ___, ___, ___, ___, ___, ___],
#  [___, ___, 5.6, ___, ___, ___, ___, ___, ___, ___],
#  [___, ___, ___, ___, ___, ___, ___, ___, ___, ___],
#  [___, ___, ___, ___, ___, ___, ___, 7.8, ___, ___],
#  [___, ___, ___, ___, ___, ___, ___, ___, ___, ___],
#  [___, ___, ___, ___, ___, ___, ___, ___, ___, ___],
#  [___, ___, ___, ___, ___, ___, ___, ___, ___, ___],
#  [___, ___, ___, ___, ___, ___, ___, ___, ___, ___],
#  [___, ___, ___, ___, ___, ___, ___, ___, ___, 9.0]]
#

indices = np.array([
    [0, 0],
    [1, 3],
    [2, 2],
    [4, 7],
    [9, 9],
], dtype=np.uint64) # Coordinates
values = np.array([1.2, 3.4, 5.6, 7.8, 9.0], dtype=np.float32) # values at each coordinate
sizes = np.array([10, 10], dtype=np.uint64) # tensor shape
sparsity = np.array([True, True], dtype=np.bool_) # a boolean for each dimension telling which dimensions are sparse

sparse_tensor = mlir_graphblas.sparse_utils.MLIRSparseTensor(indices, values, sizes, sparsity)

To initialize an instance of mlir_graphblas.sparse_utils.MLIRSparseTensor, we need to provide:

  • The coordinates of each non-zero position in the sparse tensor (see the variable indices above).

  • The values at each position (see the variable values above). There’s a one-to-one correspondence between each coordinate and each value (order matters here).

  • The shape of the sparse tensor (see the variable sizes above).

  • The sparsity of each dimension (see the variable sparsity above). This determines the sparsity/data layout, e.g. a matrix dense in the 0th dimension and sparse in the 1st dimension has a CSR data layout. For more information on how the sparse data layouts work, see the MLIR discussion on sparse tensors.

Although we give the positions and values of the non-zero elements to the constructor in a way that resembles COO format, the underlying data structure does not store them in COO format. The sparsity of each dimension (see the variable sparsity above) is what the constructor uses to determine how to store the data, as the example below illustrates.
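
For example, constructing a tensor with the same non-zero data but a CSR-like layout only requires changing the sparsity flags. Here is a minimal sketch reusing the indices, values, and sizes arrays from above (the variable names csr_sparsity and csr_tensor are ours):


# Same non-zero data, but dense in the 0th (row) dimension and sparse in the
# 1st (column) dimension, i.e. a CSR-like layout.
csr_sparsity = np.array([False, True], dtype=np.bool_)
csr_tensor = mlir_graphblas.sparse_utils.MLIRSparseTensor(indices, values, sizes, csr_sparsity)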

Using MLIR’s Sparse Tensor Data Structure in MLIR Code

We’ll now go over how we can use MLIR’s sparse tensor in some MLIR code.

Here’s the MLIR code for multiplying a sparse matrix with a dense vector.


mlir_text = """
#trait_matvec = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,
    affine_map<(i,j) -> (j)>,
    affine_map<(i,j) -> (i)>
  ],
  iterator_types = ["parallel", "reduction"]
}

#HyperSparseMatrix = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed", "compressed" ],
  dimOrdering = affine_map<(i,j) -> (i,j)>,
  pointerBitWidth = 64,
  indexBitWidth = 64
}>

func @spmv(%arga: tensor<10x10xf32, #HyperSparseMatrix>, %argb: tensor<10xf32>) -> tensor<10xf32> {
  %output_storage = linalg.init_tensor [10] : tensor<10xf32>
  %0 = linalg.generic #trait_matvec
    ins(%arga, %argb : tensor<10x10xf32, #HyperSparseMatrix>, tensor<10xf32>)
    outs(%output_storage: tensor<10xf32>) {
      ^bb(%A: f32, %b: f32, %x: f32):
        %0 = arith.mulf %A, %b : f32
        %1 = arith.addf %x, %0 : f32
        linalg.yield %1 : f32
    } -> tensor<10xf32>
  return %0 : tensor<10xf32>
}
"""

One thing that makes this different from the dense operations we’ve shown in previous tutorials is the #HyperSparseMatrix attribute created via #sparse_tensor.encoding. Its dimLevelType value of [ "compressed", "compressed" ] specifies the storage layout of the sparse matrix operand and must correspond to the sparsity of our sparse tensor (see the Python variable sparsity from earlier).

Also, note the type of our sparse tensor: tensor<10x10xf32, #HyperSparseMatrix>. The encoding attribute is how MLIR’s sparse tensor passes tell sparse tensors apart from ordinary dense ones. At runtime, a tensor with a sparse encoding is represented as an opaque pointer (!llvm.ptr<i8>) to MLIR’s sparse tensor struct, and the --sparse-tensor-conversion pass lowers operations on such tensors into calls into MLIR’s sparse runtime support library. This machinery is still under active development and is expected to mature in the upcoming months.

Tensors with a sparse encoding can otherwise be treated as normal tensors, with all the complexities of indexing into the sparse storage handled by MLIR’s sparse tensor passes.

The MLIR passes that we’ll use to lower our sparse tensor operations are --sparsification and --sparse-tensor-conversion. Here are all the passes we’ll use.


passes = [
    "--sparsification",
    "--sparse-tensor-conversion",
    "--linalg-bufferize",
    "--func-bufferize",
    "--tensor-bufferize",
    "--finalizing-bufferize",
    "--convert-linalg-to-loops",
    "--convert-scf-to-cf",
    "--convert-memref-to-llvm",
    "--convert-math-to-llvm",
    "--convert-openmp-to-llvm",
    "--convert-arith-to-llvm",
    "--convert-math-to-llvm",
    "--convert-std-to-llvm",
    "--reconcile-unrealized-casts"
]

SpMV Compilation

Let’s now actually see what our MLIR code can do.

We’ll first compile our code.


engine.add(mlir_text, passes)
spmv = engine.spmv

We already have a 10x10 sparse tensor from earlier (see the Python variable sparse_tensor) that we can use as an input. Let’s create a dense vector we can multiply it by.


dense_vector = np.arange(10, dtype=np.float32)

Let’s perform the calculation.


spmv_answer = spmv(sparse_tensor, dense_vector)
spmv_answer

array([ 0.      , 10.200001, 11.2     ,  0.      , 54.600002,  0.      ,
        0.      ,  0.      ,  0.      , 81.      ], dtype=float32)

Let’s verify that this is the result we expect.


dense_tensor = np.array([
 [1.2, 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  ],
 [0  , 0  , 0  , 3.4, 0  , 0  , 0  , 0  , 0  , 0  ],
 [0  , 0  , 5.6, 0  , 0  , 0  , 0  , 0  , 0  , 0  ],
 [0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  ],
 [0  , 0  , 0  , 0  , 0  , 0  , 0  , 7.8, 0  , 0  ],
 [0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  ],
 [0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  ],
 [0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  ],
 [0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  ],
 [0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 0  , 9.0]
], dtype=np.float32)
np_answer = dense_tensor @ dense_vector
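
As a sanity check, the same dense matrix can also be built directly from the COO arrays we used to construct the sparse tensor earlier:


# Scatter the COO coordinates and values into a dense 10x10 matrix; this
# should reproduce dense_tensor exactly.
dense_from_coo = np.zeros(tuple(sizes), dtype=np.float32)
dense_from_coo[indices[:, 0], indices[:, 1]] = values
assert (dense_from_coo == dense_tensor).all()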

all(spmv_answer == np_answer)

True
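
Exact equality happens to hold here, but since float32 accumulation order can differ between the compiled kernel and NumPy, np.allclose is generally a more robust comparison:


np.allclose(spmv_answer, np_answer)

True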