View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *   http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing,
13   * software distributed under the License is distributed on an
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   * KIND, either express or implied.  See the License for the
16   * specific language governing permissions and limitations
17   * under the License.
18   */
19  package org.apache.maven.di.impl;
20  
21  import java.io.BufferedReader;
22  import java.io.InputStream;
23  import java.io.InputStreamReader;
24  import java.lang.annotation.Annotation;
25  import java.lang.reflect.Method;
26  import java.lang.reflect.Modifier;
27  import java.lang.reflect.Type;
28  import java.net.URL;
29  import java.util.AbstractList;
30  import java.util.AbstractMap;
31  import java.util.AbstractSet;
32  import java.util.ArrayList;
33  import java.util.Arrays;
34  import java.util.Comparator;
35  import java.util.Enumeration;
36  import java.util.HashMap;
37  import java.util.HashSet;
38  import java.util.Iterator;
39  import java.util.LinkedHashSet;
40  import java.util.List;
41  import java.util.Map;
42  import java.util.Objects;
43  import java.util.Set;
44  import java.util.concurrent.ConcurrentHashMap;
45  import java.util.function.Function;
46  import java.util.function.Supplier;
47  import java.util.stream.Collectors;
48  import java.util.stream.Stream;
49  
50  import org.apache.maven.api.di.Provides;
51  import org.apache.maven.api.di.Qualifier;
52  import org.apache.maven.api.di.Singleton;
53  import org.apache.maven.api.di.Typed;
54  import org.apache.maven.di.Injector;
55  import org.apache.maven.di.Key;
56  import org.apache.maven.di.Scope;
57  
58  public class InjectorImpl implements Injector {
59  
60      private final Map<Key<?>, Set<Binding<?>>> bindings = new HashMap<>();
61      private final Map<Class<? extends Annotation>, Supplier<Scope>> scopes = new HashMap<>();
62  
63      public InjectorImpl() {
64          bindScope(Singleton.class, new SingletonScope());
65      }
66  
67      public <T> T getInstance(Class<T> key) {
68          return getInstance(Key.of(key));
69      }
70  
71      public <T> T getInstance(Key<T> key) {
72          return getCompiledBinding(key).get();
73      }
74  
75      @SuppressWarnings("unchecked")
76      @Override
77      public <T> void injectInstance(T instance) {
78          ReflectionUtils.generateInjectingInitializer(Key.of((Class<T>) instance.getClass()))
79                  .compile(this::getCompiledBinding)
80                  .accept(instance);
81      }
82  
83      @Override
84      public Injector discover(ClassLoader classLoader) {
85          try {
86              Enumeration<URL> enumeration = classLoader.getResources("META-INF/maven/org.apache.maven.api.di.Inject");
87              while (enumeration.hasMoreElements()) {
88                  try (InputStream is = enumeration.nextElement().openStream();
89                          BufferedReader reader = new BufferedReader(new InputStreamReader(Objects.requireNonNull(is)))) {
90                      for (String line :
91                              reader.lines().filter(l -> !l.startsWith("#")).collect(Collectors.toList())) {
92                          Class<?> clazz = classLoader.loadClass(line);
93                          bindImplicit(clazz);
94                      }
95                  }
96              }
97          } catch (Exception e) {
98              throw new DIException("Error while discovering DI classes from classLoader", e);
99          }
100         return this;
101     }
102 
103     public Injector bindScope(Class<? extends Annotation> scopeAnnotation, Scope scope) {
104         return bindScope(scopeAnnotation, () -> scope);
105     }
106 
107     public Injector bindScope(Class<? extends Annotation> scopeAnnotation, Supplier<Scope> scope) {
108         if (scopes.put(scopeAnnotation, scope) != null) {
109             throw new DIException(
110                     "Cannot rebind scope annotation class to a different implementation: " + scopeAnnotation);
111         }
112         return this;
113     }
114 
115     public <U> Injector bindInstance(Class<U> clazz, U instance) {
116         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
117         Binding<U> binding = Binding.toInstance(instance);
118         return doBind(key, binding);
119     }
120 
121     @Override
122     public Injector bindImplicit(Class<?> clazz) {
123         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
124         if (clazz.isInterface()) {
125             bindings.computeIfAbsent(key, $ -> new HashSet<>());
126         } else if (!Modifier.isAbstract(clazz.getModifiers())) {
127             Binding<?> binding = ReflectionUtils.generateImplicitBinding(key);
128             doBind(key, binding);
129         }
130         return this;
131     }
132 
133     private LinkedHashSet<Key<?>> current = new LinkedHashSet<>();
134 
135     private Injector doBind(Key<?> key, Binding<?> binding) {
136         if (!current.add(key)) {
137             current.add(key);
138             throw new DIException("Circular references: " + current);
139         }
140         try {
141             doBindImplicit(key, binding);
142             Class<?> cls = key.getRawType().getSuperclass();
143             while (cls != Object.class && cls != null) {
144                 key = Key.of(cls, key.getQualifier());
145                 doBindImplicit(key, binding);
146                 cls = cls.getSuperclass();
147             }
148             return this;
149         } finally {
150             current.remove(key);
151         }
152     }
153 
154     protected <U> Injector bind(Key<U> key, Binding<U> b) {
155         Set<Binding<?>> bindingSet = bindings.computeIfAbsent(key, $ -> new HashSet<>());
156         bindingSet.add(b);
157         return this;
158     }
159 
160     @SuppressWarnings({"unchecked", "rawtypes"})
161     protected <T> Set<Binding<T>> getBindings(Key<T> key) {
162         return (Set) bindings.get(key);
163     }
164 
165     protected Set<Key<?>> getBoundKeys() {
166         return bindings.keySet();
167     }
168 
169     public Map<Key<?>, Set<Binding<?>>> getBindings() {
170         return bindings;
171     }
172 
173     public <Q> Supplier<Q> getCompiledBinding(Key<Q> key) {
174         Set<Binding<Q>> res = getBindings(key);
175         if (res != null && !res.isEmpty()) {
176             List<Binding<Q>> bindingList = new ArrayList<>(res);
177             Comparator<Binding<Q>> comparing = Comparator.comparing(Binding::getPriority);
178             bindingList.sort(comparing.reversed());
179             Binding<Q> binding = bindingList.get(0);
180             return compile(binding);
181         }
182         if (key.getRawType() == List.class) {
183             Set<Binding<Object>> res2 = getBindings(key.getTypeParameter(0));
184             if (res2 != null) {
185                 List<Supplier<Object>> list = res2.stream().map(this::compile).collect(Collectors.toList());
186                 //noinspection unchecked
187                 return () -> (Q) list(list);
188             }
189         }
190         if (key.getRawType() == Map.class) {
191             Key<?> k = key.getTypeParameter(0);
192             Key<Object> v = key.getTypeParameter(1);
193             Set<Binding<Object>> res2 = getBindings(v);
194             if (k.getRawType() == String.class && res2 != null) {
195                 Map<String, Supplier<Object>> map = res2.stream()
196                         .filter(b -> b.getOriginalKey() == null
197                                 || b.getOriginalKey().getQualifier() == null
198                                 || b.getOriginalKey().getQualifier() instanceof String)
199                         .collect(Collectors.toMap(
200                                 b -> (String)
201                                         (b.getOriginalKey() != null
202                                                 ? b.getOriginalKey().getQualifier()
203                                                 : null),
204                                 this::compile));
205                 //noinspection unchecked
206                 return () -> (Q) map(map);
207             }
208         }
209         throw new DIException("No binding to construct an instance for key "
210                 + key.getDisplayString() + ".  Existing bindings:\n"
211                 + getBoundKeys().stream()
212                         .map(Key::toString)
213                         .map(String::trim)
214                         .sorted()
215                         .distinct()
216                         .collect(Collectors.joining("\n - ", " - ", "")));
217     }
218 
219     @SuppressWarnings("unchecked")
220     protected <Q> Supplier<Q> compile(Binding<Q> binding) {
221         Supplier<Q> compiled = binding.compile(this::getCompiledBinding);
222         if (binding.getScope() != null) {
223             Scope scope = scopes.entrySet().stream()
224                     .filter(e -> e.getKey().isInstance(binding.getScope()))
225                     .map(Map.Entry::getValue)
226                     .findFirst()
227                     .orElseThrow(() -> new DIException("Scope not bound for annotation "
228                             + binding.getScope().annotationType()))
229                     .get();
230             compiled = scope.scope((Key<Q>) binding.getOriginalKey(), binding.getScope(), compiled);
231         }
232         return compiled;
233     }
234 
235     protected void doBindImplicit(Key<?> key, Binding<?> binding) {
236         if (binding != null) {
237             // For non-explicit bindings, also bind all their base classes and interfaces according to the @Type
238             Object qualifier = key.getQualifier();
239             Class<?> type = key.getRawType();
240             Set<Class<?>> types = getBoundTypes(type.getAnnotation(Typed.class), type);
241             for (Type t : Types.getAllSuperTypes(type)) {
242                 if (types == null || types.contains(Types.getRawType(t))) {
243                     bind(Key.ofType(t, qualifier), binding);
244                     if (qualifier != null) {
245                         bind(Key.ofType(t), binding);
246                     }
247                 }
248             }
249         }
250         // Bind inner classes
251         for (Class<?> inner : key.getRawType().getDeclaredClasses()) {
252             boolean hasQualifier = Stream.of(inner.getAnnotations())
253                     .anyMatch(ann -> ann.annotationType().isAnnotationPresent(Qualifier.class));
254             if (hasQualifier) {
255                 bindImplicit(inner);
256             }
257         }
258         // Bind inner providers
259         for (Method method : key.getRawType().getDeclaredMethods()) {
260             if (method.isAnnotationPresent(Provides.class)) {
261                 if (method.getTypeParameters().length != 0) {
262                     throw new DIException("Parameterized method are not supported " + method);
263                 }
264                 Object qualifier = ReflectionUtils.qualifierOf(method);
265                 Annotation scope = ReflectionUtils.scopeOf(method);
266                 Type returnType = method.getGenericReturnType();
267                 Set<Class<?>> types = getBoundTypes(method.getAnnotation(Typed.class), Types.getRawType(returnType));
268                 Binding<Object> bind = ReflectionUtils.bindingFromMethod(method).scope(scope);
269                 for (Type t : Types.getAllSuperTypes(returnType)) {
270                     if (types == null || types.contains(Types.getRawType(t))) {
271                         bind(Key.ofType(t, qualifier), bind);
272                         if (qualifier != null) {
273                             bind(Key.ofType(t), bind);
274                         }
275                     }
276                 }
277             }
278         }
279     }
280 
281     private static Set<Class<?>> getBoundTypes(Typed typed, Class<?> clazz) {
282         if (typed != null) {
283             Class<?>[] typesArray = typed.value();
284             if (typesArray == null || typesArray.length == 0) {
285                 Set<Class<?>> types = new HashSet<>(Arrays.asList(clazz.getInterfaces()));
286                 types.add(Object.class);
287                 return types;
288             } else {
289                 return new HashSet<>(Arrays.asList(typesArray));
290             }
291         } else {
292             return null;
293         }
294     }
295 
296     protected <K, V> Map<K, V> map(Map<K, Supplier<V>> map) {
297         return new WrappingMap<>(map, Supplier::get);
298     }
299 
300     private static class WrappingMap<K, V, T> extends AbstractMap<K, V> {
301 
302         private final Map<K, T> delegate;
303         private final Function<T, V> mapper;
304 
305         WrappingMap(Map<K, T> delegate, Function<T, V> mapper) {
306             this.delegate = delegate;
307             this.mapper = mapper;
308         }
309 
310         @SuppressWarnings("NullableProblems")
311         @Override
312         public Set<Entry<K, V>> entrySet() {
313             return new AbstractSet<Entry<K, V>>() {
314                 @Override
315                 public Iterator<Entry<K, V>> iterator() {
316                     Iterator<Entry<K, T>> it = delegate.entrySet().iterator();
317                     return new Iterator<Entry<K, V>>() {
318                         @Override
319                         public boolean hasNext() {
320                             return it.hasNext();
321                         }
322 
323                         @Override
324                         public Entry<K, V> next() {
325                             Entry<K, T> n = it.next();
326                             return new SimpleImmutableEntry<>(n.getKey(), mapper.apply(n.getValue()));
327                         }
328                     };
329                 }
330 
331                 @Override
332                 public int size() {
333                     return delegate.size();
334                 }
335             };
336         }
337     }
338 
339     protected <T> List<T> list(List<Supplier<T>> bindingList) {
340         return new WrappingList<>(bindingList, Supplier::get);
341     }
342 
343     private static class WrappingList<Q, T> extends AbstractList<Q> {
344 
345         private final List<T> delegate;
346         private final Function<T, Q> mapper;
347 
348         WrappingList(List<T> delegate, Function<T, Q> mapper) {
349             this.delegate = delegate;
350             this.mapper = mapper;
351         }
352 
353         @Override
354         public Q get(int index) {
355             return mapper.apply(delegate.get(index));
356         }
357 
358         @Override
359         public int size() {
360             return delegate.size();
361         }
362     }
363 
364     private static class SingletonScope implements Scope {
365         Map<Key<?>, java.util.function.Supplier<?>> cache = new ConcurrentHashMap<>();
366 
367         @SuppressWarnings("unchecked")
368         @Override
369         public <T> java.util.function.Supplier<T> scope(
370                 Key<T> key, Annotation scope, java.util.function.Supplier<T> unscoped) {
371             return (java.util.function.Supplier<T>)
372                     cache.computeIfAbsent(key, k -> new java.util.function.Supplier<T>() {
373                         volatile T instance;
374 
375                         @Override
376                         public T get() {
377                             if (instance == null) {
378                                 synchronized (this) {
379                                     if (instance == null) {
380                                         instance = unscoped.get();
381                                     }
382                                 }
383                             }
384                             return instance;
385                         }
386                     });
387         }
388     }
389 }