View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *   http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing,
13   * software distributed under the License is distributed on an
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   * KIND, either express or implied.  See the License for the
16   * specific language governing permissions and limitations
17   * under the License.
18   */
19  package org.apache.maven.di.impl;
20  
21  import java.io.BufferedReader;
22  import java.io.InputStream;
23  import java.io.InputStreamReader;
24  import java.lang.annotation.Annotation;
25  import java.lang.reflect.Method;
26  import java.lang.reflect.Modifier;
27  import java.lang.reflect.Type;
28  import java.net.URL;
29  import java.util.AbstractList;
30  import java.util.AbstractMap;
31  import java.util.AbstractSet;
32  import java.util.ArrayList;
33  import java.util.Arrays;
34  import java.util.Comparator;
35  import java.util.Enumeration;
36  import java.util.HashMap;
37  import java.util.HashSet;
38  import java.util.Iterator;
39  import java.util.LinkedHashSet;
40  import java.util.List;
41  import java.util.Map;
42  import java.util.Objects;
43  import java.util.Set;
44  import java.util.concurrent.ConcurrentHashMap;
45  import java.util.function.Function;
46  import java.util.function.Supplier;
47  import java.util.stream.Collectors;
48  import java.util.stream.Stream;
49  
50  import org.apache.maven.api.annotations.Nonnull;
51  import org.apache.maven.api.di.Provides;
52  import org.apache.maven.api.di.Qualifier;
53  import org.apache.maven.api.di.Singleton;
54  import org.apache.maven.api.di.Typed;
55  import org.apache.maven.di.Injector;
56  import org.apache.maven.di.Key;
57  import org.apache.maven.di.Scope;
58  
59  public class InjectorImpl implements Injector {
60  
61      private final Map<Key<?>, Set<Binding<?>>> bindings = new HashMap<>();
62      private final Map<Class<? extends Annotation>, Supplier<Scope>> scopes = new HashMap<>();
63      private final Set<String> loadedUrls = new HashSet<>();
64  
65      public InjectorImpl() {
66          bindScope(Singleton.class, new SingletonScope());
67      }
68  
69      @Nonnull
70      @Override
71      public <T> T getInstance(@Nonnull Class<T> key) {
72          return getInstance(Key.of(key));
73      }
74  
75      @Nonnull
76      @Override
77      public <T> T getInstance(@Nonnull Key<T> key) {
78          return getCompiledBinding(new Dependency<>(key, false)).get();
79      }
80  
81      @SuppressWarnings("unchecked")
82      @Override
83      public <T> void injectInstance(@Nonnull T instance) {
84          ReflectionUtils.generateInjectingInitializer(Key.of((Class<T>) instance.getClass()))
85                  .compile(this::getCompiledBinding)
86                  .accept(instance);
87      }
88  
89      @Nonnull
90      @Override
91      public Injector discover(@Nonnull ClassLoader classLoader) {
92          try {
93              Enumeration<URL> enumeration = classLoader.getResources("META-INF/maven/org.apache.maven.api.di.Inject");
94              while (enumeration.hasMoreElements()) {
95                  URL url = enumeration.nextElement();
96                  if (loadedUrls.add(url.toExternalForm())) {
97                      try (InputStream is = url.openStream();
98                              BufferedReader reader =
99                                      new BufferedReader(new InputStreamReader(Objects.requireNonNull(is)))) {
100                         for (String line :
101                                 reader.lines().filter(l -> !l.startsWith("#")).toList()) {
102                             Class<?> clazz = classLoader.loadClass(line);
103                             bindImplicit(clazz);
104                         }
105                     }
106                 }
107             }
108         } catch (Exception e) {
109             throw new DIException("Error while discovering DI classes from classLoader", e);
110         }
111         return this;
112     }
113 
114     @Nonnull
115     @Override
116     public Injector bindScope(@Nonnull Class<? extends Annotation> scopeAnnotation, @Nonnull Scope scope) {
117         return bindScope(scopeAnnotation, () -> scope);
118     }
119 
120     @Nonnull
121     @Override
122     public Injector bindScope(@Nonnull Class<? extends Annotation> scopeAnnotation, @Nonnull Supplier<Scope> scope) {
123         if (scopes.put(scopeAnnotation, scope) != null) {
124             throw new DIException(
125                     "Cannot rebind scope annotation class to a different implementation: " + scopeAnnotation);
126         }
127         return this;
128     }
129 
130     @Nonnull
131     @Override
132     public <U> Injector bindInstance(@Nonnull Class<U> clazz, @Nonnull U instance) {
133         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
134         Binding<U> binding = Binding.toInstance(instance);
135         return doBind(key, binding);
136     }
137 
138     @Nonnull
139     @Override
140     public Injector bindImplicit(@Nonnull Class<?> clazz) {
141         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
142         if (clazz.isInterface()) {
143             bindings.computeIfAbsent(key, $ -> new HashSet<>());
144             if (key.getQualifier() != null) {
145                 bindings.computeIfAbsent(Key.ofType(clazz), $ -> new HashSet<>());
146             }
147         } else if (!Modifier.isAbstract(clazz.getModifiers())) {
148             Binding<?> binding = ReflectionUtils.generateImplicitBinding(key);
149             doBind(key, binding);
150         }
151         return this;
152     }
153 
154     private final LinkedHashSet<Key<?>> current = new LinkedHashSet<>();
155 
156     private Injector doBind(Key<?> key, Binding<?> binding) {
157         if (!current.add(key)) {
158             current.add(key);
159             throw new DIException("Circular references: " + current);
160         }
161         try {
162             doBindImplicit(key, binding);
163             Class<?> cls = key.getRawType().getSuperclass();
164             while (cls != Object.class && cls != null) {
165                 doBindImplicit(Key.of(cls, key.getQualifier()), binding);
166                 if (key.getQualifier() != null) {
167                     bind(Key.ofType(cls), binding);
168                 }
169                 cls = cls.getSuperclass();
170             }
171             return this;
172         } finally {
173             current.remove(key);
174         }
175     }
176 
177     protected <U> Injector bind(Key<U> key, Binding<U> b) {
178         Set<Binding<?>> bindingSet = bindings.computeIfAbsent(key, $ -> new HashSet<>());
179         bindingSet.add(b);
180         return this;
181     }
182 
183     @SuppressWarnings({"unchecked", "rawtypes"})
184     protected <T> Set<Binding<T>> getBindings(Key<T> key) {
185         return (Set) bindings.get(key);
186     }
187 
188     protected Set<Key<?>> getBoundKeys() {
189         return bindings.keySet();
190     }
191 
192     public Map<Key<?>, Set<Binding<?>>> getBindings() {
193         return bindings;
194     }
195 
196     public <Q> Supplier<Q> getCompiledBinding(Dependency<Q> dep) {
197         Key<Q> key = dep.key();
198         Set<Binding<Q>> res = getBindings(key);
199         if (res != null && !res.isEmpty()) {
200             List<Binding<Q>> bindingList = new ArrayList<>(res);
201             Comparator<Binding<Q>> comparing = Comparator.comparing(Binding::getPriority);
202             bindingList.sort(comparing.reversed());
203             Binding<Q> binding = bindingList.get(0);
204             return compile(binding);
205         }
206         if (key.getRawType() == List.class) {
207             Set<Binding<Object>> res2 = getBindings(key.getTypeParameter(0));
208             if (res2 != null) {
209                 List<Supplier<Object>> list = res2.stream().map(this::compile).collect(Collectors.toList());
210                 //noinspection unchecked
211                 return () -> (Q) list(list, Supplier::get);
212             }
213         }
214         if (key.getRawType() == Map.class) {
215             Key<?> k = key.getTypeParameter(0);
216             Key<Object> v = key.getTypeParameter(1);
217             Set<Binding<Object>> res2 = getBindings(v);
218             if (k.getRawType() == String.class && res2 != null) {
219                 Map<String, Supplier<Object>> map = res2.stream()
220                         .filter(b -> b.getOriginalKey() == null
221                                 || b.getOriginalKey().getQualifier() == null
222                                 || b.getOriginalKey().getQualifier() instanceof String)
223                         .collect(Collectors.toMap(
224                                 b -> (String)
225                                         (b.getOriginalKey() != null
226                                                 ? b.getOriginalKey().getQualifier()
227                                                 : null),
228                                 this::compile));
229                 //noinspection unchecked
230                 return () -> (Q) map(map, Supplier::get);
231             }
232         }
233         if (dep.optional()) {
234             return () -> null;
235         }
236         throw new DIException("No binding to construct an instance for key "
237                 + key.getDisplayString() + ".  Existing bindings:\n"
238                 + getBoundKeys().stream()
239                         .map(Key::toString)
240                         .map(String::trim)
241                         .sorted()
242                         .distinct()
243                         .collect(Collectors.joining("\n - ", " - ", "")));
244     }
245 
246     @SuppressWarnings("unchecked")
247     protected <Q> Supplier<Q> compile(Binding<Q> binding) {
248         Supplier<Q> compiled = binding.compile(this::getCompiledBinding);
249         if (binding.getScope() != null) {
250             Scope scope = scopes.entrySet().stream()
251                     .filter(e -> e.getKey().isInstance(binding.getScope()))
252                     .map(Map.Entry::getValue)
253                     .findFirst()
254                     .orElseThrow(() -> new DIException("Scope not bound for annotation "
255                             + binding.getScope().annotationType()))
256                     .get();
257             compiled = scope.scope((Key<Q>) binding.getOriginalKey(), compiled);
258         }
259         return compiled;
260     }
261 
262     protected void doBindImplicit(Key<?> key, Binding<?> binding) {
263         if (binding != null) {
264             // For non-explicit bindings, also bind all their base classes and interfaces according to the @Type
265             Object qualifier = key.getQualifier();
266             Class<?> type = key.getRawType();
267             Set<Class<?>> types = getBoundTypes(type.getAnnotation(Typed.class), type);
268             for (Type t : Types.getAllSuperTypes(type)) {
269                 if (types == null || types.contains(Types.getRawType(t))) {
270                     bind(Key.ofType(t, qualifier), binding);
271                     if (qualifier != null) {
272                         bind(Key.ofType(t), binding);
273                     }
274                 }
275             }
276         }
277         // Bind inner classes
278         for (Class<?> inner : key.getRawType().getDeclaredClasses()) {
279             boolean hasQualifier = Stream.of(inner.getAnnotations())
280                     .anyMatch(ann -> ann.annotationType().isAnnotationPresent(Qualifier.class));
281             if (hasQualifier) {
282                 bindImplicit(inner);
283             }
284         }
285         // Bind inner providers
286         for (Method method : key.getRawType().getDeclaredMethods()) {
287             if (method.isAnnotationPresent(Provides.class)) {
288                 if (method.getTypeParameters().length != 0) {
289                     throw new DIException("Parameterized method are not supported " + method);
290                 }
291                 Object qualifier = ReflectionUtils.qualifierOf(method);
292                 Annotation scope = ReflectionUtils.scopeOf(method);
293                 Type returnType = method.getGenericReturnType();
294                 Set<Class<?>> types = getBoundTypes(method.getAnnotation(Typed.class), Types.getRawType(returnType));
295                 Binding<Object> bind = ReflectionUtils.bindingFromMethod(method).scope(scope);
296                 for (Type t : Types.getAllSuperTypes(returnType)) {
297                     if (types == null || types.contains(Types.getRawType(t))) {
298                         bind(Key.ofType(t, qualifier), bind);
299                         if (qualifier != null) {
300                             bind(Key.ofType(t), bind);
301                         }
302                     }
303                 }
304             }
305         }
306     }
307 
308     private static Set<Class<?>> getBoundTypes(Typed typed, Class<?> clazz) {
309         if (typed != null) {
310             Class<?>[] typesArray = typed.value();
311             if (typesArray == null || typesArray.length == 0) {
312                 Set<Class<?>> types = new HashSet<>(Arrays.asList(clazz.getInterfaces()));
313                 types.add(Object.class);
314                 return types;
315             } else {
316                 return new HashSet<>(Arrays.asList(typesArray));
317             }
318         } else {
319             return null;
320         }
321     }
322 
323     protected <K, V, T> Map<K, V> map(Map<K, T> map, Function<T, V> mapper) {
324         return new WrappingMap<>(map, mapper);
325     }
326 
327     private static class WrappingMap<K, V, T> extends AbstractMap<K, V> {
328 
329         private final Map<K, T> delegate;
330         private final Function<T, V> mapper;
331 
332         WrappingMap(Map<K, T> delegate, Function<T, V> mapper) {
333             this.delegate = delegate;
334             this.mapper = mapper;
335         }
336 
337         @Override
338         public Set<Entry<K, V>> entrySet() {
339             return new AbstractSet<>() {
340                 @Override
341                 public Iterator<Entry<K, V>> iterator() {
342                     Iterator<Entry<K, T>> it = delegate.entrySet().iterator();
343                     return new Iterator<>() {
344                         @Override
345                         public boolean hasNext() {
346                             return it.hasNext();
347                         }
348 
349                         @Override
350                         public Entry<K, V> next() {
351                             Entry<K, T> n = it.next();
352                             return new SimpleImmutableEntry<>(n.getKey(), mapper.apply(n.getValue()));
353                         }
354                     };
355                 }
356 
357                 @Override
358                 public int size() {
359                     return delegate.size();
360                 }
361             };
362         }
363     }
364 
365     protected <Q, T> List<Q> list(List<T> bindingList, Function<T, Q> mapper) {
366         return new WrappingList<>(bindingList, mapper);
367     }
368 
369     private static class WrappingList<Q, T> extends AbstractList<Q> {
370 
371         private final List<T> delegate;
372         private final Function<T, Q> mapper;
373 
374         WrappingList(List<T> delegate, Function<T, Q> mapper) {
375             this.delegate = delegate;
376             this.mapper = mapper;
377         }
378 
379         @Override
380         public Q get(int index) {
381             return mapper.apply(delegate.get(index));
382         }
383 
384         @Override
385         public int size() {
386             return delegate.size();
387         }
388     }
389 
390     private static class SingletonScope implements Scope {
391         Map<Key<?>, java.util.function.Supplier<?>> cache = new ConcurrentHashMap<>();
392 
393         @Nonnull
394         @SuppressWarnings("unchecked")
395         @Override
396         public <T> java.util.function.Supplier<T> scope(
397                 @Nonnull Key<T> key, @Nonnull java.util.function.Supplier<T> unscoped) {
398             return (java.util.function.Supplier<T>)
399                     cache.computeIfAbsent(key, k -> new java.util.function.Supplier<T>() {
400                         volatile T instance;
401 
402                         @Override
403                         public T get() {
404                             if (instance == null) {
405                                 synchronized (this) {
406                                     if (instance == null) {
407                                         instance = unscoped.get();
408                                     }
409                                 }
410                             }
411                             return instance;
412                         }
413                     });
414         }
415     }
416 }