/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import com.github.javaparser.JavaParser
import com.github.javaparser.ParseProblemException
import com.github.javaparser.ParseResult
import com.github.javaparser.ParserConfiguration
import com.github.javaparser.ast.Node
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration
import com.github.javaparser.ast.body.FieldDeclaration
import com.github.javaparser.ast.body.TypeDeclaration
import com.github.javaparser.ast.expr.AnnotationExpr
import com.github.javaparser.ast.expr.Expression
import com.github.javaparser.ast.expr.NormalAnnotationExpr
import com.github.javaparser.ast.expr.SingleMemberAnnotationExpr
import com.github.javaparser.ast.expr.StringLiteralExpr
import com.github.javaparser.resolution.declarations.ResolvedReferenceTypeDeclaration
import com.github.javaparser.resolution.types.ResolvedPrimitiveType
import com.github.javaparser.resolution.types.ResolvedReferenceType
import com.github.javaparser.symbolsolver.JavaSymbolSolver
import com.github.javaparser.symbolsolver.javaparsermodel.declarations.JavaParserClassDeclaration
import com.github.javaparser.symbolsolver.resolution.typesolvers.CombinedTypeSolver
import com.github.javaparser.symbolsolver.resolution.typesolvers.MemoryTypeSolver
import com.github.javaparser.symbolsolver.resolution.typesolvers.ReflectionTypeSolver
import com.squareup.javapoet.ClassName
import com.squareup.javapoet.ParameterizedTypeName
import com.squareup.javapoet.TypeName
import java.nio.file.Path
import java.util.Optional

/**
 * Result of parsing one class annotated with `XmlPersistence`.
 *
 * @property name the annotation's `value` member, or `<RootClassName>Persistence` when the
 *     annotation does not supply one.
 * @property root description of the annotated class and, recursively, of its fields.
 * @property path `<name>.java`, resolved as a sibling of the source file that declares the
 *     annotated class.
 */
class PersistenceInfo(
    val name: String,
    val root: ClassFieldInfo,
    val path: Path
)

/**
 * Description of a single field discovered while walking a persisted class.
 */
sealed class FieldInfo {
    /** Name of the field's variable as declared in the Java source. */
    abstract val name: String

    /** String value of the field's `XmlName` annotation, or `null` when there is none. */
    abstract val xmlName: String?

    /** JavaPoet type of the field. */
    abstract val type: TypeName

    /** Whether the field is treated as required; see the subclasses for how this is decided. */
    abstract val isRequired: Boolean
}

/**
 * A field whose type is a Java primitive or one of the primitive wrapper classes.
 *
 * The parser passes `isRequired = true` for real primitives; for wrapper types it reflects
 * whether a `NonNull` annotation is present.
 */
class PrimitiveFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val type: TypeName,
    override val isRequired: Boolean
) : FieldInfo()

/**
 * A field of type [String]; [type] is always the JavaPoet name of `java.lang.String`.
 */
class StringFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val isRequired: Boolean
) : FieldInfo() {
    override val type: TypeName = ClassName.get(String::class.java)
}

/**
 * A field whose type is another parsed class, together with that class's own fields.
 *
 * @property fields descriptions of the class's non-static fields, in declaration order.
 */
class ClassFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val type: ClassName,
    override val isRequired: Boolean,
    val fields: List<FieldInfo>
) : FieldInfo()

/**
 * A field of type `List<E>` where `E` is a parsed class. List fields are always required.
 *
 * @property element description of the element class; the parser names it `"(element)"` and
 *     gives it the same [xmlName] as the list field itself.
 */
class ListFieldInfo(
    override val name: String,
    override val xmlName: String?,
    override val type: ParameterizedTypeName,
    val element: ClassFieldInfo
) : FieldInfo() {
    override val isRequired: Boolean = true
}

/**
 * Parses the given Java source [files] and returns one [PersistenceInfo] for every class that
 * carries an `XmlPersistence` annotation.
 *
 * All files are parsed first and every type declaration with a fully qualified name is
 * registered in a [MemoryTypeSolver], so that types declared in one file can be resolved from
 * another. Interfaces and non-static nested classes are never considered as persistence roots.
 *
 * @param files paths of the Java source files to parse.
 * @return persistence descriptions, in file order and then in pre-order of the syntax tree.
 * @throws ParseProblemException if any file fails to parse.
 * @throws IllegalArgumentException if an annotated class contains an unsupported field
 *     declaration (see the checks in `parseFieldInfo`).
 */
