View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *   http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing,
13   * software distributed under the License is distributed on an
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   * KIND, either express or implied.  See the License for the
16   * specific language governing permissions and limitations
17   * under the License.
18   */
19  package org.apache.maven.di.impl;
20  
21  import java.io.BufferedReader;
22  import java.io.InputStream;
23  import java.io.InputStreamReader;
24  import java.lang.annotation.Annotation;
25  import java.lang.reflect.Method;
26  import java.lang.reflect.Type;
27  import java.net.URL;
28  import java.util.*;
29  import java.util.function.Function;
30  import java.util.function.Supplier;
31  import java.util.stream.Collectors;
32  import java.util.stream.Stream;
33  
34  import org.apache.maven.api.di.Provides;
35  import org.apache.maven.api.di.Qualifier;
36  import org.apache.maven.api.di.Singleton;
37  import org.apache.maven.api.di.Typed;
38  import org.apache.maven.di.Injector;
39  import org.apache.maven.di.Key;
40  import org.apache.maven.di.Scope;
41  
42  public class InjectorImpl implements Injector {
43  
44      private final Map<Key<?>, Set<Binding<?>>> bindings = new HashMap<>();
45      private final Map<Class<? extends Annotation>, Scope> scopes = new HashMap<>();
46  
47      public InjectorImpl() {
48          bindScope(Singleton.class, new SingletonScope());
49      }
50  
51      public <T> T getInstance(Class<T> key) {
52          return getInstance(Key.of(key));
53      }
54  
55      public <T> T getInstance(Key<T> key) {
56          return getCompiledBinding(key).get();
57      }
58  
59      @SuppressWarnings("unchecked")
60      @Override
61      public <T> void injectInstance(T instance) {
62          ReflectionUtils.generateInjectingInitializer(Key.of((Class<T>) instance.getClass()))
63                  .compile(this::getCompiledBinding)
64                  .accept(instance);
65      }
66  
67      @Override
68      public Injector discover(ClassLoader classLoader) {
69          try {
70              Enumeration<URL> enumeration = classLoader.getResources("META-INF/maven/org.apache.maven.api.di.Inject");
71              while (enumeration.hasMoreElements()) {
72                  try (InputStream is = enumeration.nextElement().openStream();
73                          BufferedReader reader = new BufferedReader(new InputStreamReader(Objects.requireNonNull(is)))) {
74                      for (String line :
75                              reader.lines().filter(l -> !l.startsWith("#")).collect(Collectors.toList())) {
76                          Class<?> clazz = classLoader.loadClass(line);
77                          bindImplicit(clazz);
78                      }
79                  }
80              }
81          } catch (Exception e) {
82              throw new DIException("Error while discovering DI classes from classLoader", e);
83          }
84          return this;
85      }
86  
87      public Injector bindScope(Class<? extends Annotation> scopeAnnotation, Scope scope) {
88          if (scopes.put(scopeAnnotation, scope) != null) {
89              throw new DIException(
90                      "Cannot rebind scope annotation class to a different implementation: " + scopeAnnotation);
91          }
92          return this;
93      }
94  
95      public <U> Injector bindInstance(Class<U> clazz, U instance) {
96          Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
97          Binding<U> binding = Binding.toInstance(instance);
98          return doBind(key, binding);
99      }
100 
101     @Override
102     public Injector bindImplicit(Class<?> clazz) {
103         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
104         Binding<?> binding = ReflectionUtils.generateImplicitBinding(key);
105         return doBind(key, binding);
106     }
107 
108     private LinkedHashSet<Key<?>> current = new LinkedHashSet<>();
109 
110     private Injector doBind(Key<?> key, Binding<?> binding) {
111         if (!current.add(key)) {
112             current.add(key);
113             throw new DIException("Circular references: " + current);
114         }
115         doBindImplicit(key, binding);
116         Class<?> cls = key.getRawType().getSuperclass();
117         while (cls != Object.class && cls != null) {
118             key = Key.of(cls, key.getQualifier());
119             doBindImplicit(key, binding);
120             cls = cls.getSuperclass();
121         }
122         current.remove(key);
123         return this;
124     }
125 
126     protected <U> Injector bind(Key<U> key, Binding<U> b) {
127         Set<Binding<?>> bindingSet = bindings.computeIfAbsent(key, $ -> new HashSet<>());
128         bindingSet.add(b);
129         return this;
130     }
131 
132     @SuppressWarnings({"unchecked", "rawtypes"})
133     private <T> Set<Binding<T>> getBindings(Key<T> key) {
134         return (Set) bindings.get(key);
135     }
136 
137     public <Q> Supplier<Q> getCompiledBinding(Key<Q> key) {
138         Set<Binding<Q>> res = getBindings(key);
139         if (res != null) {
140             List<Binding<Q>> bindingList = new ArrayList<>(res);
141             Comparator<Binding<Q>> comparing = Comparator.comparing(Binding::getPriority);
142             bindingList.sort(comparing.reversed());
143             Binding<Q> binding = bindingList.get(0);
144             return compile(binding);
145         }
146         if (key.getRawType() == List.class) {
147             Set<Binding<Object>> res2 = getBindings(key.getTypeParameter(0));
148             if (res2 != null) {
149                 List<Supplier<Object>> bindingList =
150                         res2.stream().map(this::compile).collect(Collectors.toList());
151                 //noinspection unchecked
152                 return () -> (Q) new WrappingList<>(bindingList, Supplier::get);
153             }
154         }
155         if (key.getRawType() == Map.class) {
156             Key<?> k = key.getTypeParameter(0);
157             Key<Object> v = key.getTypeParameter(1);
158             Set<Binding<Object>> res2 = getBindings(v);
159             if (k.getRawType() == String.class && res2 != null) {
160                 Map<String, Supplier<Object>> map = res2.stream()
161                         .filter(b -> b.getOriginalKey().getQualifier() == null
162                                 || b.getOriginalKey().getQualifier() instanceof String)
163                         .collect(Collectors.toMap(
164                                 b -> (String) b.getOriginalKey().getQualifier(), this::compile));
165                 //noinspection unchecked
166                 return (() -> (Q) new WrappingMap<>(map, Supplier::get));
167             }
168         }
169         throw new DIException("No binding to construct an instance for key "
170                 + key.getDisplayString() + ".  Existing bindings:\n"
171                 + bindings.keySet().stream().map(Key::toString).collect(Collectors.joining("\n - ", " - ", "")));
172     }
173 
174     @SuppressWarnings("unchecked")
175     private <Q> Supplier<Q> compile(Binding<Q> binding) {
176         Supplier<Q> compiled = binding.compile(this::getCompiledBinding);
177         if (binding.getScope() != null) {
178             Scope scope = scopes.entrySet().stream()
179                     .filter(e -> e.getKey().isInstance(binding.getScope()))
180                     .map(Map.Entry::getValue)
181                     .findFirst()
182                     .orElseThrow(() -> new DIException("Scope not bound for annotation "
183                             + binding.getScope().getClass()));
184             compiled = scope.scope((Key<Q>) binding.getOriginalKey(), binding.getScope(), compiled);
185         }
186         return compiled;
187     }
188 
189     protected void doBindImplicit(Key<?> key, Binding<?> binding) {
190         if (binding != null) {
191             // For non-explicit bindings, also bind all their base classes and interfaces according to the @Type
192             Object qualifier = key.getQualifier();
193             Class<?> type = key.getRawType();
194             Set<Class<?>> types = getBoundTypes(type.getAnnotation(Typed.class), type);
195             for (Type t : Types.getAllSuperTypes(type)) {
196                 if (types == null || types.contains(Types.getRawType(t))) {
197                     bind(Key.ofType(t, qualifier), binding);
198                     if (qualifier != null) {
199                         bind(Key.ofType(t), binding);
200                     }
201                 }
202             }
203         }
204         // Bind inner classes
205         for (Class<?> inner : key.getRawType().getDeclaredClasses()) {
206             boolean hasQualifier = Stream.of(inner.getAnnotations())
207                     .anyMatch(ann -> ann.annotationType().isAnnotationPresent(Qualifier.class));
208             if (hasQualifier) {
209                 bindImplicit(inner);
210             }
211         }
212         // Bind inner providers
213         for (Method method : key.getRawType().getDeclaredMethods()) {
214             if (method.isAnnotationPresent(Provides.class)) {
215                 if (method.getTypeParameters().length != 0) {
216                     throw new DIException("Parameterized method are not supported " + method);
217                 }
218                 Object qualifier = ReflectionUtils.qualifierOf(method);
219                 Annotation scope = ReflectionUtils.scopeOf(method);
220                 Type returnType = method.getGenericReturnType();
221                 Set<Class<?>> types = getBoundTypes(method.getAnnotation(Typed.class), Types.getRawType(returnType));
222                 Binding<Object> bind = ReflectionUtils.bindingFromMethod(method).scope(scope);
223                 for (Type t : Types.getAllSuperTypes(returnType)) {
224                     if (types == null || types.contains(Types.getRawType(t))) {
225                         bind(Key.ofType(t, qualifier), bind);
226                         if (qualifier != null) {
227                             bind(Key.ofType(t), bind);
228                         }
229                     }
230                 }
231             }
232         }
233     }
234 
235     private static Set<Class<?>> getBoundTypes(Typed typed, Class<?> clazz) {
236         if (typed != null) {
237             Class<?>[] typesArray = typed.value();
238             if (typesArray == null || typesArray.length == 0) {
239                 Set<Class<?>> types = new HashSet<>(Arrays.asList(clazz.getInterfaces()));
240                 types.add(Object.class);
241                 return types;
242             } else {
243                 return new HashSet<>(Arrays.asList(typesArray));
244             }
245         } else {
246             return null;
247         }
248     }
249 
250     private static class WrappingMap<K, V, T> extends AbstractMap<K, V> {
251 
252         private final Map<K, T> delegate;
253         private final Function<T, V> mapper;
254 
255         WrappingMap(Map<K, T> delegate, Function<T, V> mapper) {
256             this.delegate = delegate;
257             this.mapper = mapper;
258         }
259 
260         @SuppressWarnings("NullableProblems")
261         @Override
262         public Set<Entry<K, V>> entrySet() {
263             return new AbstractSet<Entry<K, V>>() {
264                 @Override
265                 public Iterator<Entry<K, V>> iterator() {
266                     Iterator<Entry<K, T>> it = delegate.entrySet().iterator();
267                     return new Iterator<Entry<K, V>>() {
268                         @Override
269                         public boolean hasNext() {
270                             return it.hasNext();
271                         }
272 
273                         @Override
274                         public Entry<K, V> next() {
275                             Entry<K, T> n = it.next();
276                             return new SimpleImmutableEntry<>(n.getKey(), mapper.apply(n.getValue()));
277                         }
278                     };
279                 }
280 
281                 @Override
282                 public int size() {
283                     return delegate.size();
284                 }
285             };
286         }
287     }
288 
289     private static class WrappingList<Q, T> extends AbstractList<Q> {
290 
291         private final List<T> delegate;
292         private final Function<T, Q> mapper;
293 
294         WrappingList(List<T> delegate, Function<T, Q> mapper) {
295             this.delegate = delegate;
296             this.mapper = mapper;
297         }
298 
299         @Override
300         public Q get(int index) {
301             return mapper.apply(delegate.get(index));
302         }
303 
304         @Override
305         public int size() {
306             return delegate.size();
307         }
308     }
309 
310     private static class SingletonScope implements Scope {
311         Map<Key<?>, java.util.function.Supplier<?>> cache = new HashMap<>();
312 
313         @SuppressWarnings("unchecked")
314         @Override
315         public <T> java.util.function.Supplier<T> scope(
316                 Key<T> key, Annotation scope, java.util.function.Supplier<T> unscoped) {
317             return (java.util.function.Supplier<T>)
318                     cache.computeIfAbsent(key, k -> new java.util.function.Supplier<T>() {
319                         volatile T instance;
320 
321                         @Override
322                         public T get() {
323                             if (instance == null) {
324                                 synchronized (this) {
325                                     if (instance == null) {
326                                         instance = unscoped.get();
327                                     }
328                                 }
329                             }
330                             return instance;
331                         }
332                     });
333         }
334     }
335 }