Skip to content

Commit f94131a

Browse files
[mlir][vector] Support multiple result types in vector.mask
The verifier already had support for multiple result types, but the op definition assumed a single, optional result. Differential Revision: https://reviews.llvm.org/D141683
1 parent f601039 commit f94131a

2 files changed

Lines changed: 20 additions & 13 deletions

File tree

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2287,10 +2287,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
22872287
The `vector.mask` is a `MaskingOpInterface` operation that predicates the
22882288
execution of another operation. It takes an `i1` vector mask and an
22892289
optional passthru vector as arguments.
2290-
A `vector.yield`-terminated region encloses the operation to be masked.
2291-
Values used within the region are captured from above. Only one *maskable*
2292-
operation can be masked with a `vector.mask` operation at a time. An
2293-
operation is *maskable* if it implements the `MaskableOpInterface`.
2290+
2291+
A implicitly `vector.yield`-terminated region encloses the operation to be
2292+
masked. Values used within the region are captured from above. Only one
2293+
*maskable* operation can be masked with a `vector.mask` operation at a time.
2294+
An operation is *maskable* if it implements the `MaskableOpInterface`. The
2295+
terminator yields all results of the maskable operation to the result of
2296+
this operation.
22942297

22952298
The vector mask argument holds a bit for each vector lane and determines
22962299
which vector lanes should execute the maskable operation and which ones
@@ -2321,23 +2324,27 @@ def Vector_MaskOp : Vector_Op<"mask", [
23212324
```
23222325
vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
23232326
```
2327+
2328+
```
2329+
vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
2330+
```
23242331
}];
23252332

23262333
// TODO: Support multiple results and passthru values.
23272334
let arguments = (ins VectorOf<[I1]>:$mask,
23282335
Optional<AnyType>:$passthru);
2329-
let results = (outs Optional<AnyType>:$results);
2336+
let results = (outs Variadic<AnyType>:$results);
23302337
let regions = (region SizedRegion<1>:$maskRegion);
23312338

23322339
let skipDefaultBuilders = 1;
23332340
let builders = [
23342341
OpBuilder<(ins "Value":$mask,
23352342
CArg<"function_ref<void(OpBuilder &, Location)>",
23362343
"buildTerminatedBody">:$maskRegion)>,
2337-
OpBuilder<(ins "Type":$resultType, "Value":$mask,
2344+
OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
23382345
CArg<"function_ref<void(OpBuilder &, Location)>",
23392346
"buildTerminatedBody">:$maskRegion)>,
2340-
OpBuilder<(ins "Type":$resultType, "Value":$mask,
2347+
OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
23412348
"Value":$passthru,
23422349
CArg<"function_ref<void(OpBuilder &, Location)>",
23432350
"buildTerminatedBody">:$maskRegion)>

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5288,20 +5288,20 @@ void MaskOp::build(
52885288
}
52895289

52905290
void MaskOp::build(
5291-
OpBuilder &builder, OperationState &result, Type resultType, Value mask,
5292-
function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
5293-
build(builder, result, resultType, mask, /*passthru=*/Value(),
5291+
OpBuilder &builder, OperationState &result, TypeRange resultTypes,
5292+
Value mask, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
5293+
build(builder, result, resultTypes, mask, /*passthru=*/Value(),
52945294
maskRegionBuilder);
52955295
}
52965296

52975297
void MaskOp::build(
5298-
OpBuilder &builder, OperationState &result, Type resultType, Value mask,
5299-
Value passthru,
5298+
OpBuilder &builder, OperationState &result, TypeRange resultTypes,
5299+
Value mask, Value passthru,
53005300
function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
53015301
build(builder, result, mask, maskRegionBuilder);
53025302
if (passthru)
53035303
result.addOperands(passthru);
5304-
result.addTypes(resultType);
5304+
result.addTypes(resultTypes);
53055305
}
53065306

53075307
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {

0 commit comments

Comments
 (0)