This page was generated from tools/engine/scalar_times_tensor.ipynb.

JIT Engine: Scalar x Tensor

This example will go over how to compile MLIR code for multiplying a scalar by a tensor.

Previous tutorials have gone over how to broadcast vectors. For the simple task of multiplying each of a tensor's elements by a scalar, broadcasting is unnecessary. We'll go over how to implement this in a simpler and more straightforward fashion.
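As a rough NumPy analogy (not MLIR, and not part of mlir_graphblas), the two strategies below compute the same result. The first explicitly builds a tensor filled with the scalar and multiplies elementwise. The second multiplies each element by the scalar directly and never materializes that extra tensor. The names `broadcast_scalar`, `via_broadcast`, and `via_scalar` are illustrative only.

```python
import numpy as np

# A 2x3 float32 input, matching the shape used later in this tutorial.
a = np.arange(6, dtype=np.float32).reshape(2, 3)
scalar = np.float32(100)

# Broadcasting approach: materialize a 2x3 tensor filled with the scalar,
# then multiply elementwise.
broadcast_scalar = np.full(a.shape, scalar, dtype=np.float32)
via_broadcast = a * broadcast_scalar

# Direct approach: multiply every element by the scalar itself;
# no scalar-filled tensor is allocated.
via_scalar = a * scalar

print(np.array_equal(via_broadcast, via_scalar))  # True
```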

Let's first import the necessary modules and create an instance of our JIT engine.


import mlir_graphblas
import numpy as np

engine = mlir_graphblas.MlirJitEngine()
Using development graphblas-opt: /Users/pnguyen/code/mlir-graphblas/mlir_graphblas/src/build/bin/graphblas-opt

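Here's the MLIR code we'll use. Rather than broadcasting the scalar into a tensor, the body of the `linalg.generic` op refers to `%arg_scalar` directly; the op's region can use values defined in the enclosing function.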


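mlir_text = """
// Both the input and the output are accessed element-by-element as (i, j),
// and the two loops are independent ("parallel").
#trait_scale = {
 indexing_maps = [
   affine_map<(i, j) -> (i, j)>,
   affine_map<(i, j) -> (i, j)>
 ],
 iterator_types = ["parallel", "parallel"]
}

func @scale(%arg_tensor: tensor<2x3xf32>, %arg_scalar: f32) -> tensor<2x3xf32> {
  // Multiply every element by %arg_scalar, which the region reads directly
  // from the enclosing function; no broadcast tensor is created.
  %answer = linalg.generic #trait_scale
    ins(%arg_tensor: tensor<2x3xf32>)
    outs(%arg_tensor: tensor<2x3xf32>) {
      // %a is the input element; %out is the current output element (unused).
      ^bb(%a: f32, %out: f32):
        %scaled = arith.mulf %a, %arg_scalar : f32
        linalg.yield %scaled : f32
    } -> tensor<2x3xf32>
  return %answer : tensor<2x3xf32>
}
"""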
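To make the semantics of the `linalg.generic` op concrete, here is a plain NumPy model of what `@scale` computes. It is an illustrative sketch only: `scale_reference` is a hypothetical helper, not part of mlir_graphblas, and it ignores how MLIR actually lowers and bufferizes the op. The two nested loops correspond to the two "parallel" iterators, and the identity indexing maps mean element `(i, j)` of the input produces element `(i, j)` of the output.

```python
import numpy as np

def scale_reference(arg_tensor, arg_scalar):
    """Hypothetical pure-Python model of the linalg.generic op in @scale."""
    rows, cols = arg_tensor.shape
    # The outs(...) operand fixes the result's shape and type; the body never
    # reads its second block argument, so the initial values do not matter.
    answer = np.empty_like(arg_tensor)
    # Two "parallel" iterators: every (i, j) iteration is independent.
    for i in range(rows):
        for j in range(cols):
            # Both indexing maps are (i, j) -> (i, j): read input[i, j] ...
            a = arg_tensor[i, j]
            # ... and write answer[i, j], using the scalar captured from the
            # enclosing function rather than a broadcast tensor.
            answer[i, j] = a * arg_scalar
    return answer

x = np.arange(6, dtype=np.float32).reshape(2, 3)
print(scale_reference(x, np.float32(100)))
```

Because each iteration touches only its own element, the loop order is irrelevant, which is what the "parallel" iterator types declare.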

These are the passes we'll use to lower the MLIR code to the LLVM dialect.


passes = [
    "--graphblas-structuralize",
    "--graphblas-optimize",
    "--graphblas-lower",
    "--sparsification",
    "--sparse-tensor-conversion",
    "--linalg-bufferize",
    "--func-bufferize",
    "--tensor-bufferize",
    "--finalizing-bufferize",
    "--convert-linalg-to-loops",
    "--convert-scf-to-cf",
    "--convert-memref-to-llvm",
    "--convert-math-to-llvm",
    "--convert-openmp-to-llvm",
    "--convert-arith-to-llvm",
    "--convert-std-to-llvm",
    "--reconcile-unrealized-casts"
]

Let’s compile our MLIR code.


engine.add(mlir_text, passes)

['scale']

Let’s try out our compiled function.


# grab our callable
scale = engine.scale

# generate inputs
a = np.arange(6, dtype=np.float32).reshape([2,3])

# call the compiled function
result = scale(a, 100)

result

array([[  0., 100., 200.],
       [300., 400., 500.]], dtype=float32)

Let’s verify that our function works as expected.


np.all(result == a*100)

True