diff options
Diffstat (limited to 'services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java')
-rw-r--r-- | services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java | 108 |
1 files changed, 69 insertions, 39 deletions
diff --git a/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java b/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java index 5f00148335a7..89a530514263 100644 --- a/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java +++ b/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java @@ -30,6 +30,7 @@ import android.os.UserHandle; import android.service.textclassifier.ITextClassifierCallback; import android.service.textclassifier.ITextClassifierService; import android.service.textclassifier.TextClassifierService; +import android.util.ArrayMap; import android.util.Slog; import android.util.SparseArray; import android.view.textclassifier.ConversationActions; @@ -54,6 +55,7 @@ import com.android.server.SystemService; import java.io.FileDescriptor; import java.io.PrintWriter; import java.util.ArrayDeque; +import java.util.Map; import java.util.Queue; /** @@ -119,6 +121,8 @@ public final class TextClassificationManagerService extends ITextClassifierServi private final Object mLock; @GuardedBy("mLock") final SparseArray<UserState> mUserStates = new SparseArray<>(); + @GuardedBy("mLock") + private final Map<TextClassificationSessionId, Integer> mSessionUserIds = new ArrayMap<>(); private TextClassificationManagerService(Context context) { mContext = Preconditions.checkNotNull(context); @@ -127,15 +131,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi @Override public void onSuggestSelection( - TextClassificationSessionId sessionId, + @Nullable TextClassificationSessionId sessionId, TextSelection.Request request, ITextClassifierCallback callback) throws RemoteException { Preconditions.checkNotNull(request); Preconditions.checkNotNull(callback); - validateInput(mContext, request.getCallingPackageName()); + final int userId = request.getUserId(); + validateInput(mContext, request.getCallingPackageName(), userId); synchronized (mLock) { - UserState userState = getCallingUserStateLocked(); + UserState userState = getUserStateLocked(userId); if (!userState.bindLocked()) { callback.onFailure(); } else if (userState.isBoundLocked()) { @@ -150,15 +155,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi @Override public void onClassifyText( - TextClassificationSessionId sessionId, + @Nullable TextClassificationSessionId sessionId, TextClassification.Request request, ITextClassifierCallback callback) throws RemoteException { Preconditions.checkNotNull(request); Preconditions.checkNotNull(callback); - validateInput(mContext, request.getCallingPackageName()); + final int userId = request.getUserId(); + validateInput(mContext, request.getCallingPackageName(), userId); synchronized (mLock) { - UserState userState = getCallingUserStateLocked(); + UserState userState = getUserStateLocked(userId); if (!userState.bindLocked()) { callback.onFailure(); } else if (userState.isBoundLocked()) { @@ -173,15 +179,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi @Override public void onGenerateLinks( - TextClassificationSessionId sessionId, + @Nullable TextClassificationSessionId sessionId, TextLinks.Request request, ITextClassifierCallback callback) throws RemoteException { Preconditions.checkNotNull(request); Preconditions.checkNotNull(callback); - validateInput(mContext, request.getCallingPackageName()); + final int userId = request.getUserId(); + validateInput(mContext, request.getCallingPackageName(), userId); synchronized (mLock) { - UserState userState = getCallingUserStateLocked(); + UserState userState = getUserStateLocked(userId); if (!userState.bindLocked()) { callback.onFailure(); } else if (userState.isBoundLocked()) { @@ -196,12 +203,14 @@ public final class TextClassificationManagerService extends ITextClassifierServi @Override public void onSelectionEvent( - TextClassificationSessionId sessionId, SelectionEvent event) throws RemoteException { + @Nullable TextClassificationSessionId sessionId, SelectionEvent event) + throws RemoteException { Preconditions.checkNotNull(event); - validateInput(mContext, event.getPackageName()); + final int userId = event.getUserId(); + validateInput(mContext, event.getPackageName(), userId); synchronized (mLock) { - UserState userState = getCallingUserStateLocked(); + UserState userState = getUserStateLocked(userId); if (userState.isBoundLocked()) { userState.mService.onSelectionEvent(sessionId, event); } else { @@ -213,16 +222,19 @@ public final class TextClassificationManagerService extends ITextClassifierServi } @Override public void onTextClassifierEvent( - TextClassificationSessionId sessionId, + @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) throws RemoteException { Preconditions.checkNotNull(event); final String packageName = event.getEventContext() == null ? null : event.getEventContext().getPackageName(); - validateInput(mContext, packageName); + final int userId = event.getEventContext() == null + ? UserHandle.getCallingUserId() + : event.getEventContext().getUserId(); + validateInput(mContext, packageName, userId); synchronized (mLock) { - UserState userState = getCallingUserStateLocked(); + UserState userState = getUserStateLocked(userId); if (userState.isBoundLocked()) { userState.mService.onTextClassifierEvent(sessionId, event); } else { @@ -235,15 +247,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi @Override public void onDetectLanguage( - TextClassificationSessionId sessionId, + @Nullable TextClassificationSessionId sessionId, TextLanguage.Request request, ITextClassifierCallback callback) throws RemoteException { Preconditions.checkNotNull(request); Preconditions.checkNotNull(callback); - validateInput(mContext, request.getCallingPackageName()); + final int userId = request.getUserId(); + validateInput(mContext, request.getCallingPackageName(), userId); synchronized (mLock) { - UserState userState = getCallingUserStateLocked(); + UserState userState = getUserStateLocked(userId); if (!userState.bindLocked()) { callback.onFailure(); } else if (userState.isBoundLocked()) { @@ -258,15 +271,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi @Override public void onSuggestConversationActions( - TextClassificationSessionId sessionId, + @Nullable TextClassificationSessionId sessionId, ConversationActions.Request request, ITextClassifierCallback callback) throws RemoteException { Preconditions.checkNotNull(request); Preconditions.checkNotNull(callback); - validateInput(mContext, request.getCallingPackageName()); + final int userId = request.getUserId(); + validateInput(mContext, request.getCallingPackageName(), userId); synchronized (mLock) { - UserState userState = getCallingUserStateLocked(); + UserState userState = getUserStateLocked(userId); if (!userState.bindLocked()) { callback.onFailure(); } else if (userState.isBoundLocked()) { @@ -285,13 +299,15 @@ public final class TextClassificationManagerService extends ITextClassifierServi throws RemoteException { Preconditions.checkNotNull(sessionId); Preconditions.checkNotNull(classificationContext); - validateInput(mContext, classificationContext.getPackageName()); + final int userId = classificationContext.getUserId(); + validateInput(mContext, classificationContext.getPackageName(), userId); synchronized (mLock) { - UserState userState = getCallingUserStateLocked(); + UserState userState = getUserStateLocked(userId); if (userState.isBoundLocked()) { userState.mService.onCreateTextClassificationSession( classificationContext, sessionId); + mSessionUserIds.put(sessionId, userId); } else { userState.mPendingRequests.add(new PendingRequest( () -> onCreateTextClassificationSession(classificationContext, sessionId), @@ -306,9 +322,15 @@ public final class TextClassificationManagerService extends ITextClassifierServi Preconditions.checkNotNull(sessionId); synchronized (mLock) { - UserState userState = getCallingUserStateLocked(); + final int userId = mSessionUserIds.containsKey(sessionId) + ? mSessionUserIds.get(sessionId) + : UserHandle.getCallingUserId(); + validateInput(mContext, null /* packageName */, userId); + + UserState userState = getUserStateLocked(userId); if (userState.isBoundLocked()) { userState.mService.onDestroyTextClassificationSession(sessionId); + mSessionUserIds.remove(sessionId); } else { userState.mPendingRequests.add(new PendingRequest( () -> onDestroyTextClassificationSession(sessionId), @@ -318,11 +340,6 @@ public final class TextClassificationManagerService extends ITextClassifierServi } @GuardedBy("mLock") - private UserState getCallingUserStateLocked() { - return getUserStateLocked(UserHandle.getCallingUserId()); - } - - @GuardedBy("mLock") private UserState getUserStateLocked(int userId) { UserState result = mUserStates.get(userId); if (result == null) { @@ -356,6 +373,7 @@ public final class TextClassificationManagerService extends ITextClassifierServi pw.decreaseIndent(); } } + pw.println("Number of active sessions: " + mSessionUserIds.size()); } } @@ -420,20 +438,32 @@ public final class TextClassificationManagerService extends ITextClassifierServi e -> Slog.d(LOG_TAG, "Error " + opDesc + ": " + e.getMessage())); } - private static void validateInput(Context context, @Nullable String packageName) + private static void validateInput( + Context context, @Nullable String packageName, @UserIdInt int userId) throws RemoteException { - if (packageName == null) return; try { - final int packageUid = context.getPackageManager() - .getPackageUidAsUser(packageName, UserHandle.getCallingUserId()); - final int callingUid = Binder.getCallingUid(); - Preconditions.checkArgument(callingUid == packageUid - // Trust the system process: - || callingUid == android.os.Process.SYSTEM_UID); + if (packageName != null) { + final int packageUid = context.getPackageManager() + .getPackageUidAsUser(packageName, UserHandle.getCallingUserId()); + final int callingUid = Binder.getCallingUid(); + Preconditions.checkArgument(callingUid == packageUid + // Trust the system process: + || callingUid == android.os.Process.SYSTEM_UID, + "Invalid package name. Package=" + packageName + + ", CallingUid=" + callingUid); + } + + Preconditions.checkArgument(userId != UserHandle.USER_NULL, "Null userId"); + final int callingUserId = UserHandle.getCallingUserId(); + if (callingUserId != userId) { + context.enforceCallingOrSelfPermission( + android.Manifest.permission.INTERACT_ACROSS_USERS_FULL, + "Invalid userId. UserId=" + userId + ", CallingUserId=" + callingUserId); + } } catch (Exception e) { - throw new RemoteException( - String.format("Invalid package: name=%s, error=%s", packageName, e)); + throw new RemoteException("Invalid request: " + e.getMessage(), e, + /* enableSuppression */ true, /* writableStackTrace */ true); } } |