diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index baee8a61d..2285f5f33 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -2567,9 +2567,9 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr) { return doExpr(subExpr); } case CastKind::CK_HLSLVectorToMatrixCast: { - // If target type is already an 1xN matrix type, we just return the + // If target type is already an 1xN or Mx1 matrix type, we just return the // underlying vector. - if (is1xNMatrix(toType)) + if (is1xNMatrix(toType) || isMx1Matrix(toType)) return doExpr(subExpr); // A vector can have no more than 4 elements. The only remaining case diff --git a/tools/clang/test/CodeGenSPIRV/cast.vec-to-mat.implicit.hlsl b/tools/clang/test/CodeGenSPIRV/cast.vec-to-mat.implicit.hlsl new file mode 100644 index 000000000..8955a9e2a --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/cast.vec-to-mat.implicit.hlsl @@ -0,0 +1,21 @@ +// Run: %dxc -T ps_6_0 -E main + +float4 main(float4 input : A) : SV_Target { +// CHECK: [[vec:%\d+]] = OpConstantComposite %v4float %float_1 %float_2 %float_3 %float_4 +// CHECK: OpStore %var1 [[vec]] + float4 var1 = float4(1,2,3,4); + +// CHECK-NEXT: [[vec1:%\d+]] = OpLoad %v4float %var1 +// CHECK-NEXT: OpStore %var2 [[vec1]] + float4x1 var2 = var1; + +// CHECK-NEXT: [[vec2:%\d+]] = OpLoad %v4float %input +// CHECK-NEXT: OpStore %var3 [[vec2]] + float1x4 var3 = input; + +// CHECK-NEXT: [[vec3:%\d+]] = OpLoad %v4float %input +// CHECK-NEXT: OpStore %var4 [[vec3]] + float4x1 var4 = input; + + return input; +} diff --git a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp index 9d5727629..682d24e5b 100644 --- a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp +++ b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp @@ -478,6 +478,9 @@ TEST_F(FileTest, CastFlatConversionDecomposeVector) { TEST_F(FileTest, CastExplicitVecToMat) { runFileTest("cast.vec-to-mat.explicit.hlsl"); } +TEST_F(FileTest, CastImplicitVecToMat) { + runFileTest("cast.vec-to-mat.implicit.hlsl"); +} TEST_F(FileTest, CastMatrixToVector) { runFileTest("cast.mat-to-vec.hlsl"); } TEST_F(FileTest, CastBitwidth) { runFileTest("cast.bitwidth.hlsl"); }