diff --git a/strings/base_implements.h b/strings/base_implements.h index ca69bc38..0c0ff068 100644 --- a/strings/base_implements.h +++ b/strings/base_implements.h @@ -242,6 +242,12 @@ WINRT_EXPORT namespace winrt return &static_cast>*>(get_abi(from))->shim(); } + template + D* get_self(com_ptr const& from) noexcept + { + return static_cast(static_cast*>(from.get())); + } + template [[deprecated]] D* from_abi(I const& from) noexcept { diff --git a/strings/base_meta.h b/strings/base_meta.h index 54d41e61..f474fced 100644 --- a/strings/base_meta.h +++ b/strings/base_meta.h @@ -13,6 +13,9 @@ WINRT_EXPORT namespace winrt template struct com_ptr; + template + D* get_self(com_ptr const& from) noexcept; + namespace param { template diff --git a/test/old_tests/UnitTests/interop.cpp b/test/old_tests/UnitTests/interop.cpp index 1f3ae56f..357904a1 100644 --- a/test/old_tests/UnitTests/interop.cpp +++ b/test/old_tests/UnitTests/interop.cpp @@ -7,6 +7,10 @@ using namespace Windows::Foundation; namespace { + struct IClassicComInterface : ::IUnknown {}; + + struct ClassicCom : implements {}; + struct Stringable : implements { Stringable(std::wstring_view const& value = L"Stringable") : m_value(value) @@ -30,8 +34,16 @@ namespace object->AddRef(); return object->Release(); } + + template + uint32_t get_ref_count(com_ptr const& object) + { + return get_ref_count(object.get()); + } } +template <> inline constexpr winrt::guid winrt::impl::guid_v{ 0xc136bb75, 0xbc03, 0x41a6, { 0xa5, 0xdc, 0x5e, 0xfa, 0x67, 0x92, 0x4e, 0xbf } }; + TEST_CASE("interop") { uint32_t const before = get_module_lock(); @@ -108,6 +120,43 @@ TEST_CASE("self") REQUIRE(get_ref_count(object) == 1); object = nullptr; + strong = weak.get(); + REQUIRE(!strong); +} + +TEST_CASE("self_classic_com") +{ + com_ptr strong = make_self(); + + REQUIRE(get_ref_count(strong.get()) == 1); + + com_ptr object = strong.as(); + + REQUIRE(get_ref_count(strong.get()) == 2); + + ClassicCom* ptr = get_self(object); + REQUIRE(ptr == strong.get()); + + REQUIRE(get_ref_count(strong.get()) == 2); + strong = nullptr; + REQUIRE(get_ref_count(object) == 1); + + strong = get_self(object)->get_strong(); + REQUIRE(get_ref_count(object) == 2); + strong = nullptr; + REQUIRE(get_ref_count(object) == 1); + + weak_ref weak = get_self(object)->get_weak(); + REQUIRE(get_ref_count(object) == 1); // <-- still just one! + + strong = weak.get(); + REQUIRE(strong); + REQUIRE(get_ref_count(object) == 2); + + strong = nullptr; + REQUIRE(get_ref_count(object) == 1); + object = nullptr; + strong = weak.get(); REQUIRE(!strong); } \ No newline at end of file