This page was generated from tools/engine/scalar_plus_scalar.ipynb.
JIT Engine: Scalar + Scalar
This example shows how to compile MLIR code into a function callable from Python.
The MLIR code we'll use here performs scalar addition.
Let’s first import some necessary modules and generate an instance of our JIT engine.
import mlir_graphblas
import numpy as np
engine = mlir_graphblas.MlirJitEngine()
Using development graphblas-opt: /Users/pnguyen/code/mlir-graphblas/mlir_graphblas/src/build/bin/graphblas-opt
Here’s some MLIR code to add two 32-bit floating point numbers.
mlir_text = r"""
func @scalar_add_f32(%a: f32, %b: f32) -> f32 {
    %ans = arith.addf %a, %b : f32
    return %ans : f32
}
"""
Let’s say we wanted to optimize our code with the following MLIR passes:
passes = [
    "--linalg-bufferize",
    "--func-bufferize",
    "--tensor-bufferize",
    "--finalizing-bufferize",
    "--convert-linalg-to-loops",
    "--convert-scf-to-cf",
    "--convert-arith-to-llvm",
    "--convert-math-to-llvm",
    "--convert-std-to-llvm",
]
We can compile the MLIR code using our JIT engine.
engine.add(mlir_text, passes)
['scalar_add_f32']
The returned value above is a list of the names of all functions compiled in the given MLIR code.
We can access the compiled Python callables in two ways:
func_1 = engine['scalar_add_f32']
func_2 = engine.scalar_add_f32
They both point to the same function:
func_1 is func_2
True
We can call our function in Python:
scalar_add_f32 = engine.scalar_add_f32
scalar_add_f32(100.0, 200.0)
300.0
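Note that `scalar_add_f32` operates on 32-bit floats, so its results carry float32 precision rather than Python's native float64. A pure-NumPy sketch of the difference (no JIT engine involved):

```python
import numpy as np

# Pure-NumPy illustration: casting to float32 rounds each value to
# 32-bit precision first, so the float32 sum can differ slightly from
# the float64 sum of the same literals.
a = np.float32(0.1)
b = np.float32(0.2)
print(a + b)      # float32 sum
print(0.1 + 0.2)  # float64 sum, slightly different
```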
Let’s try creating a function to add two 8-bit integers.
mlir_text = r"""
func @scalar_add_i8(%a: i8, %b: i8) -> i8 {
    %ans = arith.addi %a, %b : i8
    return %ans : i8
}
"""
engine.add(mlir_text, passes)
scalar_add_i8 = engine.scalar_add_i8
Let’s verify that it works.
scalar_add_i8(30, 40)
70
What happens if we give invalid inputs, e.g. integers too large to fit into 8 bits?
scalar_add_i8(9999, 9999)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [10], in <cell line: 1>()
----> 1 scalar_add_i8(9999, 9999)
File ~/code/mlir-graphblas/mlir_graphblas/engine.py:843, in MlirJitEngine._generate_zero_or_single_valued_functions.<locals>.python_callable(mlir_function, encoders, c_callable, decoder, *args)
839 raise ValueError(
840 f"{name} expected {len(mlir_function.args)} args but got {len(args)}."
841 )
842 encoded_args = (encoder(arg) for arg, encoder in zip(args, encoders))
--> 843 encoded_args = sum(encoded_args, [])
844 encoded_result = c_callable(*encoded_args)
845 result = decoder(encoded_result)
File ~/code/mlir-graphblas/mlir_graphblas/engine.py:842, in <genexpr>(.0)
838 if len(args) != len(mlir_function.args):
839 raise ValueError(
840 f"{name} expected {len(mlir_function.args)} args but got {len(args)}."
841 )
--> 842 encoded_args = (encoder(arg) for arg, encoder in zip(args, encoders))
843 encoded_args = sum(encoded_args, [])
844 encoded_result = c_callable(*encoded_args)
File ~/code/mlir-graphblas/mlir_graphblas/engine.py:484, in input_scalar_to_ctypes.<locals>.encoder(arg)
482 can_cast = False
483 if not can_cast:
--> 484 raise TypeError(f"{repr(arg)} cannot be cast to {np_type}")
485 if not isinstance(arg, (np.number, int, float)):
486 raise TypeError(
487 f"{repr(arg)} is expected to be a scalar with dtype {np_type}"
488 )
TypeError: 9999 cannot be cast to <class 'numpy.int8'>
We get an exception! The compiled callables perform input and output type checking, so the JIT engine provides some safety against invalid values.
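The traceback shows the check happening in the encoder built by `input_scalar_to_ctypes`. A minimal sketch of that kind of range check in pure NumPy (hypothetical helper name; the real logic lives in `mlir_graphblas/engine.py`):

```python
import numpy as np

# Hypothetical sketch of an i8 argument encoder: reject values outside
# the representable range of an 8-bit signed integer before casting.
def encode_i8(arg):
    info = np.iinfo(np.int8)  # min=-128, max=127
    if not (info.min <= arg <= info.max):
        raise TypeError(f"{arg!r} cannot be cast to {np.int8}")
    return np.int8(arg)

print(encode_i8(40))   # fits in 8 bits
try:
    encode_i8(9999)    # too large for 8 bits
except TypeError as e:
    print(e)
```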