diff --git a/toolkit/components/contentanalysis/ContentAnalysis.cpp b/toolkit/components/contentanalysis/ContentAnalysis.cpp index 35770c048cad..80c79ca831c3 100644 --- a/toolkit/components/contentanalysis/ContentAnalysis.cpp +++ b/toolkit/components/contentanalysis/ContentAnalysis.cpp @@ -82,6 +82,27 @@ static nsresult GetFileDisplayName(const nsString& aFilePath, return file->GetDisplayName(aFileDisplayName); } +uint32_t ResponseResultToAcknowledgementResult(uint32_t responseResult) { + switch (responseResult) { + case nsIContentAnalysisResponse::REPORT_ONLY: + return nsIContentAnalysisAcknowledgement::REPORT_ONLY; + case nsIContentAnalysisResponse::WARN: + return nsIContentAnalysisAcknowledgement::WARN; + case nsIContentAnalysisResponse::BLOCK: + return nsIContentAnalysisAcknowledgement::BLOCK; + case nsIContentAnalysisResponse::ALLOW: + return nsIContentAnalysisAcknowledgement::ALLOW; + case nsIContentAnalysisResponse::ACTION_UNSPECIFIED: + return nsIContentAnalysisAcknowledgement::ACTION_UNSPECIFIED; + default: + LOGE( + "ResponseResultToAcknowledgementResult got unexpected responseResult " + "%d", + responseResult); + return nsIContentAnalysisAcknowledgement::ACTION_UNSPECIFIED; + } +} + } // anonymous namespace namespace mozilla::contentanalysis { @@ -374,7 +395,8 @@ static void LogRequest( } ContentAnalysisResponse::ContentAnalysisResponse( - content_analysis::sdk::ContentAnalysisResponse&& aResponse) { + content_analysis::sdk::ContentAnalysisResponse&& aResponse) + : mHasAcknowledged(false) { mAction = nsIContentAnalysisResponse::ACTION_UNSPECIFIED; for (const auto& result : aResponse.results()) { if (!result.has_status() || @@ -400,7 +422,7 @@ ContentAnalysisResponse::ContentAnalysisResponse( ContentAnalysisResponse::ContentAnalysisResponse( unsigned long aAction, const nsACString& aRequestToken) - : mAction(aAction), mRequestToken(aRequestToken) {} + : mAction(aAction), mRequestToken(aRequestToken), mHasAcknowledged(false) {} /* static */ already_AddRefed ContentAnalysisResponse::FromProtobuf( @@ -531,7 +553,23 @@ static void LogAcknowledgement( } void ContentAnalysisResponse::SetOwner(RefPtr aOwner) { - mOwner = aOwner; + mOwner = std::move(aOwner); +} + +ContentAnalysisAcknowledgement::ContentAnalysisAcknowledgement( + unsigned long aResult, unsigned long aFinalAction) + : mResult(aResult), mFinalAction(aFinalAction) {} + +NS_IMETHODIMP +ContentAnalysisAcknowledgement::GetResult(uint32_t* aResult) { + *aResult = mResult; + return NS_OK; +} + +NS_IMETHODIMP +ContentAnalysisAcknowledgement::GetFinalAction(uint32_t* aFinalAction) { + *aFinalAction = mFinalAction; + return NS_OK; } namespace { @@ -565,6 +603,8 @@ NS_IMPL_CLASSINFO(ContentAnalysisRequest, nullptr, 0, {0}); NS_IMPL_ISUPPORTS_CI(ContentAnalysisRequest, nsIContentAnalysisRequest); NS_IMPL_CLASSINFO(ContentAnalysisResponse, nullptr, 0, {0}); NS_IMPL_ISUPPORTS_CI(ContentAnalysisResponse, nsIContentAnalysisResponse); +NS_IMPL_ISUPPORTS(ContentAnalysisAcknowledgement, + nsIContentAnalysisAcknowledgement); NS_IMPL_ISUPPORTS(ContentAnalysisCallback, nsIContentAnalysisCallback); NS_IMPL_ISUPPORTS(ContentAnalysisResult, nsIContentAnalysisResult); NS_IMPL_ISUPPORTS(ContentAnalysis, nsIContentAnalysis, ContentAnalysis); @@ -672,8 +712,8 @@ RefPtr ContentAnalysis::GetContentAnalysisFromService() { } nsresult ContentAnalysis::RunAnalyzeRequestTask( - RefPtr aRequest, - RefPtr aCallback) { + const RefPtr& aRequest, bool aAutoAcknowledge, + const RefPtr& aCallback) { nsresult rv = NS_ERROR_FAILURE; auto callbackCopy = aCallback; auto se = MakeScopeExit([&] { @@ -687,14 +727,11 @@ nsresult ContentAnalysis::RunAnalyzeRequestTask( rv = ConvertToProtobuf(aRequest, &pbRequest); NS_ENSURE_SUCCESS(rv, rv); - LOGD("Issuing ContentAnalysisRequest"); - LogRequest(&pbRequest); - nsCString requestToken; nsMainThreadPtrHandle callbackHolderCopy( new nsMainThreadPtrHolder( "content analysis callback", aCallback)); - CallbackData callbackData(std::move(callbackHolderCopy)); + CallbackData callbackData(std::move(callbackHolderCopy), aAutoAcknowledge); rv = aRequest->GetRequestToken(requestToken); NS_ENSURE_SUCCESS(rv, rv); { @@ -708,117 +745,18 @@ nsresult ContentAnalysis::RunAnalyzeRequestTask( mCaClientPromise->Then( GetCurrentSerialEventTarget(), __func__, [requestToken, pbRequest = std::move(pbRequest)]( - std::shared_ptr client) { + std::shared_ptr client) mutable { // The content analysis call is synchronous so run in the background. NS_DispatchBackgroundTask( NS_NewCancelableRunnableFunction( __func__, - [requestToken, pbRequest = std::move(pbRequest), client] { - RefPtr owner = - GetContentAnalysisFromService(); - if (!owner) { - // May be shutting down - return; - } - - if (!client) { - owner->CancelWithError(std::move(requestToken), - NS_ERROR_NOT_AVAILABLE); - return; - } - { - auto callbackMap = owner->mCallbackMap.Lock(); - if (!callbackMap->Contains(requestToken)) { - LOGD( - "RunAnalyzeRequestTask token %s has already been " - "cancelled - not issuing request", - requestToken.get()); - return; - } - } - - // Run request, then dispatch back to main thread to resolve - // aCallback - content_analysis::sdk::ContentAnalysisResponse pbResponse; - int err = client->Send(pbRequest, &pbResponse); - if (err != 0) { - LOGE("RunAnalyzeRequestTask client transaction failed"); - owner->CancelWithError(std::move(requestToken), - NS_ERROR_FAILURE); - return; - } - LOGD("Content analysis client transaction succeeded"); - LogResponse(&pbResponse); - NS_DispatchToMainThread(NS_NewCancelableRunnableFunction( - "ContentAnalysis::RunAnalyzeRequestTask::HandleResponse", - [pbResponse = std::move(pbResponse)]() mutable { - RefPtr owner = - GetContentAnalysisFromService(); - if (!owner) { - // May be shutting down - return; - } - - RefPtr response = - ContentAnalysisResponse::FromProtobuf( - std::move(pbResponse)); - if (!response) { - LOGE("Content analysis got invalid response!"); - return; - } - nsCString responseRequestToken; - nsresult requestRv = - response->GetRequestToken(responseRequestToken); - if (NS_FAILED(requestRv)) { - LOGE( - "Content analysis couldn't get request token " - "from response!"); - return; - } - - Maybe maybeCallbackData; - { - auto callbackMap = owner->mCallbackMap.Lock(); - maybeCallbackData = - callbackMap->Extract(responseRequestToken); - } - if (maybeCallbackData.isNothing()) { - LOGD( - "Content analysis did not find callback for " - "token %s", - responseRequestToken.get()); - return; - } - response->SetOwner(owner); - if (maybeCallbackData->Canceled()) { - // request has already been cancelled, so there's - // nothing to do - LOGD( - "Content analysis got response but ignoring " - "because it was already cancelled for token %s", - responseRequestToken.get()); - return; - } - - LOGD( - "Content analysis resolving response promise for " - "token %s", - responseRequestToken.get()); - nsCOMPtr obsServ = - mozilla::services::GetObserverService(); - - obsServ->NotifyObservers(response, "dlp-response", - nullptr); - - nsMainThreadPtrHandle - callbackHolder = - maybeCallbackData->TakeCallbackHolder(); - callbackHolder->ContentResult(response); - })); + [requestToken, pbRequest = std::move(pbRequest), + client = std::move(client)]() mutable { + DoAnalyzeRequest(requestToken, std::move(pbRequest), client); }), NS_DISPATCH_EVENT_MAY_BLOCK); }, - [requestToken](nsresult rv) { + [requestToken](nsresult rv) mutable { LOGD("RunAnalyzeRequestTask failed to get client"); RefPtr owner = GetContentAnalysisFromService(); if (!owner) { @@ -831,9 +769,127 @@ nsresult ContentAnalysis::RunAnalyzeRequestTask( return rv; } +void ContentAnalysis::DoAnalyzeRequest( + nsCString aRequestToken, + content_analysis::sdk::ContentAnalysisRequest&& aRequest, + const std::shared_ptr& aClient) { + MOZ_ASSERT(!NS_IsMainThread()); + RefPtr owner = + ContentAnalysis::GetContentAnalysisFromService(); + if (!owner) { + // May be shutting down + return; + } + + if (!aClient) { + owner->CancelWithError(std::move(aRequestToken), NS_ERROR_NOT_AVAILABLE); + return; + } + { + auto callbackMap = owner->mCallbackMap.Lock(); + if (!callbackMap->Contains(aRequestToken)) { + LOGD( + "RunAnalyzeRequestTask token %s has already been " + "cancelled - not issuing request", + aRequestToken.get()); + return; + } + } + + // Run request, then dispatch back to main thread to resolve + // aCallback + content_analysis::sdk::ContentAnalysisResponse pbResponse; + int err = aClient->Send(aRequest, &pbResponse); + if (err != 0) { + LOGE("RunAnalyzeRequestTask client transaction failed"); + owner->CancelWithError(std::move(aRequestToken), NS_ERROR_FAILURE); + return; + } + LOGD("Content analysis client transaction succeeded"); + LogResponse(&pbResponse); + NS_DispatchToMainThread(NS_NewCancelableRunnableFunction( + "ContentAnalysis::RunAnalyzeRequestTask::HandleResponse", + [pbResponse = std::move(pbResponse)]() mutable { + RefPtr owner = GetContentAnalysisFromService(); + if (!owner) { + // May be shutting down + return; + } + + RefPtr response = + ContentAnalysisResponse::FromProtobuf(std::move(pbResponse)); + if (!response) { + LOGE("Content analysis got invalid response!"); + return; + } + nsCString responseRequestToken; + nsresult requestRv = response->GetRequestToken(responseRequestToken); + if (NS_FAILED(requestRv)) { + LOGE( + "Content analysis couldn't get request token " + "from response!"); + return; + } + + Maybe maybeCallbackData; + { + auto callbackMap = owner->mCallbackMap.Lock(); + maybeCallbackData = callbackMap->Extract(responseRequestToken); + } + if (maybeCallbackData.isNothing()) { + LOGD( + "Content analysis did not find callback for " + "token %s", + responseRequestToken.get()); + return; + } + response->SetOwner(owner); + if (maybeCallbackData->Canceled()) { + // request has already been cancelled, so there's + // nothing to do + LOGD( + "Content analysis got response but ignoring " + "because it was already cancelled for token %s", + responseRequestToken.get()); + // Note that we always acknowledge here, even if + // autoAcknowledge isn't set. From the caller's perspective + // the request has already been + // cancelled and the caller isn't going to get any + // notification that the DLP agent sent a response, so + // the caller wouldn't be able to send an + // acknowledgment. + auto acknowledgement = MakeRefPtr( + nsIContentAnalysisAcknowledgement::TOO_LATE, + nsIContentAnalysisAcknowledgement::BLOCK); + response->Acknowledge(acknowledgement); + return; + } + + LOGD( + "Content analysis resolving response promise for " + "token %s", + responseRequestToken.get()); + nsCOMPtr obsServ = + mozilla::services::GetObserverService(); + + obsServ->NotifyObservers(response, "dlp-response", nullptr); + if (maybeCallbackData->AutoAcknowledge()) { + uint32_t action = response->GetAction(); + auto acknowledgement = MakeRefPtr( + nsIContentAnalysisAcknowledgement::SUCCESS, + ResponseResultToAcknowledgementResult(action)); + response->Acknowledge(acknowledgement); + } + + nsMainThreadPtrHandle callbackHolder = + maybeCallbackData->TakeCallbackHolder(); + callbackHolder->ContentResult(response); + })); +} + NS_IMETHODIMP ContentAnalysis::AnalyzeContentRequest(nsIContentAnalysisRequest* aRequest, - JSContext* aCx, + bool aAutoAcknowledge, JSContext* aCx, mozilla::dom::Promise** aPromise) { RefPtr promise; nsresult rv = MakePromise(aCx, &promise); @@ -841,12 +897,13 @@ ContentAnalysis::AnalyzeContentRequest(nsIContentAnalysisRequest* aRequest, RefPtr callbackPtr = new ContentAnalysisCallback(promise); promise.forget(aPromise); - return AnalyzeContentRequestCallback(aRequest, callbackPtr.get()); + return AnalyzeContentRequestCallback(aRequest, aAutoAcknowledge, + callbackPtr.get()); } NS_IMETHODIMP ContentAnalysis::AnalyzeContentRequestCallback( - nsIContentAnalysisRequest* aRequest, + nsIContentAnalysisRequest* aRequest, bool aAutoAcknowledge, nsIContentAnalysisCallback* aCallback) { NS_ENSURE_ARG(aRequest); NS_ENSURE_ARG(aCallback); @@ -862,7 +919,7 @@ ContentAnalysis::AnalyzeContentRequestCallback( mozilla::services::GetObserverService(); obsServ->NotifyObservers(aRequest, "dlp-request-made", nullptr); - rv = RunAnalyzeRequestTask(aRequest, aCallback); + rv = RunAnalyzeRequestTask(aRequest, aAutoAcknowledge, aCallback); return rv; } @@ -880,15 +937,20 @@ ContentAnalysis::CancelContentAnalysisRequest(const nsACString& aRequestToken) { auto callbackMap = self->mCallbackMap.Lock(); auto entry = callbackMap->Lookup(requestToken); LOGD("Content analysis cancelling request %s", requestToken.get()); - if (entry) { - entry->SetCanceled(); + // Make sure the entry hasn't been cancelled already + if (entry && !entry->Canceled()) { RefPtr cancelResponse = ContentAnalysisResponse::FromAction( nsIContentAnalysisResponse::CANCELED, requestToken); cancelResponse->SetOwner(self); nsMainThreadPtrHandle callbackHolder = entry->TakeCallbackHolder(); - callbackHolder->ContentResult(cancelResponse.get()); + entry->SetCanceled(); + // Should only be called once + MOZ_ASSERT(callbackHolder); + if (callbackHolder) { + callbackHolder->ContentResult(cancelResponse.get()); + } } else { LOGD("Content analysis request not found when trying to cancel %s", requestToken.get()); @@ -901,6 +963,11 @@ NS_IMETHODIMP ContentAnalysisResponse::Acknowledge( nsIContentAnalysisAcknowledgement* aAcknowledgement) { MOZ_ASSERT(mOwner); + if (mHasAcknowledged) { + MOZ_ASSERT(false, "Already acknowledged this ContentAnalysisResponse!"); + return NS_ERROR_FAILURE; + } + mHasAcknowledged = true; return mOwner->RunAcknowledgeTask(aAcknowledgement, mRequestToken); }; @@ -926,11 +993,12 @@ nsresult ContentAnalysis::RunAcknowledgeTask( mCaClientPromise->Then( GetCurrentSerialEventTarget(), __func__, [pbAck = std::move(pbAck)]( - std::shared_ptr client) { + std::shared_ptr client) mutable { NS_DispatchBackgroundTask( NS_NewCancelableRunnableFunction( __func__, - [pbAck = std::move(pbAck), client] { + [pbAck = std::move(pbAck), + client = std::move(client)]() mutable { RefPtr owner = GetContentAnalysisFromService(); if (!owner) { diff --git a/toolkit/components/contentanalysis/ContentAnalysis.h b/toolkit/components/contentanalysis/ContentAnalysis.h index b6c57305e823..e9d92db901a9 100644 --- a/toolkit/components/contentanalysis/ContentAnalysis.h +++ b/toolkit/components/contentanalysis/ContentAnalysis.h @@ -18,6 +18,7 @@ namespace content_analysis::sdk { class Client; +class ContentAnalysisRequest; class ContentAnalysisResponse; } // namespace content_analysis::sdk @@ -97,13 +98,18 @@ class ContentAnalysis final : public nsIContentAnalysis { ContentAnalysis& operator=(ContentAnalysis&) = delete; nsresult CreateContentAnalysisClient(nsCString&& aPipePathName, bool aIsPerUser); - nsresult RunAnalyzeRequestTask(RefPtr aRequest, - RefPtr aCallback); + nsresult RunAnalyzeRequestTask( + const RefPtr& aRequest, bool aAutoAcknowledge, + const RefPtr& aCallback); nsresult RunAcknowledgeTask( nsIContentAnalysisAcknowledgement* aAcknowledgement, const nsACString& aRequestToken); nsresult CancelWithError(nsCString aRequestToken, nsresult aResult); static RefPtr GetContentAnalysisFromService(); + static void DoAnalyzeRequest( + nsCString aRequestToken, + content_analysis::sdk::ContentAnalysisRequest&& aRequest, + const std::shared_ptr& aClient); using ClientPromise = MozPromise, nsresult, @@ -112,20 +118,24 @@ class ContentAnalysis final : public nsIContentAnalysis { // Only accessed from the main thread bool mClientCreationAttempted; - class CallbackData { + class CallbackData final { public: - explicit CallbackData( - nsMainThreadPtrHandle&& aCallbackHolder) - : mCallbackHolder(aCallbackHolder) {} + CallbackData( + nsMainThreadPtrHandle&& aCallbackHolder, + bool aAutoAcknowledge) + : mCallbackHolder(aCallbackHolder), + mAutoAcknowledge(aAutoAcknowledge) {} nsMainThreadPtrHandle TakeCallbackHolder() { return std::move(mCallbackHolder); } + bool AutoAcknowledge() const { return mAutoAcknowledge; } void SetCanceled() { mCallbackHolder = nullptr; } bool Canceled() const { return !mCallbackHolder; } private: nsMainThreadPtrHandle mCallbackHolder; + bool mAutoAcknowledge; }; DataMutex> mCallbackMap; friend class ContentAnalysisResponse; @@ -165,9 +175,28 @@ class ContentAnalysisResponse final : public nsIContentAnalysisResponse { // the transaction. RefPtr mOwner; + // Whether the response has been acknowledged + bool mHasAcknowledged; + friend class ContentAnalysis; }; +class ContentAnalysisAcknowledgement final + : public nsIContentAnalysisAcknowledgement { + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSICONTENTANALYSISACKNOWLEDGEMENT + + ContentAnalysisAcknowledgement(unsigned long aResult, + unsigned long aFinalAction); + + private: + ~ContentAnalysisAcknowledgement() = default; + + unsigned long mResult; + unsigned long mFinalAction; +}; + class ContentAnalysisCallback final : public nsIContentAnalysisCallback { public: NS_DECL_THREADSAFE_ISUPPORTS diff --git a/toolkit/components/contentanalysis/nsIContentAnalysis.idl b/toolkit/components/contentanalysis/nsIContentAnalysis.idl index 1484586f4a7c..263e1b936376 100644 --- a/toolkit/components/contentanalysis/nsIContentAnalysis.idl +++ b/toolkit/components/contentanalysis/nsIContentAnalysis.idl @@ -56,8 +56,9 @@ interface nsIContentAnalysisResponse : nsISupports /** * Acknowledge receipt of an analysis response. - * Should always be called after successful resolution of the promise - * from AnalyzeContentRequest. + * If false is passed for aAutoAcknowledge to AnalyzeContentRequest, + * the caller is responsible for calling this after successful + * resolution of the promise. */ void acknowledge(in nsIContentAnalysisAcknowledgement aCaa); }; @@ -185,22 +186,39 @@ interface nsIContentAnalysis : nsISupports * The resulting Promise resolves to a nsIContentAnalysisResponse, * which may take some time to get from the analysis server. It will * be rejected, with an string error description, if any error occurs. - * A successful nsIContentAnalysisResponse must be acknowledged. - * See @nsIContentAnalysisResponse. + * + * @param aCar + * The request to analyze. + * @param aAutoAcknowledge + * Whether to send an acknowledge message to the agent after the agent sends a response. + * Passing false means that the caller is responsible for + * calling nsIContentAnalysisResponse::acknowledge(), unless the request is cancelled. */ [implicit_jscontext] - Promise analyzeContentRequest(in nsIContentAnalysisRequest aCar); + Promise analyzeContentRequest(in nsIContentAnalysisRequest aCar, in bool aAutoAcknowledge); /** * Same functionality as AnalyzeContentRequest(), but more convenient to call * from C++ since it takes a callback instead of returning a Promise. + * + * @param aCar + * The request to analyze. + * @param aAutoAcknowledge + * Whether to send an acknowledge message to the agent after the agent sends a response. + * Passing false means that the caller is responsible for + * calling nsIContentAnalysisResponse::acknowledge(), unless the request is cancelled. + * @param callback + * Callbacks to be called when the agent sends a response message (or when there is an error). */ - void analyzeContentRequestCallback(in nsIContentAnalysisRequest aCar, in nsIContentAnalysisCallback callback); + void analyzeContentRequestCallback(in nsIContentAnalysisRequest aCar, in bool aAutoAcknowledge, in nsIContentAnalysisCallback callback); /** * Cancels the request that is in progress. This may not actually cancel the request * with the analysis server, but it means that Gecko will immediately act like the request * was denied. + * + * @param aRequestToken + * The token for the request to cancel. */ void cancelContentAnalysisRequest(in ACString aRequestToken); };