/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.pm.test.parsing.parcelling

import android.os.Parcel
import android.os.Parcelable
import com.android.server.pm.test.util.IgnoreableExpect
import com.google.common.truth.Expect
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TestRule
import java.util.Objects
import kotlin.contracts.ExperimentalContracts
import kotlin.reflect.KClass
import kotlin.reflect.KFunction
import kotlin.reflect.KFunction1
import kotlin.reflect.KFunction2
import kotlin.reflect.KFunction3
import kotlin.reflect.KVisibility
import kotlin.reflect.full.allSuperclasses
import kotlin.reflect.full.createInstance
import kotlin.reflect.full.isSubclassOf
import kotlin.reflect.full.memberFunctions
import kotlin.reflect.full.memberProperties
import kotlin.reflect.full.staticProperties
import kotlin.reflect.jvm.jvmErasure


/**
 * Base class for tests that verify a [Parcelable] type survives a parcel/unparcel round trip.
 *
 * Subclasses supply [baseParams] (getters whose test values are generated by [autoValue]) and
 * [extraParams] (custom getter/setter pairs). [valueComparison] applies every param to
 * [initialObject], writes it to a [Parcel], reads it back through [creator], and checks each
 * getter returns the value that was set. It also checks that the set of exercised getter and
 * setter names equals the non-excluded member functions of [getterType] and [setterType].
 *
 * @param getterType the type whose member functions are treated as getters
 * @param setterType the concrete [Parcelable] whose member functions are treated as setters and
 * which is instantiated by the default [initialObject]
 */