fun parse(files: List<Path>): List<PersistenceInfo> {
    val typeSolver = CombinedTypeSolver().apply { add(ReflectionTypeSolver()) }
    val javaParser = JavaParser(
        ParserConfiguration().setSymbolResolver(JavaSymbolSolver(typeSolver))
    )
    val compilationUnits = files.map { javaParser.parse(it).getOrThrow() }
    val memoryTypeSolver = MemoryTypeSolver().apply {
        for (compilationUnit in compilationUnits) {
            for (typeDeclaration in compilationUnit.getNodesByClass<TypeDeclaration<*>>()) {
                // Declarations without a fully qualified name cannot be registered by name.
                val name = typeDeclaration.fullyQualifiedName.getOrNull() ?: continue
                addDeclaration(name, typeDeclaration.resolve())
            }
        }
    }
    // Added only after all declarations are registered; the parser's symbol resolver holds a
    // reference to the same combined solver, so it sees the new declarations from here on.
    typeSolver.add(memoryTypeSolver)
    return compilationUnits.flatMap { compilationUnit ->
        compilationUnit
            .getNodesByClass<ClassOrInterfaceDeclaration>()
            .filter { !it.isInterface && (!it.isNestedType || it.isStatic) }
            .mapNotNull { parsePersistenceInfo(it) }
    }
}

/**
 * Builds the [PersistenceInfo] for [classDeclaration], or returns `null` when the class has no
 * `XmlPersistence` annotation.
 *
 * The root field is named after the class's `XmlName` annotation value when present, otherwise
 * after the class itself, and is always marked as required.
 */
private fun parsePersistenceInfo(classDeclaration: ClassOrInterfaceDeclaration): PersistenceInfo? {
    val annotation = classDeclaration.getAnnotationByName("XmlPersistence").getOrNull()
        ?: return null
    val rootClassName = classDeclaration.nameAsString
    val name = annotation.getMemberValue("value")?.stringLiteralValue
        ?: "${rootClassName}Persistence"
    val rootXmlName = classDeclaration.getAnnotationByName("XmlName").getOrNull()
        ?.getMemberValue("value")?.stringLiteralValue
    val root = parseClassFieldInfo(
        rootXmlName ?: rootClassName, rootXmlName, true, classDeclaration
    )
    val path = classDeclaration.findCompilationUnit().get().storage.get().path
        .resolveSibling("$name.java")
    return PersistenceInfo(name, root, path)
}

/**
 * Describes [classDeclaration] as a [ClassFieldInfo], recursing into each non-static field.
 *
 * NOTE: there is no cycle detection. A class whose fields refer, directly or transitively,
 * back to the class itself makes this function and `parseFieldInfo` recurse without bound.
 */
private fun parseClassFieldInfo(
    name: String,
    xmlName: String?,
    isRequired: Boolean,
    classDeclaration: ClassOrInterfaceDeclaration
): ClassFieldInfo {
    val fields = classDeclaration.fields.filterNot { it.isStatic }.map { parseFieldInfo(it) }
    val type = classDeclaration.resolve().typeName
    return ClassFieldInfo(name, xmlName, type, isRequired, fields)
}

/**
 * Converts a single field declaration into the matching [FieldInfo] subtype.
 *
 * Annotations are collected from both the field declaration and the variable's type. The
 * `XmlName` annotation supplies [FieldInfo.xmlName]; the presence of a `NonNull` annotation
 * makes wrapper, string and class fields required.
 *
 * @throws IllegalArgumentException if the field is not `public final`, declares more than one
 *     variable, is a list without an `XmlName` annotation, is a list whose element type is not
 *     a reference type, or refers to a class that is not backed by a parsed declaration.
 * @throws IllegalStateException if the field's type is neither primitive nor a reference type.
 */
private fun parseFieldInfo(field: FieldDeclaration): FieldInfo {
    require(field.isPublic && field.isFinal) {
        "Persisted field must be public and final: $field"
    }
    val variable = requireNotNull(field.variables.singleOrNull()) {
        "Persisted field declaration must declare exactly one variable: $field"
    }
    val name = variable.nameAsString
    val annotations = field.annotations + variable.type.annotations
    val annotation = annotations.getByName("XmlName")
    val xmlName = annotation?.getMemberValue("value")?.stringLiteralValue
    val isRequired = annotations.getByName("NonNull") != null
    return when (val type = variable.type.resolve()) {
        is ResolvedPrimitiveType -> {
            // A primitive can never be null, so it is required regardless of annotations.
            val primitiveType = type.typeName
            PrimitiveFieldInfo(name, xmlName, primitiveType, true)
        }
        is ResolvedReferenceType -> {
            when (type.qualifiedName) {
                Boolean::class.javaObjectType.name, Byte::class.javaObjectType.name,
                Short::class.javaObjectType.name, Char::class.javaObjectType.name,
                Integer::class.javaObjectType.name, Long::class.javaObjectType.name,
                Float::class.javaObjectType.name, Double::class.javaObjectType.name ->
                    PrimitiveFieldInfo(name, xmlName, type.typeName, isRequired)
                String::class.java.name -> StringFieldInfo(name, xmlName, isRequired)
                List::class.java.name -> {
                    requireNotNull(xmlName) {
                        "List field '$name' must have an XmlName annotation"
                    }
                    val elementType = type.typeParametersValues().single()
                    require(elementType is ResolvedReferenceType) {
                        "Element type of list field '$name' must be a reference type, " +
                            "but was: $elementType"
                    }
                    val listType = ParameterizedTypeName.get(
                        ClassName.get(List::class.java), elementType.typeName
                    )
                    val element = parseClassFieldInfo(
                        "(element)", xmlName, true, elementType.classDeclaration
                    )
                    ListFieldInfo(name, xmlName, listType, element)
                }
                else -> parseClassFieldInfo(name, xmlName, isRequired, type.classDeclaration)
            }
        }
        else -> error("Unsupported type for field '$name': $type")
    }
}

