1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.pm.test.parsing.parcelling
18 
19 import android.os.Parcel
20 import android.os.Parcelable
21 import com.android.server.pm.test.util.IgnoreableExpect
22 import com.google.common.truth.Expect
23 import org.junit.Before
24 import org.junit.Rule
25 import org.junit.Test
26 import org.junit.rules.TestRule
27 import java.util.Objects
28 import kotlin.contracts.ExperimentalContracts
29 import kotlin.reflect.KClass
30 import kotlin.reflect.KFunction
31 import kotlin.reflect.KFunction1
32 import kotlin.reflect.KFunction2
33 import kotlin.reflect.KFunction3
34 import kotlin.reflect.KVisibility
35 import kotlin.reflect.full.allSuperclasses
36 import kotlin.reflect.full.createInstance
37 import kotlin.reflect.full.isSubclassOf
38 import kotlin.reflect.full.memberFunctions
39 import kotlin.reflect.full.memberProperties
40 import kotlin.reflect.full.staticProperties
41 import kotlin.reflect.jvm.jvmErasure
42 
43 
/**
 * Base test for verifying that a [Parcelable] component survives a parcel/unparcel cycle with
 * every getter-visible value intact.
 *
 * Subclasses list getters in [baseParams] (test values generated by [autoValue]) and may supply
 * custom [Param]s via [extraParams]. [valueComparison] applies every [Param] to a fresh
 * [initialObject], writes it to a [Parcel], reads it back with [creator], and compares every getter
 * on the result. It also asserts that the tested getter/setter names cover every non-excluded
 * member function of [getterType] and [setterType].
 *
 * @param getterType type whose member functions are treated as the getters under test
 * @param setterType type whose member functions are searched for setters/adders, and which
 *          [initialObject] instantiates by default
 */
