This page was generated from tools/engine/scalar_plus_scalar.ipynb.

JIT Engine: Scalar + Scalar

This example will go over how to compile MLIR code to a function callable from Python.

The example MLIR code we’ll use here performs scalar addition.

Let’s first import some necessary modules and generate an instance of our JIT engine.


import mlir_graphblas
import numpy as np

engine = mlir_graphblas.MlirJitEngine()
Using development graphblas-opt: /Users/pnguyen/code/mlir-graphblas/mlir_graphblas/src/build/bin/graphblas-opt

Here’s some MLIR code to add two 32-bit floating point numbers.


mlir_text = r"""
func @scalar_add_f32(%a: f32, %b: f32) -> f32 {
  %ans = arith.addf %a, %b : f32
  return %ans : f32
}
"""

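Let’s say we want to lower our code to the LLVM dialect with the following MLIR passes. The bufferization passes run first, and the conversion passes then translate the remaining dialects to LLVM: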


passes = [
    "--linalg-bufferize",
    "--func-bufferize",
    "--tensor-bufferize",
    "--finalizing-bufferize",
    "--convert-linalg-to-loops",
    "--convert-scf-to-cf",
    "--convert-arith-to-llvm",
    "--convert-math-to-llvm",
    "--convert-std-to-llvm",
]
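Each entry is an ordinary opt-style command-line flag. The engine reported above that it uses a graphblas-opt binary. Assuming it applies the passes in list order, roughly as a manual opt run would, the equivalent command line can be built as shown below. The file name scalar_add.mlir is a hypothetical placeholder, and this is not necessarily how the engine invokes the tool internally.

```python
import shlex

passes = [
    "--linalg-bufferize",
    "--func-bufferize",
    "--tensor-bufferize",
    "--finalizing-bufferize",
    "--convert-linalg-to-loops",
    "--convert-scf-to-cf",
    "--convert-arith-to-llvm",
    "--convert-math-to-llvm",
    "--convert-std-to-llvm",
]

# Hypothetical manual invocation: the input file followed by the passes in order.
cmd = shlex.join(["graphblas-opt", "scalar_add.mlir", *passes])
print(cmd)
```

The order matters. Tensors must be bufferized before linalg ops are turned into loops, and loops (scf) must be lowered to cf before everything is converted to LLVM.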

We can compile the MLIR code using our JIT engine.


engine.add(mlir_text, passes)

['scalar_add_f32']

The returned value above is a list of the names of all functions compiled in the given MLIR code.
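As a rough illustration of where those names come from, the regex below pulls the symbol names out of func definitions in MLIR text. This is only an approximation for simple inputs. The engine presumably inspects the parsed MLIR rather than using a regex, and function_names is a hypothetical helper.

```python
import re

mlir_text = r"""
func @scalar_add_f32(%a: f32, %b: f32) -> f32 {
  %ans = arith.addf %a, %b : f32
  return %ans : f32
}
"""


def function_names(text: str) -> list:
    # Match "func @name(" and also the newer "func.func @name(" spelling.
    return re.findall(r"\bfunc\s+@([A-Za-z_][\w$.]*)\s*\(", text)


print(function_names(mlir_text))  # ['scalar_add_f32']
```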

We can access the compiled Python callables in two ways:


func_1 = engine['scalar_add_f32']
func_2 = engine.scalar_add_f32

They both point to the same function:


func_1 is func_2

True
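One common way to support both lookup styles is to back them with a single dictionary, so item access and attribute access return the very same object. The CallableRegistry class below is a hypothetical sketch of that pattern, not the engine's actual implementation.

```python
class CallableRegistry:
    """Hypothetical minimal registry: one dict, two lookup styles."""

    def __init__(self):
        self._functions = {}

    def add(self, name, func):
        self._functions[name] = func
        return [name]

    def __getitem__(self, name):
        return self._functions[name]

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so real attributes
        # and methods such as `add` are unaffected.
        functions = self.__dict__.get("_functions", {})
        if name in functions:
            return functions[name]
        raise AttributeError(name)


registry = CallableRegistry()
registry.add("scalar_add_f32", lambda a, b: a + b)
print(registry["scalar_add_f32"] is registry.scalar_add_f32)  # True
```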

We can call our function in Python:


scalar_add_f32 = engine.scalar_add_f32
scalar_add_f32(100.0, 200.0)

300.0
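mlir_graphblas converts Python arguments to ctypes values before calling the compiled code, as the helper name input_scalar_to_ctypes in its engine.py indicates. To see what crossing a float(float, float) C boundary looks like, the sketch below uses ctypes alone. It wraps a Python function as a C function pointer with that signature and calls it. The Python body is a stand-in for the JIT-compiled code and is an assumption of this example.

```python
import ctypes

# C signature of an f32 (f32, f32) function: float (*)(float, float)
F32Binary = ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_float)


def _add(a, b):
    # ctypes has already converted the C floats back into Python floats here
    return a + b


c_add = F32Binary(_add)

# Arguments are converted to C floats, and the C float result comes back as a Python float.
print(c_add(100.0, 200.0))  # 300.0
```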

Let’s try creating a function to add two 8-bit integers.


mlir_text = r"""
func @scalar_add_i8(%a: i8, %b: i8) -> i8 {
  %ans = arith.addi %a, %b : i8
  return %ans : i8
}
"""
engine.add(mlir_text, passes)
scalar_add_i8 = engine.scalar_add_i8
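arith.addi on i8 is two's-complement addition that wraps modulo 2**8 on overflow. The helper add_i8_wrapping below is a hypothetical pure-Python model of that arithmetic. The notebook does not show what the compiled function returns for an overflowing sum such as 100 + 100, so the -56 below describes MLIR's integer semantics, not verified engine output.

```python
def add_i8_wrapping(a: int, b: int) -> int:
    """Hypothetical model of `arith.addi %a, %b : i8` (wraps on overflow)."""
    total = (a + b) & 0xFF  # keep only the low 8 bits
    # Reinterpret the 8-bit pattern as a signed value in [-128, 127].
    return total - 256 if total >= 128 else total


print(add_i8_wrapping(30, 40))    # 70
print(add_i8_wrapping(100, 100))  # -56
```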

Let’s verify that it works.


scalar_add_i8(30, 40)

70

What happens if we give invalid inputs, e.g. integers too large to fit into 8 bits?


scalar_add_i8(9999, 9999)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [10], in <cell line: 1>()
----> 1 scalar_add_i8(9999, 9999)

File ~/code/mlir-graphblas/mlir_graphblas/engine.py:843, in MlirJitEngine._generate_zero_or_single_valued_functions.<locals>.python_callable(mlir_function, encoders, c_callable, decoder, *args)
    839     raise ValueError(
    840         f"{name} expected {len(mlir_function.args)} args but got {len(args)}."
    841     )
    842 encoded_args = (encoder(arg) for arg, encoder in zip(args, encoders))
--> 843 encoded_args = sum(encoded_args, [])
    844 encoded_result = c_callable(*encoded_args)
    845 result = decoder(encoded_result)

File ~/code/mlir-graphblas/mlir_graphblas/engine.py:842, in <genexpr>(.0)
    838 if len(args) != len(mlir_function.args):
    839     raise ValueError(
    840         f"{name} expected {len(mlir_function.args)} args but got {len(args)}."
    841     )
--> 842 encoded_args = (encoder(arg) for arg, encoder in zip(args, encoders))
    843 encoded_args = sum(encoded_args, [])
    844 encoded_result = c_callable(*encoded_args)

File ~/code/mlir-graphblas/mlir_graphblas/engine.py:484, in input_scalar_to_ctypes.<locals>.encoder(arg)
    482     can_cast = False
    483 if not can_cast:
--> 484     raise TypeError(f"{repr(arg)} cannot be cast to {np_type}")
    485 if not isinstance(arg, (np.number, int, float)):
    486     raise TypeError(
    487         f"{repr(arg)} is expected to be a scalar with dtype {np_type}"
    488     )

TypeError: 9999 cannot be cast to <class 'numpy.int8'>

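We get an exception: a TypeError raised while the arguments are being encoded. Compiled callables perform some input and output type checking, so 9999, which cannot be cast to numpy.int8, is rejected in Python before the compiled code is ever called. The JIT engine therefore provides some safety against invalid arguments.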
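The check in the traceback can be approximated with the standard library alone. The sketch below is a hypothetical encode_i8 modeled loosely on the engine's encoder. It accepts only Python ints in the signed 8-bit range and wraps them in ctypes.c_int8. The real encoder goes further: it also accepts NumPy scalars and floats that can be cast, and it uses NumPy to decide castability.

```python
import ctypes

I8_MIN, I8_MAX = -128, 127


def encode_i8(arg):
    """Hypothetical input encoder for an i8 parameter."""
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(arg, bool) or not isinstance(arg, int):
        raise TypeError(f"{arg!r} is expected to be an integer scalar")
    if not I8_MIN <= arg <= I8_MAX:
        raise TypeError(f"{arg!r} cannot be cast to int8")
    return ctypes.c_int8(arg)


print(encode_i8(30).value)  # 30
try:
    encode_i8(9999)
except TypeError as err:
    print(err)  # 9999 cannot be cast to int8
```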