/**
 * Returns the parse result, or throws [ParseProblemException] carrying the reported problems
 * when parsing was not successful.
 */
private fun <T> ParseResult<T>.getOrThrow(): T =
    if (isSuccessful) {
        result.get()
    } else {
        throw ParseProblemException(problems)
    }

/** Reified convenience overload of [getNodesByClass]. */
private inline fun <reified T : Node> Node.getNodesByClass(): List<T> =
    getNodesByClass(T::class.java)

/**
 * Returns this node (if it is an instance of [klass]) followed by all matching descendants,
 * in pre-order.
 */
private fun <T : Node> Node.getNodesByClass(klass: Class<T>): List<T> =
    mutableListOf<T>().also { collectNodesByClass(klass, it) }

/**
 * Appends this node and its matching descendants to [out] in pre-order. A single accumulator
 * is shared by the whole traversal so that no intermediate per-subtree lists are created.
 */
private fun <T : Node> Node.collectNodesByClass(klass: Class<T>, out: MutableList<T>) {
    if (klass.isInstance(this)) {
        out += klass.cast(this)
    }
    for (childNode in childNodes) {
        childNode.collectNodesByClass(klass, out)
    }
}

/** Unwraps a Java [Optional] into a Kotlin nullable value. */
private fun <T> Optional<T>.getOrNull(): T? = orElse(null)

/** Returns the first annotation whose simple identifier equals [name], or `null`. */
private fun List<AnnotationExpr>.getByName(name: String): AnnotationExpr? =
    find { it.name.identifier == name }

/**
 * Returns the expression assigned to the annotation member [name], or `null` when absent.
 *
 * A single-member annotation such as `@Foo("x")` only answers for the member `"value"`; any
 * other kind of annotation expression has no members to return.
 */
private fun AnnotationExpr.getMemberValue(name: String): Expression? =
    when (this) {
        is NormalAnnotationExpr -> pairs.find { it.nameAsString == name }?.value
        is SingleMemberAnnotationExpr -> if (name == "value") memberValue else null
        else -> null
    }

/**
 * The value of this expression, which must be a string literal.
 *
 * @throws IllegalArgumentException if the expression is not a [StringLiteralExpr].
 */
private val Expression.stringLiteralValue: String
    get() {
        require(this is StringLiteralExpr) {
            "Expected a string literal but was: $this"
        }
        return value
    }

/**
 * The syntax-tree class declaration wrapped by this resolved type.
 *
 * @throws IllegalArgumentException if the resolved declaration is not a
 *     [JavaParserClassDeclaration].
 */
private val ResolvedReferenceType.classDeclaration: ClassOrInterfaceDeclaration
    get() {
        val resolvedClassDeclaration = typeDeclaration
        require(resolvedClassDeclaration is JavaParserClassDeclaration) {
            "Type $qualifiedName is not backed by a JavaParser class declaration"
        }
        return resolvedClassDeclaration.wrappedNode
    }

/** The JavaPoet primitive [TypeName] that corresponds to this resolved primitive type. */
private val ResolvedPrimitiveType.typeName: TypeName
    get() =
        when (this) {
            ResolvedPrimitiveType.BOOLEAN -> TypeName.BOOLEAN
            ResolvedPrimitiveType.BYTE -> TypeName.BYTE
            ResolvedPrimitiveType.SHORT -> TypeName.SHORT
            ResolvedPrimitiveType.CHAR -> TypeName.CHAR
            ResolvedPrimitiveType.INT -> TypeName.INT
            ResolvedPrimitiveType.LONG -> TypeName.LONG
            ResolvedPrimitiveType.FLOAT -> TypeName.FLOAT
            ResolvedPrimitiveType.DOUBLE -> TypeName.DOUBLE
        }

// This doesn't support type parameters.
/** The JavaPoet name of the declaration behind this reference type; type arguments are dropped. */
private val ResolvedReferenceType.typeName: TypeName
    get() = typeDeclaration.typeName

/**
 * The JavaPoet [ClassName] of this declaration. The dot-separated class name is split so that
 * its first segment becomes the top-level class and any remaining segments the nested classes.
 */
private val ResolvedReferenceTypeDeclaration.typeName: ClassName
    get() {
        val packageName = packageName
        val classNames = className.split(".")
        val topLevelClassName = classNames.first()
        val nestedClassNames = classNames.drop(1)
        return ClassName.get(packageName, topLevelClassName, *nestedClassNames.toTypedArray())
    }