1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net.cts;
18 
19 import static android.net.DnsResolver.CLASS_IN;
20 import static android.net.DnsResolver.FLAG_EMPTY;
21 import static android.net.DnsResolver.FLAG_NO_CACHE_LOOKUP;
22 import static android.net.DnsResolver.TYPE_A;
23 import static android.net.DnsResolver.TYPE_AAAA;
24 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
25 import static android.net.cts.util.CtsNetUtils.TestNetworkCallback;
26 import static android.system.OsConstants.ETIMEDOUT;
27 
28 import android.annotation.NonNull;
29 import android.annotation.Nullable;
30 import android.content.Context;
31 import android.content.ContentResolver;
32 import android.content.pm.PackageManager;
33 import android.net.ConnectivityManager;
34 import android.net.ConnectivityManager.NetworkCallback;
35 import android.net.DnsResolver;
36 import android.net.LinkProperties;
37 import android.net.Network;
38 import android.net.NetworkCapabilities;
39 import android.net.NetworkRequest;
40 import android.net.ParseException;
41 import android.net.cts.util.CtsNetUtils;
42 import android.os.CancellationSignal;
43 import android.os.Handler;
44 import android.os.Looper;
45 import android.platform.test.annotations.AppModeFull;
46 import android.provider.Settings;
47 import android.system.ErrnoException;
48 import android.test.AndroidTestCase;
49 import android.util.Log;
50 
51 import com.android.net.module.util.DnsPacket;
52 import com.android.testutils.SkipPresubmit;
53 
54 import java.net.Inet4Address;
55 import java.net.Inet6Address;
56 import java.net.InetAddress;
57 import java.util.ArrayList;
58 import java.util.List;
59 import java.util.concurrent.CountDownLatch;
60 import java.util.concurrent.Executor;
61 import java.util.concurrent.TimeUnit;
62 
63 @AppModeFull(reason = "WRITE_SECURE_SETTINGS permission can't be granted to instant apps")
64 public class DnsResolverTest extends AndroidTestCase {
    private static final String TAG = "DnsResolverTest";
    // Upper-case hexadecimal digits, indexed by nibble value; used by byteArrayToHexString().
    private static final char[] HEX_CHARS = {
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
    };

    // Name the positive-path tests expect to resolve to at least one answer record.
    static final String TEST_DOMAIN = "www.google.com";
    // Name the NXDOMAIN tests expect to produce an RCODE 3 (NXDOMAIN) response.
    static final String TEST_NX_DOMAIN = "test1-nx.metric.gstatic.com";
    // Strict-mode private DNS hostname set by testPrivateDnsBypass; presumably chosen so that
    // strict-mode lookups through it cannot succeed — TODO confirm.
    static final String INVALID_PRIVATE_DNS_SERVER = "invalid.google";
    // Strict-mode private DNS hostname used by the NXDOMAIN-with-private-DNS tests.
    static final String GOOGLE_PRIVATE_DNS_SERVER = "dns.google";
    // Pre-encoded DNS query in wire format: QDCOUNT=1, QNAME www.google.com
    // (length-prefixed labels 3 "www", 6 "google", 3 "com", root 0), QTYPE 1 (A), QCLASS 1 (IN).
    // The flags word 0x0100 sets only the RD (recursion desired) bit.
    static final byte[] TEST_BLOB = new byte[]{
            /* Header */
            0x55, 0x66, /* Transaction ID */
            0x01, 0x00, /* Flags */
            0x00, 0x01, /* Questions */
            0x00, 0x00, /* Answer RRs */
            0x00, 0x00, /* Authority RRs */
            0x00, 0x00, /* Additional RRs */
            /* Queries */
            0x03, 0x77, 0x77, 0x77, 0x06, 0x67, 0x6F, 0x6F, 0x67, 0x6c, 0x65,
            0x03, 0x63, 0x6f, 0x6d, 0x00, /* Name */
            0x00, 0x01, /* Type */
            0x00, 0x01  /* Class */
    };
    // Default time a callback's waitForAnswer() waits for a result, in milliseconds.
    static final int TIMEOUT_MS = 12_000;
    // Time the cancellation tests wait to see whether a cancelled query still called back (ms).
    static final int CANCEL_TIMEOUT_MS = 3_000;
    // Maximum number of query-then-cancel rounds before a cancellation test fails.
    static final int CANCEL_RETRY_TIMES = 5;
    // NOTE(review): not referenced in this part of the file; presumably used by the
    // continuous-query test further down — confirm.
    static final int QUERY_TIMES = 10;
    // DNS response code 3 (NXDOMAIN).
    static final int NXDOMAIN = 3;

    // NOTE(review): assigned in setUp() but not read in this part of the file.
    private ContentResolver mCR;
    private ConnectivityManager mCM;
    private PackageManager mPackageManager;
    private CtsNetUtils mCtsNetUtils;
    // Executor that posts callbacks to the main looper.
    private Executor mExecutor;
    // Executor that runs callbacks synchronously on the invoking thread.
    private Executor mExecutorInline;
    private DnsResolver mDns;

    // NOTE(review): not referenced in this part of the file; private DNS state is saved and
    // restored through mCtsNetUtils, so these may be stale — confirm before removing.
    private String mOldMode;
    private String mOldDnsSpecifier;
    // Wi-Fi NetworkRequest callback filed by getTestableNetworks(); unregistered in tearDown().
    private TestNetworkCallback mWifiRequestCallback = null;
105 
106     @Override
setUp()107     protected void setUp() throws Exception {
108         super.setUp();
109         mCM = (ConnectivityManager) getContext().getSystemService(Context.CONNECTIVITY_SERVICE);
110         mDns = DnsResolver.getInstance();
111         mExecutor = new Handler(Looper.getMainLooper())::post;
112         mExecutorInline = (Runnable r) -> r.run();
113         mCR = getContext().getContentResolver();
114         mCtsNetUtils = new CtsNetUtils(getContext());
115         mCtsNetUtils.storePrivateDnsSetting();
116         mPackageManager = mContext.getPackageManager();
117     }
118 
119     @Override
tearDown()120     protected void tearDown() throws Exception {
121         mCtsNetUtils.restorePrivateDnsSetting();
122         if (mWifiRequestCallback != null) {
123             mCM.unregisterNetworkCallback(mWifiRequestCallback);
124         }
125         super.tearDown();
126     }
127 
byteArrayToHexString(byte[] bytes)128     private static String byteArrayToHexString(byte[] bytes) {
129         char[] hexChars = new char[bytes.length * 2];
130         for (int i = 0; i < bytes.length; ++i) {
131             int b = bytes[i] & 0xFF;
132             hexChars[i * 2] = HEX_CHARS[b >>> 4];
133             hexChars[i * 2 + 1] = HEX_CHARS[b & 0x0F];
134         }
135         return new String(hexChars);
136     }
137 
getTestableNetworks()138     private Network[] getTestableNetworks() {
139         if (mPackageManager.hasSystemFeature(PackageManager.FEATURE_WIFI)) {
140             // File a NetworkRequest for Wi-Fi, so it connects even if a higher-scoring
141             // network, such as Ethernet, is already connected.
142             final NetworkRequest request = new NetworkRequest.Builder()
143                     .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
144                     .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
145                     .build();
146             mWifiRequestCallback = new TestNetworkCallback();
147             mCM.requestNetwork(request, mWifiRequestCallback);
148             mCtsNetUtils.ensureWifiConnected();
149         }
150         final ArrayList<Network> testableNetworks = new ArrayList<Network>();
151         for (Network network : mCM.getAllNetworks()) {
152             final NetworkCapabilities nc = mCM.getNetworkCapabilities(network);
153             if (nc != null
154                     && nc.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
155                     && nc.hasCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)) {
156                 testableNetworks.add(network);
157             }
158         }
159 
160         assertTrue(
161                 "This test requires that at least one network be connected. " +
162                         "Please ensure that the device is connected to a network.",
163                 testableNetworks.size() >= 1);
164         // In order to test query with null network, add null as an element.
165         // Test cases which query with null network will go on default network.
166         testableNetworks.add(null);
167         return testableNetworks.toArray(new Network[0]);
168     }
169 
assertGreaterThan(String msg, int first, int second)170     static private void assertGreaterThan(String msg, int first, int second) {
171         assertTrue(msg + " Excepted " + first + " to be greater than " + second, first > second);
172     }
173 
174     private static class DnsParseException extends Exception {
DnsParseException(String msg)175         public DnsParseException(String msg) {
176             super(msg);
177         }
178     }
179 
180     private static class DnsAnswer extends DnsPacket {
DnsAnswer(@onNull byte[] data)181         DnsAnswer(@NonNull byte[] data) throws DnsParseException {
182             super(data);
183 
184             // Check QR field.(query (0), or a response (1)).
185             if ((mHeader.flags & (1 << 15)) == 0) {
186                 throw new DnsParseException("Not an answer packet");
187             }
188         }
189 
getRcode()190         int getRcode() {
191             return mHeader.rcode;
192         }
193 
getANCount()194         int getANCount() {
195             return mHeader.getRecordCount(ANSECTION);
196         }
197 
getQDCount()198         int getQDCount() {
199             return mHeader.getRecordCount(QDSECTION);
200         }
201     }
202 
203     /**
204      * A query callback that ensures that the query is cancelled and that onAnswer is never
205      * called. If the query succeeds before it is cancelled, needRetry will return true so the
206      * test can retry.
207      */
208     class VerifyCancelCallback implements DnsResolver.Callback<byte[]> {
209         private final CountDownLatch mLatch = new CountDownLatch(1);
210         private final String mMsg;
211         private final CancellationSignal mCancelSignal;
212         private int mRcode;
213         private DnsAnswer mDnsAnswer;
214         private String mErrorMsg = null;
215 
VerifyCancelCallback(@onNull String msg, @Nullable CancellationSignal cancel)216         VerifyCancelCallback(@NonNull String msg, @Nullable CancellationSignal cancel) {
217             mMsg = msg;
218             mCancelSignal = cancel;
219         }
220 
VerifyCancelCallback(@onNull String msg)221         VerifyCancelCallback(@NonNull String msg) {
222             this(msg, null);
223         }
224 
waitForAnswer(int timeout)225         public boolean waitForAnswer(int timeout) throws InterruptedException {
226             return mLatch.await(timeout, TimeUnit.MILLISECONDS);
227         }
228 
waitForAnswer()229         public boolean waitForAnswer() throws InterruptedException {
230             return waitForAnswer(TIMEOUT_MS);
231         }
232 
needRetry()233         public boolean needRetry() throws InterruptedException {
234             return mLatch.await(CANCEL_TIMEOUT_MS, TimeUnit.MILLISECONDS);
235         }
236 
237         @Override
onAnswer(@onNull byte[] answer, int rcode)238         public void onAnswer(@NonNull byte[] answer, int rcode) {
239             if (mCancelSignal != null && mCancelSignal.isCanceled()) {
240                 mErrorMsg = mMsg + " should not have returned any answers";
241                 mLatch.countDown();
242                 return;
243             }
244 
245             mRcode = rcode;
246             try {
247                 mDnsAnswer = new DnsAnswer(answer);
248             } catch (ParseException | DnsParseException e) {
249                 mErrorMsg = mMsg + e.getMessage();
250                 mLatch.countDown();
251                 return;
252             }
253             Log.d(TAG, "Reported blob: " + byteArrayToHexString(answer));
254             mLatch.countDown();
255         }
256 
257         @Override
onError(@onNull DnsResolver.DnsException error)258         public void onError(@NonNull DnsResolver.DnsException error) {
259             mErrorMsg = mMsg + error.getMessage();
260             mLatch.countDown();
261         }
262 
assertValidAnswer()263         private void assertValidAnswer() {
264             assertNull(mErrorMsg);
265             assertNotNull(mMsg + " No valid answer", mDnsAnswer);
266             assertEquals(mMsg + " Unexpected error: reported rcode" + mRcode +
267                     " blob's rcode " + mDnsAnswer.getRcode(), mRcode, mDnsAnswer.getRcode());
268         }
269 
assertHasAnswer()270         public void assertHasAnswer() {
271             assertValidAnswer();
272             // Check rcode field.(0, No error condition).
273             assertEquals(mMsg + " Response error, rcode: " + mRcode, mRcode, 0);
274             // Check answer counts.
275             assertGreaterThan(mMsg + " No answer found", mDnsAnswer.getANCount(), 0);
276             // Check question counts.
277             assertGreaterThan(mMsg + " No question found", mDnsAnswer.getQDCount(), 0);
278         }
279 
assertNXDomain()280         public void assertNXDomain() {
281             assertValidAnswer();
282             // Check rcode field.(3, NXDomain).
283             assertEquals(mMsg + " Unexpected rcode: " + mRcode, mRcode, NXDOMAIN);
284             // Check answer counts. Expect 0 answer.
285             assertEquals(mMsg + " Not an empty answer", mDnsAnswer.getANCount(), 0);
286             // Check question counts.
287             assertGreaterThan(mMsg + " No question found", mDnsAnswer.getQDCount(), 0);
288         }
289 
assertEmptyAnswer()290         public void assertEmptyAnswer() {
291             assertValidAnswer();
292             // Check rcode field.(0, No error condition).
293             assertEquals(mMsg + " Response error, rcode: " + mRcode, mRcode, 0);
294             // Check answer counts. Expect 0 answer.
295             assertEquals(mMsg + " Not an empty answer", mDnsAnswer.getANCount(), 0);
296             // Check question counts.
297             assertGreaterThan(mMsg + " No question found", mDnsAnswer.getQDCount(), 0);
298         }
299     }
300 
testRawQuery()301     public void testRawQuery() throws Exception {
302         doTestRawQuery(mExecutor);
303     }
304 
testRawQueryInline()305     public void testRawQueryInline() throws Exception {
306         doTestRawQuery(mExecutorInline);
307     }
308 
testRawQueryBlob()309     public void testRawQueryBlob() throws Exception {
310         doTestRawQueryBlob(mExecutor);
311     }
312 
testRawQueryBlobInline()313     public void testRawQueryBlobInline() throws Exception {
314         doTestRawQueryBlob(mExecutorInline);
315     }
316 
testRawQueryRoot()317     public void testRawQueryRoot() throws Exception {
318         doTestRawQueryRoot(mExecutor);
319     }
320 
testRawQueryRootInline()321     public void testRawQueryRootInline() throws Exception {
322         doTestRawQueryRoot(mExecutorInline);
323     }
324 
testRawQueryNXDomain()325     public void testRawQueryNXDomain() throws Exception {
326         doTestRawQueryNXDomain(mExecutor);
327     }
328 
testRawQueryNXDomainInline()329     public void testRawQueryNXDomainInline() throws Exception {
330         doTestRawQueryNXDomain(mExecutorInline);
331     }
332 
testRawQueryNXDomainWithPrivateDns()333     public void testRawQueryNXDomainWithPrivateDns() throws Exception {
334         doTestRawQueryNXDomainWithPrivateDns(mExecutor);
335     }
336 
testRawQueryNXDomainInlineWithPrivateDns()337     public void testRawQueryNXDomainInlineWithPrivateDns() throws Exception {
338         doTestRawQueryNXDomainWithPrivateDns(mExecutorInline);
339     }
340 
doTestRawQuery(Executor executor)341     public void doTestRawQuery(Executor executor) throws InterruptedException {
342         final String msg = "RawQuery " + TEST_DOMAIN;
343         for (Network network : getTestableNetworks()) {
344             final VerifyCancelCallback callback = new VerifyCancelCallback(msg);
345             mDns.rawQuery(network, TEST_DOMAIN, CLASS_IN, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
346                     executor, null, callback);
347 
348             assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
349                     callback.waitForAnswer());
350             callback.assertHasAnswer();
351         }
352     }
353 
doTestRawQueryBlob(Executor executor)354     public void doTestRawQueryBlob(Executor executor) throws InterruptedException {
355         final byte[] blob = new byte[]{
356                 /* Header */
357                 0x55, 0x66, /* Transaction ID */
358                 0x01, 0x00, /* Flags */
359                 0x00, 0x01, /* Questions */
360                 0x00, 0x00, /* Answer RRs */
361                 0x00, 0x00, /* Authority RRs */
362                 0x00, 0x00, /* Additional RRs */
363                 /* Queries */
364                 0x03, 0x77, 0x77, 0x77, 0x06, 0x67, 0x6F, 0x6F, 0x67, 0x6c, 0x65,
365                 0x03, 0x63, 0x6f, 0x6d, 0x00, /* Name */
366                 0x00, 0x01, /* Type */
367                 0x00, 0x01  /* Class */
368         };
369         final String msg = "RawQuery blob " + byteArrayToHexString(blob);
370         for (Network network : getTestableNetworks()) {
371             final VerifyCancelCallback callback = new VerifyCancelCallback(msg);
372             mDns.rawQuery(network, blob, FLAG_NO_CACHE_LOOKUP, executor, null, callback);
373 
374             assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
375                     callback.waitForAnswer());
376             callback.assertHasAnswer();
377         }
378     }
379 
doTestRawQueryRoot(Executor executor)380     public void doTestRawQueryRoot(Executor executor) throws InterruptedException {
381         final String dname = "";
382         final String msg = "RawQuery empty dname(ROOT) ";
383         for (Network network : getTestableNetworks()) {
384             final VerifyCancelCallback callback = new VerifyCancelCallback(msg);
385             mDns.rawQuery(network, dname, CLASS_IN, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
386                     executor, null, callback);
387 
388             assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
389                     callback.waitForAnswer());
390             // Except no answer record because the root does not have AAAA records.
391             callback.assertEmptyAnswer();
392         }
393     }
394 
doTestRawQueryNXDomain(Executor executor)395     public void doTestRawQueryNXDomain(Executor executor) throws InterruptedException {
396         final String msg = "RawQuery " + TEST_NX_DOMAIN;
397 
398         for (Network network : getTestableNetworks()) {
399             final NetworkCapabilities nc = (network != null)
400                     ? mCM.getNetworkCapabilities(network)
401                     : mCM.getNetworkCapabilities(mCM.getActiveNetwork());
402             assertNotNull("Couldn't determine NetworkCapabilities for " + network, nc);
403             // Some cellular networks configure their DNS servers never to return NXDOMAIN, so don't
404             // test NXDOMAIN on these DNS servers.
405             // b/144521720
406             if (nc.hasTransport(TRANSPORT_CELLULAR)) continue;
407             final VerifyCancelCallback callback = new VerifyCancelCallback(msg);
408             mDns.rawQuery(network, TEST_NX_DOMAIN, CLASS_IN, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
409                     executor, null, callback);
410 
411             assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
412                     callback.waitForAnswer());
413             callback.assertNXDomain();
414         }
415     }
416 
doTestRawQueryNXDomainWithPrivateDns(Executor executor)417     public void doTestRawQueryNXDomainWithPrivateDns(Executor executor)
418             throws InterruptedException {
419         final String msg = "RawQuery " + TEST_NX_DOMAIN + " with private DNS";
420         // Enable private DNS strict mode and set server to dns.google before doing NxDomain test.
421         // b/144521720
422         mCtsNetUtils.setPrivateDnsStrictMode(GOOGLE_PRIVATE_DNS_SERVER);
423         for (Network network :  getTestableNetworks()) {
424             final Network networkForPrivateDns =
425                     (network != null) ? network : mCM.getActiveNetwork();
426             assertNotNull("Can't find network to await private DNS on", networkForPrivateDns);
427             mCtsNetUtils.awaitPrivateDnsSetting(msg + " wait private DNS setting timeout",
428                     networkForPrivateDns, GOOGLE_PRIVATE_DNS_SERVER, true);
429             final VerifyCancelCallback callback = new VerifyCancelCallback(msg);
430             mDns.rawQuery(network, TEST_NX_DOMAIN, CLASS_IN, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
431                     executor, null, callback);
432 
433             assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
434                     callback.waitForAnswer());
435             callback.assertNXDomain();
436         }
437     }
438 
testRawQueryCancel()439     public void testRawQueryCancel() throws InterruptedException {
440         final String msg = "Test cancel RawQuery " + TEST_DOMAIN;
441         // Start a DNS query and the cancel it immediately. Use VerifyCancelCallback to expect
442         // that the query is cancelled before it succeeds. If it is not cancelled before it
443         // succeeds, retry the test until it is.
444         for (Network network : getTestableNetworks()) {
445             boolean retry = false;
446             int round = 0;
447             do {
448                 if (++round > CANCEL_RETRY_TIMES) {
449                     fail(msg + " cancel failed " + CANCEL_RETRY_TIMES + " times");
450                 }
451                 final CountDownLatch latch = new CountDownLatch(1);
452                 final CancellationSignal cancelSignal = new CancellationSignal();
453                 final VerifyCancelCallback callback = new VerifyCancelCallback(msg, cancelSignal);
454                 mDns.rawQuery(network, TEST_DOMAIN, CLASS_IN, TYPE_AAAA, FLAG_EMPTY,
455                         mExecutor, cancelSignal, callback);
456                 mExecutor.execute(() -> {
457                     cancelSignal.cancel();
458                     latch.countDown();
459                 });
460 
461                 retry = callback.needRetry();
462                 assertTrue(msg + " query was not cancelled",
463                         latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
464             } while (retry);
465         }
466     }
467 
testRawQueryBlobCancel()468     public void testRawQueryBlobCancel() throws InterruptedException {
469         final String msg = "Test cancel RawQuery blob " + byteArrayToHexString(TEST_BLOB);
470         // Start a DNS query and the cancel it immediately. Use VerifyCancelCallback to expect
471         // that the query is cancelled before it succeeds. If it is not cancelled before it
472         // succeeds, retry the test until it is.
473         for (Network network : getTestableNetworks()) {
474             boolean retry = false;
475             int round = 0;
476             do {
477                 if (++round > CANCEL_RETRY_TIMES) {
478                     fail(msg + " cancel failed " + CANCEL_RETRY_TIMES + " times");
479                 }
480                 final CountDownLatch latch = new CountDownLatch(1);
481                 final CancellationSignal cancelSignal = new CancellationSignal();
482                 final VerifyCancelCallback callback = new VerifyCancelCallback(msg, cancelSignal);
483                 mDns.rawQuery(network, TEST_BLOB, FLAG_EMPTY, mExecutor, cancelSignal, callback);
484                 mExecutor.execute(() -> {
485                     cancelSignal.cancel();
486                     latch.countDown();
487                 });
488 
489                 retry = callback.needRetry();
490                 assertTrue(msg + " cancel is not done",
491                         latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
492             } while (retry);
493         }
494     }
495 
testCancelBeforeQuery()496     public void testCancelBeforeQuery() throws InterruptedException {
497         final String msg = "Test cancelled RawQuery " + TEST_DOMAIN;
498         for (Network network : getTestableNetworks()) {
499             final VerifyCancelCallback callback = new VerifyCancelCallback(msg);
500             final CancellationSignal cancelSignal = new CancellationSignal();
501             cancelSignal.cancel();
502             mDns.rawQuery(network, TEST_DOMAIN, CLASS_IN, TYPE_AAAA, FLAG_EMPTY,
503                     mExecutor, cancelSignal, callback);
504 
505             assertTrue(msg + " should not return any answers",
506                     !callback.waitForAnswer(CANCEL_TIMEOUT_MS));
507         }
508     }
509 
510     /**
511      * A query callback for InetAddress that ensures that the query is
512      * cancelled and that onAnswer is never called. If the query succeeds
513      * before it is cancelled, needRetry will return true so the
514      * test can retry.
515      */
516     class VerifyCancelInetAddressCallback implements DnsResolver.Callback<List<InetAddress>> {
517         private final CountDownLatch mLatch = new CountDownLatch(1);
518         private final String mMsg;
519         private final List<InetAddress> mAnswers;
520         private final CancellationSignal mCancelSignal;
521         private String mErrorMsg = null;
522 
VerifyCancelInetAddressCallback(@onNull String msg, @Nullable CancellationSignal cancel)523         VerifyCancelInetAddressCallback(@NonNull String msg, @Nullable CancellationSignal cancel) {
524             this.mMsg = msg;
525             this.mCancelSignal = cancel;
526             mAnswers = new ArrayList<>();
527         }
528 
waitForAnswer()529         public boolean waitForAnswer() throws InterruptedException {
530             return mLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS);
531         }
532 
needRetry()533         public boolean needRetry() throws InterruptedException {
534             return mLatch.await(CANCEL_TIMEOUT_MS, TimeUnit.MILLISECONDS);
535         }
536 
isAnswerEmpty()537         public boolean isAnswerEmpty() {
538             return mAnswers.isEmpty();
539         }
540 
hasIpv6Answer()541         public boolean hasIpv6Answer() {
542             for (InetAddress answer : mAnswers) {
543                 if (answer instanceof Inet6Address) return true;
544             }
545             return false;
546         }
547 
hasIpv4Answer()548         public boolean hasIpv4Answer() {
549             for (InetAddress answer : mAnswers) {
550                 if (answer instanceof Inet4Address) return true;
551             }
552             return false;
553         }
554 
assertNoError()555         public void assertNoError() {
556             assertNull(mErrorMsg);
557         }
558 
559         @Override
onAnswer(@onNull List<InetAddress> answerList, int rcode)560         public void onAnswer(@NonNull List<InetAddress> answerList, int rcode) {
561             if (mCancelSignal != null && mCancelSignal.isCanceled()) {
562                 mErrorMsg = mMsg + " should not have returned any answers";
563                 mLatch.countDown();
564                 return;
565             }
566             for (InetAddress addr : answerList) {
567                 Log.d(TAG, "Reported addr: " + addr.toString());
568             }
569             mAnswers.clear();
570             mAnswers.addAll(answerList);
571             mLatch.countDown();
572         }
573 
574         @Override
onError(@onNull DnsResolver.DnsException error)575         public void onError(@NonNull DnsResolver.DnsException error) {
576             mErrorMsg = mMsg + error.getMessage();
577             mLatch.countDown();
578         }
579     }
580 
testQueryForInetAddress()581     public void testQueryForInetAddress() throws Exception {
582         doTestQueryForInetAddress(mExecutor);
583     }
584 
testQueryForInetAddressInline()585     public void testQueryForInetAddressInline() throws Exception {
586         doTestQueryForInetAddress(mExecutorInline);
587     }
588 
testQueryForInetAddressIpv4()589     public void testQueryForInetAddressIpv4() throws Exception {
590         doTestQueryForInetAddressIpv4(mExecutor);
591     }
592 
testQueryForInetAddressIpv4Inline()593     public void testQueryForInetAddressIpv4Inline() throws Exception {
594         doTestQueryForInetAddressIpv4(mExecutorInline);
595     }
596 
testQueryForInetAddressIpv6()597     public void testQueryForInetAddressIpv6() throws Exception {
598         doTestQueryForInetAddressIpv6(mExecutor);
599     }
600 
testQueryForInetAddressIpv6Inline()601     public void testQueryForInetAddressIpv6Inline() throws Exception {
602         doTestQueryForInetAddressIpv6(mExecutorInline);
603     }
604 
testContinuousQueries()605     public void testContinuousQueries() throws Exception {
606         doTestContinuousQueries(mExecutor);
607     }
608 
609     @SkipPresubmit(reason = "Flaky: b/159762682; add to presubmit after fixing")
testContinuousQueriesInline()610     public void testContinuousQueriesInline() throws Exception {
611         doTestContinuousQueries(mExecutorInline);
612     }
613 
doTestQueryForInetAddress(Executor executor)614     public void doTestQueryForInetAddress(Executor executor) throws InterruptedException {
615         final String msg = "Test query for InetAddress " + TEST_DOMAIN;
616         for (Network network : getTestableNetworks()) {
617             final VerifyCancelInetAddressCallback callback =
618                     new VerifyCancelInetAddressCallback(msg, null);
619             mDns.query(network, TEST_DOMAIN, FLAG_NO_CACHE_LOOKUP, executor, null, callback);
620 
621             assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
622                     callback.waitForAnswer());
623             callback.assertNoError();
624             assertTrue(msg + " returned 0 results", !callback.isAnswerEmpty());
625         }
626     }
627 
testQueryCancelForInetAddress()628     public void testQueryCancelForInetAddress() throws InterruptedException {
629         final String msg = "Test cancel query for InetAddress " + TEST_DOMAIN;
630         // Start a DNS query and the cancel it immediately. Use VerifyCancelInetAddressCallback to
631         // expect that the query is cancelled before it succeeds. If it is not cancelled before it
632         // succeeds, retry the test until it is.
633         for (Network network : getTestableNetworks()) {
634             boolean retry = false;
635             int round = 0;
636             do {
637                 if (++round > CANCEL_RETRY_TIMES) {
638                     fail(msg + " cancel failed " + CANCEL_RETRY_TIMES + " times");
639                 }
640                 final CountDownLatch latch = new CountDownLatch(1);
641                 final CancellationSignal cancelSignal = new CancellationSignal();
642                 final VerifyCancelInetAddressCallback callback =
643                         new VerifyCancelInetAddressCallback(msg, cancelSignal);
644                 mDns.query(network, TEST_DOMAIN, FLAG_EMPTY, mExecutor, cancelSignal, callback);
645                 mExecutor.execute(() -> {
646                     cancelSignal.cancel();
647                     latch.countDown();
648                 });
649 
650                 retry = callback.needRetry();
651                 assertTrue(msg + " query was not cancelled",
652                         latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
653             } while (retry);
654         }
655     }
656 
doTestQueryForInetAddressIpv4(Executor executor)657     public void doTestQueryForInetAddressIpv4(Executor executor) throws InterruptedException {
658         final String msg = "Test query for IPv4 InetAddress " + TEST_DOMAIN;
659         for (Network network : getTestableNetworks()) {
660             final VerifyCancelInetAddressCallback callback =
661                     new VerifyCancelInetAddressCallback(msg, null);
662             mDns.query(network, TEST_DOMAIN, TYPE_A, FLAG_NO_CACHE_LOOKUP,
663                     executor, null, callback);
664 
665             assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
666                     callback.waitForAnswer());
667             callback.assertNoError();
668             assertTrue(msg + " returned 0 results", !callback.isAnswerEmpty());
669             assertTrue(msg + " returned Ipv6 results", !callback.hasIpv6Answer());
670         }
671     }
672 
doTestQueryForInetAddressIpv6(Executor executor)673     public void doTestQueryForInetAddressIpv6(Executor executor) throws InterruptedException {
674         final String msg = "Test query for IPv6 InetAddress " + TEST_DOMAIN;
675         for (Network network : getTestableNetworks()) {
676             final VerifyCancelInetAddressCallback callback =
677                     new VerifyCancelInetAddressCallback(msg, null);
678             mDns.query(network, TEST_DOMAIN, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
679                     executor, null, callback);
680 
681             assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
682                     callback.waitForAnswer());
683             callback.assertNoError();
684             assertTrue(msg + " returned 0 results", !callback.isAnswerEmpty());
685             assertTrue(msg + " returned Ipv4 results", !callback.hasIpv4Answer());
686         }
687     }
688 
testPrivateDnsBypass()689     public void testPrivateDnsBypass() throws InterruptedException {
690         final Network[] testNetworks = getTestableNetworks();
691 
692         // Set an invalid private DNS server
693         mCtsNetUtils.setPrivateDnsStrictMode(INVALID_PRIVATE_DNS_SERVER);
694         final String msg = "Test PrivateDnsBypass " + TEST_DOMAIN;
695         for (Network network : testNetworks) {
696             // This test cannot be ran with null network because we need to explicitly pass a
697             // private DNS bypassable network or bind one.
698             if (network == null) continue;
699 
700             // wait for private DNS setting propagating
701             mCtsNetUtils.awaitPrivateDnsSetting(msg + " wait private DNS setting timeout",
702                     network, INVALID_PRIVATE_DNS_SERVER, false);
703 
704             final CountDownLatch latch = new CountDownLatch(1);
705             final DnsResolver.Callback<List<InetAddress>> errorCallback =
706                     new DnsResolver.Callback<List<InetAddress>>() {
707                         @Override
708                         public void onAnswer(@NonNull List<InetAddress> answerList, int rcode) {
709                             fail(msg + " should not get valid answer");
710                         }
711 
712                         @Override
713                         public void onError(@NonNull DnsResolver.DnsException error) {
714                             assertEquals(DnsResolver.ERROR_SYSTEM, error.code);
715                             assertEquals(ETIMEDOUT, ((ErrnoException) error.getCause()).errno);
716                             latch.countDown();
717                         }
718                     };
719             // Private DNS strict mode with invalid DNS server is set
720             // Expect no valid answer returned but ErrnoException with ETIMEDOUT
721             mDns.query(network, TEST_DOMAIN, FLAG_NO_CACHE_LOOKUP, mExecutor, null, errorCallback);
722 
723             assertTrue(msg + " invalid server round. No response after " + TIMEOUT_MS + "ms.",
724                     latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
725 
726             final VerifyCancelInetAddressCallback callback =
727                     new VerifyCancelInetAddressCallback(msg, null);
728             // Bypass privateDns, expect query works fine
729             mDns.query(network.getPrivateDnsBypassingCopy(),
730                     TEST_DOMAIN, FLAG_NO_CACHE_LOOKUP, mExecutor, null, callback);
731 
732             assertTrue(msg + " bypass private DNS round. No answer after " + TIMEOUT_MS + "ms.",
733                     callback.waitForAnswer());
734             callback.assertNoError();
735             assertTrue(msg + " returned 0 results", !callback.isAnswerEmpty());
736 
737             // To ensure private DNS bypass still work even if passing null network.
738             // Bind process network with a private DNS bypassable network.
739             mCM.bindProcessToNetwork(network.getPrivateDnsBypassingCopy());
740             final VerifyCancelInetAddressCallback callbackWithNullNetwork =
741                     new VerifyCancelInetAddressCallback(msg + " with null network ", null);
742             mDns.query(null,
743                     TEST_DOMAIN, FLAG_NO_CACHE_LOOKUP, mExecutor, null, callbackWithNullNetwork);
744 
745             assertTrue(msg + " with null network bypass private DNS round. No answer after " +
746                     TIMEOUT_MS + "ms.", callbackWithNullNetwork.waitForAnswer());
747             callbackWithNullNetwork.assertNoError();
748             assertTrue(msg + " with null network returned 0 results",
749                     !callbackWithNullNetwork.isAnswerEmpty());
750 
751             // Reset process network to default.
752             mCM.bindProcessToNetwork(null);
753         }
754     }
755 
doTestContinuousQueries(Executor executor)756     public void doTestContinuousQueries(Executor executor) throws InterruptedException {
757         final String msg = "Test continuous " + QUERY_TIMES + " queries " + TEST_DOMAIN;
758         for (Network network : getTestableNetworks()) {
759             for (int i = 0; i < QUERY_TIMES ; ++i) {
760                 final VerifyCancelInetAddressCallback callback =
761                         new VerifyCancelInetAddressCallback(msg, null);
762                 // query v6/v4 in turn
763                 boolean queryV6 = (i % 2 == 0);
764                 mDns.query(network, TEST_DOMAIN, queryV6 ? TYPE_AAAA : TYPE_A,
765                         FLAG_NO_CACHE_LOOKUP, executor, null, callback);
766 
767                 assertTrue(msg + " but no answer after " + TIMEOUT_MS + "ms.",
768                         callback.waitForAnswer());
769                 callback.assertNoError();
770                 assertTrue(msg + " returned 0 results", !callback.isAnswerEmpty());
771                 assertTrue(msg + " returned " + (queryV6 ? "Ipv4" : "Ipv6") + " results",
772                         queryV6 ? !callback.hasIpv4Answer() : !callback.hasIpv6Answer());
773             }
774         }
775     }
776 }
777