View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *   http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing,
13   * software distributed under the License is distributed on an
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   * KIND, either express or implied.  See the License for the
16   * specific language governing permissions and limitations
17   * under the License.
18   */
19  package org.apache.maven.di.impl;
20  
21  import java.io.BufferedReader;
22  import java.io.InputStream;
23  import java.io.InputStreamReader;
24  import java.lang.annotation.Annotation;
25  import java.lang.reflect.Method;
26  import java.lang.reflect.Modifier;
27  import java.lang.reflect.Type;
28  import java.net.URL;
29  import java.util.AbstractList;
30  import java.util.AbstractMap;
31  import java.util.AbstractSet;
32  import java.util.ArrayList;
33  import java.util.Arrays;
34  import java.util.Enumeration;
35  import java.util.HashMap;
36  import java.util.HashSet;
37  import java.util.Iterator;
38  import java.util.LinkedHashSet;
39  import java.util.List;
40  import java.util.Map;
41  import java.util.Objects;
42  import java.util.Set;
43  import java.util.concurrent.ConcurrentHashMap;
44  import java.util.function.Function;
45  import java.util.function.Supplier;
46  import java.util.stream.Collectors;
47  import java.util.stream.Stream;
48  
49  import org.apache.maven.api.annotations.Nonnull;
50  import org.apache.maven.api.di.Provides;
51  import org.apache.maven.api.di.Qualifier;
52  import org.apache.maven.api.di.Singleton;
53  import org.apache.maven.api.di.Typed;
54  import org.apache.maven.di.Injector;
55  import org.apache.maven.di.Key;
56  import org.apache.maven.di.Scope;
57  
58  import static org.apache.maven.di.impl.Binding.getPriorityComparator;
59  
60  public class InjectorImpl implements Injector {
61  
62      private final Map<Key<?>, Set<Binding<?>>> bindings = new HashMap<>();
63      private final Map<Class<? extends Annotation>, Supplier<Scope>> scopes = new HashMap<>();
64      private final Set<String> loadedUrls = new HashSet<>();
65      private final ThreadLocal<Set<Key<?>>> resolutionStack = new ThreadLocal<>();
66  
67      public InjectorImpl() {
68          bindScope(Singleton.class, new SingletonScope());
69      }
70  
71      @Nonnull
72      @Override
73      public <T> T getInstance(@Nonnull Class<T> key) {
74          return getInstance(Key.of(key));
75      }
76  
77      @Nonnull
78      @Override
79      public <T> T getInstance(@Nonnull Key<T> key) {
80          return getCompiledBinding(new Dependency<>(key, false)).get();
81      }
82  
83      @SuppressWarnings("unchecked")
84      @Override
85      public <T> void injectInstance(@Nonnull T instance) {
86          ReflectionUtils.generateInjectingInitializer(Key.of((Class<T>) instance.getClass()))
87                  .compile(this::getCompiledBinding)
88                  .accept(instance);
89      }
90  
91      @Nonnull
92      @Override
93      public Injector discover(@Nonnull ClassLoader classLoader) {
94          try {
95              Enumeration<URL> enumeration = classLoader.getResources("META-INF/maven/org.apache.maven.api.di.Inject");
96              while (enumeration.hasMoreElements()) {
97                  URL url = enumeration.nextElement();
98                  if (loadedUrls.add(url.toExternalForm())) {
99                      try (InputStream is = url.openStream();
100                             BufferedReader reader =
101                                     new BufferedReader(new InputStreamReader(Objects.requireNonNull(is)))) {
102                         for (String line :
103                                 reader.lines().filter(l -> !l.startsWith("#")).toList()) {
104                             Class<?> clazz = classLoader.loadClass(line);
105                             bindImplicit(clazz);
106                         }
107                     }
108                 }
109             }
110         } catch (Exception e) {
111             throw new DIException("Error while discovering DI classes from classLoader", e);
112         }
113         return this;
114     }
115 
116     @Nonnull
117     @Override
118     public Injector bindScope(@Nonnull Class<? extends Annotation> scopeAnnotation, @Nonnull Scope scope) {
119         return bindScope(scopeAnnotation, () -> scope);
120     }
121 
122     @Nonnull
123     @Override
124     public Injector bindScope(@Nonnull Class<? extends Annotation> scopeAnnotation, @Nonnull Supplier<Scope> scope) {
125         if (scopes.put(scopeAnnotation, scope) != null) {
126             throw new DIException(
127                     "Cannot rebind scope annotation class to a different implementation: " + scopeAnnotation);
128         }
129         return this;
130     }
131 
132     @Nonnull
133     @Override
134     public <U> Injector bindInstance(@Nonnull Class<U> clazz, @Nonnull U instance) {
135         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
136         Binding<U> binding = Binding.toInstance(instance);
137         return doBind(key, binding);
138     }
139 
140     @Override
141     public <U> Injector bindSupplier(@Nonnull Class<U> clazz, @Nonnull Supplier<U> supplier) {
142         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
143         Binding<U> binding = Binding.toSupplier(supplier);
144         return doBind(key, binding);
145     }
146 
147     @Nonnull
148     @Override
149     public Injector bindImplicit(@Nonnull Class<?> clazz) {
150         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
151         if (clazz.isInterface()) {
152             bindings.computeIfAbsent(key, $ -> new HashSet<>());
153             if (key.getQualifier() != null) {
154                 bindings.computeIfAbsent(Key.ofType(clazz), $ -> new HashSet<>());
155             }
156         } else if (!Modifier.isAbstract(clazz.getModifiers())) {
157             Binding<?> binding = ReflectionUtils.generateImplicitBinding(key);
158             doBind(key, binding);
159         }
160         return this;
161     }
162 
163     private final LinkedHashSet<Key<?>> current = new LinkedHashSet<>();
164 
165     private Injector doBind(Key<?> key, Binding<?> binding) {
166         if (!current.add(key)) {
167             current.add(key);
168             throw new DIException("Circular references: " + current);
169         }
170         try {
171             doBindImplicit(key, binding);
172             Class<?> cls = key.getRawType().getSuperclass();
173             while (cls != Object.class && cls != null) {
174                 doBindImplicit(Key.of(cls, key.getQualifier()), binding);
175                 if (key.getQualifier() != null) {
176                     bind(Key.ofType(cls), binding);
177                 }
178                 cls = cls.getSuperclass();
179             }
180             return this;
181         } finally {
182             current.remove(key);
183         }
184     }
185 
186     protected <U> Injector bind(Key<U> key, Binding<U> b) {
187         Set<Binding<?>> bindingSet = bindings.computeIfAbsent(key, $ -> new HashSet<>());
188         bindingSet.add(b);
189         return this;
190     }
191 
192     @SuppressWarnings({"unchecked", "rawtypes"})
193     protected <T> Set<Binding<T>> getBindings(Key<T> key) {
194         return (Set) bindings.get(key);
195     }
196 
197     protected Set<Key<?>> getBoundKeys() {
198         return bindings.keySet();
199     }
200 
201     public Map<Key<?>, Set<Binding<?>>> getBindings() {
202         return bindings;
203     }
204 
205     public <T> Set<Binding<T>> getAllBindings(Class<T> clazz) {
206         return getBindings(Key.of(clazz));
207     }
208 
209     public <Q> Supplier<Q> getCompiledBinding(Dependency<Q> dep) {
210         Key<Q> key = dep.key();
211         Supplier<Q> originalSupplier = doGetCompiledBinding(dep);
212         return () -> {
213             checkCyclicDependency(key);
214             try {
215                 return originalSupplier.get();
216             } finally {
217                 removeFromResolutionStack(key);
218             }
219         };
220     }
221 
222     public <Q> Supplier<Q> doGetCompiledBinding(Dependency<Q> dep) {
223         Key<Q> key = dep.key();
224         Set<Binding<Q>> res = getBindings(key);
225         if (res != null && !res.isEmpty()) {
226             List<Binding<Q>> bindingList = new ArrayList<>(res);
227             bindingList.sort(getPriorityComparator());
228             Binding<Q> binding = bindingList.get(0);
229             return compile(binding);
230         }
231         if (key.getRawType() == List.class) {
232             Set<Binding<Object>> res2 = getBindings(key.getTypeParameter(0));
233             if (res2 != null) {
234                 // Sort bindings by priority (highest first) for deterministic ordering
235                 List<Binding<Object>> sortedBindings = new ArrayList<>(res2);
236                 sortedBindings.sort(getPriorityComparator());
237 
238                 List<Supplier<Object>> list =
239                         sortedBindings.stream().map(this::compile).collect(Collectors.toList());
240                 //noinspection unchecked
241                 return () -> (Q) list(list, Supplier::get);
242             }
243         }
244         if (key.getRawType() == Map.class) {
245             Key<?> k = key.getTypeParameter(0);
246             Key<Object> v = key.getTypeParameter(1);
247             Set<Binding<Object>> res2 = getBindings(v);
248             if (k.getRawType() == String.class && res2 != null) {
249                 Map<String, Supplier<Object>> map = res2.stream()
250                         .filter(b -> b.getOriginalKey() == null
251                                 || b.getOriginalKey().getQualifier() == null
252                                 || b.getOriginalKey().getQualifier() instanceof String)
253                         .collect(Collectors.toMap(
254                                 b -> (String)
255                                         (b.getOriginalKey() != null
256                                                 ? b.getOriginalKey().getQualifier()
257                                                 : null),
258                                 this::compile));
259                 //noinspection unchecked
260                 return () -> (Q) map(map, Supplier::get);
261             }
262         }
263         if (dep.optional()) {
264             return () -> null;
265         }
266         throw new DIException("No binding to construct an instance for key "
267                 + key.getDisplayString() + ".  Existing bindings:\n"
268                 + getBoundKeys().stream()
269                         .map(Key::toString)
270                         .map(String::trim)
271                         .sorted()
272                         .distinct()
273                         .collect(Collectors.joining("\n - ", " - ", "")));
274     }
275 
276     @SuppressWarnings("unchecked")
277     protected <Q> Supplier<Q> compile(Binding<Q> binding) {
278         Supplier<Q> compiled = binding.compile(this::getCompiledBinding);
279         if (binding.getScope() != null) {
280             Scope scope = scopes.entrySet().stream()
281                     .filter(e -> e.getKey().isInstance(binding.getScope()))
282                     .map(Map.Entry::getValue)
283                     .findFirst()
284                     .orElseThrow(() -> new DIException("Scope not bound for annotation "
285                             + binding.getScope().annotationType()))
286                     .get();
287             compiled = scope.scope((Key<Q>) binding.getOriginalKey(), compiled);
288         }
289         return compiled;
290     }
291 
292     protected void doBindImplicit(Key<?> key, Binding<?> binding) {
293         if (binding != null) {
294             // For non-explicit bindings, also bind all their base classes and interfaces according to the @Type
295             Object qualifier = key.getQualifier();
296             Class<?> type = key.getRawType();
297             Set<Class<?>> types = getBoundTypes(type.getAnnotation(Typed.class), type);
298             for (Type t : Types.getAllSuperTypes(type)) {
299                 if (types == null || types.contains(Types.getRawType(t))) {
300                     bind(Key.ofType(t, qualifier), binding);
301                     if (qualifier != null) {
302                         bind(Key.ofType(t), binding);
303                     }
304                 }
305             }
306         }
307         // Bind inner classes
308         for (Class<?> inner : key.getRawType().getDeclaredClasses()) {
309             boolean hasQualifier = Stream.of(inner.getAnnotations())
310                     .anyMatch(ann -> ann.annotationType().isAnnotationPresent(Qualifier.class));
311             if (hasQualifier) {
312                 bindImplicit(inner);
313             }
314         }
315         // Bind inner providers
316         for (Method method : key.getRawType().getDeclaredMethods()) {
317             if (method.isAnnotationPresent(Provides.class)) {
318                 if (method.getTypeParameters().length != 0) {
319                     throw new DIException("Parameterized method are not supported " + method);
320                 }
321                 Object qualifier = ReflectionUtils.qualifierOf(method);
322                 Annotation scope = ReflectionUtils.scopeOf(method);
323                 Type returnType = method.getGenericReturnType();
324                 Set<Class<?>> types = getBoundTypes(method.getAnnotation(Typed.class), Types.getRawType(returnType));
325                 Binding<Object> bind = ReflectionUtils.bindingFromMethod(method).scope(scope);
326                 for (Type t : Types.getAllSuperTypes(returnType)) {
327                     if (types == null || types.contains(Types.getRawType(t))) {
328                         bind(Key.ofType(t, qualifier), bind);
329                         if (qualifier != null) {
330                             bind(Key.ofType(t), bind);
331                         }
332                     }
333                 }
334             }
335         }
336     }
337 
338     private static Set<Class<?>> getBoundTypes(Typed typed, Class<?> clazz) {
339         if (typed != null) {
340             Class<?>[] typesArray = typed.value();
341             if (typesArray == null || typesArray.length == 0) {
342                 Set<Class<?>> types = new HashSet<>(Arrays.asList(clazz.getInterfaces()));
343                 types.add(Object.class);
344                 return types;
345             } else {
346                 return new HashSet<>(Arrays.asList(typesArray));
347             }
348         } else {
349             return null;
350         }
351     }
352 
353     protected <K, V, T> Map<K, V> map(Map<K, T> map, Function<T, V> mapper) {
354         return new WrappingMap<>(map, mapper);
355     }
356 
357     private static class WrappingMap<K, V, T> extends AbstractMap<K, V> {
358 
359         private final Map<K, T> delegate;
360         private final Function<T, V> mapper;
361 
362         WrappingMap(Map<K, T> delegate, Function<T, V> mapper) {
363             this.delegate = delegate;
364             this.mapper = mapper;
365         }
366 
367         @Override
368         public Set<Entry<K, V>> entrySet() {
369             return new AbstractSet<>() {
370                 @Override
371                 public Iterator<Entry<K, V>> iterator() {
372                     Iterator<Entry<K, T>> it = delegate.entrySet().iterator();
373                     return new Iterator<>() {
374                         @Override
375                         public boolean hasNext() {
376                             return it.hasNext();
377                         }
378 
379                         @Override
380                         public Entry<K, V> next() {
381                             Entry<K, T> n = it.next();
382                             return new SimpleImmutableEntry<>(n.getKey(), mapper.apply(n.getValue()));
383                         }
384                     };
385                 }
386 
387                 @Override
388                 public int size() {
389                     return delegate.size();
390                 }
391             };
392         }
393     }
394 
395     protected <Q, T> List<Q> list(List<T> bindingList, Function<T, Q> mapper) {
396         return new WrappingList<>(bindingList, mapper);
397     }
398 
399     private static class WrappingList<Q, T> extends AbstractList<Q> {
400 
401         private final List<T> delegate;
402         private final Function<T, Q> mapper;
403 
404         WrappingList(List<T> delegate, Function<T, Q> mapper) {
405             this.delegate = delegate;
406             this.mapper = mapper;
407         }
408 
409         @Override
410         public Q get(int index) {
411             return mapper.apply(delegate.get(index));
412         }
413 
414         @Override
415         public int size() {
416             return delegate.size();
417         }
418     }
419 
420     private void checkCyclicDependency(Key<?> key) {
421         Set<Key<?>> stack = resolutionStack.get();
422         if (stack == null) {
423             stack = new LinkedHashSet<>();
424             resolutionStack.set(stack);
425         }
426         if (!stack.add(key)) {
427             throw new DIException("Cyclic dependency detected: "
428                     + stack.stream().map(Key::getDisplayString).collect(Collectors.joining(" -> "))
429                     + " -> "
430                     + key.getDisplayString());
431         }
432     }
433 
434     private void removeFromResolutionStack(Key<?> key) {
435         Set<Key<?>> stack = resolutionStack.get();
436         if (stack != null) {
437             stack.remove(key);
438             if (stack.isEmpty()) {
439                 resolutionStack.remove();
440             }
441         }
442     }
443 
444     private static class SingletonScope implements Scope {
445         Map<Key<?>, java.util.function.Supplier<?>> cache = new ConcurrentHashMap<>();
446 
447         @Nonnull
448         @SuppressWarnings("unchecked")
449         @Override
450         public <T> java.util.function.Supplier<T> scope(
451                 @Nonnull Key<T> key, @Nonnull java.util.function.Supplier<T> unscoped) {
452             return (java.util.function.Supplier<T>)
453                     cache.computeIfAbsent(key, k -> new java.util.function.Supplier<T>() {
454                         volatile T instance;
455 
456                         @Override
457                         public T get() {
458                             if (instance == null) {
459                                 synchronized (this) {
460                                     if (instance == null) {
461                                         instance = unscoped.get();
462                                     }
463                                 }
464                             }
465                             return instance;
466                         }
467                     });
468         }
469     }
470 }