Skip to content

Commit be4c5ad

Browse files
gaofangfrankpsoni2628
authored andcommitted
[mlir][vector] Add scalable vectors support to OuterProductOp
This will probably be the first in a series of patches that tries to enable code generation for ARM SME (extension of SVE). Since SME's core operation is the outer product instruction, I figured that it would probably be a good idea to enable the outer product operation to properly accept and generate scalable vectors. Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D138718
1 parent f88c6b9 commit be4c5ad

1 file changed

Lines changed: 15 additions & 4 deletions

File tree

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2657,10 +2657,18 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
26572657
if (!vLHS)
26582658
return parser.emitError(parser.getNameLoc(),
26592659
"expected vector type for operand #1");
2660-
VectorType resType =
2661-
vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
2662-
vLHS.getElementType())
2663-
: VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
2660+
2661+
unsigned numScalableDims = vLHS.getNumScalableDims();
2662+
VectorType resType;
2663+
if (vRHS) {
2664+
numScalableDims += vRHS.getNumScalableDims();
2665+
resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
2666+
vLHS.getElementType(), numScalableDims);
2667+
} else {
2668+
// Scalar RHS operand
2669+
resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
2670+
numScalableDims);
2671+
}
26642672

26652673
if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
26662674
result.attributes.append(
@@ -2696,6 +2704,9 @@ LogicalResult OuterProductOp::verify() {
26962704
return emitOpError("expected #1 operand dim to match result dim #1");
26972705
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
26982706
return emitOpError("expected #2 operand dim to match result dim #2");
2707+
if (vRHS.isScalable() != vLHS.isScalable())
2708+
return emitOpError("expected either all or none of vector operands #1 "
2709+
"and #2 to be scalable");
26992710
} else {
27002711
// An AXPY operation.
27012712
if (vRES.getRank() != 1)

0 commit comments

Comments
 (0)