1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.internal; 18 19 import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation; 20 21 import android.app.Activity; 22 import android.graphics.Rect; 23 import android.os.Bundle; 24 import android.os.Message; 25 import android.os.ParcelFileDescriptor; 26 import android.os.Process; 27 import android.os.SystemClock; 28 import android.util.Log; 29 30 import androidx.test.filters.LargeTest; 31 32 import com.android.internal.util.function.pooled.PooledLambda; 33 import com.android.internal.util.function.pooled.PooledPredicate; 34 35 import org.junit.Assume; 36 import org.junit.Rule; 37 import org.junit.Test; 38 import org.junit.rules.TestRule; 39 import org.junit.runners.model.Statement; 40 41 import java.io.BufferedReader; 42 import java.io.IOException; 43 import java.io.InputStreamReader; 44 import java.util.ArrayList; 45 import java.util.Arrays; 46 import java.util.List; 47 import java.util.concurrent.CountDownLatch; 48 import java.util.function.Predicate; 49 import java.util.regex.Matcher; 50 import java.util.regex.Pattern; 51 52 /** Compares the performance of regular lambda and pooled lambda. */ 53 @LargeTest 54 public class LambdaPerfTest { 55 private static final boolean DEBUG = false; 56 private static final String TAG = LambdaPerfTest.class.getSimpleName(); 57 58 private static final String LAMBDA_FORM_REGULAR = "regular"; 59 private static final String LAMBDA_FORM_POOLED = "pooled"; 60 61 private static final int WARMUP_ITERATIONS = 1000; 62 private static final int TEST_ITERATIONS = 3000000; 63 private static final int TASK_COUNT = 10; 64 private static final long DELAY_AFTER_BENCH_MS = 1000; 65 66 private String mMethodName; 67 68 private final Bundle mTestResults = new Bundle(); 69 private final ArrayList<Task> mTasks = new ArrayList<>(); 70 71 // The member fields are used to ensure lambda capturing. They don't have the actual meaning. 72 private final Task mTask = new Task(); 73 private final Rect mBounds = new Rect(); 74 private int mTaskId; 75 private long mTime; 76 private boolean mTop; 77 78 @Rule 79 public final TestRule mRule = (base, description) -> new Statement() { 80 @Override 81 public void evaluate() throws Throwable { 82 mMethodName = description.getMethodName(); 83 mTasks.clear(); 84 for (int i = 0; i < TASK_COUNT; i++) { 85 final Task t = new Task(); 86 mTasks.add(t); 87 } 88 base.evaluate(); 89 90 getInstrumentation().sendStatus(Activity.RESULT_OK, mTestResults); 91 } 92 }; 93 94 @Test test1ParamPredicate()95 public void test1ParamPredicate() { 96 evaluate(LAMBDA_FORM_REGULAR, () -> handleTask(t -> t.doSomething(mTaskId, mTime))); 97 evaluate(LAMBDA_FORM_POOLED, () -> { 98 final PooledPredicate c = PooledLambda.obtainPredicate(Task::doSomething, 99 PooledLambda.__(Task.class), mTaskId, mTime); 100 handleTask(c); 101 c.recycle(); 102 }); 103 } 104 105 @Test test2PrimitiveParamsPredicate()106 public void test2PrimitiveParamsPredicate() { 107 // Not in Integer#IntegerCache (-128~127) for autoboxing, that may create new object. 108 mTaskId = 12345; 109 mTime = 54321; 110 111 evaluate(LAMBDA_FORM_REGULAR, () -> handleTask(t -> t.doSomething(mTaskId, mTime))); 112 evaluate(LAMBDA_FORM_POOLED, () -> { 113 final PooledPredicate c = PooledLambda.obtainPredicate(Task::doSomething, 114 PooledLambda.__(Task.class), mTaskId, mTime); 115 handleTask(c); 116 c.recycle(); 117 }); 118 } 119 120 @Test test3ParamsPredicate()121 public void test3ParamsPredicate() { 122 mTop = true; 123 // In Integer#IntegerCache. 124 mTaskId = 10; 125 126 evaluate(LAMBDA_FORM_REGULAR, () -> handleTask(t -> t.doSomething(mBounds, mTop, mTaskId))); 127 evaluate(LAMBDA_FORM_POOLED, () -> { 128 final PooledPredicate c = PooledLambda.obtainPredicate(Task::doSomething, 129 PooledLambda.__(Task.class), mBounds, mTop, mTaskId); 130 handleTask(c); 131 c.recycle(); 132 }); 133 } 134 135 @Test testMessage()136 public void testMessage() { 137 evaluate(LAMBDA_FORM_REGULAR, () -> { 138 final Message m = Message.obtain().setCallback(() -> mTask.doSomething(mTaskId, mTime)); 139 m.getCallback().run(); 140 m.recycle(); 141 }); 142 evaluate(LAMBDA_FORM_POOLED, () -> { 143 final Message m = PooledLambda.obtainMessage(Task::doSomething, mTask, mTaskId, mTime); 144 m.getCallback().run(); 145 m.recycle(); 146 }); 147 } 148 149 @Test testRunnable()150 public void testRunnable() { 151 evaluate(LAMBDA_FORM_REGULAR, () -> { 152 final Runnable r = mTask::doSomething; 153 r.run(); 154 }); 155 evaluate(LAMBDA_FORM_POOLED, () -> { 156 final Runnable r = PooledLambda.obtainRunnable(Task::doSomething, mTask).recycleOnUse(); 157 r.run(); 158 }); 159 } 160 161 @Test testMultiThread()162 public void testMultiThread() { 163 final int numThread = 3; 164 165 final Runnable regularAction = () -> handleTask(t -> t.doSomething(mTaskId, mTime)); 166 final Runnable[] regularActions = new Runnable[numThread]; 167 Arrays.fill(regularActions, regularAction); 168 evaluateMultiThread(LAMBDA_FORM_REGULAR, regularActions); 169 170 final Runnable pooledAction = () -> { 171 final PooledPredicate c = PooledLambda.obtainPredicate(Task::doSomething, 172 PooledLambda.__(Task.class), mTaskId, mTime); 173 handleTask(c); 174 c.recycle(); 175 }; 176 final Runnable[] pooledActions = new Runnable[numThread]; 177 Arrays.fill(pooledActions, pooledAction); 178 evaluateMultiThread(LAMBDA_FORM_POOLED, pooledActions); 179 } 180 handleTask(Predicate<Task> callback)181 private void handleTask(Predicate<Task> callback) { 182 for (int i = mTasks.size() - 1; i >= 0; i--) { 183 final Task task = mTasks.get(i); 184 if (callback.test(task)) { 185 return; 186 } 187 } 188 } 189 evaluate(String title, Runnable action)190 private void evaluate(String title, Runnable action) { 191 for (int i = 0; i < WARMUP_ITERATIONS; i++) { 192 action.run(); 193 } 194 performGc(); 195 196 final GcStatus startGcStatus = getGcStatus(); 197 final long startTime = SystemClock.elapsedRealtime(); 198 for (int i = 0; i < TEST_ITERATIONS; i++) { 199 action.run(); 200 } 201 evaluateResult(title, startGcStatus, startTime); 202 } 203 evaluateMultiThread(String title, Runnable[] actions)204 private void evaluateMultiThread(String title, Runnable[] actions) { 205 performGc(); 206 207 final CountDownLatch latch = new CountDownLatch(actions.length); 208 final GcStatus startGcStatus = getGcStatus(); 209 final long startTime = SystemClock.elapsedRealtime(); 210 for (Runnable action : actions) { 211 new Thread() { 212 @Override 213 public void run() { 214 for (int i = 0; i < TEST_ITERATIONS; i++) { 215 action.run(); 216 } 217 latch.countDown(); 218 }; 219 }.start(); 220 } 221 try { 222 latch.await(); 223 } catch (InterruptedException ignored) { 224 } 225 evaluateResult(title, startGcStatus, startTime); 226 } 227 evaluateResult(String title, GcStatus startStatus, long startTime)228 private void evaluateResult(String title, GcStatus startStatus, long startTime) { 229 final float elapsed = SystemClock.elapsedRealtime() - startTime; 230 // Sleep a while to see if GC may happen. 231 SystemClock.sleep(DELAY_AFTER_BENCH_MS); 232 final GcStatus endStatus = getGcStatus(); 233 final GcInfo info = startStatus.calculateGcTime(endStatus, title, mTestResults); 234 mTestResults.putFloat("[" + title + "-execution-time]", elapsed); 235 Log.i(TAG, mMethodName + "_" + title + " execution time: " 236 + elapsed + "ms (avg=" + String.format("%.5f", elapsed / TEST_ITERATIONS) + "ms)" 237 + " GC time: " + String.format("%.3f", info.mTotalGcTime) + "ms" 238 + " GC paused time: " + String.format("%.3f", info.mTotalGcPausedTime) + "ms"); 239 } 240 241 /** Cleans the test environment. */ performGc()242 private static void performGc() { 243 System.gc(); 244 System.runFinalization(); 245 System.gc(); 246 } 247 getGcStatus()248 private static GcStatus getGcStatus() { 249 if (DEBUG) { 250 Log.i(TAG, "===== Read GC dump ====="); 251 } 252 final GcStatus status = new GcStatus(); 253 final List<String> vmDump = getVmDump(); 254 Assume.assumeFalse("VM dump is empty", vmDump.isEmpty()); 255 for (String line : vmDump) { 256 status.visit(line); 257 if (line.startsWith("DALVIK THREADS")) { 258 break; 259 } 260 } 261 return status; 262 } 263 getVmDump()264 private static List<String> getVmDump() { 265 final int myPid = Process.myPid(); 266 // Another approach Debug#dumpJavaBacktraceToFileTimeout requires setenforce 0. 267 Process.sendSignal(myPid, Process.SIGNAL_QUIT); 268 // Give a chance to handle the signal. 269 SystemClock.sleep(100); 270 271 String dump = null; 272 final String pattern = myPid + " written to: "; 273 final List<String> logs = shell("logcat -v brief -d tombstoned:I *:S"); 274 for (int i = logs.size() - 1; i >= 0; i--) { 275 final String log = logs.get(i); 276 // Log pattern: Traces for pid 9717 written to: /data/anr/trace_07 277 final int pos = log.indexOf(pattern); 278 if (pos > 0) { 279 dump = log.substring(pattern.length() + pos); 280 if (!dump.startsWith("/data/anr/")) { 281 dump = "/data/anr/" + dump; 282 } 283 break; 284 } 285 } 286 287 Assume.assumeNotNull("Unable to find VM dump", dump); 288 // It requires system or root uid to read the trace. 289 return shell("cat " + dump); 290 } 291 shell(String command)292 private static List<String> shell(String command) { 293 final ParcelFileDescriptor.AutoCloseInputStream stream = 294 new ParcelFileDescriptor.AutoCloseInputStream( 295 getInstrumentation().getUiAutomation().executeShellCommand(command)); 296 final ArrayList<String> lines = new ArrayList<>(); 297 try (BufferedReader br = new BufferedReader(new InputStreamReader(stream))) { 298 String line; 299 while ((line = br.readLine()) != null) { 300 lines.add(line); 301 } 302 } catch (IOException e) { 303 throw new RuntimeException(e); 304 } 305 return lines; 306 } 307 308 /** An empty class which provides some methods with different type arguments. */ 309 static class Task { doSomething()310 void doSomething() { 311 } 312 doSomething(int taskId, long time)313 boolean doSomething(int taskId, long time) { 314 return false; 315 } 316 doSomething(Rect bounds, boolean top, int taskId)317 boolean doSomething(Rect bounds, boolean top, int taskId) { 318 return false; 319 } 320 } 321 322 static class ValPattern { 323 static final int TYPE_COUNT = 0; 324 static final int TYPE_TIME = 1; 325 static final String PATTERN_COUNT = "(\\d+)"; 326 static final String PATTERN_TIME = "(\\d+\\.?\\d+)(\\w+)"; 327 final String mRawPattern; 328 final Pattern mPattern; 329 final int mType; 330 331 int mIntValue; 332 float mFloatValue; 333 ValPattern(String p, int type)334 ValPattern(String p, int type) { 335 mRawPattern = p; 336 mPattern = Pattern.compile( 337 p + (type == TYPE_TIME ? PATTERN_TIME : PATTERN_COUNT) + ".*"); 338 mType = type; 339 } 340 visit(String line)341 boolean visit(String line) { 342 final Matcher matcher = mPattern.matcher(line); 343 if (!matcher.matches()) { 344 return false; 345 } 346 final String value = matcher.group(1); 347 if (value == null) { 348 return false; 349 } 350 if (mType == TYPE_COUNT) { 351 mIntValue = Integer.parseInt(value); 352 return true; 353 } 354 final float time = Float.parseFloat(value); 355 final String unit = matcher.group(2); 356 if (unit == null) { 357 return false; 358 } 359 // Refer to art/libartbase/base/time_utils.cc 360 switch (unit) { 361 case "s": 362 mFloatValue = time * 1000; 363 break; 364 case "ms": 365 mFloatValue = time; 366 break; 367 case "us": 368 mFloatValue = time / 1000; 369 break; 370 case "ns": 371 mFloatValue = time / 1000 / 1000; 372 break; 373 default: 374 throw new IllegalArgumentException(); 375 } 376 377 return true; 378 } 379 380 @Override toString()381 public String toString() { 382 return mRawPattern + (mType == TYPE_TIME ? (mFloatValue + "ms") : mIntValue); 383 } 384 } 385 386 /** Parses the dump pattern of Heap::DumpGcPerformanceInfo. */ 387 private static class GcStatus { 388 private static final int TOTAL_GC_TIME_INDEX = 1; 389 private static final int TOTAL_GC_PAUSED_TIME_INDEX = 5; 390 391 // Refer to art/runtime/gc/heap.cc 392 final ValPattern[] mPatterns = { 393 new ValPattern("Total GC count: ", ValPattern.TYPE_COUNT), 394 new ValPattern("Total GC time: ", ValPattern.TYPE_TIME), 395 new ValPattern("Total time waiting for GC to complete: ", ValPattern.TYPE_TIME), 396 new ValPattern("Total blocking GC count: ", ValPattern.TYPE_COUNT), 397 new ValPattern("Total blocking GC time: ", ValPattern.TYPE_TIME), 398 new ValPattern("Total mutator paused time: ", ValPattern.TYPE_TIME), 399 new ValPattern("Total number of allocations ", ValPattern.TYPE_COUNT), 400 new ValPattern("concurrent copying paused: Sum: ", ValPattern.TYPE_TIME), 401 new ValPattern("concurrent copying total time: ", ValPattern.TYPE_TIME), 402 new ValPattern("concurrent copying freed: ", ValPattern.TYPE_COUNT), 403 new ValPattern("Peak regions allocated ", ValPattern.TYPE_COUNT), 404 }; 405 visit(String dumpLine)406 void visit(String dumpLine) { 407 for (ValPattern p : mPatterns) { 408 if (p.visit(dumpLine)) { 409 if (DEBUG) { 410 Log.i(TAG, " " + p); 411 } 412 } 413 } 414 } 415 calculateGcTime(GcStatus newStatus, String title, Bundle result)416 GcInfo calculateGcTime(GcStatus newStatus, String title, Bundle result) { 417 Log.i(TAG, "===== GC status of " + title + " ====="); 418 final GcInfo info = new GcInfo(); 419 for (int i = 0; i < mPatterns.length; i++) { 420 final ValPattern p = mPatterns[i]; 421 if (p.mType == ValPattern.TYPE_COUNT) { 422 final int diff = newStatus.mPatterns[i].mIntValue - p.mIntValue; 423 Log.i(TAG, " " + p.mRawPattern + diff); 424 if (diff > 0) { 425 result.putInt("[" + title + "] " + p.mRawPattern, diff); 426 } 427 continue; 428 } 429 final float diff = newStatus.mPatterns[i].mFloatValue - p.mFloatValue; 430 Log.i(TAG, " " + p.mRawPattern + diff + "ms"); 431 if (diff > 0) { 432 result.putFloat("[" + title + "] " + p.mRawPattern + "(ms)", diff); 433 } 434 if (i == TOTAL_GC_TIME_INDEX) { 435 info.mTotalGcTime = diff; 436 } else if (i == TOTAL_GC_PAUSED_TIME_INDEX) { 437 info.mTotalGcPausedTime = diff; 438 } 439 } 440 return info; 441 } 442 } 443 444 private static class GcInfo { 445 float mTotalGcTime; 446 float mTotalGcPausedTime; 447 } 448 } 449