summaryrefslogtreecommitdiff
path: root/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java
diff options
context:
space:
mode:
Diffstat (limited to 'services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java')
-rw-r--r--services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java108
1 files changed, 69 insertions, 39 deletions
diff --git a/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java b/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java
index 5f00148335a7..89a530514263 100644
--- a/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java
+++ b/services/core/java/com/android/server/textclassifier/TextClassificationManagerService.java
@@ -30,6 +30,7 @@ import android.os.UserHandle;
import android.service.textclassifier.ITextClassifierCallback;
import android.service.textclassifier.ITextClassifierService;
import android.service.textclassifier.TextClassifierService;
+import android.util.ArrayMap;
import android.util.Slog;
import android.util.SparseArray;
import android.view.textclassifier.ConversationActions;
@@ -54,6 +55,7 @@ import com.android.server.SystemService;
import java.io.FileDescriptor;
import java.io.PrintWriter;
import java.util.ArrayDeque;
+import java.util.Map;
import java.util.Queue;
/**
@@ -119,6 +121,8 @@ public final class TextClassificationManagerService extends ITextClassifierServi
private final Object mLock;
@GuardedBy("mLock")
final SparseArray<UserState> mUserStates = new SparseArray<>();
+ @GuardedBy("mLock")
+ private final Map<TextClassificationSessionId, Integer> mSessionUserIds = new ArrayMap<>();
private TextClassificationManagerService(Context context) {
mContext = Preconditions.checkNotNull(context);
@@ -127,15 +131,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onSuggestSelection(
- TextClassificationSessionId sessionId,
+ @Nullable TextClassificationSessionId sessionId,
TextSelection.Request request, ITextClassifierCallback callback)
throws RemoteException {
Preconditions.checkNotNull(request);
Preconditions.checkNotNull(callback);
- validateInput(mContext, request.getCallingPackageName());
+ final int userId = request.getUserId();
+ validateInput(mContext, request.getCallingPackageName(), userId);
synchronized (mLock) {
- UserState userState = getCallingUserStateLocked();
+ UserState userState = getUserStateLocked(userId);
if (!userState.bindLocked()) {
callback.onFailure();
} else if (userState.isBoundLocked()) {
@@ -150,15 +155,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onClassifyText(
- TextClassificationSessionId sessionId,
+ @Nullable TextClassificationSessionId sessionId,
TextClassification.Request request, ITextClassifierCallback callback)
throws RemoteException {
Preconditions.checkNotNull(request);
Preconditions.checkNotNull(callback);
- validateInput(mContext, request.getCallingPackageName());
+ final int userId = request.getUserId();
+ validateInput(mContext, request.getCallingPackageName(), userId);
synchronized (mLock) {
- UserState userState = getCallingUserStateLocked();
+ UserState userState = getUserStateLocked(userId);
if (!userState.bindLocked()) {
callback.onFailure();
} else if (userState.isBoundLocked()) {
@@ -173,15 +179,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onGenerateLinks(
- TextClassificationSessionId sessionId,
+ @Nullable TextClassificationSessionId sessionId,
TextLinks.Request request, ITextClassifierCallback callback)
throws RemoteException {
Preconditions.checkNotNull(request);
Preconditions.checkNotNull(callback);
- validateInput(mContext, request.getCallingPackageName());
+ final int userId = request.getUserId();
+ validateInput(mContext, request.getCallingPackageName(), userId);
synchronized (mLock) {
- UserState userState = getCallingUserStateLocked();
+ UserState userState = getUserStateLocked(userId);
if (!userState.bindLocked()) {
callback.onFailure();
} else if (userState.isBoundLocked()) {
@@ -196,12 +203,14 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onSelectionEvent(
- TextClassificationSessionId sessionId, SelectionEvent event) throws RemoteException {
+ @Nullable TextClassificationSessionId sessionId, SelectionEvent event)
+ throws RemoteException {
Preconditions.checkNotNull(event);
- validateInput(mContext, event.getPackageName());
+ final int userId = event.getUserId();
+ validateInput(mContext, event.getPackageName(), userId);
synchronized (mLock) {
- UserState userState = getCallingUserStateLocked();
+ UserState userState = getUserStateLocked(userId);
if (userState.isBoundLocked()) {
userState.mService.onSelectionEvent(sessionId, event);
} else {
@@ -213,16 +222,19 @@ public final class TextClassificationManagerService extends ITextClassifierServi
}
@Override
public void onTextClassifierEvent(
- TextClassificationSessionId sessionId,
+ @Nullable TextClassificationSessionId sessionId,
TextClassifierEvent event) throws RemoteException {
Preconditions.checkNotNull(event);
final String packageName = event.getEventContext() == null
? null
: event.getEventContext().getPackageName();
- validateInput(mContext, packageName);
+ final int userId = event.getEventContext() == null
+ ? UserHandle.getCallingUserId()
+ : event.getEventContext().getUserId();
+ validateInput(mContext, packageName, userId);
synchronized (mLock) {
- UserState userState = getCallingUserStateLocked();
+ UserState userState = getUserStateLocked(userId);
if (userState.isBoundLocked()) {
userState.mService.onTextClassifierEvent(sessionId, event);
} else {
@@ -235,15 +247,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onDetectLanguage(
- TextClassificationSessionId sessionId,
+ @Nullable TextClassificationSessionId sessionId,
TextLanguage.Request request,
ITextClassifierCallback callback) throws RemoteException {
Preconditions.checkNotNull(request);
Preconditions.checkNotNull(callback);
- validateInput(mContext, request.getCallingPackageName());
+ final int userId = request.getUserId();
+ validateInput(mContext, request.getCallingPackageName(), userId);
synchronized (mLock) {
- UserState userState = getCallingUserStateLocked();
+ UserState userState = getUserStateLocked(userId);
if (!userState.bindLocked()) {
callback.onFailure();
} else if (userState.isBoundLocked()) {
@@ -258,15 +271,16 @@ public final class TextClassificationManagerService extends ITextClassifierServi
@Override
public void onSuggestConversationActions(
- TextClassificationSessionId sessionId,
+ @Nullable TextClassificationSessionId sessionId,
ConversationActions.Request request,
ITextClassifierCallback callback) throws RemoteException {
Preconditions.checkNotNull(request);
Preconditions.checkNotNull(callback);
- validateInput(mContext, request.getCallingPackageName());
+ final int userId = request.getUserId();
+ validateInput(mContext, request.getCallingPackageName(), userId);
synchronized (mLock) {
- UserState userState = getCallingUserStateLocked();
+ UserState userState = getUserStateLocked(userId);
if (!userState.bindLocked()) {
callback.onFailure();
} else if (userState.isBoundLocked()) {
@@ -285,13 +299,15 @@ public final class TextClassificationManagerService extends ITextClassifierServi
throws RemoteException {
Preconditions.checkNotNull(sessionId);
Preconditions.checkNotNull(classificationContext);
- validateInput(mContext, classificationContext.getPackageName());
+ final int userId = classificationContext.getUserId();
+ validateInput(mContext, classificationContext.getPackageName(), userId);
synchronized (mLock) {
- UserState userState = getCallingUserStateLocked();
+ UserState userState = getUserStateLocked(userId);
if (userState.isBoundLocked()) {
userState.mService.onCreateTextClassificationSession(
classificationContext, sessionId);
+ mSessionUserIds.put(sessionId, userId);
} else {
userState.mPendingRequests.add(new PendingRequest(
() -> onCreateTextClassificationSession(classificationContext, sessionId),
@@ -306,9 +322,15 @@ public final class TextClassificationManagerService extends ITextClassifierServi
Preconditions.checkNotNull(sessionId);
synchronized (mLock) {
- UserState userState = getCallingUserStateLocked();
+ final int userId = mSessionUserIds.containsKey(sessionId)
+ ? mSessionUserIds.get(sessionId)
+ : UserHandle.getCallingUserId();
+ validateInput(mContext, null /* packageName */, userId);
+
+ UserState userState = getUserStateLocked(userId);
if (userState.isBoundLocked()) {
userState.mService.onDestroyTextClassificationSession(sessionId);
+ mSessionUserIds.remove(sessionId);
} else {
userState.mPendingRequests.add(new PendingRequest(
() -> onDestroyTextClassificationSession(sessionId),
@@ -318,11 +340,6 @@ public final class TextClassificationManagerService extends ITextClassifierServi
}
@GuardedBy("mLock")
- private UserState getCallingUserStateLocked() {
- return getUserStateLocked(UserHandle.getCallingUserId());
- }
-
- @GuardedBy("mLock")
private UserState getUserStateLocked(int userId) {
UserState result = mUserStates.get(userId);
if (result == null) {
@@ -356,6 +373,7 @@ public final class TextClassificationManagerService extends ITextClassifierServi
pw.decreaseIndent();
}
}
+ pw.println("Number of active sessions: " + mSessionUserIds.size());
}
}
@@ -420,20 +438,32 @@ public final class TextClassificationManagerService extends ITextClassifierServi
e -> Slog.d(LOG_TAG, "Error " + opDesc + ": " + e.getMessage()));
}
- private static void validateInput(Context context, @Nullable String packageName)
+ private static void validateInput(
+ Context context, @Nullable String packageName, @UserIdInt int userId)
throws RemoteException {
- if (packageName == null) return;
try {
- final int packageUid = context.getPackageManager()
- .getPackageUidAsUser(packageName, UserHandle.getCallingUserId());
- final int callingUid = Binder.getCallingUid();
- Preconditions.checkArgument(callingUid == packageUid
- // Trust the system process:
- || callingUid == android.os.Process.SYSTEM_UID);
+ if (packageName != null) {
+ final int packageUid = context.getPackageManager()
+ .getPackageUidAsUser(packageName, UserHandle.getCallingUserId());
+ final int callingUid = Binder.getCallingUid();
+ Preconditions.checkArgument(callingUid == packageUid
+ // Trust the system process:
+ || callingUid == android.os.Process.SYSTEM_UID,
+ "Invalid package name. Package=" + packageName
+ + ", CallingUid=" + callingUid);
+ }
+
+ Preconditions.checkArgument(userId != UserHandle.USER_NULL, "Null userId");
+ final int callingUserId = UserHandle.getCallingUserId();
+ if (callingUserId != userId) {
+ context.enforceCallingOrSelfPermission(
+ android.Manifest.permission.INTERACT_ACROSS_USERS_FULL,
+ "Invalid userId. UserId=" + userId + ", CallingUserId=" + callingUserId);
+ }
} catch (Exception e) {
- throw new RemoteException(
- String.format("Invalid package: name=%s, error=%s", packageName, e));
+ throw new RemoteException("Invalid request: " + e.getMessage(), e,
+ /* enableSuppression */ true, /* writableStackTrace */ true);
}
}