@ExperimentalContracts
abstract class ParcelableComponentTest(
    private val getterType: KClass<*>,
    private val setterType: KClass<out Parcelable>
) {

    companion object {
        /** Member names never expected to be covered by a [Param]. */
        private val DEFAULT_EXCLUDED = listOf(
            // Java
            "toString",
            "equals",
            "hashCode",
            // Parcelable
            "getStability",
            "describeContents",
            "writeToParcel",
            // @DataClass
            "__metadata"
        )
    }

    internal val ignoreableExpect = IgnoreableExpect()

    // Hides internal type
    @get:Rule
    val ignoreableAsTestRule: TestRule = ignoreableExpect

    val expect: Expect
        get() = ignoreableExpect.expect

    /** Source of distinct numeric values handed out by [autoValue]. */
    protected var testCounter = 1

    /** Instance whose Boolean getters provide the baseline that [autoValue] negates. */
    protected abstract val defaultImpl: Any
    protected abstract val creator: Parcelable.Creator<out Parcelable>

    /** Additional member names that should not be required to have a [Param]. */
    protected open val excludedMethods: Collection<String> = emptyList()

    /** Getters whose values can be generated automatically by [autoValue]. */
    protected abstract val baseParams: Collection<KFunction1<*, Any?>>

    private val getters = getterType.memberFunctions
        .filterNot { DEFAULT_EXCLUDED.contains(it.name) }

    private val setters = setterType.memberFunctions
        .filterNot { DEFAULT_EXCLUDED.contains(it.name) }

    constructor(kClass: KClass<out Parcelable>) : this(kClass, kClass)

    @Before
    fun checkNoPublicFields() {
        // Fields are not currently testable, and the idea is to enforce interface access for
        // immutability purposes, so disallow any public fields from existing.
        expect.that(getterType.memberProperties.filter { it.visibility == KVisibility.PUBLIC }
            .filterNot { DEFAULT_EXCLUDED.contains(it.name) })
            .isEmpty()
    }

    /**
     * Builds a [Param] for [getFunction] using a value generated by [autoValue] and the setter
     * located by [findSetFunction]. Returns null when no value can be generated or no setter
     * exists; both of those paths record an [expect] failure.
     */
    @Suppress("UNCHECKED_CAST")
    private fun <ObjectType, ReturnType> buildParams(
        getFunction: KFunction1<ObjectType, ReturnType>,
    ): Param? {
        return buildParams<ObjectType, ReturnType, ReturnType, ReturnType>(
            getFunction,
            autoValue(getFunction) as ReturnType ?: return null
        )
    }

    /**
     * Builds a [Param] that passes [value] to the setter paired with [getFunction] and expects
     * the getter to return it. Returns null if no setter is found.
     */
    @Suppress("UNCHECKED_CAST")
    private fun <ObjectType, ReturnType, SetType : Any?, CompareType : Any?> buildParams(
        getFunction: KFunction1<ObjectType, ReturnType>,
        value: SetType,
    ): Param? {
        return getSetByValue<ObjectType, ReturnType, SetType, Any?>(
            getFunction,
            findSetFunction(getFunction) ?: return null,
            value
        )
    }

    /**
     * Finds the setter paired with [getFunction] by replacing its `get`/`is`/`has` prefix with
     * `set`. When overloads exist, the one with the fewest parameters wins.
     *
     * @return the setter, or null (after recording an [expect] failure) if none exists
     * @throws IllegalArgumentException if the getter name has none of the supported prefixes
     */
    @Suppress("UNCHECKED_CAST")
    private fun <ObjectType, ReturnType> findSetFunction(
        getFunction: KFunction1<ObjectType, ReturnType>
    ): KFunction2<ObjectType, ReturnType, Any?>? {
        val getFunctionName = getFunction.name
        val prefix = when {
            getFunctionName.startsWith("get") -> "get"
            getFunctionName.startsWith("is") -> "is"
            getFunctionName.startsWith("has") -> "has"
            else -> throw IllegalArgumentException("Unsupported method name $getFunctionName")
        }
        val setFunctionName = "set" + getFunctionName.removePrefix(prefix)
        val setFunction = setters.filter { it.name == setFunctionName }
            .minByOrNull { it.parameters.size }

        if (setFunction == null) {
            expect.withMessage("$getFunctionName does not have corresponding $setFunctionName")
                .fail()
            return null
        }

        return setFunction as KFunction2<ObjectType, ReturnType, Any?>
    }

    /**
     * Finds the `add` method paired with a collection getter by replacing `get` with `add` and
     * naively singularizing the rest, e.g. `getActivities` -> `addActivity` and
     * `getSplitNames` -> `addSplitName`. When overloads exist, the one with the fewest
     * parameters wins.
     *
     * @return the adder, or null (after recording an [expect] failure) if none exists
     * @throws IllegalArgumentException if the getter name does not start with `get`
     */
    @Suppress("UNCHECKED_CAST")
    private fun <ObjectType, ReturnType, SetType> findAddFunction(
        getFunction: KFunction1<ObjectType, ReturnType>
    ): KFunction2<ObjectType, SetType, Any?>? {
        val getFunctionName = getFunction.name
        if (!getFunctionName.startsWith("get")) {
            throw IllegalArgumentException("Unsupported method name $getFunctionName")
        }

        val setFunctionName = "add" + getFunctionName.removePrefix("get").run {
            // Remove plurality
            when {
                endsWith("ies") -> "${removeSuffix("ies")}y"
                endsWith("s") -> removeSuffix("s")
                else -> this
            }
        }

        val setFunction = setters.filter { it.name == setFunctionName }
            .minByOrNull { it.parameters.size }

        if (setFunction == null) {
            expect.withMessage("$getFunctionName does not have corresponding $setFunctionName")
                .fail()
            return null
        }

        return setFunction as KFunction2<ObjectType, SetType, Any?>
    }

    /**
     * Builds a [Param] that sets [valueToSet] through the setter paired with [getFunction] and
     * expects the getter to return the same value. Returns null if no setter is found.
     */
    protected fun <ObjectType, ReturnType> getter(
        getFunction: KFunction1<ObjectType, ReturnType>,
        valueToSet: ReturnType
    ) = buildParams<ObjectType, ReturnType, ReturnType, ReturnType>(getFunction, valueToSet)

    /**
     * Builds a [Param] that sets [valueToSet] through the setter paired with [getFunction] but
     * expects the getter to return [expectedValue]. This is for setters whose input differs from
     * the getter's output. Returns null if no setter is found.
     */
    protected fun <ObjectType, ReturnType, SetType : Any?, CompareType : Any?> getter(
        getFunction: KFunction1<ObjectType, ReturnType>,
        expectedValue: CompareType,
        valueToSet: SetType
    ): Param? {
        return getSetByValue(
            getFunction,
            findSetFunction(getFunction) ?: return null,
            value = expectedValue,
            transformSet = { valueToSet }
        )
    }

    /**
     * Builds a [Param] for a collection-style getter. [value] is passed to the matching `add`
     * method (see [findAddFunction]). After unparcelling, the getter's result is expected to
     * hold exactly that one element. A result with zero or several elements, or an unsupported
     * container type, is compared as null.
     */
    @Suppress("UNCHECKED_CAST")
    protected fun <ObjectType, ReturnType> adder(
        getFunction: KFunction1<ObjectType, ReturnType>,
        value: ReturnType,
    ): Param? {
        return getSetByValue(
            getFunction,
            findAddFunction<ObjectType, Any?, ReturnType>(getFunction) ?: return null,
            value,
            transformGet = {
                // Arrays (primitive or object) don't implement Iterable, so unwrap each one
                // explicitly.
                when (it) {
                    is BooleanArray -> it.singleOrNull()
                    is ByteArray -> it.singleOrNull()
                    is CharArray -> it.singleOrNull()
                    is ShortArray -> it.singleOrNull()
                    is IntArray -> it.singleOrNull()
                    is LongArray -> it.singleOrNull()
                    is FloatArray -> it.singleOrNull()
                    is DoubleArray -> it.singleOrNull()
                    is Array<*> -> it.singleOrNull()
                    is Iterable<*> -> it.singleOrNull()
                    else -> null
                }
            },
        )
    }

    /**
     * Method to provide custom getter and setter logic for values which are not simple primitives
     * or cannot be directly compared using [Objects.equals].
     *
     * @param getFunction the getter function which will be called and marked as tested
     * @param setFunction the setter function which will be called and marked as tested
     * @param value the value for comparison through the parcel-unparcel cycle, which can be
     * anything, like the [String] ID of an inner object
     * @param transformGet the function to transform the result of [getFunction] into [value]
     * @param transformSet the function to transform [value] into an input for [setFunction]
     * @param compare the function that compares the pre/post-parcel [value] objects; a null
     * result counts as "not equal"
     */
    @Suppress("UNCHECKED_CAST")
    protected fun <ObjectType, ReturnType, SetType : Any?, CompareType : Any?> getSetByValue(
        getFunction: KFunction1<ObjectType, ReturnType>,
        setFunction: KFunction2<ObjectType, SetType, Any?>,
        value: CompareType,
        transformGet: (ReturnType) -> CompareType = { it as CompareType },
        transformSet: (CompareType) -> SetType = { it as SetType },
        compare: (CompareType, CompareType) -> Boolean? = Objects::equals
    ) = Param(
        getFunction.name,
        { transformGet(getFunction.call(it as ObjectType)) },
        setFunction.name,
        { setFunction.call(it.first() as ObjectType, transformSet(it[1] as CompareType)) },
        { value },
        { first, second -> compare(first as CompareType, second as CompareType) == true }
    )

    /**
     * Variant of [getSetByValue] that allows specifying a non-member [getFunction]. Mostly used
     * for AndroidPackageHidden for APIs which are hidden from the interface.
     *
     * The resulting [Param] has [Param.verifyFunctionName] set to false, so its names are not
     * counted toward the "every getter/setter was tested" check.
     */
    @Suppress("UNCHECKED_CAST")
    protected fun <ObjectType, ReturnType, SetType : Any?, CompareType : Any?> getSetByValue(
        getFunction: (ObjectType) -> ReturnType,
        getFunctionName: String,
        setFunction: KFunction2<ObjectType, SetType, Any?>,
        value: CompareType,
        transformGet: (ReturnType) -> CompareType = { it as CompareType },
        transformSet: (CompareType) -> SetType = { it as SetType },
        compare: (CompareType, CompareType) -> Boolean? = Objects::equals
    ) = Param(
        getFunctionName,
        { transformGet(getFunction(it as ObjectType)) },
        setFunction.name,
        { setFunction.call(it.first() as ObjectType, transformSet(it[1] as CompareType)) },
        { value },
        { first, second -> compare(first as CompareType, second as CompareType) == true },
        verifyFunctionName = false
    )

    /**
     * Variant of [getSetByValue] that allows specifying a [setFunction] with 2 inputs.
     * [transformSet] maps [value] to the pair of arguments passed to [setFunction].
     *
     * The setter's return type and [compare]'s nullable result mirror [getSetByValue]. A null
     * comparison result counts as "not equal".
     */
    @Suppress("UNCHECKED_CAST")
    protected fun <ObjectType, ReturnType, SetType1 : Any?, SetType2 : Any?, CompareType : Any?>
            getSetByValue2(
        getFunction: KFunction1<ObjectType, ReturnType>,
        setFunction: KFunction3<ObjectType, SetType1, SetType2, Any?>,
        value: CompareType,
        transformGet: (ReturnType) -> CompareType = { it as CompareType },
        transformSet: (CompareType) -> Pair<SetType1, SetType2> =
            { it as Pair<SetType1, SetType2> },
        compare: (CompareType, CompareType) -> Boolean? = Objects::equals
    ) = Param(
        getFunction.name,
        { transformGet(getFunction.call(it as ObjectType)) },
        setFunction.name,
        {
            val pair = transformSet(it[1] as CompareType)
            setFunction.call(it.first() as ObjectType, pair.first, pair.second)
        },
        { value },
        { first, second -> compare(first as CompareType, second as CompareType) == true }
    )

    /**
     * Generates a test value for [getFunction] based on its erased return type:
     * - Boolean: the negation of [defaultImpl]'s current value, or true if that value is null
     * - CharSequence/String: the getter name suffixed with "TEST"
     * - Int/Long/Float: the next value of [testCounter]
     *
     * Any other type records an [expect] failure and yields null, meaning the subclass must
     * supply the value itself via [extraParams].
     */
    protected fun autoValue(getFunction: KFunction<*>) = when (getFunction.returnType.jvmErasure) {
        Boolean::class -> (getFunction.call(defaultImpl) as Boolean?)?.not() ?: true
        CharSequence::class,
        String::class -> getFunction.name + "TEST"
        Int::class -> testCounter++
        Long::class -> (testCounter++).toLong()
        Float::class -> (testCounter++).toFloat()
        else -> {
            expect.withMessage("${getFunction.name} needs to provide value").fail()
            null
        }
    }

    /**
     * Verifies two instances are equivalent via a series of properties. For use when a public API
     * class has not implemented equals. A null instance contributes null for every property.
     */
    @Suppress("UNCHECKED_CAST")
    protected fun <T : Any> equalBy(
        first: T?,
        second: T?,
        vararg properties: (T) -> Any?
    ) = properties.all { property ->
        first?.let { property(it) } == second?.let { property(it) }
    }

    /**
     * Creates the object under test from [initialObject], applies every [Param]'s setter with
     * its value, and calls [finalizeObject].
     *
     * @return the params that were applied, paired with the populated object
     */
    fun buildBefore(): Pair<List<Param>, Parcelable> {
        val params = baseParams.mapNotNull(::buildParams) + extraParams().filterNotNull()
        val before = initialObject()

        params.forEach { it.setFunction(arrayOf(before, it.value())) }
        finalizeObject(before)
        return params to before
    }

    /** Hook for subclasses to complete an object after its setters have been applied. */
    protected open fun finalizeObject(parcelable: Parcelable) {
    }

    @Test
    fun valueComparison() {
        val (params, before) = buildBefore()

        val parcel = Parcel.obtain()
        // Recycle in `finally` so a throwing writeToParcel/createFromParcel cannot leak the parcel.
        val after = try {
            writeToParcel(parcel, before)
            val dataSize = parcel.dataSize()
            parcel.setDataPosition(0)

            // Check that something substantial actually changed in the test object
            expect.that(parcel.dataSize()).isGreaterThan(baselineDataSize())

            creator.createFromParcel(parcel).also {
                expect.withMessage("Mismatched write and read data sizes")
                    .that(parcel.dataPosition())
                    .isEqualTo(dataSize)
            }
        } finally {
            parcel.recycle()
        }

        runAssertions(params, before, after)
    }

    /** Returns the parcelled size of a fresh [initialObject] after [finalizeObject]. */
    private fun baselineDataSize(): Int {
        val baseline = initialObject()
        finalizeObject(baseline)

        val baselineParcel = Parcel.obtain()
        return try {
            writeToParcel(baselineParcel, baseline)
            baselineParcel.dataSize()
        } finally {
            baselineParcel.recycle()
        }
    }

    @Test
    open fun parcellingSize() {
        val parcelOne = Parcel.obtain()
        val parcelTwo = Parcel.obtain()
        // Always recycle both parcels, even if a check below throws.
        try {
            writeToParcel(parcelOne, initialObject())
            initialObject().writeToParcel(parcelTwo, 0)

            val superDataSizes = setterType.allSuperclasses
                .filter { it.isSubclassOf(Parcelable::class) }
                .mapNotNull { superclass ->
                    superclass.memberFunctions.find { it.name == "writeToParcel" }
                }
                .filter { it.isFinal }
                .map {
                    // NOTE(review): the matched function is not invoked directly. Because it
                    // is final, a virtual writeToParcel call on a fresh instance should
                    // dispatch to it.
                    val parcel = Parcel.obtain()
                    try {
                        initialObject().writeToParcel(parcel, 0)
                        parcel.dataSize()
                    } finally {
                        parcel.recycle()
                    }
                }

            if ((superDataSizes + parcelOne.dataSize() + parcelTwo.dataSize()).distinct().size != 1) {
                listOf(getterType, setterType).distinct().forEach { kClass ->
                    checkNoSubclassCreator(kClass, parcelTwo)
                }
            }
        } finally {
            parcelOne.recycle()
            parcelTwo.recycle()
        }
    }

    /**
     * Records a failure if [kClass] declares a static `CREATOR` whose output, when read from the
     * start of [source], is a subclass of [setterType]. Also records a failure if there are
     * multiple `CREATOR` properties or the property is not a [Parcelable.Creator]. Classes
     * without a `CREATOR` have nothing to check and are skipped.
     */
    private fun checkNoSubclassCreator(kClass: KClass<*>, source: Parcel) {
        val creatorProperties = kClass.staticProperties.filter { it.name == "CREATOR" }
        if (creatorProperties.size > 1) {
            expect.withMessage("Multiple matching CREATOR fields found for ${kClass.qualifiedName}")
                .that(creatorProperties)
                .hasSize(1)
            return
        }

        val creatorValue = creatorProperties.singleOrNull()?.get() ?: return
        if (creatorValue !is Parcelable.Creator<*>) {
            expect.that(creatorValue).isInstanceOf(Parcelable.Creator::class.java)
            return
        }

        source.setDataPosition(0)
        val parcelable = creatorValue.createFromParcel(source)
        if (parcelable::class.isSubclassOf(setterType)) {
            expect.withMessage(
                "${kClass.qualifiedName} which does not safely override writeToParcel " +
                        "cannot contain a subclass CREATOR field"
            )
                .fail()
        }
    }

    /**
     * Checks every [Param]'s getter on [after] against its expected value, then runs
     * [extraAssertions]. Finally verifies that the getter/setter names covered by params with
     * [Param.verifyFunctionName] are exactly the non-excluded member functions of [getterType]
     * and [setterType].
     */
    private fun runAssertions(params: List<Param>, before: Parcelable, after: Parcelable) {
        params.forEach {
            val actual = it.getFunction(after)
            val expected = it.value()
            val equal = it.compare(actual, expected)
            expect.withMessage("${it.getFunctionName} was $actual, expected $expected")
                .that(equal)
                .isTrue()
        }

        extraAssertions(before, after)

        // TODO: Handle method overloads?
        val expectedFunctions = (getters.map { it.name } + setters.map { it.name } - excludedMethods)
            .distinct()

        val allTestedFunctions = params.filter(Param::verifyFunctionName).flatMap {
            listOfNotNull(it.getFunctionName, it.setFunctionName)
        }
        expect.that(allTestedFunctions).containsExactlyElementsIn(expectedFunctions)
    }

    /** Custom params for getters whose values cannot be generated by [autoValue]. */
    open fun extraParams(): Collection<Param?> = emptyList()

    /** Creates the empty object that params are applied to; defaults to the no-arg constructor. */
    open fun initialObject(): Parcelable = setterType.createInstance()

    /** Hook for additional checks comparing the pre-parcel and post-unparcel objects. */
    open fun extraAssertions(before: Parcelable, after: Parcelable) {}

    /** Writes [value] into [parcel] with no flags; overridable for custom write paths. */
    open fun writeToParcel(parcel: Parcel, value: Parcelable) = value.writeToParcel(parcel, 0)

    /**
     * One getter/setter pair under test.
     *
     * @property getFunctionName name of the getter, counted toward coverage
     * @property getFunction reads the (possibly transformed) value from an object
     * @property setFunctionName name of the setter, counted toward coverage when non-null
     * @property setFunction receives `[object, value]` and applies the value to the object
     * @property value supplies the value to set and later compare against
     * @property compare decides whether the unparcelled value matches [value]
     * @property verifyFunctionName whether the names count toward the coverage check
     */
    data class Param(
        val getFunctionName: String,
        val getFunction: (Any?) -> Any?,
        val setFunctionName: String?,
        val setFunction: (Array<Any?>) -> Unit,
        val value: () -> Any?,
        val compare: (Any?, Any?) -> Boolean = Objects::equals,
        val verifyFunctionName: Boolean = true
    )
}