diff --git a/tools/clang/lib/AST/ExprConstant.cpp b/tools/clang/lib/AST/ExprConstant.cpp index 7275f2e08..8543a5bc0 100644 --- a/tools/clang/lib/AST/ExprConstant.cpp +++ b/tools/clang/lib/AST/ExprConstant.cpp @@ -76,6 +76,31 @@ static const FunctionDecl *GetCallExprFunction(const CallExpr *CE) { return FDecl; } + +// Returns true if the given InitListExpr is for constructing a HLSL vector +// with the matching number of initializers and each initializer has the +// matching element type. +static bool IsHLSLVecInitList(const Expr* expr) { + if (const auto* initExpr = dyn_cast(expr)) { + const QualType vecType = initExpr->getType(); + if (!hlsl::IsHLSLVecType(vecType)) + return false; + + const uint32_t size = hlsl::GetHLSLVecSize(vecType); + const QualType elemType = hlsl::GetHLSLVecElementType(vecType).getCanonicalType(); + + if (initExpr->getNumInits() != size) + return false; + + for (uint32_t i = 0; i < size; ++i) + if (initExpr->getInit(i)->getType().getCanonicalType() != elemType) + return false; + + return true; + } + + return false; +} // HLSL Change Ends @@ -4254,7 +4279,7 @@ public: bool VisitInitListExpr(const InitListExpr *E) { if (E->getNumInits() == 0) return DerivedZeroInitialization(E); - if (Info.getLangOpts().HLSL) return Error(E); // HLSL Change + if (Info.getLangOpts().HLSL && !IsHLSLVecInitList(E)) return Error(E); // HLSL Change if (E->getNumInits() == 1) return StmtVisitorTy::Visit(E->getInit(0)); return Error(E); @@ -4295,9 +4320,13 @@ public: } bool VisitCastExpr(const CastExpr *E) { - if (Info.getLangOpts().HLSL && E->getSubExpr()->getStmtClass() == Stmt::InitListExprClass) { // HLSL Change - return Error(E); + // HLSL Change Begins + if (Info.getLangOpts().HLSL) { + const auto* subExpr = E->getSubExpr(); + if (subExpr->getStmtClass() == Stmt::InitListExprClass && !IsHLSLVecInitList(subExpr)) + return Error(E); } + // HLSL Change Ends switch (E->getCastKind()) { default: break; @@ -5557,7 +5586,7 @@ public: } } bool VisitInitListExpr(const InitListExpr *E) { - if (Info.getLangOpts().HLSL) return Error(E); // HLSL Change + if (Info.getLangOpts().HLSL && !IsHLSLVecInitList(E)) return Error(E); // HLSL Change return VisitConstructExpr(E); } bool VisitCXXConstructExpr(const CXXConstructExpr *E) { @@ -5643,6 +5672,7 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr* E) { QualType SETy = SE->getType(); switch (E->getCastKind()) { + case CK_HLSLVectorSplat: // HLSL Change case CK_VectorSplat: { APValue Val = APValue(); if (SETy->isIntegerType()) { @@ -8476,12 +8506,10 @@ static bool Evaluate(APValue &Result, EvalInfo &Info, const Expr *E) { // In C, function designators are not lvalues, but we evaluate them as if they // are. // HLSL Change Begins. - if (Info.getLangOpts().HLSL && E->getStmtClass() == Stmt::InitListExprClass) { // HLSL Change - if (hlsl::IsHLSLVecType(E->getType())) { - if (EvaluateVector(E, Result, Info)) + if (Info.getLangOpts().HLSL) { + if (E->isRValue() && hlsl::IsHLSLVecType(E->getType()) && EvaluateVector(E, Result, Info)) return true; - } - if (!E->getType()->isScalarType()) + if (E->getStmtClass() == Stmt::InitListExprClass && !E->getType()->isScalarType()) return false; } // HLSL Change Ends. diff --git a/tools/clang/test/CodeGenSPIRV/cast.vector.splat.hlsl b/tools/clang/test/CodeGenSPIRV/cast.vector.splat.hlsl index ed2d1ebc9..aaa7516c3 100644 --- a/tools/clang/test/CodeGenSPIRV/cast.vector.splat.hlsl +++ b/tools/clang/test/CodeGenSPIRV/cast.vector.splat.hlsl @@ -6,9 +6,10 @@ void main() { // CHECK-LABEL: %bb_entry = OpLabel // From constant -// CHECK: OpStore %vf4 [[v4f32c]] +// CHECK: %vf4 = OpVariable %_ptr_Function_v4float Function [[v4f32c]] float4 vf4 = 1; -// CHECK-NEXT: [[v3f32c:%\d+]] = OpCompositeConstruct %v3float %float_2 %float_2 %float_2 + +// CHECK: [[v3f32c:%\d+]] = OpCompositeConstruct %v3float %float_2 %float_2 %float_2 // CHECK-NEXT: OpStore %vf3 [[v3f32c]] float3 vf3; vf3 = float1(2); diff --git a/tools/clang/test/CodeGenSPIRV/var.init.hlsl b/tools/clang/test/CodeGenSPIRV/var.init.hlsl index 0140cc888..db3ccdabe 100644 --- a/tools/clang/test/CodeGenSPIRV/var.init.hlsl +++ b/tools/clang/test/CodeGenSPIRV/var.init.hlsl @@ -16,20 +16,11 @@ float4 main(float component: COLOR) : SV_TARGET { // CHECK-LABEL: %bb_entry = OpLabel -// CHECK-NEXT: %a = OpVariable %_ptr_Function_int Function %int_0 -// CHECK-NEXT: %b = OpVariable %_ptr_Function_int Function - -// CHECK-NEXT: %i = OpVariable %_ptr_Function_float Function %float_3 -// CHECK-NEXT: %j = OpVariable %_ptr_Function_float Function - -// CHECK-NEXT: %m = OpVariable %_ptr_Function_v4float Function -// CHECK-NEXT: %n = OpVariable %_ptr_Function_v4float Function -// CHECK-NEXT: %o = OpVariable %_ptr_Function_v4float Function - -// CHECK-NEXT: %p = OpVariable %_ptr_Function_v2int Function [[int2constant]] -// CHECK-NEXT: %q = OpVariable %_ptr_Function_v3int Function - -// CHECK-NEXT: %x = OpVariable %_ptr_Function_uint Function +// CHECK: %a = OpVariable %_ptr_Function_int Function %int_0 +// CHECK: %i = OpVariable %_ptr_Function_float Function %float_3 +// CHECK: %m = OpVariable %_ptr_Function_v4float Function [[float4constant]] +// CHECK: %p = OpVariable %_ptr_Function_v2int Function [[int2constant]] +// CHECK: %x = OpVariable %_ptr_Function_uint Function %uint_1 // Initializer already attached to the var definition int a = 0; // From constant @@ -43,9 +34,8 @@ float4 main(float component: COLOR) : SV_TARGET { // CHECK-NEXT: OpStore %j [[component0]] float j = component; // From stage variable -// CHECK-NEXT: OpStore %m [[float4constant]] float4 m = float4(1.0, 2.0, 3.0, 4.0); // All components are constants -// CHECK-NEXT: [[j0:%\d+]] = OpLoad %float %j +// CHECK: [[j0:%\d+]] = OpLoad %float %j // CHECK-NEXT: [[j1:%\d+]] = OpLoad %float %j // CHECK-NEXT: [[j2:%\d+]] = OpLoad %float %j // CHECK-NEXT: [[j3:%\d+]] = OpLoad %float %j @@ -65,7 +55,6 @@ float4 main(float component: COLOR) : SV_TARGET { // CHECK-NEXT: OpStore %q [[qinit]] int3 q = {4, b, a}; // Mixed cases -// CHECK-NEXT: OpStore %x %uint_1 uint1 x = uint1(1); // Special case: vector of size 1 return o;