1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
5  * except in compliance with the License. You may obtain a copy of the License at
6  *
7  *      http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software distributed under the
10  * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
11  * KIND, either express or implied. See the License for the specific language governing
12  * permissions and limitations under the License.
13  */
14 
15 package android.testing;
16 
17 import android.os.Handler;
18 import android.os.HandlerThread;
19 import android.os.Looper;
20 import android.os.Message;
21 import android.os.MessageQueue;
22 import android.os.TestLooperManager;
23 import android.util.ArrayMap;
24 
25 import androidx.test.InstrumentationRegistry;
26 
27 import org.junit.runners.model.FrameworkMethod;
28 
29 import java.lang.annotation.ElementType;
30 import java.lang.annotation.Retention;
31 import java.lang.annotation.RetentionPolicy;
32 import java.lang.annotation.Target;
33 import java.lang.reflect.Field;
34 import java.util.Map;
35 import java.util.concurrent.atomic.AtomicBoolean;
36 
37 /**
38  * This is a wrapper around {@link TestLooperManager} to make it easier to manage
39  * and provide an easy annotation for use with tests.
40  *
41  * @see TestableLooperTest TestableLooperTest for examples.
42  */
43 public class TestableLooper {
44 
45     /**
46      * Whether to hold onto the main thread through all tests in an attempt to
47      * catch crashes.
48      */
49     public static final boolean HOLD_MAIN_THREAD = false;
50     private static final Field MESSAGE_QUEUE_MESSAGES_FIELD;
51     private static final Field MESSAGE_NEXT_FIELD;
52     private static final Field MESSAGE_WHEN_FIELD;
53 
54     private Looper mLooper;
55     private MessageQueue mQueue;
56     private MessageHandler mMessageHandler;
57 
58     private Handler mHandler;
59     private TestLooperManager mQueueWrapper;
60 
61     static {
62         try {
63             MESSAGE_QUEUE_MESSAGES_FIELD = MessageQueue.class.getDeclaredField("mMessages");
64             MESSAGE_QUEUE_MESSAGES_FIELD.setAccessible(true);
65             MESSAGE_NEXT_FIELD = Message.class.getDeclaredField("next");
66             MESSAGE_NEXT_FIELD.setAccessible(true);
67             MESSAGE_WHEN_FIELD = Message.class.getDeclaredField("when");
68             MESSAGE_WHEN_FIELD.setAccessible(true);
69         } catch (NoSuchFieldException e) {
70             throw new RuntimeException("Failed to initialize TestableLooper", e);
71         }
72     }
73 
TestableLooper(Looper l)74     public TestableLooper(Looper l) throws Exception {
75         this(acquireLooperManager(l), l);
76     }
77 
TestableLooper(TestLooperManager wrapper, Looper l)78     private TestableLooper(TestLooperManager wrapper, Looper l) {
79         mQueueWrapper = wrapper;
80         setupQueue(l);
81     }
82 
TestableLooper(Looper looper, boolean b)83     private TestableLooper(Looper looper, boolean b) {
84         setupQueue(looper);
85     }
86 
getLooper()87     public Looper getLooper() {
88         return mLooper;
89     }
90 
setupQueue(Looper l)91     private void setupQueue(Looper l) {
92         mLooper = l;
93         mQueue = mLooper.getQueue();
94         mHandler = new Handler(mLooper);
95     }
96 
97     /**
98      * Must be called to release the looper when the test is complete, otherwise
99      * the looper will not be available for any subsequent tests. This is
100      * automatically handled for tests using {@link RunWithLooper}.
101      */
destroy()102     public void destroy() {
103         mQueueWrapper.release();
104         if (HOLD_MAIN_THREAD && mLooper == Looper.getMainLooper()) {
105             TestableInstrumentation.releaseMain();
106         }
107     }
108 
109     /**
110      * Sets a callback for all messages processed on this TestableLooper.
111      *
112      * @see {@link MessageHandler}
113      */
setMessageHandler(MessageHandler handler)114     public void setMessageHandler(MessageHandler handler) {
115         mMessageHandler = handler;
116     }
117 
118     /**
119      * Parse num messages from the message queue.
120      *
121      * @param num Number of messages to parse
122      */
processMessages(int num)123     public int processMessages(int num) {
124         return processMessagesInternal(num, null);
125     }
126 
processMessagesInternal(int num, Runnable barrierRunnable)127     private int processMessagesInternal(int num, Runnable barrierRunnable) {
128         for (int i = 0; i < num; i++) {
129             if (!processSingleMessage(barrierRunnable)) {
130                 return i + 1;
131             }
132         }
133         return num;
134     }
135 
136     /**
137      * Process up to a certain number of messages, not blocking if the queue has less messages than
138      * that
139      * @param num the maximum number of messages to process
140      * @return the number of messages processed. This will be at most {@code num}.
141      */
142 
processMessagesNonBlocking(int num)143     public int processMessagesNonBlocking(int num) {
144         final AtomicBoolean reachedBarrier = new AtomicBoolean(false);
145         Runnable barrierRunnable = () -> {
146             reachedBarrier.set(true);
147         };
148         mHandler.post(barrierRunnable);
149         waitForMessage(mQueueWrapper, mHandler, barrierRunnable);
150         try {
151             return processMessagesInternal(num, barrierRunnable) + (reachedBarrier.get() ? -1 : 0);
152         } finally {
153             mHandler.removeCallbacks(barrierRunnable);
154         }
155     }
156 
157     /**
158      * Process messages in the queue until no more are found.
159      */
processAllMessages()160     public void processAllMessages() {
161         while (processQueuedMessages() != 0) ;
162     }
163 
moveTimeForward(long milliSeconds)164     public void moveTimeForward(long milliSeconds) {
165         try {
166             Message msg = getMessageLinkedList();
167             while (msg != null) {
168                 long updatedWhen = msg.getWhen() - milliSeconds;
169                 if (updatedWhen < 0) {
170                     updatedWhen = 0;
171                 }
172                 MESSAGE_WHEN_FIELD.set(msg, updatedWhen);
173                 msg = (Message) MESSAGE_NEXT_FIELD.get(msg);
174             }
175         } catch (IllegalAccessException e) {
176             throw new RuntimeException("Access failed in TestableLooper: set - Message.when", e);
177         }
178     }
179 
getMessageLinkedList()180     private Message getMessageLinkedList() {
181         try {
182             MessageQueue queue = mLooper.getQueue();
183             return (Message) MESSAGE_QUEUE_MESSAGES_FIELD.get(queue);
184         } catch (IllegalAccessException e) {
185             throw new RuntimeException(
186                     "Access failed in TestableLooper: get - MessageQueue.mMessages",
187                     e);
188         }
189     }
190 
processQueuedMessages()191     private int processQueuedMessages() {
192         int count = 0;
193         Runnable barrierRunnable = () -> { };
194         mHandler.post(barrierRunnable);
195         waitForMessage(mQueueWrapper, mHandler, barrierRunnable);
196         while (processSingleMessage(barrierRunnable)) count++;
197         return count;
198     }
199 
processSingleMessage(Runnable barrierRunnable)200     private boolean processSingleMessage(Runnable barrierRunnable) {
201         try {
202             Message result = mQueueWrapper.next();
203             if (result != null) {
204                 // This is a break message.
205                 if (result.getCallback() == barrierRunnable) {
206                     mQueueWrapper.execute(result);
207                     mQueueWrapper.recycle(result);
208                     return false;
209                 }
210 
211                 if (mMessageHandler != null) {
212                     if (mMessageHandler.onMessageHandled(result)) {
213                         mQueueWrapper.execute(result);
214                         mQueueWrapper.recycle(result);
215                     } else {
216                         mQueueWrapper.recycle(result);
217                         // Message handler indicated it doesn't want us to continue.
218                         return false;
219                     }
220                 } else {
221                     mQueueWrapper.execute(result);
222                     mQueueWrapper.recycle(result);
223                 }
224             } else {
225                 // No messages, don't continue parsing
226                 return false;
227             }
228         } catch (Exception e) {
229             throw new RuntimeException(e);
230         }
231         return true;
232     }
233 
234     /**
235      * Runs an executable with myLooper set and processes all messages added.
236      */
runWithLooper(RunnableWithException runnable)237     public void runWithLooper(RunnableWithException runnable) throws Exception {
238         new Handler(getLooper()).post(() -> {
239             try {
240                 runnable.run();
241             } catch (Exception e) {
242                 throw new RuntimeException(e);
243             }
244         });
245         processAllMessages();
246     }
247 
248     public interface RunnableWithException {
run()249         void run() throws Exception;
250     }
251 
252     /**
253      * Annotation that tells the {@link AndroidTestingRunner} to create a TestableLooper and
254      * run this test/class on that thread. The {@link TestableLooper} can be acquired using
255      * {@link #get(Object)}.
256      */
257     @Retention(RetentionPolicy.RUNTIME)
258     @Target({ElementType.METHOD, ElementType.TYPE})
259     public @interface RunWithLooper {
setAsMainLooper()260         boolean setAsMainLooper() default false;
261     }
262 
waitForMessage(TestLooperManager queueWrapper, Handler handler, Runnable execute)263     private static void waitForMessage(TestLooperManager queueWrapper, Handler handler,
264             Runnable execute) {
265         for (int i = 0; i < 10; i++) {
266             if (!queueWrapper.hasMessages(handler, null, execute)) {
267                 try {
268                     Thread.sleep(1);
269                 } catch (InterruptedException e) {
270                 }
271             }
272         }
273         if (!queueWrapper.hasMessages(handler, null, execute)) {
274             throw new RuntimeException("Message didn't queue...");
275         }
276     }
277 
acquireLooperManager(Looper l)278     private static TestLooperManager acquireLooperManager(Looper l) {
279         if (HOLD_MAIN_THREAD && l == Looper.getMainLooper()) {
280             TestableInstrumentation.acquireMain();
281         }
282         return InstrumentationRegistry.getInstrumentation().acquireLooperManager(l);
283     }
284 
285     private static final Map<Object, TestableLooper> sLoopers = new ArrayMap<>();
286 
287     /**
288      * For use with {@link RunWithLooper}, used to get the TestableLooper that was
289      * automatically created for this test.
290      */
get(Object test)291     public static TestableLooper get(Object test) {
292         return sLoopers.get(test);
293     }
294 
remove(Object test)295     public static void remove(Object test) {
296         sLoopers.remove(test);
297     }
298 
299     static class LooperFrameworkMethod extends FrameworkMethod {
300         private HandlerThread mHandlerThread;
301 
302         private final TestableLooper mTestableLooper;
303         private final Looper mLooper;
304         private final Handler mHandler;
305 
LooperFrameworkMethod(FrameworkMethod base, boolean setAsMain, Object test)306         public LooperFrameworkMethod(FrameworkMethod base, boolean setAsMain, Object test) {
307             super(base.getMethod());
308             try {
309                 mLooper = setAsMain ? Looper.getMainLooper() : createLooper();
310                 mTestableLooper = new TestableLooper(mLooper, false);
311                 if (!setAsMain) {
312                     mTestableLooper.getLooper().getThread().setName(test.getClass().getName());
313                 }
314             } catch (Exception e) {
315                 throw new RuntimeException(e);
316             }
317             sLoopers.put(test, mTestableLooper);
318             mHandler = new Handler(mLooper);
319         }
320 
LooperFrameworkMethod(TestableLooper other, FrameworkMethod base)321         public LooperFrameworkMethod(TestableLooper other, FrameworkMethod base) {
322             super(base.getMethod());
323             mLooper = other.mLooper;
324             mTestableLooper = other;
325             mHandler = Handler.createAsync(mLooper);
326         }
327 
get(FrameworkMethod base, boolean setAsMain, Object test)328         public static FrameworkMethod get(FrameworkMethod base, boolean setAsMain, Object test) {
329             if (sLoopers.containsKey(test)) {
330                 return new LooperFrameworkMethod(sLoopers.get(test), base);
331             }
332             return new LooperFrameworkMethod(base, setAsMain, test);
333         }
334 
335         @Override
invokeExplosively(Object target, Object... params)336         public Object invokeExplosively(Object target, Object... params) throws Throwable {
337             if (Looper.myLooper() == mLooper) {
338                 // Already on the right thread from another statement, just execute then.
339                 return super.invokeExplosively(target, params);
340             }
341             boolean set = mTestableLooper.mQueueWrapper == null;
342             if (set) {
343                 mTestableLooper.mQueueWrapper = acquireLooperManager(mLooper);
344             }
345             try {
346                 Object[] ret = new Object[1];
347                 // Run the execution on the looper thread.
348                 Runnable execute = () -> {
349                     try {
350                         ret[0] = super.invokeExplosively(target, params);
351                     } catch (Throwable throwable) {
352                         throw new LooperException(throwable);
353                     }
354                 };
355                 Message m = Message.obtain(mHandler, execute);
356 
357                 // Dispatch our message.
358                 try {
359                     mTestableLooper.mQueueWrapper.execute(m);
360                 } catch (LooperException e) {
361                     throw e.getSource();
362                 } catch (RuntimeException re) {
363                     // If the TestLooperManager has to post, it will wrap what it throws in a
364                     // RuntimeException, make sure we grab the actual source.
365                     if (re.getCause() instanceof LooperException) {
366                         throw ((LooperException) re.getCause()).getSource();
367                     } else {
368                         throw re.getCause();
369                     }
370                 } finally {
371                     m.recycle();
372                 }
373                 return ret[0];
374             } finally {
375                 if (set) {
376                     mTestableLooper.mQueueWrapper.release();
377                     mTestableLooper.mQueueWrapper = null;
378                     if (HOLD_MAIN_THREAD && mLooper == Looper.getMainLooper()) {
379                         TestableInstrumentation.releaseMain();
380                     }
381                 }
382             }
383         }
384 
createLooper()385         private Looper createLooper() {
386             // TODO: Find way to share these.
387             mHandlerThread = new HandlerThread(TestableLooper.class.getSimpleName());
388             mHandlerThread.start();
389             return mHandlerThread.getLooper();
390         }
391 
392         @Override
finalize()393         protected void finalize() throws Throwable {
394             super.finalize();
395             if (mHandlerThread != null) {
396                 mHandlerThread.quit();
397             }
398         }
399 
400         private static class LooperException extends RuntimeException {
401             private final Throwable mSource;
402 
LooperException(Throwable t)403             public LooperException(Throwable t) {
404                 mSource = t;
405             }
406 
getSource()407             public Throwable getSource() {
408                 return mSource;
409             }
410         }
411     }
412 
413     /**
414      * Callback to control the execution of messages on the looper, when set with
415      * {@link #setMessageHandler(MessageHandler)} then {@link #onMessageHandled(Message)}
416      * will get called back for every message processed on the {@link TestableLooper}.
417      */
418     public interface MessageHandler {
419         /**
420          * Return true to have the message executed and delivered to target.
421          * Return false to not execute the message and stop executing messages.
422          */
onMessageHandled(Message m)423         boolean onMessageHandled(Message m);
424     }
425 }
426