summaryrefslogtreecommitdiff
path: root/include/input/MotionPredictorMetricsManager.h
blob: 38472d8df7140dd8384cc543d3007a02d0ae9a9b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
/*
 * Copyright 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <vector>

#include <input/Input.h> // for MotionEvent
#include <input/RingBuffer.h>
#include <utils/Timers.h> // for nsecs_t

#include "Eigen/Core"

namespace android {

/**
 * Class to handle computing and reporting metrics for MotionPredictor.
 *
 * The public API provides two methods: `onRecord` and `onPredict`, which expect to receive the
 * MotionEvents from the corresponding methods in MotionPredictor.
 *
 * This class stores AggregatedStrokeMetrics, updating them as new MotionEvents are passed in. When
 * onRecord receives an UP or CANCEL event, this indicates the end of the stroke, and the final
 * AtomFields are computed and reported. The number of atoms reported is equal to the value of
 * `maxNumPredictions` passed to the constructor. Each atom corresponds to one "prediction time
 * bucket" — the amount of time into the future being predicted.
 *
 * Atoms are reported through the `reportAtomFunction` given to the constructor, which defaults to
 * `defaultReportAtomFunction` (the Android stats library). Passing a custom function (e.g. from a
 * test) lets the caller capture each reported AtomFields instead.
 */
class MotionPredictorMetricsManager {
public:
    struct AtomFields;

    // Callback invoked with one atom's final field values.
    using ReportAtomFunction = std::function<void(const AtomFields&)>;

    // Default reporter: writes the atom using the Android stats library.
    static void defaultReportAtomFunction(const AtomFields& atomFields);

    // Parameters:
    //  • predictionInterval: the time interval between successive prediction target timestamps.
    //    Note: the MetricsManager assumes that the input interval equals the prediction interval.
    //  • maxNumPredictions: the maximum number of distinct target timestamps the prediction model
    //    will generate predictions for. The MetricsManager reports this many atoms per stroke.
    //  • [Optional] reportAtomFunction: the function that will be called to report metrics. If
    //    omitted (or if an empty function is given), the `stats_write(…)` function from the Android
    //    stats library will be used.
    MotionPredictorMetricsManager(
            nsecs_t predictionInterval,
            size_t maxNumPredictions,
            ReportAtomFunction reportAtomFunction = defaultReportAtomFunction);

    // This method should be called once for each call to MotionPredictor::record, receiving the
    // forwarded MotionEvent argument.
    void onRecord(const MotionEvent& inputEvent);

    // This method should be called once for each call to MotionPredictor::predict, receiving the
    // MotionEvent that will be returned by MotionPredictor::predict.
    void onPredict(const MotionEvent& predictionEvent);

    // Simple structs to hold relevant touch input information. Public so they can be used in tests.
    // Members have default initializers so that partially-initialized instances never contain
    // indeterminate values; the structs remain aggregates, so brace/designated initialization
    // continues to work.

    struct TouchPoint {
        Eigen::Vector2f position = Eigen::Vector2f::Zero(); // (y, x) in pixels
        float pressure = 0;
    };

    struct GroundTruthPoint : TouchPoint {
        nsecs_t timestamp = 0;
    };

    struct PredictionPoint : TouchPoint {
        // The timestamp of the last ground truth point when the prediction was made.
        nsecs_t originTimestamp = 0;

        nsecs_t targetTimestamp = 0;

        // Order by targetTimestamp when sorting.
        bool operator<(const PredictionPoint& other) const {
            return this->targetTimestamp < other.targetTimestamp;
        }
    };

    // Metrics aggregated so far for the current stroke. These are not the final fields to be
    // reported in the atom (see AtomFields below), but rather an intermediate representation of the
    // data that can be conveniently aggregated and from which the atom fields can be derived later.
    //
    // Displacement units are in pixels.
    //
    // "Along-trajectory error" is the dot product of the prediction error with the unit vector
    // pointing towards the ground truth point whose timestamp corresponds to the prediction
    // target timestamp, originating from the preceding ground truth point.
    //
    // "Off-trajectory error" is the component of the prediction error orthogonal to the
    // "along-trajectory" unit vector described above.
    //
    // "High-velocity" errors are errors that are only accumulated when the velocity between the
    // most recent two input events exceeds a certain threshold.
    //
    // "Scale-invariant errors" are the errors produced when the path length of the stroke is
    // scaled to 1. (In other words, the error distances are normalized by the path length.)
    struct AggregatedStrokeMetrics {
        // General errors
        float alongTrajectoryErrorSum = 0;
        float alongTrajectorySumSquaredErrors = 0;
        float offTrajectorySumSquaredErrors = 0;
        float pressureSumSquaredErrors = 0;
        size_t generalErrorsCount = 0;

