GraphBLAS Dialect Op Reference

The graphblas dialect describes standard sparse tensor operations found in the GraphBLAS spec. The ops are not one-to-one equivalents of GraphBLAS function calls, because they must fit MLIR’s SSA requirements.

This document is a reference manual for the ops in the GraphBLAS dialect rather than a tutorial. Tutorials can be found in later sections of our documentation.

Assumptions

Although the sparse tensor encoding in MLIR is extremely flexible, the graphblas dialect and its associated lowering pass currently support only three encodings: CSR64, CSC64, and CV64.

The CSR64 encoding is usually defined with the alias:

#CSR64 = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  dimOrdering = affine_map<(i,j) -> (i,j)>,
  pointerBitWidth = 64,
  indexBitWidth = 64
}>

The CSC64 encoding can be defined with the alias:

#CSC64 = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  dimOrdering = affine_map<(i,j) -> (j,i)>,
  pointerBitWidth = 64,
  indexBitWidth = 64
}>

In terms of data structure contents, CSR and CSC are identical (both have pointer, index, and value arrays); only the dimension ordering is reversed for CSC. A sparse tensor is then declared like a regular MLIR tensor, with the encoding as an additional attribute:

tensor<?x?xf64, #CSC64>
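To make the “same arrays, reversed indexing” point concrete, here is a minimal Python sketch that builds the pointer, index, and value arrays for a small dense matrix in both layouts. It models the storage scheme only; the helper names `to_csr` and `to_csc` are hypothetical, and treating zeros as missing entries is an assumption of the sketch, not of the encoding.

```python
def to_csr(dense):
    """Return (pointers, indices, values) for a CSR layout of a dense list of rows."""
    pointers, indices, values = [0], [], []
    for row in dense:
        for j, v in enumerate(row):
            if v != 0:  # assumption: zeros are treated as missing entries
                indices.append(j)  # column index of each stored value
                values.append(v)
        pointers.append(len(values))  # row i spans values[pointers[i]:pointers[i + 1]]
    return pointers, indices, values


def to_csc(dense):
    """CSC is CSR of the transpose: same kinds of arrays, rows and columns swapped."""
    transposed = [list(col) for col in zip(*dense)]
    return to_csr(transposed)


dense = [
    [0.0, 1.0, 2.0, 0.0],
    [0.0, 0.0, 0.0, 3.0],
]
print(to_csr(dense))  # ([0, 2, 3], [1, 2, 3], [1.0, 2.0, 3.0])
print(to_csc(dense))  # ([0, 0, 1, 2, 3], [0, 0, 1], [1.0, 2.0, 3.0])
```

The value arrays happen to coincide for this matrix; in general the CSC value array is reordered column by column.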

The CV64 encoding (for sparse vectors) is usually defined with the alias:

#CV64 = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed" ],
  pointerBitWidth = 64,
  indexBitWidth = 64
}>

Note that the --graphblas-lower pass only supports tensors with unknown dimensions (indicated by the ?).

Operation Definitions

graphblas.apply_generic

Generic tensor apply operation.

Syntax:

operation ::= `graphblas.apply_generic` $input attr-dict `:` type($input) `to` type($output) $extensions

Applies an arbitrary transformation, defined by the given transformation block, to every stored element of a CSR matrix, CSC matrix, or sparse vector. Only one transformation block is allowed. If the boolean in_place attribute is set to true, the changes happen in place and the input and output are the same tensor. The in_place attribute defaults to false.

Example:

%answer = graphblas.apply_generic %m : tensor<?x?xf64, #CSR64> to tensor<?x?xf64, #CSR64> {
  ^bb0(%val: f64):
    %result = arith.negf %val : f64
    graphblas.yield transform_out %result : f64
}
%thunk = arith.constant 0.0 : f64
%answer = graphblas.apply_generic %sparse_tensor : tensor<?xf64, #CV64> to tensor<?xf64, #CV64> {
  ^bb0(%val: f64):
    %pick = arith.cmpf olt, %val, %thunk : f64
    %result = arith.select %pick, %val, %thunk : f64
    graphblas.yield transform_out %result : f64
}
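The following Python sketch models what the two examples above compute. It assumes a sparse vector is a `(size, {index: value})` pair, a hypothetical representation chosen for brevity rather than the dialect's storage. Only stored entries are transformed, so the output structure always matches the input, even when the transform yields 0.0.

```python
def apply_generic(vec, transform):
    """Apply a unary transform to every stored value of a sparse vector.

    vec is (size, {index: value}); missing entries stay missing, so the
    output has exactly the same structure as the input.
    """
    size, entries = vec
    return size, {i: transform(v) for i, v in entries.items()}


vec = (4, {1: 10.0, 2: -20.0})

# First example: negate every stored value.
negated = apply_generic(vec, lambda v: -v)

# Second example: min(value, thunk) with thunk = 0.0, as computed by the
# cmpf/select pair. The stored 10.0 becomes an explicitly stored 0.0.
thunk = 0.0
clipped = apply_generic(vec, lambda v: v if v < thunk else thunk)

print(negated)  # (4, {1: -10.0, 2: 20.0})
print(clipped)  # (4, {1: 0.0, 2: -20.0})
```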

Attributes:

Attribute    MLIR Type           Description
in_place     ::mlir::BoolAttr    bool attribute

Operands:

Operand    Description
input      1D/2D tensor of any type values

Results:

Result    Description
output    1D/2D tensor of any type values

graphblas.apply

Tensor element-wise apply operation.

Syntax:

operation ::= `graphblas.apply` $left (`,` $right^)? attr-dict `:` `(` type($left) (`,` type($right)^)? `)` `to` type($output)

Applies the function named by the apply_operator attribute element-wise to each stored element of the given sparse tensor. The operator can be unary or binary; binary operators also require a scalar thunk. Most unary and binary operators are supported, along with the custom “identity” apply operator.

If the boolean in_place attribute is true, then the changes will update the input tensor in place. The in_place attribute has a default value of false.

Note that using the “identity” operator does not create a copy of the input tensor.

The given sparse tensor must either be a CSR matrix, CSC matrix, or a sparse vector.

Some binary operators, e.g. “div”, are not commutative, so pass the sparse tensor and the thunk in the order in which the binary operator should receive them. For example, to divide every element of a matrix by 2, use the following:

%thunk = arith.constant 2 : i64
%matrix_answer = graphblas.apply %sparse_matrix, %thunk { apply_operator = "div" } : (tensor<?x?xi64, #CSR64>, i64) to tensor<?x?xi64, #CSR64>

As another example, to divide 10 by each element of a sparse vector, use the following:

%thunk = arith.constant 10 : i64
%vector_answer = graphblas.apply %thunk, %sparse_vector { apply_operator = "div" } : (i64, tensor<?xi64, #CV64>) to tensor<?xi64, #CV64>

The operator is applied only to elements that are present in the tensor; missing values stay missing. For example, 1.0 / [ _ , 2.0, _ ] == [ _ , 0.5, _ ].

The shape of the output tensor will match that of the input tensor.
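A minimal Python model of the operand-order rule and the “missing stays missing” rule, assuming a `(size, {index: value})` pair for a sparse vector (a hypothetical representation) and using floating-point division for simplicity rather than the integer “div” in the examples:

```python
import operator


def apply(left, right, op):
    """Model of graphblas.apply with a binary operator and a scalar thunk.

    Exactly one of left/right is a sparse vector given as (size, {index: value});
    the other is the scalar thunk. Operand order is preserved, which matters for
    non-commutative operators such as division.
    """
    if isinstance(left, tuple):  # tensor op thunk
        size, entries = left
        return size, {i: op(v, right) for i, v in entries.items()}
    size, entries = right  # thunk op tensor
    return size, {i: op(left, v) for i, v in entries.items()}


vec = (3, {1: 2.0})  # the vector [ _ , 2.0, _ ]

# 1.0 / vec: only the stored element at index 1 is divided.
print(apply(1.0, vec, operator.truediv))  # (3, {1: 0.5})
# vec / 2.0
print(apply(vec, 2.0, operator.truediv))  # (3, {1: 1.0})
# min(vec, 100.0), as in the "min" example
print(apply(vec, 100.0, min))  # (3, {1: 2.0})
```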

Example:

%thunk = arith.constant 100 : i64
%matrix_answer = graphblas.apply %sparse_matrix, %thunk { apply_operator = "min" } : (tensor<?x?xi64, #CSR64>, i64) to tensor<?x?xi64, #CSR64>
%vector_answer = graphblas.apply %sparse_vector { apply_operator = "abs" } : (tensor<?xi64, #CV64>) to tensor<?xi64, #CV64>

Attributes:

Attribute         MLIR Type             Description
apply_operator    ::mlir::StringAttr    string attribute
in_place          ::mlir::BoolAttr      bool attribute

Operands:

Operand    Description
left       any type
right      any type

Results:

Result    Description
output    1D/2D tensor of any type values

graphblas.cast

Changes graph storage parameters: dtype and bitwidths.

Syntax:

operation ::= `graphblas.cast` $input attr-dict `:` type($input) `to` type($output)

Rewrites the contents of a sparse tensor to use a new dtype or a new pointer or index bitwidth. Layout changes (e.g. CSR → CSC) are not supported by this operation; use graphblas.convert_layout instead.

Example:

%a_int = graphblas.cast %a : tensor<?x?xf64, #CSR64> to tensor<?x?xi32, #CSR64>

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
input      1D/2D tensor of any type values

Results:

Result    Description
output    1D/2D tensor of any type values

graphblas.comment

Comment operation.

Syntax:

operation ::= `graphblas.comment` attr-dict

graphblas.comment is a no-op that returns no values. It only holds a string attribute containing a code comment.

Example:

graphblas.comment { comment = "here is a comment!" }

Attributes:

Attribute    MLIR Type             Description
comment      ::mlir::StringAttr    string attribute

graphblas.convert_layout

Converts graph storage layout.

Syntax:

operation ::= `graphblas.convert_layout` $input attr-dict `:` type($input) `to` type($output)

Rewrite the contents of a sparse tensor to change it from CSR to CSC, or vice versa.

Vacuous conversions (e.g. CSC → CSC or CSR → CSR) are equivalent to no-ops and are removed by the --graphblas-lower pass.

Example:

%answer = graphblas.convert_layout %sparse_tensor : tensor<?x?xf64, #CSR64> to tensor<?x?xf64, #CSC64>
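One common way to turn CSR arrays into CSC arrays is a counting sort on the column indices. The Python sketch below shows that technique on plain lists; it illustrates the data movement involved and is not a description of how --graphblas-lower actually lowers this op.

```python
def csr_to_csc(nrows, ncols, pointers, indices, values):
    """Convert CSR arrays to CSC arrays with a counting sort on column indices."""
    # 1. Count the stored elements in each column.
    col_ptr = [0] * (ncols + 1)
    for j in indices:
        col_ptr[j + 1] += 1
    # 2. Prefix-sum the counts so col_ptr[j] is where column j starts.
    for j in range(ncols):
        col_ptr[j + 1] += col_ptr[j]
    # 3. Scatter every element into its column. Rows are visited in order,
    #    so row indices within each column come out sorted.
    next_slot = col_ptr[:-1]  # slicing makes a copy that we can advance
    row_idx = [0] * len(values)
    csc_vals = [None] * len(values)
    for i in range(nrows):
        for k in range(pointers[i], pointers[i + 1]):
            dest = next_slot[indices[k]]
            row_idx[dest] = i
            csc_vals[dest] = values[k]
            next_slot[indices[k]] += 1
    return col_ptr, row_idx, csc_vals


# 2x4 matrix [[_, 1, 2, _], [_, _, _, 3]] in CSR form.
print(csr_to_csc(2, 4, [0, 2, 3], [1, 2, 3], [1.0, 2.0, 3.0]))
# ([0, 0, 1, 2, 3], [0, 0, 1], [1.0, 2.0, 3.0])
```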

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
input      2D tensor of any type values

Results:

Result    Description
output    2D tensor of any type values

graphblas.diag

Diagonal operation.

Syntax:

operation ::= `graphblas.diag` $input attr-dict `:` type($input) `to` type($output)

When given a square CSR or CSC matrix, returns the diagonal as a sparse vector. When given a sparse vector, returns a square CSR or CSC matrix with the vector’s values along the diagonal.

Example:

%csr_matrix_answer = graphblas.diag %vec : tensor<?xi64, #CV64> to tensor<?x?xi64, #CSR64>
%csc_matrix_answer = graphblas.diag %vec : tensor<?xi64, #CV64> to tensor<?x?xi64, #CSC64>
%vector_answer = graphblas.diag %mat : tensor<?x?xi64, #CSR64> to tensor<?xi64, #CV64>
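Conceptually, the two directions of graphblas.diag can be modeled as below. This is a hypothetical Python sketch using `(n, {coordinate: value})` pairs; missing diagonal entries stay missing in both directions.

```python
def vector_to_diag_matrix(vec):
    """(n, {i: value}) -> (n, {(i, i): value}): an n x n matrix with the vector on its diagonal."""
    n, entries = vec
    return n, {(i, i): v for i, v in entries.items()}


def matrix_to_diag_vector(mat):
    """(n, {(i, j): value}) for a square n x n matrix -> (n, {i: value}) holding its diagonal."""
    n, entries = mat
    return n, {i: v for (i, j), v in entries.items() if i == j}


print(vector_to_diag_matrix((3, {0: 5, 2: 7})))  # (3, {(0, 0): 5, (2, 2): 7})
print(matrix_to_diag_vector((3, {(0, 0): 5, (0, 1): 1, (2, 2): 7})))  # (3, {0: 5, 2: 7})
```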

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
input      1D/2D tensor of any type values

Results:

Result    Description
output    1D/2D tensor of any type values

graphblas.dup

Returns a duplicate of the input sparse tensor.

Syntax:

operation ::= `graphblas.dup` $input attr-dict `:` type($input)

Returns a duplicate copy of the input sparse tensor.

Example:

%B = graphblas.dup %A : tensor<?x?xf64, #CSR64>
%new_vec = graphblas.dup %vec : tensor<?xf64, #CV64>

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
input      1D/2D tensor of any type values

Results:

Result    Description
output    1D/2D tensor of any type values

graphblas.equal

Equality checking operation for vectors and matrices.

Syntax:

operation ::= `graphblas.equal` $a `,` $b attr-dict `:` type($a) `,` type($b)

Performs an equality check. The given tensors must be sparse vectors, CSR matrices, or CSC matrices. Checks equality of rank and size of tensors, as well as values and structure. Returns a single boolean value.

Example:

%answer = graphblas.equal %vec, %other_vec : tensor<?xi64, #CV64>, tensor<?xi64, #CV64>
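The comparison rule (shape, then structure, then values) can be sketched in Python as follows, assuming `(shape, {coordinate: value})` pairs as a hypothetical stand-in for the sparse tensors. Under this model an explicitly stored zero counts as part of the structure.

```python
def sparse_equal(a, b):
    """a and b are (shape, {coordinate: value}) pairs.

    They are equal only if the shapes match (which covers rank and size),
    the sets of stored coordinates match (structure), and the stored values match.
    """
    shape_a, entries_a = a
    shape_b, entries_b = b
    if shape_a != shape_b:
        return False
    if entries_a.keys() != entries_b.keys():
        return False
    return all(entries_a[k] == entries_b[k] for k in entries_a)


v1 = ((4,), {1: 10, 2: 20})
v2 = ((4,), {1: 10, 2: 20})
v3 = ((4,), {1: 10, 2: 20, 3: 0})  # an explicitly stored 0 changes the structure
v4 = ((5,), {1: 10, 2: 20})  # same entries, different size

print(sparse_equal(v1, v2), sparse_equal(v1, v3), sparse_equal(v1, v4))  # True False False
```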

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
a          1D/2D tensor of any type values
b          1D/2D tensor of any type values

Results:

Result    Description
output    1-bit signless integer

graphblas.from_coo

Build new sparse tensor from coordinate tensors.

Syntax:

operation ::= `graphblas.from_coo` $indices `,` $values `[` $sizes `]` attr-dict `:` type($indices) `,` type($values) `to` type($output)

Builds a new sparse tensor from two dense tensors: the coordinates of the stored elements (indices) and their associated values. The indices must be sorted; unsorted input leads to incorrect results.

Example:

%v = graphblas.from_coo %indices, %vals [%nrows, %ncols] : tensor<?x?xindex>, tensor<?xf64> to tensor<?x?xf64, #CSR64>
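The sketch below shows, in Python, how sorted coordinates map onto CSR arrays and why ordering matters: the pointer array is built from per-row counts, which only describes the data correctly if each row's column indices are contiguous and ordered. The name `from_coo` and the explicit validation are assumptions of this sketch; the op's documentation only states that the input must be sorted.

```python
def from_coo(coords, values, nrows, ncols):
    """Build CSR arrays (pointers, indices, values) from COO input.

    coords is a list of (row, col) pairs that must be strictly increasing in
    row-major order; this sketch rejects unsorted or duplicate coordinates
    instead of silently producing a malformed tensor.
    """
    for r, c in coords:
        if not (0 <= r < nrows and 0 <= c < ncols):
            raise ValueError(f"coordinate {(r, c)} is out of bounds")
    for k in range(len(coords) - 1):
        if coords[k] >= coords[k + 1]:  # tuple comparison = row-major order
            raise ValueError("coordinates must be sorted in row-major order")
    pointers = [0] * (nrows + 1)
    for r, _ in coords:
        pointers[r + 1] += 1  # count entries per row
    for r in range(nrows):
        pointers[r + 1] += pointers[r]  # prefix sum -> row start offsets
    indices = [c for _, c in coords]
    return pointers, indices, list(values)


print(from_coo([(0, 1), (0, 2), (1, 3)], [1.0, 2.0, 3.0], nrows=2, ncols=4))
# ([0, 2, 3], [1, 2, 3], [1.0, 2.0, 3.0])
```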

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
indices    2D tensor of index values
values     1D tensor of any type values
sizes      index

Results:

Result    Description
output    1D/2D tensor of any type values

graphblas.intersect_generic

Element-wise intersection operation.

Syntax:

operation ::= `graphblas.intersect_generic` $a `,` $b attr-dict `:` `(` type($a) `,` type($b)  `)` `to` type($output) $extensions

Performs an element-wise intersection between two matrices or two vectors. The resulting sparse structure will be the intersection of the two input structures. Where both inputs have an element at the same position, the given block combines the two values into the result.

Example:

%combined = graphblas.intersect_generic %A, %B : (
  tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSR64>) to tensor<?x?xf64, #CSR64> {
  ^bb0(%a : f64, %b : f64):
    %result = arith.addf %a, %b : f64
    graphblas.yield mult %result : f64
}
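In Python terms, the intersection semantics reduce to combining values at the coordinates shared by both inputs. A minimal sketch, with `combine` standing in for the op's block and dictionaries of coordinates standing in for the sparse tensors (both assumptions for illustration):

```python
import operator


def intersect(a, b, combine):
    """a and b map coordinates to values for two tensors of the same shape.

    The output contains only coordinates stored in both inputs; combine(x, y)
    plays the role of the op's block and is called as combine(a_value, b_value).
    """
    return {k: combine(a[k], b[k]) for k in sorted(a.keys() & b.keys())}


a = {0: 1.0, 2: 3.0, 5: 4.0}
b = {2: 10.0, 3: 7.0, 5: 0.5}
print(intersect(a, b, operator.add))  # {2: 13.0, 5: 4.5}
```

Coordinates 0 and 3 appear in only one input and are therefore absent from the result.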

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
a          1D/2D tensor of any type values
b          1D/2D tensor of any type values

Results:

Result    Description
output    1D/2D tensor of any type values

graphblas.intersect

Element-wise intersection operation.

Syntax:

operation ::= `graphblas.intersect` $a `,` $b (`,` $mask^)? attr-dict `:` `(` type($a) `,` type($b) (`,` type($mask)^)? `)` `to` type($output)

Performs an element-wise intersection between two matrices or two vectors. The resulting sparse structure will be the intersection of the two input structures. Where both inputs have an element at the same position, the two values are combined using the operator given by the intersect_operator attribute. All binary operators are supported.

The mask (if provided) must be the same format as the returned object. There’s an optional boolean mask_complement attribute (which has a default value of false) that will use the structural complement of the mask.

Example:

%combined = graphblas.intersect %A, %B { intersect_operator = "mult" } : (
    tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSR64>) to tensor<?x?xf64, #CSR64>

%combined = graphblas.intersect %A, %B, %mask { intersect_operator = "mult", mask_complement = true } : (
    tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSR64>, tensor<?x?xi32, #CSR64>) to tensor<?x?xf64, #CSR64>

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute             MLIR Type             Description
intersect_operator    ::mlir::StringAttr    string attribute
mask_complement       ::mlir::BoolAttr      bool attribute

Operands:

Operand    Description
a          1D/2D tensor of any type values
b          1D/2D tensor of any type values
mask       1D/2D tensor of any type values

Results:

Result    Description
output    1D/2D tensor of any type values

graphblas.matrix_multiply_generic

Generic matrix multiply operation with an optional structural mask.

Syntax:

operation ::= `graphblas.matrix_multiply_generic` $a `,` $b (`,` $mask^)? attr-dict `:` `(` type($a) `,` type($b)  (`,` type($mask)^)? `)` `to` type($output) $extensions

This op computes over the two sparse tensor inputs using the same access pattern as a conventional matrix multiply, while the given blocks define what is computed at each step.

This op takes two sparse tensors and an optional structural mask as inputs.

Additionally, this op takes 3 required blocks (we’ll refer to them as the “mult”, “add”, and “add_identity” blocks) and 1 optional block (we’ll refer to it as the “transform_out” block).

In a conventional matrix multiply where the multiplication between two elements takes place, this op instead performs the behavior specified in the “mult” block. The “mult” block takes two scalar arguments and uses the graphblas.yield terminator op (with the “kind” attribute set to “mult”) to return the result of the element-wise computation.

In a conventional matrix multiply, where the products from the element-wise multiplications are summed, this op instead performs the behavior specified in the “add” block to aggregate the results. The “add” block takes two scalar arguments (the first representing the current aggregation and the second representing the next value to be aggregated) and uses the graphblas.yield terminator op (with the “kind” attribute set to “add”) to return the result of the current aggregation.

The aggregation taking place in the “add” block requires an initial value (for conventional matrix multiplication, this value is zero). Using the graphblas.yield terminator op (with the “kind” attribute set to “add_identity”) in the “add_identity” block lets us specify this initial value. This block takes no arguments.

This op additionally takes an optional “transform_out” block that performs an element-wise transformation on the final aggregated values from the “add” block. The “transform_out” block takes one argument and returns one value via the graphblas.yield terminator op (with the “kind” attribute set to “transform_out”).

The mask (if provided) must be the same format as the returned object. A required boolean mask_complement attribute indicates whether to use the structural complement of the mask.

Example:

%answer = graphblas.matrix_multiply_generic %a, %b {mask_complement = false} : (tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSC64>) to tensor<?x?xf64, #CSR64> {
    ^bb0:
         %identity = arith.constant 0.0 : f64
         graphblas.yield add_identity %identity : f64
  },{
    ^bb0(%add_a: f64, %add_b: f64):
         %add_result = arith.addf %add_a, %add_b : f64
         graphblas.yield add %add_result : f64
  },{
    ^bb0(%mult_a: f64, %mult_b: f64):
         %mult_result = arith.mulf %mult_a, %mult_b : f64
         graphblas.yield mult %mult_result : f64
  },{
     ^bb0(%value: f64):
         %result = arith.addf %value, %c100_f64: f64
         graphblas.yield transform_out %result : f64
  }
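The interplay of the four blocks and the mask can be summarized with a small reference model. This Python sketch assumes dictionaries keyed by `(row, col)` for the sparse matrices and plain callables for the blocks; it reproduces the op's semantics as described above, not its lowered implementation or its performance.

```python
import operator


def matrix_multiply_generic(a, b, add_identity, add, mult,
                            transform_out=None, mask=None, mask_complement=False):
    """a maps (i, k) -> value and b maps (k, j) -> value.

    Output position (i, j) exists only if some k has both a[i, k] and b[k, j].
    Products from mult are folded with add, starting from add_identity, then
    optionally passed through transform_out. A structural mask keeps (i, j)
    if it is stored in the mask (or, with mask_complement, if it is not).
    """
    b_rows = {}
    for (k, j), bv in b.items():
        b_rows.setdefault(k, []).append((j, bv))
    out = {}
    for (i, k), av in a.items():
        for j, bv in b_rows.get(k, []):
            if mask is not None and ((i, j) in mask) == mask_complement:
                continue  # position excluded by the structural mask
            out[(i, j)] = add(out.get((i, j), add_identity), mult(av, bv))
    if transform_out is not None:
        out = {pos: transform_out(v) for pos, v in out.items()}
    return out


A = {(0, 0): 1.0, (0, 1): 2.0, (1, 1): 3.0}
B = {(0, 0): 4.0, (1, 0): 5.0, (1, 1): 6.0}
print(matrix_multiply_generic(A, B, 0.0, operator.add, operator.mul))
# {(0, 0): 14.0, (0, 1): 12.0, (1, 0): 15.0, (1, 1): 18.0}
```

Passing `transform_out=lambda v: v + 100.0`, mirroring the example's transform_out block, shifts each of those values by 100.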

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute          MLIR Type           Description
mask_complement    ::mlir::BoolAttr    bool attribute

Operands:

Operand    Description
a          1D/2D tensor of any type values
b          1D/2D tensor of any type values
mask       1D/2D tensor of any type values

Results:

Result    Description
output    any type

graphblas.matrix_multiply

Matrix multiply operation with an optional structural mask.

Syntax:

operation ::= `graphblas.matrix_multiply` $a `,` $b (`,` $mask^)? attr-dict `:` `(` type($a) `,` type($b)  (`,` type($mask)^)? `)` `to` type($output)

Performs a matrix multiply according to the given semiring and optional structural mask. The structural mask specifies which values in the output are to be computed and thus must have the same shape as the expected output.

The semiring must be a string of the form “<add>_<mult>”, e.g. “plus_times” or “min_plus”.

Matrix times vector will return a vector. Vector times matrix will return a vector. Matrix times matrix will return a CSR matrix. Vector times vector will return a scalar.

The mask (if provided) must be the same format as the returned object. There’s an optional boolean mask_complement attribute (which has a default value of false) that will use the structural complement of the mask.

Masks are not allowed for vector-times-vector multiplication.

Examples:

%answer = graphblas.matrix_multiply %argA, %argB { semiring = "plus_plus" } : (tensor<?x?xi64, #CSR64>, tensor<?x?xi64, #CSC64>) to tensor<?x?xi64, #CSR64>
%answer = graphblas.matrix_multiply %argA, %argB, %mask { semiring = "min_times" } : (tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSC64>, tensor<?x?xf64, #CSR64>) to tensor<?x?xf64, #CSR64>
%answer = graphblas.matrix_multiply %mat, %vec, %mask { semiring = "any_first", mask_complement = true } : (tensor<?x?xf64, #CSR64>, tensor<?xf64, #CV64>, tensor<?xf64, #CV64>) to tensor<?xf64, #CV64>
%answer = graphblas.matrix_multiply %vec, %mat { semiring = "min_second" } : (tensor<?xf64, #CV64>, tensor<?x?xf64, #CSC64>) to tensor<?xf64, #CV64>
%answer = graphblas.matrix_multiply %vecA, %vecB { semiring = "any_pair" } : (tensor<?xf64, #CV64>, tensor<?xf64, #CV64>) to f64
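The semiring string splits into an “add” name and a “mult” name. The Python sketch below parses it and evaluates the vector-times-vector case, which returns a scalar. The operator tables are a small hypothetical subset chosen for illustration and are not the dialect's list of supported names.

```python
import math
import operator

# Hypothetical subset of operator tables, for illustration only.
ADD_MONOIDS = {
    "plus": (operator.add, 0),
    "times": (operator.mul, 1),
    "min": (min, math.inf),
    "max": (max, -math.inf),
}
MULT_OPS = {
    "times": operator.mul,
    "plus": operator.add,
    "first": lambda x, y: x,
    "second": lambda x, y: y,
    "pair": lambda x, y: 1,
}


def vector_dot(a, b, semiring):
    """Vector times vector under an '<add>_<mult>' semiring; returns a scalar."""
    add_name, mult_name = semiring.split("_", 1)
    add, identity = ADD_MONOIDS[add_name]
    mult = MULT_OPS[mult_name]
    acc = identity
    for k in sorted(a.keys() & b.keys()):  # only indices stored in both vectors
        acc = add(acc, mult(a[k], b[k]))
    return acc


a = {0: 1, 1: 2, 3: 4}
b = {1: 3, 3: 5, 4: 9}
print(vector_dot(a, b, "plus_times"))  # 26
print(vector_dot(a, b, "min_plus"))  # 5
```

If the two vectors share no stored index, this sketch simply returns the add identity (0 or infinity); the op's behavior in that case is not specified here.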

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute          MLIR Type             Description
semiring           ::mlir::StringAttr    string attribute
mask_complement    ::mlir::BoolAttr      bool attribute

Operands:

Operand    Description
a          1D/2D tensor of any type values
b          1D/2D tensor of any type values
mask       1D/2D tensor of any type values

Results:

Result    Description
output    any type

graphblas.matrix_multiply_reduce_to_scalar_generic

Matrix multiply followed by reduction to a scalar with an optional structural mask.

Syntax:

operation ::= `graphblas.matrix_multiply_reduce_to_scalar_generic` $a `,` $b (`,` $mask^)? attr-dict `:` `(` type($a) `,` type($b)  (`,` type($mask)^)? `)` `to` type($output) $extensions

Performs a matrix multiply followed by a reduction to scalar. Supports the same extension blocks as graphblas.matrix_multiply_generic and also requires a binary aggregation block and aggregation identity block. These latter two blocks are used for reducing the result of the matrix multiply to a scalar.

Example:

%answer = graphblas.matrix_multiply_reduce_to_scalar_generic %a, %b : (tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSC64>) to f64 {
    ^bb0:
        %identity = arith.constant 0.0 : f64
        graphblas.yield add_identity %identity : f64
},{
    ^bb0(%add_a: f64, %add_b: f64):
        %add_result = arith.addf %add_a, %add_b : f64
        graphblas.yield add %add_result : f64
},{
    ^bb0(%mult_a: f64, %mult_b: f64):
        %mult_result = arith.mulf %mult_a, %mult_b : f64
        graphblas.yield mult %mult_result : f64
},{
    %agg_identity = arith.constant 0.0 : f64
    graphblas.yield agg_identity %agg_identity : f64
},{
    ^bb0(%lhs: f64, %rhs: f64):
        %agg_result = arith.addf %lhs, %rhs: f64
        graphblas.yield agg %agg_result : f64
}

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
a          2D tensor of any type values
b          2D tensor of any type values
mask       2D tensor of any type values

Results:

Result    Description
output    any type

graphblas.matrix_select_random

Random selection of elements from each row of matrix.

Syntax:

operation ::= `graphblas.matrix_select_random` $input `,` $n `,` $rng_context attr-dict `:` `(` type($input) `,` type($n) `,` type($rng_context)  `)` `to` type($output)

Selects a random subset of up to n elements in each row of a CSR matrix. If a row has fewer than n elements, all elements in the row are included in the output.

An external function must be provided via the choose_n attribute to the op with the following signature:

func @my_choose_n(%context: !llvm.ptr<i8>,
                  %n: IndexType, %max_i: IndexType,
                  %output_indices: memref<?xIndexType>,
                  %row_values: memref<?xValueType>)

where IndexType corresponds to the index element type of the sparse tensor input and ValueType corresponds to the value element type of the sparse tensor input. This external function selects n random indices from the interval [0, max_i) and writes them to the output_indices memref in increasing order. If desired, the distribution of selected indices can be biased by the values in row_values, which will have length max_i. A uniform choice function will ignore this last argument.

The implementation of the choose_n function is not specified by this op because it will differ significantly depending on use case (uniform or weighted sampling) and desired execution target (serial, parallel, GPU, etc).

Example:

%output = graphblas.matrix_select_random %a, %n, %rng_context { choose_n = @uniform_choose_n } : (tensor<?x?xf64, #CSR64>, i64, !llvm.ptr<i8>) to tensor<?x?xf64, #CSR64>
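Because the choose_n implementation is left to the user, here is one possible uniform strategy, sketched in Python. The callback mirrors the documented argument order (context, n, max_i, output_indices, row_values), but it is only a model: the real function is an external symbol with the MLIR signature above, and `random.Random` stands in for the opaque rng_context pointer. The names `uniform_choose_n` and `select_random_row` are hypothetical.

```python
import random


def uniform_choose_n(rng, n, max_i, output_indices, row_values):
    """Hypothetical uniform choose_n: write n distinct indices from [0, max_i)
    into output_indices in increasing order. row_values is ignored because the
    choice is uniform; a weighted variant could use it to bias the selection."""
    output_indices[:n] = sorted(rng.sample(range(max_i), n))


def select_random_row(rng, n, row_cols, row_vals, choose_n=uniform_choose_n):
    """Pick up to n entries of one CSR row; rows with <= n entries are kept whole."""
    if len(row_cols) <= n:
        return list(row_cols), list(row_vals)
    picks = [0] * n
    choose_n(rng, n, len(row_cols), picks, row_vals)
    return [row_cols[p] for p in picks], [row_vals[p] for p in picks]


rng = random.Random(42)  # stands in for the opaque rng_context pointer
cols, vals = select_random_row(rng, 2, [1, 4, 6, 9], [0.1, 0.2, 0.3, 0.4])
print(cols, vals)
```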

Attributes:

Attribute    MLIR Type                Description
choose_n     ::mlir::SymbolRefAttr    symbol reference attribute

Operands:

Operand        Description
input          2D tensor of any type values
n              integer
rng_context    any type

Results:

Result    Description
output    2D tensor of any type values

graphblas.num_cols

Returns the number of columns in a matrix.

Syntax:

operation ::= `graphblas.num_cols` $input attr-dict `:` type($input)

Returns the number of columns in a CSR or CSC matrix.

Example:

%ncols = graphblas.num_cols %sparse_matrix : tensor<?x?xf64, #CSR64>

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
input      2D tensor of any type values

Results:

Result    Description
result    index

graphblas.num_rows

Returns the number of rows in a matrix.

Syntax:

operation ::= `graphblas.num_rows` $input attr-dict `:` type($input)

Returns the number of rows in a CSR or CSC matrix.

Example:

%nrows = graphblas.num_rows %sparse_matrix : tensor<?x?xf64, #CSR64>

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
input      2D tensor of any type values

Results:

Result    Description
result    index

graphblas.num_vals

Returns the number of values in a sparse tensor.

Syntax:

operation ::= `graphblas.num_vals` $input attr-dict `:` type($input)

Returns the number of values in a CSC matrix, CSR matrix, or sparse vector.

Example:

%csr_nnz = graphblas.num_vals %csr_matrix : tensor<?x?xf64, #CSR64>
%vector_nnz = graphblas.num_vals %sparse_vector : tensor<?xf64, #CV64>
%csc_nnz = graphblas.num_vals %csc_matrix : tensor<?x?xf64, #CSC64>

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
input      1D/2D tensor of any type values

Results:

Result    Description
result    index

graphblas.print

Print operation.

Syntax:

operation ::= `graphblas.print` $values attr-dict `:` type($values)

graphblas.print pretty-prints values to stdout and is intended for debugging only. The strings attribute is a list of strings. The op is variadic and takes an arbitrary number of inputs; the output alternates between the strings and the input values.

Example:

%c9_9_f32 = arith.constant 9.9 : f32
%c1_i32 = arith.constant 1 : i32

// prints "start 9.9 middle 1 end ".
graphblas.print %c9_9_f32, %c1_i32 { strings = ["start ", " middle ", " end"] } : f32, i32

// prints "start 9.9 middle   end  z y x ".
graphblas.print %c9_9_f32 { strings = ["start ", " middle ", " end", " z", "y", "x"] } : f32

// prints "start 9.9 1 1 1 1".
graphblas.print %c9_9_f32, %c1_i32, %c1_i32, %c1_i32, %c1_i32 { strings = ["start "] } : f32, i32, i32, i32, i32

// prints "9.9".
graphblas.print %c9_9_f32 { strings = [] } : f32

%dense_vec_fixed = arith.constant dense<[0.0, 10.0, 20.0, 0.0]> : tensor<4xf64>
%dense_vec = tensor.cast %dense_vec_fixed : tensor<4xf64> to tensor<?xf64>
%vec = sparse_tensor.convert %dense_vec : tensor<?xf64> to tensor<?xf64, #CV64>

// prints "vec [0, 10, 20, 0]".
graphblas.print %dense_vec_fixed { strings = ["vec "] } : tensor<4xf64>

// prints "vec [0, 10, 20, 0]".
graphblas.print %dense_vec { strings = ["vec "] } : tensor<?xf64>

// prints "vec [_, 10, 20, _]".
graphblas.print %vec { strings = ["vec "] } : tensor<?xf64, #CV64>

%dense_mat_fixed = arith.constant dense<[
    [0.0, 1.0, 2.0, 0.0],
    [0.0, 0.0, 0.0, 3.0]
  ]> : tensor<2x4xf64>
%dense_mat = tensor.cast %dense_mat_fixed : tensor<2x4xf64> to tensor<?x?xf64>
%mat = sparse_tensor.convert %dense_mat : tensor<?x?xf64> to tensor<?x?xf64, #CSR64>
%mat_csc = graphblas.convert_layout %mat : tensor<?x?xf64, #CSR64> to tensor<?x?xf64, #CSC64>

// prints "mat [
//   [0, 1, 2, 0],
//   [0, 0, 0, 3],
// ]".
graphblas.print %dense_mat_fixed { strings = ["mat "] } : tensor<2x4xf64>

// prints "mat [
//   [0, 1, 2, 0],
//   [0, 0, 0, 3],
// ]".
graphblas.print %dense_mat { strings = ["mat "] } : tensor<?x?xf64>

// prints "mat [
//   [_, 1, 2, _],
//   [_, _, _, 3],
// ]".
graphblas.print %mat { strings = ["mat "] } : tensor<?x?xf64, #CSR64>

// prints "mat_csc [
//   [_, 1, 2, _],
//   [_, _, _, 3],
// ]".
graphblas.print %mat_csc { strings = ["mat_csc "] } : tensor<?x?xf64, #CSC64>

Attributes:

Attribute    MLIR Type            Description
strings      ::mlir::ArrayAttr    string array attribute

Operands:

Operand    Description
values     any type

graphblas.print_tensor

Print tensor operation.

Syntax:

operation ::= `graphblas.print_tensor` $input attr-dict `:` type($input)

Prints the sparse components of a tensor to stdout. This is useful for debugging and testing.

Example:

graphblas.print_tensor %sparse_tensor { level=4 } : tensor<?x?xf64, #CSR64>

Attributes:

Attribute    MLIR Type              Description
level        ::mlir::IntegerAttr    64-bit signless integer attribute

Operands:

Operand    Description
input      1D/2D tensor of any type values

graphblas.reduce_to_scalar_generic

Reduce to scalar generic operation.

Syntax:

operation ::= `graphblas.reduce_to_scalar_generic` $input attr-dict `:` type($input) `to` type($output) $extensions

Reduces a sparse tensor to a scalar according to the given aggregator block. If the tensor is a matrix, it must have a CSR sparsity or a CSC sparsity. The resulting scalar’s type will depend on the type of the input tensor.

Example:

%ci0 = arith.constant 0 : i64
%answer = graphblas.reduce_to_scalar_generic %sparse_vector : tensor<?xi64, #CV64> to i64 {
    graphblas.yield agg_identity %ci0 : i64
},  {
  ^bb0(%a : i64, %b : i64):
    %result = arith.addi %a, %b : i64
    graphblas.yield agg %result : i64
}

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand    Description
input      1D/2D tensor of any type values

Results:

Result    Description
output    any type

graphblas.reduce_to_scalar

Reduce to scalar operation.

Syntax:

operation ::= `graphblas.reduce_to_scalar` $input attr-dict `:` type($input) `to` type($output)

Reduces a sparse tensor (CSR matrix, CSC matrix, or sparse vector) to a scalar according to the given aggregator. Matrices must have a CSR sparsity or a CSC sparsity.

The resulting scalar’s type depends on the type of the input tensor, except for the custom aggregators “count”, “argmin”, and “argmax”, whose output type is a 64-bit integer.

Example:

%answer_1 = graphblas.reduce_to_scalar %sparse_matrix { aggregator = "plus" } : tensor<?x?xf32, #CSR64> to f32
%answer_2 = graphblas.reduce_to_scalar %sparse_vector { aggregator = "argmax" } : tensor<?xf64, #CV64> to i64
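A Python sketch of the custom aggregators on a sparse vector, assuming an `{index: value}` dictionary for the stored elements. The integer-valued results of “count”, “argmin”, and “argmax” match the rule above; the tie-breaking choice (smallest index wins) and the meaning of “count” as the number of stored elements are assumptions of the sketch.

```python
def reduce_vector_to_scalar(entries, aggregator):
    """entries maps index -> value for the stored elements of a sparse vector.

    'plus' keeps the value type; 'count', 'argmin', and 'argmax' return an
    integer (a count or an index). Ties go to the smallest index (assumption).
    """
    if aggregator == "plus":
        return sum(entries.values())
    if aggregator == "count":
        return len(entries)  # stored elements, not the vector's size
    if aggregator == "argmin":
        return min(entries, key=lambda i: (entries[i], i))
    if aggregator == "argmax":
        return max(entries, key=lambda i: (entries[i], -i))
    raise ValueError(f"unsupported aggregator in this sketch: {aggregator}")


v = {0: 3.0, 2: 9.0, 5: 9.0, 7: -1.0}
print([reduce_vector_to_scalar(v, agg) for agg in ("plus", "count", "argmin", "argmax")])
# [20.0, 4, 7, 2]
```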

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute     MLIR Type             Description
aggregator    ::mlir::StringAttr    string attribute

Operands:

Operand    Description
input      1D/2D tensor of any type values

Results:

Result    Description
output    any type

graphblas.reduce_to_vector_generic

Matrix reduce to vector generic operation.

Syntax:

operation ::= `graphblas.reduce_to_vector_generic` $input (`,` $mask^)? attr-dict `:` type($input) (`,` type($mask)^)? `to` type($output) $extensions

Reduces a CSR or CSC matrix to a vector according to the given axis using the specified aggregator block.

If the axis attribute is 0, the input tensor will be reduced column-wise, so the resulting vector’s size must be the number of columns in the input tensor.

If the axis attribute is 1, the input tensor will be reduced row-wise, so the resulting vector’s size must be the number of rows in the input tensor.

A vector mask is allowed to limit the output.

Example:

%vec1 = graphblas.reduce_to_vector_generic %matrix_1 { axis = 0 } : tensor<7x9xf16, #CSR64> to tensor<9xf16, #CV64> {
      graphblas.yield agg_identity %cf0 : f16
  },  {
    ^bb0(%a : f16, %b : f16):
      %result = arith.addf %a, %b : f16
      graphblas.yield agg %result : f16
  }
%vec2 = graphblas.reduce_to_vector_generic %matrix_2, %mask { axis = 1, mask_complement = true } : tensor<7x9xi64, #CSR64>, tensor<7xi64, #CV64> to tensor<7xi64, #CV64> {
      graphblas.yield agg_identity %ci0 : i64
  },  {
    ^bb0(%a : i64, %b : i64):
      %cmp = arith.cmpi "slt", %a, %b : i64
      %result = arith.select %cmp, %a, %b : i64
      graphblas.yield agg %result : i64
  }
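The effect of the axis attribute can be modeled in Python as below, assuming a `{(row, col): value}` dictionary for the matrix and callables for the agg_identity and agg blocks (both hypothetical stand-ins). The mask is omitted for brevity.

```python
import operator


def reduce_to_vector(entries, nrows, ncols, axis, agg_identity, agg):
    """entries maps (row, col) -> value for a sparse matrix.

    axis=0 folds each column, giving a vector of size ncols; axis=1 folds each
    row, giving a vector of size nrows. A column or row with no stored elements
    produces no output element.
    """
    size = ncols if axis == 0 else nrows
    out = {}
    for (i, j), v in entries.items():
        key = j if axis == 0 else i
        out[key] = agg(out.get(key, agg_identity), v)
    return size, out


m = {(0, 1): 1.0, (0, 2): 2.0, (1, 3): 3.0}  # same 2x4 matrix as the graphblas.print example
print(reduce_to_vector(m, 2, 4, axis=0, agg_identity=0.0, agg=operator.add))
# (4, {1: 1.0, 2: 2.0, 3: 3.0})
print(reduce_to_vector(m, 2, 4, axis=1, agg_identity=0.0, agg=operator.add))
# (2, {0: 3.0, 1: 3.0})
```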

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute          MLIR Type              Description
axis               ::mlir::IntegerAttr    64-bit signless integer attribute
mask_complement    ::mlir::BoolAttr       bool attribute

Operands:

Operand    Description
input      2D tensor of any type values
mask       1D tensor of any type values

Results:

Result    Description
output    1D tensor of any type values

graphblas.reduce_to_vector

Matrix reduce to vector operation.

Syntax:

operation ::= `graphblas.reduce_to_vector` $input (`,` $mask^)? attr-dict `:` type($input) (`,` type($mask)^)? `to` type($output)

Reduces a CSR or CSC matrix to a vector according to the given aggregator and axis.

The resulting sparse vector’s element type varies according to the given aggregator. All monoids are allowed as aggregators. Additional custom aggregators are “count”, “argmin”, “argmax”, “first”, and “last”. The output’s element type is normally the same as the input’s, but “count”, “argmin”, and “argmax” produce 64-bit integer elements.

If the axis attribute is 0, the input tensor will be reduced column-wise, so the resulting vector’s size must be the number of columns in the input tensor.

If the axis attribute is 1, the input tensor will be reduced row-wise, so the resulting vector’s size must be the number of rows in the input tensor.

A vector mask is allowed to limit the output.

Example:

%vec1 = graphblas.reduce_to_vector %matrix_1 { aggregator = "plus", axis = 0 } : tensor<7x9xf16, #CSR64> to tensor<9xf16, #CV64>
%vec2 = graphblas.reduce_to_vector %matrix_2 { aggregator = "count", axis = 1 } : tensor<7x9xf16, #CSR64> to tensor<7xi64, #CV64>
%vec3 = graphblas.reduce_to_vector %matrix_3, %mask { aggregator = "plus", axis = 1, mask_complement = false } : tensor<7x9xf32, #CSR64>, tensor<7xi64, #CV64> to tensor<7xf32, #CV64>

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute          MLIR Type              Description
aggregator         ::mlir::StringAttr     string attribute
axis               ::mlir::IntegerAttr    64-bit signless integer attribute
mask_complement    ::mlir::BoolAttr       bool attribute

Operands:

Operand    Description
input      2D tensor of any type values
mask       1D tensor of any type values

Results:

Result    Description
output    1D tensor of any type values

graphblas.select_generic

Generic select operation.

Syntax:

operation ::= `graphblas.select_generic` $input attr-dict `:` type($input) `to` type($output) $extensions

Returns a new sparse tensor containing the subset of elements from the given tensor for which the selector block yields true.

The given tensor must be a sparse vector or a matrix with CSR or CSC sparsity. The resulting sparse tensor will have the same encoding as the input tensor.

Selector Example (strictly upper triangle):

%result = graphblas.select_generic %sparse_tensor : tensor<?x?xf64, #CSR64> to tensor<?x?xf64, #CSR64> {
  ^bb0(%val: f64, %row: index, %col: index):
    %keep = arith.cmpi "ugt", %col, %row : index
    graphblas.yield select_out %keep : i1
}

Selector with Thunk Example:

%thunk = arith.constant 0.0 : f64
%result = graphblas.select_generic %sparse_tensor : tensor<?x?xf64, #CSR64> to tensor<?x?xf64, #CSR64> {
  ^bb0(%val: f64):
    %keep = arith.cmpf "olt", %val, %thunk : f64
    graphblas.yield select_out %keep : i1
}
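The selector block acts as a per-element predicate. As a hedged illustration in plain Python (not the dialect's lowering), the sketch below models a sparse matrix as a dict from (row, col) to value and the selector block as a Python callable; the names select_generic_reference and keep are hypothetical. A thunk is simply a value captured by the callable, mirroring how the second example's block refers to %thunk defined outside the region.

```python
from typing import Callable, Dict, Tuple

Coord = Tuple[int, int]
# Hypothetical selector signature mirroring ^bb0(%val, %row, %col).
Selector = Callable[[float, int, int], bool]


def select_generic_reference(entries: Dict[Coord, float], keep: Selector) -> Dict[Coord, float]:
    """Return the stored elements for which the selector returns True."""
    return {(r, c): v for (r, c), v in entries.items() if keep(v, r, c)}


m = {(0, 0): 1.0, (0, 2): -2.0, (1, 0): 3.0, (1, 2): 4.0}

# Strictly upper triangle, like the first example (col > row).
upper = select_generic_reference(m, lambda v, r, c: c > r)

# Thunk captured from the enclosing scope, like the second example (val < thunk).
thunk = 0.0
negative = select_generic_reference(m, lambda v, r, c: v < thunk)
```

Here upper keeps (0, 2) and (1, 2), and negative keeps only (0, 2).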

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
input | 1D/2D tensor of any type values

Results:

Result | Description
output | 1D/2D tensor of any type values

graphblas.select_mask

Select part of a sparse tensor using a mask.

Syntax:

operation ::= `graphblas.select_mask` $input `,` $mask attr-dict `:` type($input) `,` type($mask) `to` type($output)

Given a sparse tensor and a mask of the same dimensions, returns a new sparse tensor containing the elements of the original tensor whose positions are also present in the mask (i.e. the intersection of the two sparse structures).

There’s an optional boolean mask_complement attribute (which has a default value of false) that will make the op use the complement of the mask.

Example:

%vector_answer = graphblas.select_mask %vec, %msk : tensor<?xf64, #CV64>, tensor<?xi64, #CV64> to tensor<?xf64, #CV64>
%matrix_answer = graphblas.select_mask %mat, %msk {mask_complement = true} : tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSR64> to tensor<?x?xf64, #CSR64>
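Because only the mask's structure matters, which is why the examples can pair an f64 tensor with an i64 mask, the op reduces to a set-membership test on positions. The Python sketch below models that with dict-based sparse tensors; the function name select_mask_reference is hypothetical, and it assumes purely structural semantics as described above (a stored zero in the mask still counts as present).

```python
from typing import Dict, Hashable, Mapping


def select_mask_reference(tensor: Mapping[Hashable, float],
                          mask: Mapping[Hashable, object],
                          mask_complement: bool = False) -> Dict[Hashable, float]:
    """Keep elements whose position is (or, if complemented, is not) stored in the mask.

    Only the mask's structure (its keys) is consulted; its values are ignored.
    """
    return {pos: v for pos, v in tensor.items() if (pos in mask) != mask_complement}


vec = {0: 1.5, 2: 2.5, 5: 3.5}
msk = {2: 0, 3: 7, 5: 1}  # the stored 0 at position 2 still counts as present
kept = select_mask_reference(vec, msk)
dropped = select_mask_reference(vec, msk, mask_complement=True)
```

The same function works for matrices if positions are (row, col) tuples.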

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
mask_complement | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
input | 1D/2D tensor of any type values
mask | 1D/2D tensor of any type values

Results:

Result | Description
output | 1D/2D tensor of any type values

graphblas.select

Select operation.

Syntax:

operation ::= `graphblas.select` $input (`,` $thunks^)? attr-dict `:` type($input) (`,` type($thunks)^)? `to` type($output)

Returns a new sparse tensor with a subset of elements from the given tensor.

An element is only included in the resulting sparse tensor if the selector returns true. Selectors may be unary or binary, but must return a boolean.

Binary selectors, e.g. “gt”, require a thunk. Some custom selectors, e.g. “probability”, require a thunk and a random number generator context.

The given tensor must be a sparse vector or a matrix with CSR or CSC sparsity. The resulting sparse tensor will have the same encoding as the input tensor.

Selector Example:

%result = graphblas.select %sparse_tensor { selector = "triu" } : tensor<?x?xf64, #CSR64> to tensor<?x?xf64, #CSR64>

Selector with Thunk Example:

%thunk = arith.constant 0.0 : f64
%result = graphblas.select %sparse_tensor, %thunk { selector = "gt" } : tensor<?x?xf64, #CSR64>, f64 to tensor<?x?xf64, #CSR64>

Random Selector Example:

%percentage = arith.constant 0.5 : f64
%result = graphblas.select %sparse_tensor, %percentage, %rng_context { selector = "probability" } : tensor<?x?xf64, #CSR64>, f64, !llvm.ptr<i8> to tensor<?x?xf64, #CSR64>
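A hedged Python model of the two thunk-taking selectors shown above. The predicate meanings are assumptions inferred from the names and examples: “gt” keeps values strictly greater than the thunk, and “probability” keeps each element independently with probability equal to the thunk. Python's random.Random stands in for the opaque !llvm.ptr<i8> RNG context, and select_reference is a hypothetical name.

```python
import random
from typing import Dict, Optional, Tuple

Coord = Tuple[int, int]


def select_reference(entries: Dict[Coord, float], selector: str, thunk: float,
                     rng: Optional[random.Random] = None) -> Dict[Coord, float]:
    """Model of graphblas.select for two of the named selectors."""
    if selector == "gt":
        # Binary selector: compare each stored value against the thunk.
        return {pos: v for pos, v in entries.items() if v > thunk}
    if selector == "probability":
        # Custom selector: keep each element independently with probability `thunk`.
        if rng is None:
            raise ValueError("the probability selector needs an RNG")
        return {pos: v for pos, v in entries.items() if rng.random() < thunk}
    raise ValueError(f"selector not modeled here: {selector}")
```

Since random() draws from [0, 1), a thunk of 1.0 keeps every element and 0.0 keeps none; passing a seeded random.Random makes runs reproducible.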

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
selector | ::mlir::StringAttr | string attribute

Operands:

Operand | Description
input | 1D/2D tensor of any type values
thunks | any type

Results:

Result | Description
output | 1D/2D tensor of any type values

graphblas.size

Return the size of a sparse vector.

Syntax:

operation ::= `graphblas.size` $input attr-dict `:` type($input)

Returns the size of a sparse vector, i.e. its dimension length (not its number of stored values), as an index.

Example:

%size = graphblas.size %sparse_vector : tensor<?xf64, #CV64>

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
input | 1D tensor of any type values

Results:

Result | Description
result | index

graphblas.to_coo

Extract coordinate tensors from a sparse tensor.

Syntax:

operation ::= `graphblas.to_coo` $input attr-dict `:` type($input) `to` type($indices) `,` type($values)

Returns the coordinates and associated values of the stored elements of a sparse tensor. The indices are returned as a tensor<?x?xindex> and the values as a tensor<?xVALUE_TYPE>, where VALUE_TYPE is the element type of the input.

Example:

%indices, %vals = graphblas.to_coo %v : tensor<?x?xf64, #CSR64> to tensor<?x?xindex>, tensor<?xf64>
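For a CSR64 matrix, extracting coordinates amounts to expanding the pointer array into explicit row numbers. The sketch below is an illustrative Python model, not the dialect's implementation; csr_to_coo_reference is a hypothetical name, and it assumes the indices tensor holds one row per stored element and one column per dimension (nnz x rank), a layout this reference does not state explicitly.

```python
from typing import List, Tuple


def csr_to_coo_reference(nrows: int, pointers: List[int], indices: List[int],
                         values: List[float]) -> Tuple[List[List[int]], List[float]]:
    """Expand CSR arrays into (coordinates, values), in row-major order."""
    coords: List[List[int]] = []
    for row in range(nrows):
        # Elements of `row` occupy positions pointers[row] .. pointers[row + 1] - 1.
        for k in range(pointers[row], pointers[row + 1]):
            coords.append([row, indices[k]])
    return coords, list(values)
```

For a CSR matrix the values array is already in row-major order, so it can be copied unchanged; only the row coordinates have to be materialized.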

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
input | 1D/2D tensor of any type values

Results:

Result | Description
indices | 2D tensor of index values
values | 1D tensor of any type values

graphblas.transpose

Transpose operation.

Syntax:

operation ::= `graphblas.transpose` $input attr-dict `:` type($input) `to` type($output)

Returns a new sparse matrix that is the transpose of the input matrix. The input must be a matrix (i.e. have rank 2) with CSR or CSC sparsity, and the output type must also be CSR or CSC.

Example:

%a = graphblas.transpose %sparse_tensor : tensor<?x?xf64, #CSR64> to tensor<?x?xf64, #CSC64>
%b = graphblas.transpose %sparse_tensor : tensor<?x?xf64, #CSR64> to tensor<?x?xf64, #CSR64>
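The two examples differ in cost. The CSR arrays of a matrix are exactly the CSC arrays of its transpose, so a CSR-to-CSC transpose can in principle reuse the input arrays, while a CSR-to-CSR transpose must regroup the elements by column. The Python sketch below shows that regrouping as a counting sort; transpose_csr_to_csr is a hypothetical name, and the sketch illustrates the technique without claiming that --graphblas-lower implements it this way.

```python
from typing import List, Tuple

# Hypothetical CSR layout: (nrows, ncols, pointers, indices, values).
CSR = Tuple[int, int, List[int], List[int], List[float]]


def transpose_csr_to_csr(a: CSR) -> CSR:
    nrows, ncols, pointers, indices, values = a
    # 1. Count stored elements per column; these become the row lengths of A^T.
    counts = [0] * ncols
    for col in indices:
        counts[col] += 1
    # 2. Prefix-sum the counts into the pointer array of A^T.
    t_pointers = [0] * (ncols + 1)
    for col in range(ncols):
        t_pointers[col + 1] = t_pointers[col] + counts[col]
    # 3. Scatter each element into its slot; visiting rows in order keeps
    #    the indices within each row of A^T sorted.
    next_slot = t_pointers[:-1]  # slicing makes a copy
    t_indices = [0] * len(indices)
    t_values = [0.0] * len(values)
    for row in range(nrows):
        for k in range(pointers[row], pointers[row + 1]):
            col = indices[k]
            dst = next_slot[col]
            t_indices[dst] = row
            t_values[dst] = values[k]
            next_slot[col] += 1
    return ncols, nrows, t_pointers, t_indices, t_values
```

The sort runs in O(nnz + ncols) time, compared with reusing the arrays as-is for the CSR-to-CSC case.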

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
input | 2D tensor of any type values

Results:

Result | Description
output | 2D tensor of any type values

graphblas.uniform_complement

Structural complement filled with a uniform value.

Syntax:

operation ::= `graphblas.uniform_complement` $input `,` $value attr-dict `:` type($input) `,` type($value) `to` type($output)

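Returns the complement of the structure of a vector or matrix, filled with a uniform scalar value: every position that is empty in the input stores the given value in the output, and every position stored in the input is empty in the output.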

Example:

%cf2 = arith.constant 2.0 : f32
%answer = graphblas.uniform_complement %mat, %cf2 : tensor<?x?xf64, #CSR64>, f32 to tensor<?x?xf32, #CSR64>
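A minimal Python sketch of the structural complement on a dict-based matrix (uniform_complement_reference is a hypothetical name, not part of the dialect). Note that the output stores nrows * ncols - nnz elements, so for a typical sparse input the result is far denser than the input.

```python
from typing import Dict, Tuple

Coord = Tuple[int, int]


def uniform_complement_reference(shape: Tuple[int, int], entries: Dict[Coord, float],
                                 value: float) -> Dict[Coord, float]:
    """Store `value` at every position that is empty in the input, and nothing elsewhere."""
    nrows, ncols = shape
    return {(r, c): value
            for r in range(nrows) for c in range(ncols)
            if (r, c) not in entries}
```

For a 2x2 input storing only the diagonal, the result stores the value at (0, 1) and (1, 0).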

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
input | 1D/2D tensor of any type values
value | any type

Results:

Result | Description
output | 1D/2D tensor of any type values

graphblas.union_generic

Element-wise union operation.

Syntax:

operation ::= `graphblas.union_generic` $a `,` $b attr-dict `:` `(` type($a) `,` type($b)  `)` `to` type($output) $extensions

Performs an element-wise union between two matrices or two vectors. The resulting sparse structure will be the union of the two input structures. An element present in only one input is copied to the output. When both inputs have an element at the same position, the two values are combined by the binary operation defined in the block.

Example:

%combined = graphblas.union_generic %A, %B : (
  tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSR64>) to tensor<?x?xf64, #CSR64> {
  ^bb0(%a : f64, %b : f64):
    %result = arith.addf %a, %b : f64
    graphblas.yield mult %result : f64
}

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
a | 1D/2D tensor of any type values
b | 1D/2D tensor of any type values

Results:

Result | Description
output | 1D/2D tensor of any type values

graphblas.union

Element-wise union operation.

Syntax:

operation ::= `graphblas.union` $a `,` $b (`,` $mask^)? attr-dict `:` `(` type($a) `,` type($b) (`,` type($mask)^)? `)` `to` type($output)

Performs an element-wise union between two matrices or two vectors. The resulting sparse structure will be the union of the two input structures. An element present in only one input is copied to the output. When both inputs have an element at the same position, the two values are combined using the given union_operator.

The allowable operators are monoids and the custom operators “first” and “second”.

The mask (if provided) must have the same format as the returned object. There’s an optional boolean mask_complement attribute (which has a default value of false) that will use the structural complement of the mask.

Example:

%combined = graphblas.union %A, %B { union_operator = "plus" } : (
    tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSR64>) to tensor<?x?xf64, #CSR64>

%combined_masked = graphblas.union %A, %B, %mask { union_operator = "plus", mask_complement = true } : (
    tensor<?x?xf64, #CSR64>, tensor<?x?xf64, #CSR64>, tensor<?x?xi32, #CSR64>) to tensor<?x?xf64, #CSR64>
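A hedged Python model covering both union variants: the combining function is either a Python callable (standing in for the graphblas.union_generic block) or one of a few operator names. The name table and the reading of “first”/“second” as “take the value from a”/“take the value from b” are assumptions for illustration, union_reference is a hypothetical name, and the mask is applied structurally to the union result, following the mask_complement description above.

```python
from typing import Callable, Dict, Hashable, Mapping, Optional, Union

BinaryOp = Callable[[float, float], float]

# Assumed meanings of a few operator names; not an exhaustive list.
NAMED_OPS: Dict[str, BinaryOp] = {
    "plus": lambda x, y: x + y,
    "times": lambda x, y: x * y,
    "min": min,
    "max": max,
    "first": lambda x, y: x,
    "second": lambda x, y: y,
}


def union_reference(a: Mapping[Hashable, float], b: Mapping[Hashable, float],
                    op: Union[str, BinaryOp],
                    mask: Optional[Mapping[Hashable, object]] = None,
                    mask_complement: bool = False) -> Dict[Hashable, float]:
    combine = NAMED_OPS[op] if isinstance(op, str) else op
    out: Dict[Hashable, float] = dict(a)  # elements only in `a` are copied
    for pos, bv in b.items():
        # Overlapping positions are combined; elements only in `b` are copied.
        out[pos] = combine(a[pos], bv) if pos in a else bv
    if mask is not None:
        # Keep only positions present in the mask (or absent, if complemented).
        out = {pos: v for pos, v in out.items() if (pos in mask) != mask_complement}
    return out
```

With A = {(0, 0): 1.0, (0, 1): 2.0} and B = {(0, 1): 10.0, (1, 1): 5.0}, “plus” yields {(0, 0): 1.0, (0, 1): 12.0, (1, 1): 5.0}; only the overlapping position (0, 1) goes through the operator.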

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
union_operator | ::mlir::StringAttr | string attribute
mask_complement | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
a | 1D/2D tensor of any type values
b | 1D/2D tensor of any type values
mask | 1D/2D tensor of any type values

Results:

Result | Description
output | 1D/2D tensor of any type values

graphblas.update_generic

Update operation handling accumulation, mask, and replacement.

Syntax:

operation ::= `graphblas.update_generic` $input `->` $output (`(` $mask^ `)`)? attr-dict `:` type($input) `->` type($output) (`(` type($mask)^ `)`)? $extensions

Updates the output tensor based on the input and the desired accumulation, mask, and replacement. The op returns no values; it modifies the output in place. The given tensors must be sparse. Accumulation is performed by the given accumulator block.

There’s an optional boolean mask_complement attribute (which has a default value of false) that will make the op use the complement of the mask.

Example:

graphblas.update_generic %other_mat -> %mat(%mask) { replace = true, mask_complement = true } :
    tensor<?x?xf64, #CSR64> -> tensor<?x?xf64, #CSR64>(tensor<?x?xf64, #CSR64>) {
    ^bb0(%a : f64, %b : f64):
      %result = arith.addf %a, %b : f64
      graphblas.yield accumulate %result : f64
}
graphblas.update_generic %other_vec -> %vec : tensor<?xi64, #CV64> -> tensor<?xi64, #CV64> {
    ^bb0(%a : i64, %b : i64):
      %result = arith.muli %a, %b : i64
      graphblas.yield accumulate %result : i64
}

Attributes:

Attribute | MLIR Type | Description
replace | ::mlir::BoolAttr | bool attribute
mask_complement | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
input | 1D/2D tensor of any type values
output | 1D/2D tensor of any type values
mask | 1D/2D tensor of any type values

graphblas.update

Update operation handling accumulation, mask, and replacement.

Syntax:

operation ::= `graphblas.update` $input `->` $output (`(` $mask^ `)`)? attr-dict `:` type($input) `->` type($output) (`(` type($mask)^ `)`)?

Updates the output tensor based on the input and the desired accumulation, mask, and replacement. The op returns no values; it modifies the output in place. The supported accumulate operators are “plus”, “times”, “min”, “max”, “first”, and “second”. The given tensors must be sparse.

There’s an optional boolean mask_complement attribute (which has a default value of false) that will make the op use the structural complement of the mask.

Example:

graphblas.update %other_mat -> %mat(%mask) { accumulate_operator = "times", replace = true, mask_complement = true } : tensor<?x?xi64, #CSR64> -> tensor<?x?xi64, #CSR64>(tensor<?x?xi64, #CSR64>)
graphblas.update %other_vec -> %vec { accumulate_operator = "plus" } : tensor<?xi64, #CV64> -> tensor<?xi64, #CV64>
graphblas.update %other_mat -> %mat(%mask) { accumulate_operator = "max", replace = true } : tensor<?x?xi64, #CSC64> -> tensor<?x?xi64, #CSC64>(tensor<?x?xi64, #CSC64>)
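The interplay of accumulation, mask, and replacement can be modeled with the GraphBLAS C&lt;M, replace&gt; accum= T write pattern. The Python sketch below assumes the dialect follows the GraphBLAS specification's write semantics: the accumulator merges the existing output with the input over the union of their structures, positions allowed by the mask take the merged value, and positions outside the mask keep their old values when replace is false or are cleared when it is true. The name update_reference, the operator table, and the convention that “first” keeps the existing output value are assumptions; an accumulate callable stands in for the graphblas.update_generic block.

```python
from typing import Callable, Dict, Hashable, Mapping, Optional, Union

BinaryOp = Callable[[float, float], float]

# Assumed meanings; the first argument is the existing output value.
ACCUMULATORS: Dict[str, BinaryOp] = {
    "plus": lambda old, new: old + new,
    "times": lambda old, new: old * new,
    "min": min,
    "max": max,
    "first": lambda old, new: old,
    "second": lambda old, new: new,
}


def update_reference(inp: Mapping[Hashable, float], out: Dict[Hashable, float],
                     accumulate: Union[str, BinaryOp, None] = None,
                     mask: Optional[Mapping[Hashable, object]] = None,
                     replace: bool = False, mask_complement: bool = False) -> None:
    """Modify `out` in place; like the op, return nothing."""
    # 1. Merge: with an accumulator, Z = out (+) inp over the union of structures;
    #    without one, Z is just the input.
    if accumulate is None:
        z = dict(inp)
    else:
        acc = ACCUMULATORS[accumulate] if isinstance(accumulate, str) else accumulate
        z = dict(out)
        for pos, v in inp.items():
            z[pos] = acc(out[pos], v) if pos in out else v

    def writable(pos: Hashable) -> bool:
        return mask is None or ((pos in mask) != mask_complement)

    # 2. Write Z through the mask; outside the mask keep old values unless replace.
    result = {pos: v for pos, v in z.items() if writable(pos)}
    if not replace:
        result.update({pos: v for pos, v in out.items() if not writable(pos)})
    out.clear()
    out.update(result)
```

With out = {0: 1, 1: 2, 3: 5}, input {1: 10, 2: 20}, “plus”, and a mask storing positions 1 and 3, the output becomes {0: 1, 1: 12, 3: 5}: position 2 is outside the mask, so it is never written, and position 0 survives because replace is false.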

Attributes:

Attribute | MLIR Type | Description
accumulate_operator | ::mlir::StringAttr | string attribute
replace | ::mlir::BoolAttr | bool attribute
mask_complement | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
input | 1D/2D tensor of any type values
output | 1D/2D tensor of any type values
mask | 1D/2D tensor of any type values

graphblas.yield

GraphBLAS yield operation.

Syntax:

operation ::= `graphblas.yield` $kind $values attr-dict `:` type($values)

graphblas.yield is a special terminator operation for blocks inside regions in several GraphBLAS dialect operations, e.g. graphblas.matrix_multiply_generic. It returns a value to the enclosing op, with a meaning that depends on the required “kind” attribute, which must be one of the following:

  • transform_in_a

  • transform_in_b

  • transform_out

  • select_in_a

  • select_in_b

  • select_out

  • add_identity

  • add

  • mult

  • accumulate

Example:

graphblas.yield transform_out %f0 : f64

Traits: ReturnLike, Terminator

Interfaces: NoSideEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
kind | ::mlir::graphblas::YieldKindAttr | allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11

Operands:

Operand | Description
values | any type