1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.internal;
18 
19 import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
20 
21 import android.app.Activity;
22 import android.graphics.Rect;
23 import android.os.Bundle;
24 import android.os.Message;
25 import android.os.ParcelFileDescriptor;
26 import android.os.Process;
27 import android.os.SystemClock;
28 import android.util.Log;
29 
30 import androidx.test.filters.LargeTest;
31 
32 import com.android.internal.util.function.pooled.PooledLambda;
33 import com.android.internal.util.function.pooled.PooledPredicate;
34 
35 import org.junit.Assume;
36 import org.junit.Rule;
37 import org.junit.Test;
38 import org.junit.rules.TestRule;
39 import org.junit.runners.model.Statement;
40 
41 import java.io.BufferedReader;
42 import java.io.IOException;
43 import java.io.InputStreamReader;
44 import java.util.ArrayList;
45 import java.util.Arrays;
46 import java.util.List;
47 import java.util.concurrent.CountDownLatch;
48 import java.util.function.Predicate;
49 import java.util.regex.Matcher;
50 import java.util.regex.Pattern;
51 
/** Compares the performance of regular lambdas against pooled lambdas ({@link PooledLambda}). */
53 @LargeTest
54 public class LambdaPerfTest {
55     private static final boolean DEBUG = false;
56     private static final String TAG = LambdaPerfTest.class.getSimpleName();
57 
58     private static final String LAMBDA_FORM_REGULAR = "regular";
59     private static final String LAMBDA_FORM_POOLED = "pooled";
60 
61     private static final int WARMUP_ITERATIONS = 1000;
62     private static final int TEST_ITERATIONS = 3000000;
63     private static final int TASK_COUNT = 10;
64     private static final long DELAY_AFTER_BENCH_MS = 1000;
65 
66     private String mMethodName;
67 
68     private final Bundle mTestResults = new Bundle();
69     private final ArrayList<Task> mTasks = new ArrayList<>();
70 
71     // The member fields are used to ensure lambda capturing. They don't have the actual meaning.
72     private final Task mTask = new Task();
73     private final Rect mBounds = new Rect();
74     private int mTaskId;
75     private long mTime;
76     private boolean mTop;
77 
78     @Rule
79     public final TestRule mRule = (base, description) -> new Statement() {
80         @Override
81         public void evaluate() throws Throwable {
82             mMethodName = description.getMethodName();
83             mTasks.clear();
84             for (int i = 0; i < TASK_COUNT; i++) {
85                 final Task t = new Task();
86                 mTasks.add(t);
87             }
88             base.evaluate();
89 
90             getInstrumentation().sendStatus(Activity.RESULT_OK, mTestResults);
91         }
92     };
93 
94     @Test
test1ParamPredicate()95     public void test1ParamPredicate() {
96         evaluate(LAMBDA_FORM_REGULAR, () -> handleTask(t -> t.doSomething(mTaskId, mTime)));
97         evaluate(LAMBDA_FORM_POOLED, () -> {
98             final PooledPredicate c = PooledLambda.obtainPredicate(Task::doSomething,
99                     PooledLambda.__(Task.class), mTaskId, mTime);
100             handleTask(c);
101             c.recycle();
102         });
103     }
104 
105     @Test
test2PrimitiveParamsPredicate()106     public void test2PrimitiveParamsPredicate() {
107         // Not in Integer#IntegerCache (-128~127) for autoboxing, that may create new object.
108         mTaskId = 12345;
109         mTime = 54321;
110 
111         evaluate(LAMBDA_FORM_REGULAR, () -> handleTask(t -> t.doSomething(mTaskId, mTime)));
112         evaluate(LAMBDA_FORM_POOLED, () -> {
113             final PooledPredicate c = PooledLambda.obtainPredicate(Task::doSomething,
114                     PooledLambda.__(Task.class), mTaskId, mTime);
115             handleTask(c);
116             c.recycle();
117         });
118     }
119 
120     @Test
test3ParamsPredicate()121     public void test3ParamsPredicate() {
122         mTop = true;
123         // In Integer#IntegerCache.
124         mTaskId = 10;
125 
126         evaluate(LAMBDA_FORM_REGULAR, () -> handleTask(t -> t.doSomething(mBounds, mTop, mTaskId)));
127         evaluate(LAMBDA_FORM_POOLED, () -> {
128             final PooledPredicate c = PooledLambda.obtainPredicate(Task::doSomething,
129                     PooledLambda.__(Task.class), mBounds, mTop, mTaskId);
130             handleTask(c);
131             c.recycle();
132         });
133     }
134 
135     @Test
testMessage()136     public void testMessage() {
137         evaluate(LAMBDA_FORM_REGULAR, () -> {
138             final Message m = Message.obtain().setCallback(() -> mTask.doSomething(mTaskId, mTime));
139             m.getCallback().run();
140             m.recycle();
141         });
142         evaluate(LAMBDA_FORM_POOLED, () -> {
143             final Message m = PooledLambda.obtainMessage(Task::doSomething, mTask, mTaskId, mTime);
144             m.getCallback().run();
145             m.recycle();
146         });
147     }
148 
149     @Test
testRunnable()150     public void testRunnable() {
151         evaluate(LAMBDA_FORM_REGULAR, () -> {
152             final Runnable r = mTask::doSomething;
153             r.run();
154         });
155         evaluate(LAMBDA_FORM_POOLED, () -> {
156             final Runnable r = PooledLambda.obtainRunnable(Task::doSomething, mTask).recycleOnUse();
157             r.run();
158         });
159     }
160 
161     @Test
testMultiThread()162     public void testMultiThread() {
163         final int numThread = 3;
164 
165         final Runnable regularAction = () -> handleTask(t -> t.doSomething(mTaskId, mTime));
166         final Runnable[] regularActions = new Runnable[numThread];
167         Arrays.fill(regularActions, regularAction);
168         evaluateMultiThread(LAMBDA_FORM_REGULAR, regularActions);
169 
170         final Runnable pooledAction = () -> {
171             final PooledPredicate c = PooledLambda.obtainPredicate(Task::doSomething,
172                     PooledLambda.__(Task.class), mTaskId, mTime);
173             handleTask(c);
174             c.recycle();
175         };
176         final Runnable[] pooledActions = new Runnable[numThread];
177         Arrays.fill(pooledActions, pooledAction);
178         evaluateMultiThread(LAMBDA_FORM_POOLED, pooledActions);
179     }
180 
handleTask(Predicate<Task> callback)181     private void handleTask(Predicate<Task> callback) {
182         for (int i = mTasks.size() - 1; i >= 0; i--) {
183             final Task task = mTasks.get(i);
184             if (callback.test(task)) {
185                 return;
186             }
187         }
188     }
189 
evaluate(String title, Runnable action)190     private void evaluate(String title, Runnable action) {
191         for (int i = 0; i < WARMUP_ITERATIONS; i++) {
192             action.run();
193         }
194         performGc();
195 
196         final GcStatus startGcStatus = getGcStatus();
197         final long startTime = SystemClock.elapsedRealtime();
198         for (int i = 0; i < TEST_ITERATIONS; i++) {
199             action.run();
200         }
201         evaluateResult(title, startGcStatus, startTime);
202     }
203 
evaluateMultiThread(String title, Runnable[] actions)204     private void evaluateMultiThread(String title, Runnable[] actions) {
205         performGc();
206 
207         final CountDownLatch latch = new CountDownLatch(actions.length);
208         final GcStatus startGcStatus = getGcStatus();
209         final long startTime = SystemClock.elapsedRealtime();
210         for (Runnable action : actions) {
211             new Thread() {
212                 @Override
213                 public void run() {
214                     for (int i = 0; i < TEST_ITERATIONS; i++) {
215                         action.run();
216                     }
217                     latch.countDown();
218                 };
219             }.start();
220         }
221         try {
222             latch.await();
223         } catch (InterruptedException ignored) {
224         }
225         evaluateResult(title, startGcStatus, startTime);
226     }
227 
evaluateResult(String title, GcStatus startStatus, long startTime)228     private void evaluateResult(String title, GcStatus startStatus, long startTime) {
229         final float elapsed = SystemClock.elapsedRealtime() - startTime;
230         // Sleep a while to see if GC may happen.
231         SystemClock.sleep(DELAY_AFTER_BENCH_MS);
232         final GcStatus endStatus = getGcStatus();
233         final GcInfo info = startStatus.calculateGcTime(endStatus, title, mTestResults);
234         mTestResults.putFloat("[" + title + "-execution-time]", elapsed);
235         Log.i(TAG, mMethodName + "_" + title + " execution time: "
236                 + elapsed + "ms (avg=" + String.format("%.5f", elapsed / TEST_ITERATIONS) + "ms)"
237                 + " GC time: " + String.format("%.3f", info.mTotalGcTime) + "ms"
238                 + " GC paused time: " + String.format("%.3f", info.mTotalGcPausedTime) + "ms");
239     }
240 
    /**
     * Cleans the test environment by requesting garbage collection before a measurement starts.
     *
     * <p>NOTE(review): the gc -> runFinalization -> gc order looks deliberate (on Android,
     * {@code System.gc()} on its own is presumably not guaranteed to collect); keep it as-is.
     */
    private static void performGc() {
        System.gc();
        // Run pending finalizers so that whatever they release can be reclaimed by the next gc.
        System.runFinalization();
        System.gc();
    }
247 
getGcStatus()248     private static GcStatus getGcStatus() {
249         if (DEBUG) {
250             Log.i(TAG, "===== Read GC dump =====");
251         }
252         final GcStatus status = new GcStatus();
253         final List<String> vmDump = getVmDump();
254         Assume.assumeFalse("VM dump is empty", vmDump.isEmpty());
255         for (String line : vmDump) {
256             status.visit(line);
257             if (line.startsWith("DALVIK THREADS")) {
258                 break;
259             }
260         }
261         return status;
262     }
263 
getVmDump()264     private static List<String> getVmDump() {
265         final int myPid = Process.myPid();
266         // Another approach Debug#dumpJavaBacktraceToFileTimeout requires setenforce 0.
267         Process.sendSignal(myPid, Process.SIGNAL_QUIT);
268         // Give a chance to handle the signal.
269         SystemClock.sleep(100);
270 
271         String dump = null;
272         final String pattern = myPid + " written to: ";
273         final List<String> logs = shell("logcat -v brief -d tombstoned:I *:S");
274         for (int i = logs.size() - 1; i >= 0; i--) {
275             final String log = logs.get(i);
276             // Log pattern: Traces for pid 9717 written to: /data/anr/trace_07
277             final int pos = log.indexOf(pattern);
278             if (pos > 0) {
279                 dump = log.substring(pattern.length() + pos);
280                 if (!dump.startsWith("/data/anr/")) {
281                     dump = "/data/anr/" + dump;
282                 }
283                 break;
284             }
285         }
286 
287         Assume.assumeNotNull("Unable to find VM dump", dump);
288         // It requires system or root uid to read the trace.
289         return shell("cat " + dump);
290     }
291 
shell(String command)292     private static List<String> shell(String command) {
293         final ParcelFileDescriptor.AutoCloseInputStream stream =
294                 new ParcelFileDescriptor.AutoCloseInputStream(
295                 getInstrumentation().getUiAutomation().executeShellCommand(command));
296         final ArrayList<String> lines = new ArrayList<>();
297         try (BufferedReader br = new BufferedReader(new InputStreamReader(stream))) {
298             String line;
299             while ((line = br.readLine()) != null) {
300                 lines.add(line);
301             }
302         } catch (IOException e) {
303             throw new RuntimeException(e);
304         }
305         return lines;
306     }
307 
308     /** An empty class which provides some methods with different type arguments. */
309     static class Task {
doSomething()310         void doSomething() {
311         }
312 
doSomething(int taskId, long time)313         boolean doSomething(int taskId, long time) {
314             return false;
315         }
316 
doSomething(Rect bounds, boolean top, int taskId)317         boolean doSomething(Rect bounds, boolean top, int taskId) {
318             return false;
319         }
320     }
321 
322     static class ValPattern {
323         static final int TYPE_COUNT = 0;
324         static final int TYPE_TIME = 1;
325         static final String PATTERN_COUNT = "(\\d+)";
326         static final String PATTERN_TIME = "(\\d+\\.?\\d+)(\\w+)";
327         final String mRawPattern;
328         final Pattern mPattern;
329         final int mType;
330 
331         int mIntValue;
332         float mFloatValue;
333 
ValPattern(String p, int type)334         ValPattern(String p, int type) {
335             mRawPattern = p;
336             mPattern = Pattern.compile(
337                     p + (type == TYPE_TIME ? PATTERN_TIME : PATTERN_COUNT) + ".*");
338             mType = type;
339         }
340 
visit(String line)341         boolean visit(String line) {
342             final Matcher matcher = mPattern.matcher(line);
343             if (!matcher.matches()) {
344                 return false;
345             }
346             final String value = matcher.group(1);
347             if (value == null) {
348                 return false;
349             }
350             if (mType == TYPE_COUNT) {
351                 mIntValue = Integer.parseInt(value);
352                 return true;
353             }
354             final float time = Float.parseFloat(value);
355             final String unit = matcher.group(2);
356             if (unit == null) {
357                 return false;
358             }
359             // Refer to art/libartbase/base/time_utils.cc
360             switch (unit) {
361                 case "s":
362                     mFloatValue = time * 1000;
363                     break;
364                 case "ms":
365                     mFloatValue = time;
366                     break;
367                 case "us":
368                     mFloatValue = time / 1000;
369                     break;
370                 case "ns":
371                     mFloatValue = time / 1000 / 1000;
372                     break;
373                 default:
374                     throw new IllegalArgumentException();
375             }
376 
377             return true;
378         }
379 
380         @Override
toString()381         public String toString() {
382             return mRawPattern + (mType == TYPE_TIME ? (mFloatValue + "ms") : mIntValue);
383         }
384     }
385 
386     /** Parses the dump pattern of Heap::DumpGcPerformanceInfo. */
387     private static class GcStatus {
388         private static final int TOTAL_GC_TIME_INDEX = 1;
389         private static final int TOTAL_GC_PAUSED_TIME_INDEX = 5;
390 
391         // Refer to art/runtime/gc/heap.cc
392         final ValPattern[] mPatterns = {
393                 new ValPattern("Total GC count: ", ValPattern.TYPE_COUNT),
394                 new ValPattern("Total GC time: ", ValPattern.TYPE_TIME),
395                 new ValPattern("Total time waiting for GC to complete: ", ValPattern.TYPE_TIME),
396                 new ValPattern("Total blocking GC count: ", ValPattern.TYPE_COUNT),
397                 new ValPattern("Total blocking GC time: ", ValPattern.TYPE_TIME),
398                 new ValPattern("Total mutator paused time: ", ValPattern.TYPE_TIME),
399                 new ValPattern("Total number of allocations ", ValPattern.TYPE_COUNT),
400                 new ValPattern("concurrent copying paused:  Sum: ", ValPattern.TYPE_TIME),
401                 new ValPattern("concurrent copying total time: ", ValPattern.TYPE_TIME),
402                 new ValPattern("concurrent copying freed: ", ValPattern.TYPE_COUNT),
403                 new ValPattern("Peak regions allocated ", ValPattern.TYPE_COUNT),
404         };
405 
visit(String dumpLine)406         void visit(String dumpLine) {
407             for (ValPattern p : mPatterns) {
408                 if (p.visit(dumpLine)) {
409                     if (DEBUG) {
410                         Log.i(TAG, "  " + p);
411                     }
412                 }
413             }
414         }
415 
calculateGcTime(GcStatus newStatus, String title, Bundle result)416         GcInfo calculateGcTime(GcStatus newStatus, String title, Bundle result) {
417             Log.i(TAG, "===== GC status of " + title + " =====");
418             final GcInfo info = new GcInfo();
419             for (int i = 0; i < mPatterns.length; i++) {
420                 final ValPattern p = mPatterns[i];
421                 if (p.mType == ValPattern.TYPE_COUNT) {
422                     final int diff = newStatus.mPatterns[i].mIntValue - p.mIntValue;
423                     Log.i(TAG, "  " + p.mRawPattern + diff);
424                     if (diff > 0) {
425                         result.putInt("[" + title + "] " + p.mRawPattern, diff);
426                     }
427                     continue;
428                 }
429                 final float diff = newStatus.mPatterns[i].mFloatValue - p.mFloatValue;
430                 Log.i(TAG, "  " + p.mRawPattern + diff + "ms");
431                 if (diff > 0) {
432                     result.putFloat("[" + title + "] " + p.mRawPattern + "(ms)", diff);
433                 }
434                 if (i == TOTAL_GC_TIME_INDEX) {
435                     info.mTotalGcTime = diff;
436                 } else if (i == TOTAL_GC_PAUSED_TIME_INDEX) {
437                     info.mTotalGcPausedTime = diff;
438                 }
439             }
440             return info;
441         }
442     }
443 
444     private static class GcInfo {
445         float mTotalGcTime;
446         float mTotalGcPausedTime;
447     }
448 }
449