summaryrefslogtreecommitdiff
path: root/include/input/TfLiteMotionPredictor.h
blob: 2edc138f67b126baf7a3c1c07c3f889b06881d17 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <span>

#include <android-base/mapped_file.h>
#include <input/RingBuffer.h>
#include <utils/Timers.h>

#include <tensorflow/lite/core/api/error_reporter.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/model.h>
#include <tensorflow/lite/signature_runner.h>

namespace android {

// A single motion sample as fed to the TFLite motion predictor.
//
// Every field has a zero default member initializer so that a default-constructed sample
// never carries indeterminate values. The type remains an aggregate (and trivially
// copyable), so brace/designated initialization such as
// `{.position = {.x = 1, .y = 2}, .pressure = 0.5f}` keeps working unchanged.
struct TfLiteMotionPredictorSample {
    // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample.
    struct Point {
        float x = 0;
        float y = 0;
    } position;
    // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION.
    float pressure = 0;
    float tilt = 0;
    float orientation = 0;
};

// Component-wise difference of two sample positions: the displacement vector that
// carries rhs onto lhs (result.x == lhs.x - rhs.x, result.y == lhs.y - rhs.y).
inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs,
                                                    const TfLiteMotionPredictorSample::Point& rhs) {
    TfLiteMotionPredictorSample::Point delta = lhs;
    delta.x -= rhs.x;
    delta.y -= rhs.y;
    return delta;
}

class TfLiteMotionPredictorModel;

// Buffer storage for a TfLiteMotionPredictorModel.
//
// Keeps one ring buffer per model input feature (r, phi, pressure, tilt, orientation), the
// timestamp of the most recently pushed sample, and the two samples that define the current
// polar axis. The buffered history is handed to a model's input buffers with copyTo().
class TfLiteMotionPredictorBuffers {
public:
    // Creates buffer storage for a model with the given input length.
    // Marked explicit so that a bare integer is never implicitly converted into a buffer object.
    explicit TfLiteMotionPredictorBuffers(size_t inputLength);

    // Adds a motion sample to the buffers.
    void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample);

    // Returns true if the buffers are complete enough to generate a prediction.
    bool isReady() const {
        // Predictions can't be applied unless there are at least two points to determine
        // the direction to apply them in.
        return mAxisFrom && mAxisTo;
    }

    // Resets all buffers to their initial state.
    void reset();

    // Copies the buffers to those of a model for prediction.
    void copyTo(TfLiteMotionPredictorModel& model) const;

    // Returns the samples defining the current axis of the buffer's samples.
    // Precondition: isReady() is true. The underlying optionals are dereferenced without a
    // check, so calling these earlier is undefined behavior.
    TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; }
    TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; }

    // Returns the timestamp of the last sample (0 until a sample has been pushed).
    int64_t lastTimestamp() const { return mTimestamp; }

private:
    // Timestamp of the most recently pushed sample; starts at 0.
    int64_t mTimestamp = 0;

    // Per-feature input history copied into the model by copyTo().
    // NOTE(review): r/phi are presumably polar coordinates relative to the current axis —
    // confirm against pushSample().
    RingBuffer<float> mInputR;
    RingBuffer<float> mInputPhi;
    RingBuffer<float> mInputPressure;
    RingBuffer<float> mInputTilt;
    RingBuffer<float> mInputOrientation;

    // The samples defining the current polar axis. Both must be set for isReady().
    std::optional<TfLiteMotionPredictorSample> mAxisFrom;
    std::optional<TfLiteMotionPredictorSample> mAxisTo;
};

// A TFLite model for generating motion predictions.
//
// Instances are obtained from create() and own the mapped model file together with the
// TFLite objects built on top of it. The class is neither copyable nor movable (the
// unique_ptr and const members already imply this; it is spelled out below). The cached
// TfLiteTensor pointers presumably refer into the owned interpreter, so a copy could never
// be independent of the original.
class TfLiteMotionPredictorModel {
public:
    struct Config {
        // The time between predictions.
        nsecs_t predictionInterval = 0;
        // The noise floor for predictions.
        // Distances (r) less than this should be discarded as noise.
        float distanceNoiseFloor = 0;
    };

    // Creates a model from its encoded Flatbuffer representation.
    // Takes no arguments: locating and memory-mapping the model file and deriving its Config
    // are left to the implementation.
    // NOTE(review): presumably returns nullptr on failure — confirm in the .cpp.
    static std::unique_ptr<TfLiteMotionPredictorModel> create();

    ~TfLiteMotionPredictorModel();

    // Non-copyable; with the user-declared destructor this also leaves it non-movable.
    TfLiteMotionPredictorModel(const TfLiteMotionPredictorModel&) = delete;
    TfLiteMotionPredictorModel& operator=(const TfLiteMotionPredictorModel&) = delete;

    // Returns the length of the model's input buffers.
    size_t inputLength() const;

    // Returns the length of the model's output buffers.
    size_t outputLength() const;

    // Returns the configuration the model was created with.
    const Config& config() const { return mConfig; }

    // Executes the model.
    // Returns true if the model successfully executed and the output tensors can be read.
    bool invoke();

    // Returns mutable buffers to the input tensors of inputLength() elements.
    std::span<float> inputR();
    std::span<float> inputPhi();
    std::span<float> inputPressure();
    std::span<float> inputOrientation();
    std::span<float> inputTilt();

    // Returns immutable buffers to the output tensors, all of the same length
    // (presumably outputLength()). Only valid after a successful call to invoke().
    std::span<const float> outputR() const;
    std::span<const float> outputPhi() const;
    std::span<const float> outputPressure() const;

private:
    explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,
                                        Config config);

    void allocateTensors();
    void attachInputTensors();
    void attachOutputTensors();

    // Non-owning tensor handles, set up by the attach*Tensors() helpers.
    TfLiteTensor* mInputR = nullptr;
    TfLiteTensor* mInputPhi = nullptr;
    TfLiteTensor* mInputPressure = nullptr;
    TfLiteTensor* mInputTilt = nullptr;
    TfLiteTensor* mInputOrientation = nullptr;

    const TfLiteTensor* mOutputR = nullptr;
    const TfLiteTensor* mOutputPhi = nullptr;
    const TfLiteTensor* mOutputPressure = nullptr;

    // Declaration order matters: members are destroyed in reverse order. The interpreter is
    // therefore torn down before the model, the error reporter, and the mapped flatbuffer,
    // any of which it may still reference. Do not reorder.
    std::unique_ptr<android::base::MappedFile> mFlatBuffer;
    std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
    std::unique_ptr<tflite::FlatBufferModel> mModel;
    std::unique_ptr<tflite::Interpreter> mInterpreter;
    // Non-owning.
    tflite::SignatureRunner* mRunner = nullptr;

    const Config mConfig = {};
};

} // namespace android