View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *   http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing,
13   * software distributed under the License is distributed on an
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   * KIND, either express or implied.  See the License for the
16   * specific language governing permissions and limitations
17   * under the License.
18   */
19  package org.apache.maven.di.impl;
20  
21  import java.io.BufferedReader;
22  import java.io.InputStream;
23  import java.io.InputStreamReader;
24  import java.lang.annotation.Annotation;
25  import java.lang.reflect.Method;
26  import java.lang.reflect.Modifier;
27  import java.lang.reflect.Type;
28  import java.net.URL;
29  import java.util.AbstractList;
30  import java.util.AbstractMap;
31  import java.util.AbstractSet;
32  import java.util.ArrayList;
33  import java.util.Arrays;
34  import java.util.Comparator;
35  import java.util.Enumeration;
36  import java.util.HashMap;
37  import java.util.HashSet;
38  import java.util.Iterator;
39  import java.util.LinkedHashSet;
40  import java.util.List;
41  import java.util.Map;
42  import java.util.Objects;
43  import java.util.Set;
44  import java.util.concurrent.ConcurrentHashMap;
45  import java.util.function.Function;
46  import java.util.function.Supplier;
47  import java.util.stream.Collectors;
48  import java.util.stream.Stream;
49  
50  import org.apache.maven.api.di.Provides;
51  import org.apache.maven.api.di.Qualifier;
52  import org.apache.maven.api.di.Singleton;
53  import org.apache.maven.api.di.Typed;
54  import org.apache.maven.di.Injector;
55  import org.apache.maven.di.Key;
56  import org.apache.maven.di.Scope;
57  
58  public class InjectorImpl implements Injector {
59  
60      private final Map<Key<?>, Set<Binding<?>>> bindings = new HashMap<>();
61      private final Map<Class<? extends Annotation>, Supplier<Scope>> scopes = new HashMap<>();
62      private final Set<String> loadedUrls = new HashSet<>();
63  
64      public InjectorImpl() {
65          bindScope(Singleton.class, new SingletonScope());
66      }
67  
68      @Override
69      public <T> T getInstance(Class<T> key) {
70          return getInstance(Key.of(key));
71      }
72  
73      @Override
74      public <T> T getInstance(Key<T> key) {
75          return getCompiledBinding(new Dependency<>(key, false)).get();
76      }
77  
78      @SuppressWarnings("unchecked")
79      @Override
80      public <T> void injectInstance(T instance) {
81          ReflectionUtils.generateInjectingInitializer(Key.of((Class<T>) instance.getClass()))
82                  .compile(this::getCompiledBinding)
83                  .accept(instance);
84      }
85  
86      @Override
87      public Injector discover(ClassLoader classLoader) {
88          try {
89              Enumeration<URL> enumeration = classLoader.getResources("META-INF/maven/org.apache.maven.api.di.Inject");
90              while (enumeration.hasMoreElements()) {
91                  URL url = enumeration.nextElement();
92                  if (loadedUrls.add(url.toExternalForm())) {
93                      try (InputStream is = url.openStream();
94                              BufferedReader reader =
95                                      new BufferedReader(new InputStreamReader(Objects.requireNonNull(is)))) {
96                          for (String line :
97                                  reader.lines().filter(l -> !l.startsWith("#")).toList()) {
98                              Class<?> clazz = classLoader.loadClass(line);
99                              bindImplicit(clazz);
100                         }
101                     }
102                 }
103             }
104         } catch (Exception e) {
105             throw new DIException("Error while discovering DI classes from classLoader", e);
106         }
107         return this;
108     }
109 
110     @Override
111     public Injector bindScope(Class<? extends Annotation> scopeAnnotation, Scope scope) {
112         return bindScope(scopeAnnotation, () -> scope);
113     }
114 
115     @Override
116     public Injector bindScope(Class<? extends Annotation> scopeAnnotation, Supplier<Scope> scope) {
117         if (scopes.put(scopeAnnotation, scope) != null) {
118             throw new DIException(
119                     "Cannot rebind scope annotation class to a different implementation: " + scopeAnnotation);
120         }
121         return this;
122     }
123 
124     @Override
125     public <U> Injector bindInstance(Class<U> clazz, U instance) {
126         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
127         Binding<U> binding = Binding.toInstance(instance);
128         return doBind(key, binding);
129     }
130 
131     @Override
132     public Injector bindImplicit(Class<?> clazz) {
133         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
134         if (clazz.isInterface()) {
135             bindings.computeIfAbsent(key, $ -> new HashSet<>());
136             if (key.getQualifier() != null) {
137                 bindings.computeIfAbsent(Key.ofType(clazz), $ -> new HashSet<>());
138             }
139         } else if (!Modifier.isAbstract(clazz.getModifiers())) {
140             Binding<?> binding = ReflectionUtils.generateImplicitBinding(key);
141             doBind(key, binding);
142         }
143         return this;
144     }
145 
146     private final LinkedHashSet<Key<?>> current = new LinkedHashSet<>();
147 
148     private Injector doBind(Key<?> key, Binding<?> binding) {
149         if (!current.add(key)) {
150             current.add(key);
151             throw new DIException("Circular references: " + current);
152         }
153         try {
154             doBindImplicit(key, binding);
155             Class<?> cls = key.getRawType().getSuperclass();
156             while (cls != Object.class && cls != null) {
157                 doBindImplicit(Key.of(cls, key.getQualifier()), binding);
158                 if (key.getQualifier() != null) {
159                     bind(Key.ofType(cls), binding);
160                 }
161                 cls = cls.getSuperclass();
162             }
163             return this;
164         } finally {
165             current.remove(key);
166         }
167     }
168 
169     protected <U> Injector bind(Key<U> key, Binding<U> b) {
170         Set<Binding<?>> bindingSet = bindings.computeIfAbsent(key, $ -> new HashSet<>());
171         bindingSet.add(b);
172         return this;
173     }
174 
175     @SuppressWarnings({"unchecked", "rawtypes"})
176     protected <T> Set<Binding<T>> getBindings(Key<T> key) {
177         return (Set) bindings.get(key);
178     }
179 
180     protected Set<Key<?>> getBoundKeys() {
181         return bindings.keySet();
182     }
183 
184     public Map<Key<?>, Set<Binding<?>>> getBindings() {
185         return bindings;
186     }
187 
188     public <Q> Supplier<Q> getCompiledBinding(Dependency<Q> dep) {
189         Key<Q> key = dep.key();
190         Set<Binding<Q>> res = getBindings(key);
191         if (res != null && !res.isEmpty()) {
192             List<Binding<Q>> bindingList = new ArrayList<>(res);
193             Comparator<Binding<Q>> comparing = Comparator.comparing(Binding::getPriority);
194             bindingList.sort(comparing.reversed());
195             Binding<Q> binding = bindingList.get(0);
196             return compile(binding);
197         }
198         if (key.getRawType() == List.class) {
199             Set<Binding<Object>> res2 = getBindings(key.getTypeParameter(0));
200             if (res2 != null) {
201                 List<Supplier<Object>> list = res2.stream().map(this::compile).collect(Collectors.toList());
202                 //noinspection unchecked
203                 return () -> (Q) list(list, Supplier::get);
204             }
205         }
206         if (key.getRawType() == Map.class) {
207             Key<?> k = key.getTypeParameter(0);
208             Key<Object> v = key.getTypeParameter(1);
209             Set<Binding<Object>> res2 = getBindings(v);
210             if (k.getRawType() == String.class && res2 != null) {
211                 Map<String, Supplier<Object>> map = res2.stream()
212                         .filter(b -> b.getOriginalKey() == null
213                                 || b.getOriginalKey().getQualifier() == null
214                                 || b.getOriginalKey().getQualifier() instanceof String)
215                         .collect(Collectors.toMap(
216                                 b -> (String)
217                                         (b.getOriginalKey() != null
218                                                 ? b.getOriginalKey().getQualifier()
219                                                 : null),
220                                 this::compile));
221                 //noinspection unchecked
222                 return () -> (Q) map(map, Supplier::get);
223             }
224         }
225         if (dep.optional()) {
226             return () -> null;
227         }
228         throw new DIException("No binding to construct an instance for key "
229                 + key.getDisplayString() + ".  Existing bindings:\n"
230                 + getBoundKeys().stream()
231                         .map(Key::toString)
232                         .map(String::trim)
233                         .sorted()
234                         .distinct()
235                         .collect(Collectors.joining("\n - ", " - ", "")));
236     }
237 
238     @SuppressWarnings("unchecked")
239     protected <Q> Supplier<Q> compile(Binding<Q> binding) {
240         Supplier<Q> compiled = binding.compile(this::getCompiledBinding);
241         if (binding.getScope() != null) {
242             Scope scope = scopes.entrySet().stream()
243                     .filter(e -> e.getKey().isInstance(binding.getScope()))
244                     .map(Map.Entry::getValue)
245                     .findFirst()
246                     .orElseThrow(() -> new DIException("Scope not bound for annotation "
247                             + binding.getScope().annotationType()))
248                     .get();
249             compiled = scope.scope((Key<Q>) binding.getOriginalKey(), compiled);
250         }
251         return compiled;
252     }
253 
254     protected void doBindImplicit(Key<?> key, Binding<?> binding) {
255         if (binding != null) {
256             // For non-explicit bindings, also bind all their base classes and interfaces according to the @Type
257             Object qualifier = key.getQualifier();
258             Class<?> type = key.getRawType();
259             Set<Class<?>> types = getBoundTypes(type.getAnnotation(Typed.class), type);
260             for (Type t : Types.getAllSuperTypes(type)) {
261                 if (types == null || types.contains(Types.getRawType(t))) {
262                     bind(Key.ofType(t, qualifier), binding);
263                     if (qualifier != null) {
264                         bind(Key.ofType(t), binding);
265                     }
266                 }
267             }
268         }
269         // Bind inner classes
270         for (Class<?> inner : key.getRawType().getDeclaredClasses()) {
271             boolean hasQualifier = Stream.of(inner.getAnnotations())
272                     .anyMatch(ann -> ann.annotationType().isAnnotationPresent(Qualifier.class));
273             if (hasQualifier) {
274                 bindImplicit(inner);
275             }
276         }
277         // Bind inner providers
278         for (Method method : key.getRawType().getDeclaredMethods()) {
279             if (method.isAnnotationPresent(Provides.class)) {
280                 if (method.getTypeParameters().length != 0) {
281                     throw new DIException("Parameterized method are not supported " + method);
282                 }
283                 Object qualifier = ReflectionUtils.qualifierOf(method);
284                 Annotation scope = ReflectionUtils.scopeOf(method);
285                 Type returnType = method.getGenericReturnType();
286                 Set<Class<?>> types = getBoundTypes(method.getAnnotation(Typed.class), Types.getRawType(returnType));
287                 Binding<Object> bind = ReflectionUtils.bindingFromMethod(method).scope(scope);
288                 for (Type t : Types.getAllSuperTypes(returnType)) {
289                     if (types == null || types.contains(Types.getRawType(t))) {
290                         bind(Key.ofType(t, qualifier), bind);
291                         if (qualifier != null) {
292                             bind(Key.ofType(t), bind);
293                         }
294                     }
295                 }
296             }
297         }
298     }
299 
300     private static Set<Class<?>> getBoundTypes(Typed typed, Class<?> clazz) {
301         if (typed != null) {
302             Class<?>[] typesArray = typed.value();
303             if (typesArray == null || typesArray.length == 0) {
304                 Set<Class<?>> types = new HashSet<>(Arrays.asList(clazz.getInterfaces()));
305                 types.add(Object.class);
306                 return types;
307             } else {
308                 return new HashSet<>(Arrays.asList(typesArray));
309             }
310         } else {
311             return null;
312         }
313     }
314 
315     protected <K, V, T> Map<K, V> map(Map<K, T> map, Function<T, V> mapper) {
316         return new WrappingMap<>(map, mapper);
317     }
318 
319     private static class WrappingMap<K, V, T> extends AbstractMap<K, V> {
320 
321         private final Map<K, T> delegate;
322         private final Function<T, V> mapper;
323 
324         WrappingMap(Map<K, T> delegate, Function<T, V> mapper) {
325             this.delegate = delegate;
326             this.mapper = mapper;
327         }
328 
329         @Override
330         public Set<Entry<K, V>> entrySet() {
331             return new AbstractSet<>() {
332                 @Override
333                 public Iterator<Entry<K, V>> iterator() {
334                     Iterator<Entry<K, T>> it = delegate.entrySet().iterator();
335                     return new Iterator<>() {
336                         @Override
337                         public boolean hasNext() {
338                             return it.hasNext();
339                         }
340 
341                         @Override
342                         public Entry<K, V> next() {
343                             Entry<K, T> n = it.next();
344                             return new SimpleImmutableEntry<>(n.getKey(), mapper.apply(n.getValue()));
345                         }
346                     };
347                 }
348 
349                 @Override
350                 public int size() {
351                     return delegate.size();
352                 }
353             };
354         }
355     }
356 
357     protected <Q, T> List<Q> list(List<T> bindingList, Function<T, Q> mapper) {
358         return new WrappingList<>(bindingList, mapper);
359     }
360 
361     private static class WrappingList<Q, T> extends AbstractList<Q> {
362 
363         private final List<T> delegate;
364         private final Function<T, Q> mapper;
365 
366         WrappingList(List<T> delegate, Function<T, Q> mapper) {
367             this.delegate = delegate;
368             this.mapper = mapper;
369         }
370 
371         @Override
372         public Q get(int index) {
373             return mapper.apply(delegate.get(index));
374         }
375 
376         @Override
377         public int size() {
378             return delegate.size();
379         }
380     }
381 
382     private static class SingletonScope implements Scope {
383         Map<Key<?>, java.util.function.Supplier<?>> cache = new ConcurrentHashMap<>();
384 
385         @SuppressWarnings("unchecked")
386         @Override
387         public <T> java.util.function.Supplier<T> scope(Key<T> key, java.util.function.Supplier<T> unscoped) {
388             return (java.util.function.Supplier<T>)
389                     cache.computeIfAbsent(key, k -> new java.util.function.Supplier<T>() {
390                         volatile T instance;
391 
392                         @Override
393                         public T get() {
394                             if (instance == null) {
395                                 synchronized (this) {
396                                     if (instance == null) {
397                                         instance = unscoped.get();
398                                     }
399                                 }
400                             }
401                             return instance;
402                         }
403                     });
404         }
405     }
406 }