This page was generated from tools/engine/scalar_plus_scalar.ipynb.

JIT Engine: Scalar + Scalar

This example shows how to compile MLIR code into a function callable from Python.

The example MLIR code we’ll use here performs scalar addition.

Let’s first import some necessary modules and create an instance of our JIT engine.


import mlir_graphblas
import numpy as np

engine = mlir_graphblas.MlirJitEngine()

Here’s some MLIR code to add two 32-bit floating point numbers.


mlir_text = r"""
func @scalar_add_f32(%a: f32, %b: f32) -> f32 {
  %ans = arith.addf %a, %b : f32
  return %ans : f32
}
"""

Let’s say we want to lower our code to the LLVM dialect with the following MLIR passes:


passes = [
    "--linalg-bufferize",
    "--func-bufferize",
    "--tensor-bufferize",
    "--tensor-constant-bufferize",
    "--finalizing-bufferize",
    "--convert-linalg-to-loops",
    "--convert-scf-to-std",
    "--convert-arith-to-llvm",
    "--convert-math-to-llvm",
    "--convert-std-to-llvm",
]

We can compile the MLIR code using our JIT engine.


engine.add(mlir_text, passes)

['scalar_add_f32']

The return value is a list of the names of all functions compiled from the given MLIR code.
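As a rough illustration of which names come back for a given module, the sketch below scans MLIR text for function definitions with a regular expression. The helper function_names is hypothetical and is not necessarily how MlirJitEngine.add determines the names. A regex only approximates real MLIR parsing.

```python
import re

# Hypothetical helper: matches "func @name" (and "func.func @name") definitions.
# MLIR bare identifiers start with a letter or "_" and may contain "$" and ".".
_FUNC_DEF = re.compile(r"\bfunc\s+@([A-Za-z_][\w$.]*)")


def function_names(mlir_text: str) -> list[str]:
    """Return the names of functions defined in mlir_text, in source order."""
    return _FUNC_DEF.findall(mlir_text)


sample_mlir = r"""
func @scalar_add_f32(%a: f32, %b: f32) -> f32 {
  %ans = arith.addf %a, %b : f32
  return %ans : f32
}
"""
print(function_names(sample_mlir))  # ['scalar_add_f32']
```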

We can access a compiled Python callable in two ways, by key or as an attribute:


func_1 = engine['scalar_add_f32']
func_2 = engine.scalar_add_f32

Both refer to the same function object:


func_1 is func_2

True
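Supporting both engine['name'] and engine.name is a common Python pattern. The callables are stored in a dict and exposed through __getitem__ and a fallback __getattr__. The sketch below uses a hypothetical CallableRegistry class to show why both lookups return the identical object. It illustrates the general pattern only and is not a copy of MlirJitEngine’s code.

```python
class CallableRegistry:
    """Hypothetical stand-in for a JIT engine's table of compiled callables."""

    def __init__(self):
        self._callables = {}

    def register(self, name, func):
        self._callables[name] = func

    def __getitem__(self, name):
        # registry["name"] style access
        return self._callables[name]

    def __getattr__(self, name):
        # registry.name style access; Python only calls __getattr__
        # after normal attribute lookup has failed.
        try:
            return self.__dict__["_callables"][name]
        except KeyError:
            raise AttributeError(name) from None


registry = CallableRegistry()
registry.register("scalar_add_f32", lambda a, b: a + b)
print(registry["scalar_add_f32"] is registry.scalar_add_f32)  # True
```

Both paths read the same dictionary entry, so no copy or wrapper is created per lookup and the identity check holds.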

We can call our function in Python:


scalar_add_f32 = engine.scalar_add_f32
scalar_add_f32(100.0, 200.0)

300.0
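Because scalar_add_f32 computes in 32-bit floats, its result has single precision, whereas Python floats are 64-bit. To see what f32 arithmetic does, the sketch below emulates it in pure Python with the struct module. The helpers to_f32 and add_f32 are hypothetical, and their results are not output captured from the JIT engine.

```python
import struct


def to_f32(x: float) -> float:
    """Round a 64-bit Python float to the nearest 32-bit float value.

    struct raises OverflowError for finite values beyond the f32 range.
    """
    return struct.unpack("f", struct.pack("f", x))[0]


def add_f32(a: float, b: float) -> float:
    """Emulate f32 addition: round inputs to f32, add, round the sum to f32."""
    return to_f32(to_f32(a) + to_f32(b))


print(add_f32(100.0, 200.0))  # 300.0
print(add_f32(0.1, 0.2) == to_f32(0.3))  # True
```

Adding in 64-bit and rounding once to 32-bit gives the correctly rounded f32 sum, because a double has more than twice the precision of a single.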

Let’s try creating a function to add two 8-bit integers.


mlir_text = r"""
func @scalar_add_i8(%a: i8, %b: i8) -> i8 {
  %ans = arith.addi %a, %b : i8
  return %ans : i8
}
"""
engine.add(mlir_text, passes)
scalar_add_i8 = engine.scalar_add_i8

Let’s verify that it works.


scalar_add_i8(30, 40)

70
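Inputs that fit in 8 bits can still produce a sum that doesn’t. MLIR’s arith.addi defines the result as the sum modulo 2^n for n-bit operands, so an i8 sum wraps around in two’s complement. The hypothetical helper below emulates that arithmetic in pure Python. It illustrates the semantics only; we haven’t shown here what scalar_add_i8 itself returns for such inputs.

```python
def add_i8_wrapping(a: int, b: int) -> int:
    """Emulate i8 addition: sum modulo 2**8, read back as a signed value."""
    low_bits = (a + b) & 0xFF  # keep only the low 8 bits of the sum
    # Bit patterns 128..255 represent negative numbers in two's complement.
    return low_bits - 0x100 if low_bits >= 0x80 else low_bits


print(add_i8_wrapping(30, 40))    # 70
print(add_i8_wrapping(100, 100))  # -56, since 200 does not fit in i8
print(add_i8_wrapping(127, 1))    # -128
```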

What happens if we pass invalid inputs, e.g. integers too large to fit into 8 bits?


scalar_add_i8(9999, 9999)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/yj/nmf5xtns3hx6qgdnybj964q80000gp/T/ipykernel_63061/1410601357.py in <module>
----> 1 scalar_add_i8(9999, 9999)

~/code/mlir-graphblas/mlir_graphblas/engine.py in python_callable(mlir_function, encoders, c_callable, decoder, *args)
    841                     )
    842                 encoded_args = (encoder(arg) for arg, encoder in zip(args, encoders))
--> 843                 encoded_args = sum(encoded_args, [])
    844                 encoded_result = c_callable(*encoded_args)
    845                 result = decoder(encoded_result)

~/code/mlir-graphblas/mlir_graphblas/engine.py in <genexpr>(.0)
    840                         f"{name} expected {len(mlir_function.args)} args but got {len(args)}."
    841                     )
--> 842                 encoded_args = (encoder(arg) for arg, encoder in zip(args, encoders))
    843                 encoded_args = sum(encoded_args, [])
    844                 encoded_result = c_callable(*encoded_args)

~/code/mlir-graphblas/mlir_graphblas/engine.py in encoder(arg)
    482             can_cast = False
    483         if not can_cast:
--> 484             raise TypeError(f"{repr(arg)} cannot be cast to {np_type}")
    485         if not isinstance(arg, (np.number, int, float)):
    486             raise TypeError(

TypeError: 9999 cannot be cast to <class 'numpy.int8'>

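We get an exception! Compiled callables perform some input and output type checking, so the JIT engine provides a degree of safety. Here, 9999 cannot be represented as a numpy.int8, so the call raises a TypeError while encoding the arguments, before the compiled code runs.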
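The traceback shows the check happening in a per-argument encoder before the compiled function is called. The sketch below shows a minimal range check of that kind for signed integer arguments. The name check_signed_int is hypothetical, and the engine’s real encoder may differ (its error message indicates it works with NumPy types).

```python
def check_signed_int(arg, bits: int = 8) -> int:
    """Hypothetical encoder check: accept arg only if it fits in a signed bits-bit int."""
    if isinstance(arg, bool) or not isinstance(arg, int):
        raise TypeError(f"{arg!r} is not an integer")
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1  # for i8: -128..127
    if not lo <= arg <= hi:
        raise TypeError(f"{arg!r} cannot be cast to int{bits}")
    return arg


print(check_signed_int(30))  # 30
try:
    check_signed_int(9999)
except TypeError as err:
    print(err)  # 9999 cannot be cast to int8
```

Rejecting the value up front avoids silently truncating 9999 to its low 8 bits when it is passed to the compiled code.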