@ExperimentalContracts
abstract class ParcelableComponentTest(
    private val getterType: KClass<*>,
    private val setterType: KClass<out Parcelable>
) {

    companion object {
        /** Member names that are never treated as getters/setters under test. */
        private val DEFAULT_EXCLUDED = listOf(
            // Java
            "toString",
            "equals",
            "hashCode",
            // Parcelable
            "getStability",
            "describeContents",
            "writeToParcel",
            // @DataClass
            "__metadata"
        )
    }

    internal val ignoreableExpect = IgnoreableExpect()

    // Hides internal type
    @get:Rule
    val ignoreableAsTestRule: TestRule = ignoreableExpect

    /** Soft-assertion entry point shared by every check in this class. */
    val expect: Expect
        get() = ignoreableExpect.expect

    /** Incremented by [autoValue] for every generated numeric value so generated values differ. */
    protected var testCounter = 1

    /** Object whose current boolean getter values [autoValue] inverts to produce test values. */
    protected abstract val defaultImpl: Any

    /** Creator used by [valueComparison] to read the parcelled object back. */
    protected abstract val creator: Parcelable.Creator<out Parcelable>

    /** Function names removed from the set that must be covered by some [Param]. */
    protected open val excludedMethods: Collection<String> = emptyList()

    /** Getters whose test values are generated automatically by [autoValue]. */
    protected abstract val baseParams: Collection<KFunction1<*, Any?>>

    /** Member functions of [getterType], minus [DEFAULT_EXCLUDED]. */
    private val getters = getterType.memberFunctions
        .filterNot { DEFAULT_EXCLUDED.contains(it.name) }

    /** Member functions of [setterType], minus [DEFAULT_EXCLUDED]; searched for setters/adders. */
    private val setters = setterType.memberFunctions
        .filterNot { DEFAULT_EXCLUDED.contains(it.name) }

    /** Uses [kClass] as both the getter type and the setter type. */
    constructor(kClass: KClass<out Parcelable>) : this(kClass, kClass)

    @Before
    fun checkNoPublicFields() {
        // Fields are not currently testable, and the idea is to enforce interface access for
        // immutability purposes, so disallow any public fields from existing.
        val publicProperties = getterType.memberProperties
            .filter { it.visibility == KVisibility.PUBLIC }
            .filterNot { DEFAULT_EXCLUDED.contains(it.name) }
        expect.that(publicProperties).isEmpty()
    }

    /**
     * Builds a [Param] for [getFunction] using a value from [autoValue]. Returns null when no value
     * could be generated or no matching setter exists (both cases record a failure).
     */
    @Suppress("UNCHECKED_CAST")
    private fun <ObjectType, ReturnType> buildParams(
        getFunction: KFunction1<ObjectType, ReturnType>,
    ): Param? {
        return buildParams<ObjectType, ReturnType, ReturnType, ReturnType>(
            getFunction,
            autoValue(getFunction) as ReturnType ?: return null
        )
    }

    /**
     * Builds a [Param] that sets [value] through the setter matching [getFunction] and expects the
     * getter to return it. Returns null (after recording a failure) when no setter is found.
     */
    @Suppress("UNCHECKED_CAST")
    private fun <ObjectType, ReturnType, SetType : Any?, CompareType : Any?> buildParams(
        getFunction: KFunction1<ObjectType, ReturnType>,
        value: SetType,
    ): Param? {
        return getSetByValue<ObjectType, ReturnType, SetType, Any?>(
            getFunction,
            findSetFunction(getFunction) ?: return null,
            value
        )
    }

    /**
     * Returns the member of [setters] named [setFunctionName] with the fewest parameters, or
     * records a failure mentioning [getFunctionName] and returns null if there is none.
     */
    private fun findSetterNamed(getFunctionName: String, setFunctionName: String): KFunction<*>? {
        val setFunction = setters.filter { it.name == setFunctionName }
            .minByOrNull { it.parameters.size }

        if (setFunction == null) {
            expect.withMessage("$getFunctionName does not have corresponding $setFunctionName")
                .fail()
        }
        return setFunction
    }

    /**
     * Finds the `setX` counterpart of a `getX`/`isX`/`hasX` getter.
     *
     * @throws IllegalArgumentException if the getter name has none of those prefixes
     */
    @Suppress("UNCHECKED_CAST")
    private fun <ObjectType, ReturnType> findSetFunction(
        getFunction: KFunction1<ObjectType, ReturnType>
    ): KFunction2<ObjectType, ReturnType, Any?>? {
        val getFunctionName = getFunction.name
        val prefix = when {
            getFunctionName.startsWith("get") -> "get"
            getFunctionName.startsWith("is") -> "is"
            getFunctionName.startsWith("has") -> "has"
            else -> throw IllegalArgumentException("Unsupported method name $getFunctionName")
        }
        val setFunctionName = "set" + getFunctionName.removePrefix(prefix)
        val setFunction = findSetterNamed(getFunctionName, setFunctionName)
        return setFunction as KFunction2<ObjectType, ReturnType, Any?>?
    }

    /**
     * Finds the singular `addX` counterpart of a plural `getXs` collection getter, e.g.
     * `getActivities` -> `addActivity`, `getPermissions` -> `addPermission`.
     *
     * @throws IllegalArgumentException if the getter name does not start with "get"
     */
    @Suppress("UNCHECKED_CAST")
    private fun <ObjectType, ReturnType, SetType> findAddFunction(
        getFunction: KFunction1<ObjectType, ReturnType>
    ): KFunction2<ObjectType, SetType, Any?>? {
        val getFunctionName = getFunction.name
        if (!getFunctionName.startsWith("get")) {
            throw IllegalArgumentException("Unsupported method name $getFunctionName")
        }

        val singularName = getFunctionName.removePrefix("get").run {
            // Remove plurality
            when {
                endsWith("ies") -> "${removeSuffix("ies")}y"
                endsWith("s") -> removeSuffix("s")
                else -> this
            }
        }

        val addFunction = findSetterNamed(getFunctionName, "add$singularName")
        return addFunction as KFunction2<ObjectType, SetType, Any?>?
    }

    /** Tests [getFunction] by setting [valueToSet] and expecting the getter to return it. */
    protected fun <ObjectType, ReturnType> getter(
        getFunction: KFunction1<ObjectType, ReturnType>,
        valueToSet: ReturnType
    ) = buildParams<ObjectType, ReturnType, ReturnType, ReturnType>(getFunction, valueToSet)

    /**
     * Tests [getFunction] by passing [valueToSet] to its setter and expecting the getter to return
     * [expectedValue], for setters whose input type differs from the getter's output.
     */
    protected fun <ObjectType, ReturnType, SetType : Any?, CompareType : Any?> getter(
        getFunction: KFunction1<ObjectType, ReturnType>,
        expectedValue: CompareType,
        valueToSet: SetType
    ): Param? {
        return getSetByValue(
            getFunction,
            findSetFunction(getFunction) ?: return null,
            value = expectedValue,
            transformSet = { valueToSet }
        )
    }

    /**
     * Tests a collection getter through its `addX` function: [value] is added once and the
     * getter's result must contain exactly that single element.
     */
    @Suppress("UNCHECKED_CAST")
    protected fun <ObjectType, ReturnType> adder(
        getFunction: KFunction1<ObjectType, ReturnType>,
        value: ReturnType,
    ): Param? {
        return getSetByValue(
            getFunction,
            findAddFunction<ObjectType, Any?, ReturnType>(getFunction) ?: return null,
            value,
            transformGet = {
                // Primitive arrays don't implement Iterable, so cast manually
                when (it) {
                    is BooleanArray -> it.singleOrNull()
                    is IntArray -> it.singleOrNull()
                    is LongArray -> it.singleOrNull()
                    is FloatArray -> it.singleOrNull()
                    is Array<*> -> it.singleOrNull()
                    is Iterable<*> -> it.singleOrNull()
                    else -> null
                }
            },
        )
    }

    /**
     * Method to provide custom getter and setter logic for values which are not simple primitives
     * or cannot be directly compared using [Objects.equals].
     *
     * @param getFunction the getter function which will be called and marked as tested
     * @param setFunction the setter function which will be called and marked as tested
     * @param value the value for comparison through the parcel-unparcel cycle, which can be
     *          anything, like the [String] ID of an inner object
     * @param transformGet the function to transform the result of [getFunction] into [value]
     * @param transformSet the function to transform [value] into an input for [setFunction]
     * @param compare the function that compares the pre/post-parcel [value] objects; a null
     *          result is treated as "not equal"
     */
    @Suppress("UNCHECKED_CAST")
    protected fun <ObjectType, ReturnType, SetType : Any?, CompareType : Any?> getSetByValue(
        getFunction: KFunction1<ObjectType, ReturnType>,
        setFunction: KFunction2<ObjectType, SetType, Any?>,
        value: CompareType,
        transformGet: (ReturnType) -> CompareType = { it as CompareType },
        transformSet: (CompareType) -> SetType = { it as SetType },
        compare: (CompareType, CompareType) -> Boolean? = Objects::equals
    ) = Param(
        getFunction.name,
        { transformGet(getFunction.call(it as ObjectType)) },
        setFunction.name,
        { setFunction.call(it.first() as ObjectType, transformSet(it[1] as CompareType)) },
        { value },
        { first, second -> compare(first as CompareType, second as CompareType) == true }
    )

    /**
     * Variant of [getSetByValue] that allows specifying a non-member [getFunction]. Mostly used
     * for AndroidPackageHidden for APIs which are hidden from the interface. The resulting [Param]
     * does not count toward the getter/setter coverage check.
     */
    @Suppress("UNCHECKED_CAST")
    protected fun <ObjectType, ReturnType, SetType : Any?, CompareType : Any?> getSetByValue(
        getFunction: (ObjectType) -> ReturnType,
        getFunctionName: String,
        setFunction: KFunction2<ObjectType, SetType, Any?>,
        value: CompareType,
        transformGet: (ReturnType) -> CompareType = { it as CompareType },
        transformSet: (CompareType) -> SetType = { it as SetType },
        compare: (CompareType, CompareType) -> Boolean? = Objects::equals
    ) = Param(
        getFunctionName,
        { transformGet(getFunction(it as ObjectType)) },
        setFunction.name,
        { setFunction.call(it.first() as ObjectType, transformSet(it[1] as CompareType)) },
        { value },
        { first, second -> compare(first as CompareType, second as CompareType) == true },
        verifyFunctionName = false
    )

    /**
     * Variant of [getSetByValue] that allows specifying a [setFunction] with 2 inputs;
     * [transformSet] maps the compared value to the pair of setter arguments.
     */
    @Suppress("UNCHECKED_CAST")
    protected fun <ObjectType, ReturnType, SetType1 : Any?, SetType2 : Any?, CompareType : Any?>
            getSetByValue2(
        getFunction: KFunction1<ObjectType, ReturnType>,
        setFunction: KFunction3<ObjectType, SetType1, SetType2, Any?>,
        value: CompareType,
        transformGet: (ReturnType) -> CompareType = { it as CompareType },
        transformSet: (CompareType) -> Pair<SetType1, SetType2> =
            { it as Pair<SetType1, SetType2> },
        compare: (CompareType, CompareType) -> Boolean? = Objects::equals
    ) = Param(
        getFunction.name,
        { transformGet(getFunction.call(it as ObjectType)) },
        setFunction.name,
        {
            val pair = transformSet(it[1] as CompareType)
            setFunction.call(it.first() as ObjectType, pair.first, pair.second)
        },
        { value },
        { first, second -> compare(first as CompareType, second as CompareType) == true }
    )

    /**
     * Generates a test value for [getFunction] from its erased return type:
     * - Boolean: the negation of [defaultImpl]'s current value, or true if that is null
     * - CharSequence/String: the getter name followed by "TEST"
     * - Int/Long/Float/Double: the next [testCounter] value
     *
     * Records a failure and returns null for any other type.
     */
    protected fun autoValue(getFunction: KFunction<*>) = when (getFunction.returnType.jvmErasure) {
        Boolean::class -> (getFunction.call(defaultImpl) as Boolean?)?.not() ?: true
        CharSequence::class,
        String::class -> getFunction.name + "TEST"
        Int::class -> testCounter++
        Long::class -> (testCounter++).toLong()
        Float::class -> (testCounter++).toFloat()
        Double::class -> (testCounter++).toDouble()
        else -> {
            expect.withMessage("${getFunction.name} needs to provide value").fail()
            null
        }
    }

    /**
     * Verifies two instances are equivalent via a series of properties. For use when a public API
     * class has not implemented equals. A null instance yields null for every property.
     */
    protected fun <T : Any> equalBy(
        first: T?,
        second: T?,
        vararg properties: (T) -> Any?
    ) = properties.all { property ->
        first?.let { property(it) } == second?.let { property(it) }
    }

    /**
     * Builds all [Param]s ([baseParams] plus non-null [extraParams]), applies each value to a
     * fresh [initialObject], then runs [finalizeObject] on it.
     *
     * @return the params and the populated object
     */
    fun buildBefore(): Pair<List<Param>, Parcelable> {
        val params = baseParams.mapNotNull(::buildParams) + extraParams().filterNotNull()
        val before = initialObject()

        params.forEach { it.setFunction(arrayOf(before, it.value())) }
        finalizeObject(before)
        return params to before
    }

    /** Hook run on every object after its values are set and before it is parcelled. */
    protected open fun finalizeObject(parcelable: Parcelable) {
    }

    @Test
    fun valueComparison() {
        val (params, before) = buildBefore()

        val parcel = Parcel.obtain()
        val baselineParcel = Parcel.obtain()
        // Parcels are recycled even if writing or reading throws.
        val after = try {
            writeToParcel(parcel, before)
            val dataSize = parcel.dataSize()
            parcel.setDataPosition(0)

            val baseline = initialObject()
            finalizeObject(baseline)
            writeToParcel(baselineParcel, baseline)

            // Check that something substantial actually changed in the test object
            expect.withMessage("Parcelled test object should be larger than the default object")
                .that(parcel.dataSize())
                .isGreaterThan(baselineParcel.dataSize())

            creator.createFromParcel(parcel).also {
                expect.withMessage("Mismatched write and read data sizes")
                    .that(parcel.dataPosition())
                    .isEqualTo(dataSize)
            }
        } finally {
            parcel.recycle()
            baselineParcel.recycle()
        }

        runAssertions(params, before, after)
    }

    @Test
    open fun parcellingSize() {
        val parcelOne = Parcel.obtain()
        val parcelTwo = Parcel.obtain()
        try {
            writeToParcel(parcelOne, initialObject())
            initialObject().writeToParcel(parcelTwo, 0)

            val superDataSizes = setterType.allSuperclasses
                .filter { it.isSubclassOf(Parcelable::class) }
                .mapNotNull { superclass ->
                    superclass.memberFunctions.find { it.name == "writeToParcel" }
                }
                .filter { it.isFinal }
                .map {
                    val parcel = Parcel.obtain()
                    try {
                        initialObject().writeToParcel(parcel, 0)
                        parcel.dataSize()
                    } finally {
                        parcel.recycle()
                    }
                }

            val sizesDiffer =
                (superDataSizes + parcelOne.dataSize() + parcelTwo.dataSize()).distinct().size != 1
            if (sizesDiffer) {
                listOf(getterType, setterType).distinct().forEach { type ->
                    verifyNoSubclassCreator(type, parcelTwo)
                }
            }
        } finally {
            parcelOne.recycle()
            parcelTwo.recycle()
        }
    }

    /**
     * Checks that the `CREATOR` static property of [type], if present, does not produce a
     * [setterType] instance when reading [parcel] from position 0. Types without a `CREATOR`
     * property are skipped; multiple or non-[Parcelable.Creator] properties record a failure.
     */
    private fun verifyNoSubclassCreator(type: KClass<*>, parcel: Parcel) {
        val creatorProperties = type.staticProperties.filter { it.name == "CREATOR" }
        if (creatorProperties.isEmpty()) {
            return
        }
        if (creatorProperties.size > 1) {
            expect.withMessage("Multiple matching CREATOR fields found for ${type.qualifiedName}")
                .that(creatorProperties)
                .hasSize(1)
            return
        }

        val creator = creatorProperties.single().get()
        if (creator !is Parcelable.Creator<*>) {
            expect.that(creator).isInstanceOf(Parcelable.Creator::class.java)
            return
        }

        parcel.setDataPosition(0)
        val parcelable = creator.createFromParcel(parcel)
        if (parcelable::class.isSubclassOf(setterType)) {
            expect.withMessage(
                "${type.qualifiedName} which does not safely override writeToParcel " +
                        "cannot contain a subclass CREATOR field"
            )
                .fail()
        }
    }

    /**
     * Compares every [Param] against [after], runs [extraAssertions], then checks that the names
     * of all tested functions exactly match the non-excluded getters and setters.
     */
    private fun runAssertions(params: List<Param>, before: Parcelable, after: Parcelable) {
        params.forEach {
            val actual = it.getFunction(after)
            val expected = it.value()
            val equal = it.compare(actual, expected)
            expect.withMessage("${it.getFunctionName} was $actual, expected $expected")
                .that(equal)
                .isTrue()
        }

        extraAssertions(before, after)

        // TODO: Handle method overloads?
        val expectedFunctions = (getters.map { it.name }
                + setters.map { it.name }
                - excludedMethods)
            .distinct()

        val allTestedFunctions = params.filter(Param::verifyFunctionName).flatMap {
            listOfNotNull(it.getFunctionName, it.setFunctionName)
        }
        expect.that(allTestedFunctions).containsExactlyElementsIn(expectedFunctions)
    }

    /** Additional, custom [Param]s; null entries are ignored. */
    open fun extraParams(): Collection<Param?> = emptyList()

    /** Creates the object that values are set on; defaults to [setterType]'s no-arg constructor. */
    open fun initialObject(): Parcelable = setterType.createInstance()

    /** Hook for checks beyond the per-[Param] comparisons. */
    open fun extraAssertions(before: Parcelable, after: Parcelable) {}

    /** Writes [value] into [parcel]; defaults to `value.writeToParcel(parcel, 0)`. */
    open fun writeToParcel(parcel: Parcel, value: Parcelable) = value.writeToParcel(parcel, 0)

    /**
     * A single tested value.
     *
     * @property getFunctionName getter name, counted as tested when [verifyFunctionName] is true
     * @property getFunction reads the value under test from an object
     * @property setFunctionName setter name, counted as tested when [verifyFunctionName] is true
     * @property setFunction applies the value; receives `[object, value]`
     * @property value supplies the value to set and to compare against
     * @property compare returns true when the read-back value matches the expected value
     * @property verifyFunctionName whether the names count toward the coverage check
     */
    data class Param(
        val getFunctionName: String,
        val getFunction: (Any?) -> Any?,
        val setFunctionName: String?,
        val setFunction: (Array<Any?>) -> Unit,
        val value: () -> Any?,
        val compare: (Any?, Any?) -> Boolean = Objects::equals,
        val verifyFunctionName: Boolean = true
    )
}
457