Fix race condition for VQDS functionalities

Lock are missing for thread safety on major VQDS functions.
These methods can be invoked consecutively and race condition can
happen. We need to add a lock to prevent it from happening.

Bug: 295390470
Test: atest CtsVoiceInteractionTestCases
Flag: N/A
Change-Id: I99c43deb024f2d5fee5c1ac12e549966f949dc1e
This commit is contained in:
charleschen 2023-12-14 02:08:03 +00:00
parent 6fafa0de38
commit f9835510ba
2 changed files with 133 additions and 91 deletions

View File

@ -94,7 +94,9 @@ public class VisualQueryDetector {
*/ */
public void updateState(@Nullable PersistableBundle options, public void updateState(@Nullable PersistableBundle options,
@Nullable SharedMemory sharedMemory) { @Nullable SharedMemory sharedMemory) {
mInitializationDelegate.updateState(options, sharedMemory); synchronized (mInitializationDelegate.getLock()) {
mInitializationDelegate.updateState(options, sharedMemory);
}
} }
@ -116,18 +118,21 @@ public class VisualQueryDetector {
if (DEBUG) { if (DEBUG) {
Slog.i(TAG, "#startRecognition"); Slog.i(TAG, "#startRecognition");
} }
// check if the detector is active with the initialization delegate synchronized (mInitializationDelegate.getLock()) {
mInitializationDelegate.startRecognition(); // check if the detector is active with the initialization delegate
mInitializationDelegate.startRecognition();
try { try {
mManagerService.startPerceiving(new BinderCallback(mExecutor, mCallback)); mManagerService.startPerceiving(new BinderCallback(
} catch (SecurityException e) { mExecutor, mCallback, mInitializationDelegate.getLock()));
Slog.e(TAG, "startRecognition failed: " + e); } catch (SecurityException e) {
return false; Slog.e(TAG, "startRecognition failed: " + e);
} catch (RemoteException e) { return false;
e.rethrowFromSystemServer(); } catch (RemoteException e) {
e.rethrowFromSystemServer();
}
return true;
} }
return true;
} }
/** /**
@ -140,15 +145,17 @@ public class VisualQueryDetector {
if (DEBUG) { if (DEBUG) {
Slog.i(TAG, "#stopRecognition"); Slog.i(TAG, "#stopRecognition");
} }
// check if the detector is active with the initialization delegate synchronized (mInitializationDelegate.getLock()) {
mInitializationDelegate.startRecognition(); // check if the detector is active with the initialization delegate
mInitializationDelegate.stopRecognition();
try { try {
mManagerService.stopPerceiving(); mManagerService.stopPerceiving();
} catch (RemoteException e) { } catch (RemoteException e) {
e.rethrowFromSystemServer(); e.rethrowFromSystemServer();
}
return true;
} }
return true;
} }
/** /**
@ -160,12 +167,16 @@ public class VisualQueryDetector {
if (DEBUG) { if (DEBUG) {
Slog.i(TAG, "#destroy"); Slog.i(TAG, "#destroy");
} }
mInitializationDelegate.destroy(); synchronized (mInitializationDelegate.getLock()) {
mInitializationDelegate.destroy();
}
} }
/** @hide */ /** @hide */
public void dump(String prefix, PrintWriter pw) { public void dump(String prefix, PrintWriter pw) {
// TODO: implement this synchronized (mInitializationDelegate.getLock()) {
mInitializationDelegate.dump(prefix, pw);
}
} }
/** @hide */ /** @hide */
@ -175,7 +186,9 @@ public class VisualQueryDetector {
/** @hide */ /** @hide */
void registerOnDestroyListener(Consumer<AbstractDetector> onDestroyListener) { void registerOnDestroyListener(Consumer<AbstractDetector> onDestroyListener) {
mInitializationDelegate.registerOnDestroyListener(onDestroyListener); synchronized (mInitializationDelegate.getLock()) {
mInitializationDelegate.registerOnDestroyListener(onDestroyListener);
}
} }
/** /**
@ -282,6 +295,15 @@ public class VisualQueryDetector {
public boolean isUsingSandboxedDetectionService() { public boolean isUsingSandboxedDetectionService() {
return true; return true;
} }
@Override
public void dump(String prefix, PrintWriter pw) {
// No-op
}
private Object getLock() {
return mLock;
}
} }
private static class BinderCallback private static class BinderCallback
@ -289,31 +311,43 @@ public class VisualQueryDetector {
private final Executor mExecutor; private final Executor mExecutor;
private final VisualQueryDetector.Callback mCallback; private final VisualQueryDetector.Callback mCallback;
BinderCallback(Executor executor, VisualQueryDetector.Callback callback) { private final Object mLock;
BinderCallback(Executor executor, VisualQueryDetector.Callback callback, Object lock) {
this.mExecutor = executor; this.mExecutor = executor;
this.mCallback = callback; this.mCallback = callback;
this.mLock = lock;
} }
/** Called when the detected result is valid. */ /** Called when the detected result is valid. */
@Override @Override
public void onQueryDetected(@NonNull String partialQuery) { public void onQueryDetected(@NonNull String partialQuery) {
Slog.v(TAG, "BinderCallback#onQueryDetected"); Slog.v(TAG, "BinderCallback#onQueryDetected");
Binder.withCleanCallingIdentity(() -> mExecutor.execute( Binder.withCleanCallingIdentity(() -> {
() -> mCallback.onQueryDetected(partialQuery))); synchronized (mLock) {
mCallback.onQueryDetected(partialQuery);
}
});
} }
@Override @Override
public void onQueryFinished() { public void onQueryFinished() {
Slog.v(TAG, "BinderCallback#onQueryFinished"); Slog.v(TAG, "BinderCallback#onQueryFinished");
Binder.withCleanCallingIdentity(() -> mExecutor.execute( Binder.withCleanCallingIdentity(() -> {
() -> mCallback.onQueryFinished())); synchronized (mLock) {
mCallback.onQueryFinished();
}
});
} }
@Override @Override
public void onQueryRejected() { public void onQueryRejected() {
Slog.v(TAG, "BinderCallback#onQueryRejected"); Slog.v(TAG, "BinderCallback#onQueryRejected");
Binder.withCleanCallingIdentity(() -> mExecutor.execute( Binder.withCleanCallingIdentity(() -> {
() -> mCallback.onQueryRejected())); synchronized (mLock) {
mCallback.onQueryRejected();
}
});
} }
/** Called when the detection fails due to an error. */ /** Called when the detection fails due to an error. */

View File

@ -106,96 +106,104 @@ final class VisualQueryDetectorSession extends DetectorSession {
@Override @Override
public void onAttentionGained() { public void onAttentionGained() {
Slog.v(TAG, "BinderCallback#onAttentionGained"); Slog.v(TAG, "BinderCallback#onAttentionGained");
mEgressingData = true; synchronized (mLock) {
if (mAttentionListener == null) { mEgressingData = true;
return; if (mAttentionListener == null) {
} return;
try { }
mAttentionListener.onAttentionGained(); try {
} catch (RemoteException e) { mAttentionListener.onAttentionGained();
Slog.e(TAG, "Error delivering attention gained event.", e); } catch (RemoteException e) {
try { Slog.e(TAG, "Error delivering attention gained event.", e);
callback.onVisualQueryDetectionServiceFailure( try {
new VisualQueryDetectionServiceFailure( callback.onVisualQueryDetectionServiceFailure(
ERROR_CODE_ILLEGAL_ATTENTION_STATE, new VisualQueryDetectionServiceFailure(
"Attention listener failed to switch to GAINED state.")); ERROR_CODE_ILLEGAL_ATTENTION_STATE,
} catch (RemoteException ex) { "Attention listener fails to switch to GAINED state."));
Slog.v(TAG, "Fail to call onVisualQueryDetectionServiceFailure"); } catch (RemoteException ex) {
Slog.v(TAG, "Fail to call onVisualQueryDetectionServiceFailure");
}
} }
return;
} }
} }
@Override @Override
public void onAttentionLost() { public void onAttentionLost() {
Slog.v(TAG, "BinderCallback#onAttentionLost"); Slog.v(TAG, "BinderCallback#onAttentionLost");
mEgressingData = false; synchronized (mLock) {
if (mAttentionListener == null) { mEgressingData = false;
return; if (mAttentionListener == null) {
} return;
try { }
mAttentionListener.onAttentionLost(); try {
} catch (RemoteException e) { mAttentionListener.onAttentionLost();
Slog.e(TAG, "Error delivering attention lost event.", e); } catch (RemoteException e) {
try { Slog.e(TAG, "Error delivering attention lost event.", e);
callback.onVisualQueryDetectionServiceFailure( try {
new VisualQueryDetectionServiceFailure( callback.onVisualQueryDetectionServiceFailure(
ERROR_CODE_ILLEGAL_ATTENTION_STATE, new VisualQueryDetectionServiceFailure(
"Attention listener failed to switch to LOST state.")); ERROR_CODE_ILLEGAL_ATTENTION_STATE,
} catch (RemoteException ex) { "Attention listener fails to switch to LOST state."));
Slog.v(TAG, "Fail to call onVisualQueryDetectionServiceFailure"); } catch (RemoteException ex) {
Slog.v(TAG, "Fail to call onVisualQueryDetectionServiceFailure");
}
} }
return;
} }
} }
@Override @Override
public void onQueryDetected(@NonNull String partialQuery) throws RemoteException { public void onQueryDetected(@NonNull String partialQuery) throws RemoteException {
Objects.requireNonNull(partialQuery);
Slog.v(TAG, "BinderCallback#onQueryDetected"); Slog.v(TAG, "BinderCallback#onQueryDetected");
if (!mEgressingData) { synchronized (mLock) {
Slog.v(TAG, "Query should not be egressed within the unattention state."); Objects.requireNonNull(partialQuery);
callback.onVisualQueryDetectionServiceFailure( if (!mEgressingData) {
new VisualQueryDetectionServiceFailure( Slog.v(TAG, "Query should not be egressed within the unattention state.");
ERROR_CODE_ILLEGAL_STREAMING_STATE, callback.onVisualQueryDetectionServiceFailure(
"Cannot stream queries without attention signals.")); new VisualQueryDetectionServiceFailure(
return; ERROR_CODE_ILLEGAL_STREAMING_STATE,
"Cannot stream queries without attention signals."));
return;
}
mQueryStreaming = true;
callback.onQueryDetected(partialQuery);
Slog.i(TAG, "Egressed from visual query detection process.");
} }
mQueryStreaming = true;
callback.onQueryDetected(partialQuery);
Slog.i(TAG, "Egressed from visual query detection process.");
} }
@Override @Override
public void onQueryFinished() throws RemoteException { public void onQueryFinished() throws RemoteException {
Slog.v(TAG, "BinderCallback#onQueryFinished"); Slog.v(TAG, "BinderCallback#onQueryFinished");
if (!mQueryStreaming) { synchronized (mLock) {
Slog.v(TAG, "Query streaming state signal FINISHED is block since there is" if (!mQueryStreaming) {
+ " no active query being streamed."); Slog.v(TAG, "Query streaming state signal FINISHED is block since there is"
callback.onVisualQueryDetectionServiceFailure( + " no active query being streamed.");
new VisualQueryDetectionServiceFailure( callback.onVisualQueryDetectionServiceFailure(
ERROR_CODE_ILLEGAL_STREAMING_STATE, new VisualQueryDetectionServiceFailure(
"Cannot send FINISHED signal with no query streamed.")); ERROR_CODE_ILLEGAL_STREAMING_STATE,
return; "Cannot send FINISHED signal with no query streamed."));
return;
}
callback.onQueryFinished();
mQueryStreaming = false;
} }
callback.onQueryFinished();
mQueryStreaming = false;
} }
@Override @Override
public void onQueryRejected() throws RemoteException { public void onQueryRejected() throws RemoteException {
Slog.v(TAG, "BinderCallback#onQueryRejected"); Slog.v(TAG, "BinderCallback#onQueryRejected");
if (!mQueryStreaming) { synchronized (mLock) {
Slog.v(TAG, "Query streaming state signal REJECTED is block since there is" if (!mQueryStreaming) {
+ " no active query being streamed."); Slog.v(TAG, "Query streaming state signal REJECTED is block since there is"
callback.onVisualQueryDetectionServiceFailure( + " no active query being streamed.");
new VisualQueryDetectionServiceFailure( callback.onVisualQueryDetectionServiceFailure(
ERROR_CODE_ILLEGAL_STREAMING_STATE, new VisualQueryDetectionServiceFailure(
"Cannot send REJECTED signal with no query streamed.")); ERROR_CODE_ILLEGAL_STREAMING_STATE,
return; "Cannot send REJECTED signal with no query streamed."));
return;
}
callback.onQueryRejected();
mQueryStreaming = false;
} }
callback.onQueryRejected();
mQueryStreaming = false;
} }
}; };
return mRemoteDetectionService.run( return mRemoteDetectionService.run(