summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTreeHugger Robot <treehugger-gerrit@google.com>2022-07-20 15:43:25 +0000
committerAndroid (Google) Code Review <android-gerrit@google.com>2022-07-20 15:43:25 +0000
commit5c7f06a2ee4e6fac8be2c7cb3274294e83a94291 (patch)
tree18c31f935231acf8854fed342213c90e1443fb01
parenta2a2d0bd0c8c8a6052b0f082a467ea4eb2618cd1 (diff)
parent7713b243f1de0849cdd20461458a2c117d672de0 (diff)
downloadbase-5c7f06a2ee4e6fac8be2c7cb3274294e83a94291.tar.gz
Merge "Add benchmarks for WASM bidding logic" into tm-mainline-prod
-rw-r--r--apct-tests/perftests/rubidium/Android.bp1
-rw-r--r--apct-tests/perftests/rubidium/assets/generate_bid.wasmbin0 -> 1308806 bytes
-rw-r--r--apct-tests/perftests/rubidium/assets/generate_bid_using_wasm.js24
-rw-r--r--apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java97
4 files changed, 110 insertions, 12 deletions
diff --git a/apct-tests/perftests/rubidium/Android.bp b/apct-tests/perftests/rubidium/Android.bp
index 339ef3066ca0..ba2b44241c5a 100644
--- a/apct-tests/perftests/rubidium/Android.bp
+++ b/apct-tests/perftests/rubidium/Android.bp
@@ -33,6 +33,7 @@ android_test {
"compatibility-device-util-axt",
"platform-test-annotations",
"adservices-service-core",
+ "androidx.core_core",
],
test_suites: ["device-tests"],
data: [":perfetto_artifacts"],
diff --git a/apct-tests/perftests/rubidium/assets/generate_bid.wasm b/apct-tests/perftests/rubidium/assets/generate_bid.wasm
new file mode 100644
index 000000000000..5e7fe9ee5fe7
--- /dev/null
+++ b/apct-tests/perftests/rubidium/assets/generate_bid.wasm
Binary files differ
diff --git a/apct-tests/perftests/rubidium/assets/generate_bid_using_wasm.js b/apct-tests/perftests/rubidium/assets/generate_bid_using_wasm.js
new file mode 100644
index 000000000000..bc50d0af0954
--- /dev/null
+++ b/apct-tests/perftests/rubidium/assets/generate_bid_using_wasm.js
@@ -0,0 +1,24 @@
+function generateBid(ad, wasmModule) {
+ let input = ad.metadata.input;
+
+ const instance = new WebAssembly.Instance(wasmModule);
+
+ const memory = instance.exports.memory;
+ const input_in_memory = new Float32Array(memory.buffer, 0, 200);
+ for (let i = 0; i < input.length; ++i) {
+ input_in_memory[i] = input[i];
+ }
+ const results = [
+ instance.exports.nn_forward_model0(input_in_memory.length, input_in_memory),
+ instance.exports.nn_forward_model1(input_in_memory.length, input_in_memory),
+ instance.exports.nn_forward_model2(input_in_memory.length, input_in_memory),
+ instance.exports.nn_forward_model3(input_in_memory.length, input_in_memory),
+ instance.exports.nn_forward_model4(input_in_memory.length, input_in_memory),
+ ];
+ const bid = results.map(x => Math.max(x, 1)).reduce((x, y) => x * y);
+ return {
+ ad: 'example',
+ bid: bid,
+ render: ad.renderUrl
+ }
+} \ No newline at end of file
diff --git a/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java b/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java
index bf9ff3a47c40..0ddec236b6da 100644
--- a/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java
+++ b/apct-tests/perftests/rubidium/src/android/rubidium/js/JSScriptEnginePerfTests.java
@@ -24,6 +24,9 @@ import static com.android.adservices.service.js.JSScriptArgument.stringArrayArg;
import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assume.assumeTrue;
+
+import android.annotation.SuppressLint;
import android.content.Context;
import android.perftests.utils.BenchmarkState;
import android.perftests.utils.PerfStatusReporter;
@@ -45,48 +48,44 @@ import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import org.json.JSONArray;
-import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
+import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
+import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
/** To run the unit tests for this class, run "atest RubidiumPerfTests:JSScriptEnginePerfTests" */
@MediumTest
@RunWith(AndroidJUnit4.class)
public class JSScriptEnginePerfTests {
- private static final String TAG = JSScriptEnginePerfTests.class.getSimpleName();
+ private static final String TAG = JSScriptEngine.TAG;
private static final Context sContext = ApplicationProvider.getApplicationContext();
private static final ExecutorService sExecutorService = Executors.newFixedThreadPool(10);
- private static JSScriptEngine sJSScriptEngine;
+ private static final JSScriptEngine sJSScriptEngine =
+ JSScriptEngine.getInstanceForTesting(
+ sContext, Profiler.createInstance(JSScriptEngine.TAG));
@Rule public PerfStatusReporter mPerfStatusReporter = new PerfStatusReporter();
@Before
public void before() throws Exception {
- Profiler profiler = Profiler.createInstance(JSScriptEngine.TAG);
- sJSScriptEngine = JSScriptEngine.getInstanceForTesting(sContext, profiler);
-
// Warm up the sandbox env.
callJSEngine(
"function test() { return \"hello world\";" + " }", ImmutableList.of(), "test");
}
- @After
- public void after() {
- sJSScriptEngine.shutdown();
- }
-
@Test
public void evaluate_helloWorld() throws Exception {
BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
@@ -156,6 +155,7 @@ public class JSScriptEnginePerfTests {
runParametrizedTurtledoveScript(75);
}
+ @SuppressLint("DefaultLocale")
private void runParametrizedTurtledoveScript(int numAds) throws Exception {
BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
state.pauseTiming();
@@ -220,7 +220,34 @@ public class JSScriptEnginePerfTests {
return arrayArg("foo", Collections.nCopies(numCustomAudiences, interestGroupArg));
}
- private static String callJSEngine(
+ @Test
+ public void evaluate_turtledoveWasm() throws Exception {
+ assumeTrue(sJSScriptEngine.isWasmSupported().get(3, TimeUnit.SECONDS));
+
+ BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
+ state.pauseTiming();
+
+ String jsTestFile = readAsset("generate_bid_using_wasm.js");
+ byte[] wasmTestFile = readBinaryAsset("generate_bid.wasm");
+ JSScriptArgument[] inputBytes = new JSScriptArgument[200];
+ Random rand = new Random();
+ for (int i = 0; i < inputBytes.length; i++) {
+ byte value = (byte) (rand.nextInt(2 * Byte.MAX_VALUE) - Byte.MIN_VALUE);
+ inputBytes[i] = JSScriptArgument.numericArg("_", value);
+ }
+ JSScriptArgument adDataArgument =
+ recordArg(
+ "ad",
+ stringArg("render_url", "http://google.com"),
+ recordArg("metadata", JSScriptArgument.arrayArg("input", inputBytes)));
+
+ state.resumeTiming();
+ while (state.keepRunning()) {
+ callJSEngine(jsTestFile, wasmTestFile, ImmutableList.of(adDataArgument), "generateBid");
+ }
+ }
+
+ private String callJSEngine(
@NonNull String jsScript,
@NonNull List<JSScriptArgument> args,
@NonNull String functionName)
@@ -228,6 +255,15 @@ public class JSScriptEnginePerfTests {
return callJSEngine(sJSScriptEngine, jsScript, args, functionName);
}
+ private String callJSEngine(
+ @NonNull String jsScript,
+ @NonNull byte[] wasmScript,
+ @NonNull List<JSScriptArgument> args,
+ @NonNull String functionName)
+ throws Exception {
+ return callJSEngine(sJSScriptEngine, jsScript, wasmScript, args, functionName);
+ }
+
private static String callJSEngine(
@NonNull JSScriptEngine jsScriptEngine,
@NonNull String jsScript,
@@ -241,6 +277,21 @@ public class JSScriptEnginePerfTests {
return futureResult.get();
}
+ private String callJSEngine(
+ @NonNull JSScriptEngine jsScriptEngine,
+ @NonNull String jsScript,
+ @NonNull byte[] wasmScript,
+ @NonNull List<JSScriptArgument> args,
+ @NonNull String functionName)
+ throws Exception {
+ CountDownLatch resultLatch = new CountDownLatch(1);
+ ListenableFuture<String> futureResult =
+ callJSEngineAsync(
+ jsScriptEngine, jsScript, wasmScript, args, functionName, resultLatch);
+ resultLatch.await();
+ return futureResult.get();
+ }
+
private static ListenableFuture<String> callJSEngineAsync(
@NonNull String jsScript,
@NonNull List<JSScriptArgument> args,
@@ -261,4 +312,26 @@ public class JSScriptEnginePerfTests {
result.addListener(resultLatch::countDown, sExecutorService);
return result;
}
+
+ private ListenableFuture<String> callJSEngineAsync(
+ @NonNull JSScriptEngine engine,
+ @NonNull String jsScript,
+ @NonNull byte[] wasmScript,
+ @NonNull List<JSScriptArgument> args,
+ @NonNull String functionName,
+ @NonNull CountDownLatch resultLatch) {
+ Objects.requireNonNull(engine);
+ Objects.requireNonNull(resultLatch);
+ ListenableFuture<String> result = engine.evaluate(jsScript, wasmScript, args, functionName);
+ result.addListener(resultLatch::countDown, sExecutorService);
+ return result;
+ }
+
+ private byte[] readBinaryAsset(@NonNull String assetName) throws IOException {
+ return sContext.getAssets().open(assetName).readAllBytes();
+ }
+
+ private String readAsset(@NonNull String assetName) throws IOException {
+ return new String(readBinaryAsset(assetName), StandardCharsets.UTF_8);
+ }
}