diff options
author | TreeHugger Robot <treehugger-gerrit@google.com> | 2022-07-20 15:43:25 +0000 |
---|---|---|
committer | Android (Google) Code Review <android-gerrit@google.com> | 2022-07-20 15:43:25 +0000 |
commit | 5c7f06a2ee4e6fac8be2c7cb3274294e83a94291 (patch) | |
tree | 18c31f935231acf8854fed342213c90e1443fb01 | |
parent | a2a2d0bd0c8c8a6052b0f082a467ea4eb2618cd1 (diff) | |
parent | 7713b243f1de0849cdd20461458a2c117d672de0 (diff) | |
download | base-5c7f06a2ee4e6fac8be2c7cb3274294e83a94291.tar.gz |
Merge "Add benchmarks for WASM bidding logic" into tm-mainline-prod
-rw-r--r-- | apct-tests/perftests/rubidium/Android.bp | 1 | ||||
-rw-r--r-- | apct-tests/perftests/rubidium/assets/generate_bid.wasm | bin | 0 -> 1308806 bytes | |||
-rw-r--r-- | apct-tests/perftests/rubidium/assets/generate_bid_using_wasm.js | 24 | ||||
-rw-r--r-- | apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java | 97 |
4 files changed, 110 insertions, 12 deletions
diff --git a/apct-tests/perftests/rubidium/Android.bp b/apct-tests/perftests/rubidium/Android.bp index 339ef3066ca0..ba2b44241c5a 100644 --- a/apct-tests/perftests/rubidium/Android.bp +++ b/apct-tests/perftests/rubidium/Android.bp @@ -33,6 +33,7 @@ android_test { "compatibility-device-util-axt", "platform-test-annotations", "adservices-service-core", + "androidx.core_core", ], test_suites: ["device-tests"], data: [":perfetto_artifacts"], diff --git a/apct-tests/perftests/rubidium/assets/generate_bid.wasm b/apct-tests/perftests/rubidium/assets/generate_bid.wasm Binary files differnew file mode 100644 index 000000000000..5e7fe9ee5fe7 --- /dev/null +++ b/apct-tests/perftests/rubidium/assets/generate_bid.wasm diff --git a/apct-tests/perftests/rubidium/assets/generate_bid_using_wasm.js b/apct-tests/perftests/rubidium/assets/generate_bid_using_wasm.js new file mode 100644 index 000000000000..bc50d0af0954 --- /dev/null +++ b/apct-tests/perftests/rubidium/assets/generate_bid_using_wasm.js @@ -0,0 +1,24 @@ +function generateBid(ad, wasmModule) { + let input = ad.metadata.input; + + const instance = new WebAssembly.Instance(wasmModule); + + const memory = instance.exports.memory; + const input_in_memory = new Float32Array(memory.buffer, 0, 200); + for (let i = 0; i < input.length; ++i) { + input_in_memory[i] = input[i]; + } + const results = [ + instance.exports.nn_forward_model0(input_in_memory.length, input_in_memory), + instance.exports.nn_forward_model1(input_in_memory.length, input_in_memory), + instance.exports.nn_forward_model2(input_in_memory.length, input_in_memory), + instance.exports.nn_forward_model3(input_in_memory.length, input_in_memory), + instance.exports.nn_forward_model4(input_in_memory.length, input_in_memory), + ]; + const bid = results.map(x => Math.max(x, 1)).reduce((x, y) => x * y); + return { + ad: 'example', + bid: bid, + render: ad.renderUrl + } +}
\ No newline at end of file diff --git a/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java b/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java index bf9ff3a47c40..0ddec236b6da 100644 --- a/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java +++ b/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java @@ -24,6 +24,9 @@ import static com.android.adservices.service.js.JSScriptArgument.stringArrayArg; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assume.assumeTrue; + +import android.annotation.SuppressLint; import android.content.Context; import android.perftests.utils.BenchmarkState; import android.perftests.utils.PerfStatusReporter; @@ -45,48 +48,44 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import org.json.JSONArray; -import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; +import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.List; import java.util.Objects; +import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; /** To run the unit tests for this class, run "atest RubidiumPerfTests:JSScriptEnginePerfTests" */ @MediumTest @RunWith(AndroidJUnit4.class) public class JSScriptEnginePerfTests { - private static final String TAG = JSScriptEnginePerfTests.class.getSimpleName(); + private static final String TAG = JSScriptEngine.TAG; private static final Context sContext = ApplicationProvider.getApplicationContext(); private static final ExecutorService sExecutorService = Executors.newFixedThreadPool(10); - private static JSScriptEngine sJSScriptEngine; + private static final JSScriptEngine sJSScriptEngine = + JSScriptEngine.getInstanceForTesting( + sContext, Profiler.createInstance(JSScriptEngine.TAG)); @Rule public PerfStatusReporter mPerfStatusReporter = new PerfStatusReporter(); @Before public void before() throws Exception { - Profiler profiler = Profiler.createInstance(JSScriptEngine.TAG); - sJSScriptEngine = JSScriptEngine.getInstanceForTesting(sContext, profiler); - // Warm up the sandbox env. callJSEngine( "function test() { return \"hello world\";" + " }", ImmutableList.of(), "test"); } - @After - public void after() { - sJSScriptEngine.shutdown(); - } - @Test public void evaluate_helloWorld() throws Exception { BenchmarkState state = mPerfStatusReporter.getBenchmarkState(); @@ -156,6 +155,7 @@ public class JSScriptEnginePerfTests { runParametrizedTurtledoveScript(75); } + @SuppressLint("DefaultLocale") private void runParametrizedTurtledoveScript(int numAds) throws Exception { BenchmarkState state = mPerfStatusReporter.getBenchmarkState(); state.pauseTiming(); @@ -220,7 +220,34 @@ public class JSScriptEnginePerfTests { return arrayArg("foo", Collections.nCopies(numCustomAudiences, interestGroupArg)); } - private static String callJSEngine( + @Test + public void evaluate_turtledoveWasm() throws Exception { + assumeTrue(sJSScriptEngine.isWasmSupported().get(3, TimeUnit.SECONDS)); + + BenchmarkState state = mPerfStatusReporter.getBenchmarkState(); + state.pauseTiming(); + + String jsTestFile = readAsset("generate_bid_using_wasm.js"); + byte[] wasmTestFile = readBinaryAsset("generate_bid.wasm"); + JSScriptArgument[] inputBytes = new JSScriptArgument[200]; + Random rand = new Random(); + for (int i = 0; i < inputBytes.length; i++) { + byte value = (byte) (rand.nextInt(2 * Byte.MAX_VALUE) - Byte.MIN_VALUE); + inputBytes[i] = JSScriptArgument.numericArg("_", value); + } + JSScriptArgument adDataArgument = + recordArg( + "ad", + stringArg("render_url", "http://google.com"), + recordArg("metadata", JSScriptArgument.arrayArg("input", inputBytes))); + + state.resumeTiming(); + while (state.keepRunning()) { + callJSEngine(jsTestFile, wasmTestFile, ImmutableList.of(adDataArgument), "generateBid"); + } + } + + private String callJSEngine( @NonNull String jsScript, @NonNull List<JSScriptArgument> args, @NonNull String functionName) @@ -228,6 +255,15 @@ public class JSScriptEnginePerfTests { return callJSEngine(sJSScriptEngine, jsScript, args, functionName); } + private String callJSEngine( + @NonNull String jsScript, + @NonNull byte[] wasmScript, + @NonNull List<JSScriptArgument> args, + @NonNull String functionName) + throws Exception { + return callJSEngine(sJSScriptEngine, jsScript, wasmScript, args, functionName); + } + private static String callJSEngine( @NonNull JSScriptEngine jsScriptEngine, @NonNull String jsScript, @@ -241,6 +277,21 @@ public class JSScriptEnginePerfTests { return futureResult.get(); } + private String callJSEngine( + @NonNull JSScriptEngine jsScriptEngine, + @NonNull String jsScript, + @NonNull byte[] wasmScript, + @NonNull List<JSScriptArgument> args, + @NonNull String functionName) + throws Exception { + CountDownLatch resultLatch = new CountDownLatch(1); + ListenableFuture<String> futureResult = + callJSEngineAsync( + jsScriptEngine, jsScript, wasmScript, args, functionName, resultLatch); + resultLatch.await(); + return futureResult.get(); + } + private static ListenableFuture<String> callJSEngineAsync( @NonNull String jsScript, @NonNull List<JSScriptArgument> args, @@ -261,4 +312,26 @@ public class JSScriptEnginePerfTests { result.addListener(resultLatch::countDown, sExecutorService); return result; } + + private ListenableFuture<String> callJSEngineAsync( + @NonNull JSScriptEngine engine, + @NonNull String jsScript, + @NonNull byte[] wasmScript, + @NonNull List<JSScriptArgument> args, + @NonNull String functionName, + @NonNull CountDownLatch resultLatch) { + Objects.requireNonNull(engine); + Objects.requireNonNull(resultLatch); + ListenableFuture<String> result = engine.evaluate(jsScript, wasmScript, args, functionName); + result.addListener(resultLatch::countDown, sExecutorService); + return result; + } + + private byte[] readBinaryAsset(@NonNull String assetName) throws IOException { + return sContext.getAssets().open(assetName).readAllBytes(); + } + + private String readAsset(@NonNull String assetName) throws IOException { + return new String(readBinaryAsset(assetName), StandardCharsets.UTF_8); + } } |