This page was generated from tools/engine/tensor_plus_tensor.ipynb.

JIT Engine: Tensor + Tensor

This example will go over how to compile MLIR code to a function callable from Python.

The example MLIR code we’ll use here performs element-wise tensor addition.

Let’s first import some necessary modules and generate an instance of our JIT engine.


import mlir_graphblas
import numpy as np

engine = mlir_graphblas.MlirJitEngine()
Using development graphblas-opt: /Users/pnguyen/code/mlir-graphblas/mlir_graphblas/src/build/bin/graphblas-opt

We’ll use the same set of passes to optimize and compile all of our examples below.


passes = [
    "--graphblas-structuralize",
    "--graphblas-optimize",
    "--graphblas-lower",
    "--sparsification",
    "--sparse-tensor-conversion",
    "--linalg-bufferize",
    "--func-bufferize",
    "--tensor-bufferize",
    "--finalizing-bufferize",
    "--convert-linalg-to-loops",
    "--convert-scf-to-cf",
    "--convert-memref-to-llvm",
    "--convert-math-to-llvm",
    "--convert-openmp-to-llvm",
    "--convert-arith-to-llvm",
    "--convert-std-to-llvm",
    "--reconcile-unrealized-casts"
]

Fixed-Size Tensor Addition

Here’s some MLIR code to add two 32-bit floating-point tensors with shape 2x3.


mlir_text = """
#trait_add = {
 indexing_maps = [
   affine_map<(i, j) -> (i, j)>,
   affine_map<(i, j) -> (i, j)>,
   affine_map<(i, j) -> (i, j)>
 ],
 iterator_types = ["parallel", "parallel"]
}

func @matrix_add_f32(%arga: tensor<2x3xf32>, %argb: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %answer = linalg.generic #trait_add
    ins(%arga, %argb: tensor<2x3xf32>, tensor<2x3xf32>)
    outs(%arga: tensor<2x3xf32>) {
      ^bb(%a: f32, %b: f32, %s: f32):
        %sum = arith.addf %a, %b : f32
        linalg.yield %sum : f32
  } -> tensor<2x3xf32>
  return %answer : tensor<2x3xf32>
}
"""
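To make the `#trait_add` trait concrete, here is a rough NumPy model of what this `linalg.generic` computes: two nested "parallel" loops over (i, j), with identity indexing maps, that call the region body once per element. The function name `generic_elementwise` is hypothetical, and taking the loop bounds from the first input is an assumption of this sketch. The real op is compiled by the passes listed above, not interpreted like this.

```python
import numpy as np

def generic_elementwise(body, ins, out):
    """Rough model of a 2-D linalg.generic whose indexing maps are all
    (i, j) -> (i, j) and whose iterator types are both "parallel".

    `body` gets one scalar per input plus the current output element
    (like the ^bb(%a, %b, %s) block arguments) and returns the yielded value.
    """
    result = out.copy()  # tensors are values, so the original `out` is untouched
    rows, cols = ins[0].shape  # assumption: loop bounds come from the first input
    for i in range(rows):  # first "parallel" iterator
        for j in range(cols):  # second "parallel" iterator
            operands = [x[i, j] for x in ins]
            result[i, j] = body(*operands, result[i, j])
    return result

a = np.arange(6, dtype=np.float32).reshape([2, 3])
b = np.full([2, 3], 100, dtype=np.float32)

# Region body: %sum = arith.addf %a, %b ; linalg.yield %sum
result_model = generic_elementwise(lambda x, y, s: x + y, [a, b], a)
print(result_model)
```

Passing `a` as the output operand mirrors `outs(%arga ...)` in the MLIR: its values are ignored by the body, which only uses `%a` and `%b`.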

Let’s compile our MLIR code.


engine.add(mlir_text, passes)

['matrix_add_f32']

Let’s try out our compiled function.


# grab our callable
matrix_add_f32 = engine.matrix_add_f32

# generate inputs
a = np.arange(6, dtype=np.float32).reshape([2, 3])
b = np.full([2, 3], 100, dtype=np.float32)

# generate output
result = matrix_add_f32(a, b)

result

array([[100., 101., 102.],
       [103., 104., 105.]], dtype=float32)
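When MLIR lowers to LLVM, a ranked memref is represented by a descriptor holding an allocated pointer, an aligned pointer, an offset, the sizes, and the strides (counted in elements). The sketch below shows how those fields could be derived from a NumPy array. It is an illustration of the memref layout only; `memref_descriptor_fields` is a hypothetical helper, not mlir_graphblas's actual encoder.

```python
import numpy as np

def memref_descriptor_fields(arr):
    """Return (allocated_ptr, aligned_ptr, offset, sizes, strides) for a
    NumPy array, with strides converted from bytes to elements."""
    ptr = arr.ctypes.data  # address of the array's first element
    sizes = list(arr.shape)
    strides = [s // arr.itemsize for s in arr.strides]
    # Both pointers refer to the same buffer here, so the offset is 0.
    return ptr, ptr, 0, sizes, strides

a = np.arange(6, dtype=np.float32).reshape([2, 3])
allocated, aligned, offset, sizes, strides = memref_descriptor_fields(a)
print(offset, sizes, strides)  # 0 [2, 3] [3, 1]
```

Because strides are part of the descriptor, a non-contiguous view such as `a.T` can in principle be described without copying (its strides would be `[1, 3]`).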

Let’s verify that our function works as expected.


np.all(result == np.add(a, b))

True

Arbitrary-Size Tensor Addition

The above example created a function to add two matrices of size 2x3. This function won’t work if we want to add two matrices of size 4x5 or any other size.


a = np.arange(20, dtype=np.float32).reshape([4, 5])
b = np.full([4, 5], 100, dtype=np.float32)
matrix_add_f32(a, b)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [8], in <cell line: 3>()
      1 a = np.arange(20, dtype=np.float32).reshape([4, 5])
      2 b = np.full([4, 5], 100, dtype=np.float32)
----> 3 matrix_add_f32(a, b)

File ~/code/mlir-graphblas/mlir_graphblas/engine.py:843, in MlirJitEngine._generate_zero_or_single_valued_functions.<locals>.python_callable(mlir_function, encoders, c_callable, decoder, *args)
    839     raise ValueError(
    840         f"{name} expected {len(mlir_function.args)} args but got {len(args)}."
    841     )
    842 encoded_args = (encoder(arg) for arg, encoder in zip(args, encoders))
--> 843 encoded_args = sum(encoded_args, [])
    844 encoded_result = c_callable(*encoded_args)
    845 result = decoder(encoded_result)

File ~/code/mlir-graphblas/mlir_graphblas/engine.py:842, in <genexpr>(.0)
    838 if len(args) != len(mlir_function.args):
    839     raise ValueError(
    840         f"{name} expected {len(mlir_function.args)} args but got {len(args)}."
    841     )
--> 842 encoded_args = (encoder(arg) for arg, encoder in zip(args, encoders))
    843 encoded_args = sum(encoded_args, [])
    844 encoded_result = c_callable(*encoded_args)

File ~/code/mlir-graphblas/mlir_graphblas/engine.py:400, in input_tensor_to_ctypes.<locals>.encoder(arg)
    395 if not len(dimensions) == len(arg.shape):
    396     raise ValueError(
    397         f"{repr(arg)} is expected to have rank {len(dimensions)} but has rank {len(arg.shape)}."
    398     )
--> 400 validate_arg_shape(arg)
    402 encoded_args = [arg, arg, 0]
    403 encoded_args += list(arg.shape)

File ~/code/mlir-graphblas/mlir_graphblas/engine.py:310, in input_tensor_to_ctypes.<locals>.validate_arg_shape(arg)
    305     expected_dim_size = dimensions[dim_index]
    306     if (
    307         expected_dim_size is not None
    308         and arg.shape[dim_index] != expected_dim_size
    309     ):
--> 310         raise ValueError(
    311             f"{repr(arg)} is expected to have size {expected_dim_size} in the "
    312             f"{dim_index}th dimension but has size {arg.shape[dim_index]}."
    313         )
    314 return

ValueError: array([[ 0.,  1.,  2.,  3.,  4.],
       [ 5.,  6.,  7.,  8.,  9.],
       [10., 11., 12., 13., 14.],
       [15., 16., 17., 18., 19.]], dtype=float32) is expected to have size 2 in the 0th dimension but has size 4.
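The error comes from a check the engine runs before calling into the compiled code. Below is a minimal sketch of that kind of check, modeled loosely on the `validate_arg_shape` frame in the traceback. `check_shape` is a hypothetical name, and `None` stands for a dimension with no fixed size.

```python
import numpy as np

def check_shape(arr, expected_dims):
    """Raise ValueError unless `arr` matches `expected_dims`.

    Each entry of `expected_dims` is a fixed size, or None for a dimension
    whose size is only known at runtime.
    """
    if arr.ndim != len(expected_dims):
        raise ValueError(f"expected rank {len(expected_dims)} but got rank {arr.ndim}")
    for index, (expected, actual) in enumerate(zip(expected_dims, arr.shape)):
        if expected is not None and actual != expected:
            raise ValueError(
                f"expected size {expected} in dimension {index} but got {actual}"
            )

check_shape(np.zeros([2, 3], dtype=np.float32), [2, 3])        # OK: exact match
check_shape(np.zeros([4, 5], dtype=np.float32), [None, None])  # OK: both dims dynamic
try:
    check_shape(np.zeros([4, 5], dtype=np.float32), [2, 3])
except ValueError as err:
    print(err)  # expected size 2 in dimension 0 but got 4
```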

While it’s nice that the JIT engine is able to detect that there’s a size mismatch, it’d be nicer to have a function that can add two tensors of arbitrary size.

We’ll now show how to create such a function for matrices of 32-bit integers.


mlir_text = """
#trait_add = {
 indexing_maps = [
   affine_map<(i, j) -> (i, j)>,
   affine_map<(i, j) -> (i, j)>,
   affine_map<(i, j) -> (i, j)>
 ],
 iterator_types = ["parallel", "parallel"]
}

func @matrix_add_i32(%arga: tensor<?x?xi32>, %argb: tensor<?x?xi32>) -> tensor<?x?xi32> {
  // Find the max dimensions of both args
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %arga_dim0 = tensor.dim %arga, %c0 : tensor<?x?xi32>
  %arga_dim1 = tensor.dim %arga, %c1 : tensor<?x?xi32>
  %argb_dim0 = tensor.dim %argb, %c0 : tensor<?x?xi32>
  %argb_dim1 = tensor.dim %argb, %c1 : tensor<?x?xi32>
  %dim0_gt = arith.cmpi "ugt", %arga_dim0, %argb_dim0 : index
  %dim1_gt = arith.cmpi "ugt", %arga_dim1, %argb_dim1 : index
  %output_dim0 = arith.select %dim0_gt, %arga_dim0, %argb_dim0 : index
  %output_dim1 = arith.select %dim1_gt, %arga_dim1, %argb_dim1 : index
  %output_tensor = linalg.init_tensor [%output_dim0, %output_dim1] : tensor<?x?xi32>

  // Perform addition
  %answer = linalg.generic #trait_add
    ins(%arga, %argb: tensor<?x?xi32>, tensor<?x?xi32>)
    outs(%output_tensor: tensor<?x?xi32>) {
      ^bb(%a: i32, %b: i32, %s: i32):
        %sum = arith.addi %a, %b : i32
        linalg.yield %sum : i32
    } -> tensor<?x?xi32>
  return %answer : tensor<?x?xi32>
}
"""

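Compiling this MLIR code works the same way as in our first example. The main difference is in the MLIR code itself: each “?” in a type such as “tensor<?x?xi32>” marks a dynamic dimension whose size is only known at runtime, so the output tensor is created with linalg.init_tensor using sizes read from the inputs with tensor.dim.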
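As an illustration of how a static type like `tensor<2x3xf32>` differs from a dynamic one like `tensor<?x?xi32>`, the hypothetical helper below splits a simple ranked tensor type into its dimensions (`None` for “?”) and its element type. It is not part of mlir_graphblas, and it deliberately ignores encoding attributes such as sparse tensor encodings.

```python
def parse_tensor_type(type_str):
    """Split a simple ranked MLIR tensor type, e.g. "tensor<?x?xi32>",
    into (dims, element_type), using None for dynamic ("?") dimensions."""
    prefix, suffix = "tensor<", ">"
    if not (type_str.startswith(prefix) and type_str.endswith(suffix)):
        raise ValueError(f"not a simple tensor type: {type_str!r}")
    parts = type_str[len(prefix):-len(suffix)].split("x")
    dims = []
    # Leading parts are dimensions; stop at the first one that is not "?" or a number.
    while parts and (parts[0] == "?" or parts[0].isdigit()):
        token = parts.pop(0)
        dims.append(None if token == "?" else int(token))
    # Rejoin the rest so element types containing "x" (e.g. "index") survive.
    return dims, "x".join(parts)

print(parse_tensor_type("tensor<2x3xf32>"))  # ([2, 3], 'f32')
print(parse_tensor_type("tensor<?x?xi32>"))  # ([None, None], 'i32')
```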


# compile
engine.add(mlir_text, passes)
matrix_add_i32 = engine.matrix_add_i32

# generate inputs
a = np.arange(20, dtype=np.int32).reshape([4, 5])
b = np.full([4, 5], 100, dtype=np.int32)

# generate output
result = matrix_add_i32(a, b)

result

array([[100, 101, 102, 103, 104],
       [105, 106, 107, 108, 109],
       [110, 111, 112, 113, 114],
       [115, 116, 117, 118, 119]], dtype=int32)

assert np.all(result == np.add(a, b))

We still get some type safety: passing a tensor with the wrong dtype raises a TypeError before the compiled code is called.


matrix_add_i32(a, b.astype(np.int64))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [13], in <cell line: 1>()
----> 1 matrix_add_i32(a, b.astype(np.int64))

File ~/code/mlir-graphblas/mlir_graphblas/engine.py:843, in MlirJitEngine._generate_zero_or_single_valued_functions.<locals>.python_callable(mlir_function, encoders, c_callable, decoder, *args)
    839     raise ValueError(
    840         f"{name} expected {len(mlir_function.args)} args but got {len(args)}."
    841     )
    842 encoded_args = (encoder(arg) for arg, encoder in zip(args, encoders))
--> 843 encoded_args = sum(encoded_args, [])
    844 encoded_result = c_callable(*encoded_args)
    845 result = decoder(encoded_result)

File ~/code/mlir-graphblas/mlir_graphblas/engine.py:842, in <genexpr>(.0)
    838 if len(args) != len(mlir_function.args):
    839     raise ValueError(
    840         f"{name} expected {len(mlir_function.args)} args but got {len(args)}."
    841     )
--> 842 encoded_args = (encoder(arg) for arg, encoder in zip(args, encoders))
    843 encoded_args = sum(encoded_args, [])
    844 encoded_result = c_callable(*encoded_args)

File ~/code/mlir-graphblas/mlir_graphblas/engine.py:392, in input_tensor_to_ctypes.<locals>.encoder(arg)
    388     raise TypeError(
    389         f"{repr(arg)} is expected to be an instance of {np.ndarray.__qualname__}"
    390     )
    391 if not arg.dtype == element_np_type:
--> 392     raise TypeError(
    393         f"{repr(arg)} is expected to have dtype {element_np_type}"
    394     )
    395 if not len(dimensions) == len(arg.shape):
    396     raise ValueError(
    397         f"{repr(arg)} is expected to have rank {len(dimensions)} but has rank {len(arg.shape)}."
    398     )

TypeError: array([[100, 100, 100, 100, 100],
       [100, 100, 100, 100, 100],
       [100, 100, 100, 100, 100],
       [100, 100, 100, 100, 100]]) is expected to have dtype <class 'numpy.int32'>
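A dtype check of this kind can be sketched with a lookup from MLIR element types to NumPy dtypes. The table `MLIR_TO_NUMPY` and the function `check_dtype` are hypothetical; mapping the signless `iN` types to NumPy's signed integers is an assumption of this sketch.

```python
import numpy as np

# Hypothetical mapping from MLIR element types to NumPy dtypes (a few shown).
# Assumption: signless MLIR integers are treated as signed NumPy integers.
MLIR_TO_NUMPY = {
    "i8": np.int8, "i16": np.int16, "i32": np.int32, "i64": np.int64,
    "f32": np.float32, "f64": np.float64,
}

def check_dtype(arr, mlir_element_type):
    """Raise TypeError unless arr's dtype matches the MLIR element type."""
    expected = np.dtype(MLIR_TO_NUMPY[mlir_element_type])
    if arr.dtype != expected:
        raise TypeError(f"expected dtype {expected} but got {arr.dtype}")

b = np.full([4, 5], 100, dtype=np.int32)
check_dtype(b, "i32")  # OK
try:
    check_dtype(b.astype(np.int64), "i32")
except TypeError as err:
    print(err)  # expected dtype int32 but got int64
```

Converting an argument with `.astype(np.int32)` before the call is the straightforward way to satisfy such a check.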

Note that in the MLIR code, each dimension of the output tensor is the maximum of the corresponding dimensions of the two inputs.

A consequence of this is that our function doesn’t enforce that our inputs are the same shape.
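In plain Python, the `arith.cmpi "ugt"` / `arith.select` pairs amount to the hypothetical helper below. Nothing in it compares the two shapes for equality, which is why mismatched inputs are accepted.

```python
def output_shape(shape_a, shape_b):
    """Mirror the cmpi "ugt" + select pairs: each output dimension is the
    larger of the two corresponding input dimensions."""
    return tuple(da if da > db else db for da, db in zip(shape_a, shape_b))

print(output_shape((4, 5), (4, 5)))  # (4, 5)
print(output_shape((2, 3), (4, 5)))  # (4, 5): no error, even though the shapes differ
```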


# generate differently shaped inputs
a = np.arange(6, dtype=np.int32).reshape([2, 3])
b = np.full([4, 5], 100, dtype=np.int32)

# generate output
result = matrix_add_i32(a, b)



result.shape

(4, 5)

result

array([[       100,        101,        102, -536870912,          7],
       [       103,        104,        105,          0,         48],
       [1852990827,  808348773,  862337379,  758342450, 1667588407],
       [ 879047725,  809053497, 1680696121, 1650798691,  878994488]],
      dtype=int32)

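This result is unexpected. Only the top-left 2x3 block, which matches the shape of a, holds the sums we would expect. The remaining values (the zeros, small numbers, and large numbers) are whatever happened to be in the uninitialized memory allocated for our output (i.e. the buffer that backs %output_tensor after bufferization).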
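The printed result is consistent with the loops running only over the 2x3 shape of `a` while writing into a 4x5 output buffer. The NumPy sketch below reproduces that situation under that assumption: `np.empty` stands in for the uninitialized buffer, so the values outside the top-left block are arbitrary and will differ from run to run.

```python
import numpy as np

a = np.arange(6, dtype=np.int32).reshape([2, 3])
b = np.full([4, 5], 100, dtype=np.int32)

# Output sized to the max of each dimension, like linalg.init_tensor.
# np.empty does not initialize its contents.
out = np.empty((4, 5), dtype=np.int32)

# Assumption: the loop bounds come from the first input's shape (2x3).
rows, cols = a.shape
for i in range(rows):
    for j in range(cols):
        out[i, j] = a[i, j] + b[i, j]

print(out[:2, :3])
# Every element outside out[:2, :3] is whatever was already in memory.
```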

This is a flaw in how we wrote our MLIR code: nothing enforces that both inputs have the same shape. Special care must be taken when dealing with arbitrarily sized tensors, or we may get bugs or unexpected results like the ones shown here.
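One way to guard against this from the Python side is to check the shapes before calling the compiled function. The wrapper `require_same_shape` below is a hypothetical sketch; `np.add` stands in for the compiled `matrix_add_i32` so the example runs on its own.

```python
import numpy as np

def require_same_shape(fn):
    """Wrap a two-argument elementwise function so that mismatched input
    shapes raise ValueError instead of reaching the compiled code."""
    def wrapper(x, y):
        if x.shape != y.shape:
            raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
        return fn(x, y)
    return wrapper

# np.add is a stand-in for the compiled function in this sketch.
safe_add = require_same_shape(np.add)
print(safe_add(np.ones((2, 3), dtype=np.int32), np.ones((2, 3), dtype=np.int32)))
try:
    safe_add(np.ones((2, 3), dtype=np.int32), np.ones((4, 5), dtype=np.int32))
except ValueError as err:
    print(err)  # shape mismatch: (2, 3) vs (4, 5)
```

With the JIT engine, the same wrapper would be applied as `require_same_shape(engine.matrix_add_i32)`. Enforcing the check inside the MLIR code itself is the other option.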