        // High-velocity errors
        float highVelocityAlongTrajectorySse = 0;
        float highVelocityOffTrajectorySse = 0;
        size_t highVelocityErrorsCount = 0;

        // Scale-invariant errors
        float scaleInvariantAlongTrajectorySse = 0;
        float scaleInvariantOffTrajectorySse = 0;
        size_t scaleInvariantErrorsCount = 0;
    };

    // In order to explicitly indicate "no relevant data" for a metric, we report this
    // large-magnitude negative sentinel value. (Most metrics are non-negative, so this value is
    // completely unobtainable. For along-trajectory error mean, which can be negative, the
    // magnitude makes it unobtainable in practice.)
    //
    // Declared constexpr (implicitly inline since C++17) so that ODR-uses of the constant — e.g.
    // binding it to a const reference inside a test assertion — do not require a separate
    // out-of-line definition to link.
    static constexpr int NO_DATA_SENTINEL = std::numeric_limits<int32_t>::min();

    // Final metric values reported in the atom.
    struct AtomFields {
        int deltaTimeBucketMilliseconds = 0;

        // General errors
        int alongTrajectoryErrorMeanMillipixels = NO_DATA_SENTINEL;
        int alongTrajectoryErrorStdMillipixels = NO_DATA_SENTINEL;
        int offTrajectoryRmseMillipixels = NO_DATA_SENTINEL;
        int pressureRmseMilliunits = NO_DATA_SENTINEL;

        // High-velocity errors
        int highVelocityAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
        int highVelocityOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels

        // Scale-invariant errors
        int scaleInvariantAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
        int scaleInvariantOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels
    };

private:
    // The interval between consecutive predictions' target timestamps. We assume that the input
    // interval also equals this value.
    const nsecs_t mPredictionInterval;

    // The maximum number of input frames into the future the model can predict.
    // Used to perform time-bucketing of metrics.
    const size_t mMaxNumPredictions;

    // History of mMaxNumPredictions + 1 ground truth points, used to compute scale-invariant
    // error. (Also, the last two points are used to compute the ground truth trajectory.)
    RingBuffer<GroundTruthPoint> mRecentGroundTruthPoints;

    // Predictions having a targetTimestamp after the most recent ground truth point's timestamp.
    // Invariant: sorted in ascending order of targetTimestamp.
    std::vector<PredictionPoint> mRecentPredictions;

    // Containers for the intermediate representation of stroke metrics and the final atom fields.
    // These are indexed by the number of input frames into the future being predicted minus one,
    // and always have size mMaxNumPredictions.
    std::vector<AggregatedStrokeMetrics> mAggregatedMetrics;
    std::vector<AtomFields> mAtomFields;

    // Function used to report each atom (see the constructor's reportAtomFunction parameter).
    const ReportAtomFunction mReportAtomFunction;

    // Helper methods for the implementation of onRecord and onPredict.

    // Clears stored ground truth and prediction points, as well as all stored metrics for the
    // current stroke.
    void clearStrokeData();

    // Adds the new ground truth point to mRecentGroundTruthPoints, removes outdated predictions
    // from mRecentPredictions, and updates the aggregated metrics to include the recent predictions
    // that fuzzily match with the new ground truth point.
    void incorporateNewGroundTruth(const GroundTruthPoint& groundTruthPoint);

    // Given a new prediction with targetTimestamp matching the latest ground truth point's
    // timestamp, computes the corresponding metrics and updates mAggregatedMetrics.
    void updateAggregatedMetrics(const PredictionPoint& predictionPoint);

    // Computes the final atom fields from the values in mAggregatedMetrics and stores them in
    // mAtomFields.
    void computeAtomFields();

    // Reports the current data in mAtomFields by calling mReportAtomFunction.
    void reportMetrics();
};

} // namespace android