View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one
3    * or more contributor license agreements.  See the NOTICE file
4    * distributed with this work for additional information
5    * regarding copyright ownership.  The ASF licenses this file
6    * to you under the Apache License, Version 2.0 (the
7    * "License"); you may not use this file except in compliance
8    * with the License.  You may obtain a copy of the License at
9    *
10   *   http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing,
13   * software distributed under the License is distributed on an
14   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   * KIND, either express or implied.  See the License for the
16   * specific language governing permissions and limitations
17   * under the License.
18   */
19  package org.apache.maven.di.impl;
20  
21  import java.io.BufferedReader;
22  import java.io.InputStream;
23  import java.io.InputStreamReader;
24  import java.lang.annotation.Annotation;
25  import java.lang.reflect.Method;
26  import java.lang.reflect.Modifier;
27  import java.lang.reflect.Type;
28  import java.net.URL;
29  import java.util.AbstractList;
30  import java.util.AbstractMap;
31  import java.util.AbstractSet;
32  import java.util.ArrayList;
33  import java.util.Arrays;
34  import java.util.Enumeration;
35  import java.util.HashMap;
36  import java.util.HashSet;
37  import java.util.Iterator;
38  import java.util.LinkedHashSet;
39  import java.util.List;
40  import java.util.Map;
41  import java.util.Objects;
42  import java.util.Set;
43  import java.util.concurrent.ConcurrentHashMap;
44  import java.util.function.Function;
45  import java.util.function.Supplier;
46  import java.util.stream.Collectors;
47  import java.util.stream.Stream;
48  
49  import org.apache.maven.api.annotations.Nonnull;
50  import org.apache.maven.api.di.Provides;
51  import org.apache.maven.api.di.Qualifier;
52  import org.apache.maven.api.di.Singleton;
53  import org.apache.maven.api.di.Typed;
54  import org.apache.maven.di.Injector;
55  import org.apache.maven.di.Key;
56  import org.apache.maven.di.Scope;
57  
58  import static org.apache.maven.di.impl.Binding.getPriorityComparator;
59  
60  public class InjectorImpl implements Injector {
61  
62      private final Map<Key<?>, Set<Binding<?>>> bindings = new HashMap<>();
63      private final Map<Class<? extends Annotation>, Supplier<Scope>> scopes = new HashMap<>();
64      private final Set<String> loadedUrls = new HashSet<>();
65      private final ThreadLocal<Set<Key<?>>> resolutionStack = new ThreadLocal<>();
66  
67      public InjectorImpl() {
68          bindScope(Singleton.class, new SingletonScope());
69      }
70  
71      @Nonnull
72      @Override
73      public <T> T getInstance(@Nonnull Class<T> key) {
74          return getInstance(Key.of(key));
75      }
76  
77      @Nonnull
78      @Override
79      public <T> T getInstance(@Nonnull Key<T> key) {
80          return getCompiledBinding(new Dependency<>(key, false)).get();
81      }
82  
83      @SuppressWarnings("unchecked")
84      @Override
85      public <T> void injectInstance(@Nonnull T instance) {
86          ReflectionUtils.generateInjectingInitializer(Key.of((Class<T>) instance.getClass()))
87                  .compile(this::getCompiledBinding)
88                  .accept(instance);
89      }
90  
91      @Nonnull
92      @Override
93      public Injector discover(@Nonnull ClassLoader classLoader) {
94          try {
95              Enumeration<URL> enumeration = classLoader.getResources("META-INF/maven/org.apache.maven.api.di.Inject");
96              while (enumeration.hasMoreElements()) {
97                  URL url = enumeration.nextElement();
98                  if (loadedUrls.add(url.toExternalForm())) {
99                      try (InputStream is = url.openStream();
100                             BufferedReader reader =
101                                     new BufferedReader(new InputStreamReader(Objects.requireNonNull(is)))) {
102                         for (String line :
103                                 reader.lines().filter(l -> !l.startsWith("#")).toList()) {
104                             Class<?> clazz = classLoader.loadClass(line);
105                             bindImplicit(clazz);
106                         }
107                     }
108                 }
109             }
110         } catch (Exception e) {
111             throw new DIException("Error while discovering DI classes from classLoader", e);
112         }
113         return this;
114     }
115 
116     @Nonnull
117     @Override
118     public Injector bindScope(@Nonnull Class<? extends Annotation> scopeAnnotation, @Nonnull Scope scope) {
119         return bindScope(scopeAnnotation, () -> scope);
120     }
121 
122     @Nonnull
123     @Override
124     public Injector bindScope(@Nonnull Class<? extends Annotation> scopeAnnotation, @Nonnull Supplier<Scope> scope) {
125         if (scopes.put(scopeAnnotation, scope) != null) {
126             throw new DIException(
127                     "Cannot rebind scope annotation class to a different implementation: " + scopeAnnotation);
128         }
129         return this;
130     }
131 
132     @Nonnull
133     @Override
134     public <U> Injector bindInstance(@Nonnull Class<U> clazz, @Nonnull U instance) {
135         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
136         Binding<U> binding = Binding.toInstance(instance);
137         return doBind(key, binding);
138     }
139 
140     @Nonnull
141     @Override
142     public Injector bindImplicit(@Nonnull Class<?> clazz) {
143         Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
144         if (clazz.isInterface()) {
145             bindings.computeIfAbsent(key, $ -> new HashSet<>());
146             if (key.getQualifier() != null) {
147                 bindings.computeIfAbsent(Key.ofType(clazz), $ -> new HashSet<>());
148             }
149         } else if (!Modifier.isAbstract(clazz.getModifiers())) {
150             Binding<?> binding = ReflectionUtils.generateImplicitBinding(key);
151             doBind(key, binding);
152         }
153         return this;
154     }
155 
156     private final LinkedHashSet<Key<?>> current = new LinkedHashSet<>();
157 
158     private Injector doBind(Key<?> key, Binding<?> binding) {
159         if (!current.add(key)) {
160             current.add(key);
161             throw new DIException("Circular references: " + current);
162         }
163         try {
164             doBindImplicit(key, binding);
165             Class<?> cls = key.getRawType().getSuperclass();
166             while (cls != Object.class && cls != null) {
167                 doBindImplicit(Key.of(cls, key.getQualifier()), binding);
168                 if (key.getQualifier() != null) {
169                     bind(Key.ofType(cls), binding);
170                 }
171                 cls = cls.getSuperclass();
172             }
173             return this;
174         } finally {
175             current.remove(key);
176         }
177     }
178 
179     protected <U> Injector bind(Key<U> key, Binding<U> b) {
180         Set<Binding<?>> bindingSet = bindings.computeIfAbsent(key, $ -> new HashSet<>());
181         bindingSet.add(b);
182         return this;
183     }
184 
185     @SuppressWarnings({"unchecked", "rawtypes"})
186     protected <T> Set<Binding<T>> getBindings(Key<T> key) {
187         return (Set) bindings.get(key);
188     }
189 
190     protected Set<Key<?>> getBoundKeys() {
191         return bindings.keySet();
192     }
193 
194     public Map<Key<?>, Set<Binding<?>>> getBindings() {
195         return bindings;
196     }
197 
198     public <Q> Supplier<Q> getCompiledBinding(Dependency<Q> dep) {
199         Key<Q> key = dep.key();
200         Supplier<Q> originalSupplier = doGetCompiledBinding(dep);
201         return () -> {
202             checkCyclicDependency(key);
203             try {
204                 return originalSupplier.get();
205             } finally {
206                 removeFromResolutionStack(key);
207             }
208         };
209     }
210 
211     public <Q> Supplier<Q> doGetCompiledBinding(Dependency<Q> dep) {
212         Key<Q> key = dep.key();
213         Set<Binding<Q>> res = getBindings(key);
214         if (res != null && !res.isEmpty()) {
215             List<Binding<Q>> bindingList = new ArrayList<>(res);
216             bindingList.sort(getPriorityComparator());
217             Binding<Q> binding = bindingList.get(0);
218             return compile(binding);
219         }
220         if (key.getRawType() == List.class) {
221             Set<Binding<Object>> res2 = getBindings(key.getTypeParameter(0));
222             if (res2 != null) {
223                 // Sort bindings by priority (highest first) for deterministic ordering
224                 List<Binding<Object>> sortedBindings = new ArrayList<>(res2);
225                 sortedBindings.sort(getPriorityComparator());
226 
227                 List<Supplier<Object>> list =
228                         sortedBindings.stream().map(this::compile).collect(Collectors.toList());
229                 //noinspection unchecked
230                 return () -> (Q) list(list, Supplier::get);
231             }
232         }
233         if (key.getRawType() == Map.class) {
234             Key<?> k = key.getTypeParameter(0);
235             Key<Object> v = key.getTypeParameter(1);
236             Set<Binding<Object>> res2 = getBindings(v);
237             if (k.getRawType() == String.class && res2 != null) {
238                 Map<String, Supplier<Object>> map = res2.stream()
239                         .filter(b -> b.getOriginalKey() == null
240                                 || b.getOriginalKey().getQualifier() == null
241                                 || b.getOriginalKey().getQualifier() instanceof String)
242                         .collect(Collectors.toMap(
243                                 b -> (String)
244                                         (b.getOriginalKey() != null
245                                                 ? b.getOriginalKey().getQualifier()
246                                                 : null),
247                                 this::compile));
248                 //noinspection unchecked
249                 return () -> (Q) map(map, Supplier::get);
250             }
251         }
252         if (dep.optional()) {
253             return () -> null;
254         }
255         throw new DIException("No binding to construct an instance for key "
256                 + key.getDisplayString() + ".  Existing bindings:\n"
257                 + getBoundKeys().stream()
258                         .map(Key::toString)
259                         .map(String::trim)
260                         .sorted()
261                         .distinct()
262                         .collect(Collectors.joining("\n - ", " - ", "")));
263     }
264 
265     @SuppressWarnings("unchecked")
266     protected <Q> Supplier<Q> compile(Binding<Q> binding) {
267         Supplier<Q> compiled = binding.compile(this::getCompiledBinding);
268         if (binding.getScope() != null) {
269             Scope scope = scopes.entrySet().stream()
270                     .filter(e -> e.getKey().isInstance(binding.getScope()))
271                     .map(Map.Entry::getValue)
272                     .findFirst()
273                     .orElseThrow(() -> new DIException("Scope not bound for annotation "
274                             + binding.getScope().annotationType()))
275                     .get();
276             compiled = scope.scope((Key<Q>) binding.getOriginalKey(), compiled);
277         }
278         return compiled;
279     }
280 
281     protected void doBindImplicit(Key<?> key, Binding<?> binding) {
282         if (binding != null) {
283             // For non-explicit bindings, also bind all their base classes and interfaces according to the @Type
284             Object qualifier = key.getQualifier();
285             Class<?> type = key.getRawType();
286             Set<Class<?>> types = getBoundTypes(type.getAnnotation(Typed.class), type);
287             for (Type t : Types.getAllSuperTypes(type)) {
288                 if (types == null || types.contains(Types.getRawType(t))) {
289                     bind(Key.ofType(t, qualifier), binding);
290                     if (qualifier != null) {
291                         bind(Key.ofType(t), binding);
292                     }
293                 }
294             }
295         }
296         // Bind inner classes
297         for (Class<?> inner : key.getRawType().getDeclaredClasses()) {
298             boolean hasQualifier = Stream.of(inner.getAnnotations())
299                     .anyMatch(ann -> ann.annotationType().isAnnotationPresent(Qualifier.class));
300             if (hasQualifier) {
301                 bindImplicit(inner);
302             }
303         }
304         // Bind inner providers
305         for (Method method : key.getRawType().getDeclaredMethods()) {
306             if (method.isAnnotationPresent(Provides.class)) {
307                 if (method.getTypeParameters().length != 0) {
308                     throw new DIException("Parameterized method are not supported " + method);
309                 }
310                 Object qualifier = ReflectionUtils.qualifierOf(method);
311                 Annotation scope = ReflectionUtils.scopeOf(method);
312                 Type returnType = method.getGenericReturnType();
313                 Set<Class<?>> types = getBoundTypes(method.getAnnotation(Typed.class), Types.getRawType(returnType));
314                 Binding<Object> bind = ReflectionUtils.bindingFromMethod(method).scope(scope);
315                 for (Type t : Types.getAllSuperTypes(returnType)) {
316                     if (types == null || types.contains(Types.getRawType(t))) {
317                         bind(Key.ofType(t, qualifier), bind);
318                         if (qualifier != null) {
319                             bind(Key.ofType(t), bind);
320                         }
321                     }
322                 }
323             }
324         }
325     }
326 
327     private static Set<Class<?>> getBoundTypes(Typed typed, Class<?> clazz) {
328         if (typed != null) {
329             Class<?>[] typesArray = typed.value();
330             if (typesArray == null || typesArray.length == 0) {
331                 Set<Class<?>> types = new HashSet<>(Arrays.asList(clazz.getInterfaces()));
332                 types.add(Object.class);
333                 return types;
334             } else {
335                 return new HashSet<>(Arrays.asList(typesArray));
336             }
337         } else {
338             return null;
339         }
340     }
341 
342     protected <K, V, T> Map<K, V> map(Map<K, T> map, Function<T, V> mapper) {
343         return new WrappingMap<>(map, mapper);
344     }
345 
346     private static class WrappingMap<K, V, T> extends AbstractMap<K, V> {
347 
348         private final Map<K, T> delegate;
349         private final Function<T, V> mapper;
350 
351         WrappingMap(Map<K, T> delegate, Function<T, V> mapper) {
352             this.delegate = delegate;
353             this.mapper = mapper;
354         }
355 
356         @Override
357         public Set<Entry<K, V>> entrySet() {
358             return new AbstractSet<>() {
359                 @Override
360                 public Iterator<Entry<K, V>> iterator() {
361                     Iterator<Entry<K, T>> it = delegate.entrySet().iterator();
362                     return new Iterator<>() {
363                         @Override
364                         public boolean hasNext() {
365                             return it.hasNext();
366                         }
367 
368                         @Override
369                         public Entry<K, V> next() {
370                             Entry<K, T> n = it.next();
371                             return new SimpleImmutableEntry<>(n.getKey(), mapper.apply(n.getValue()));
372                         }
373                     };
374                 }
375 
376                 @Override
377                 public int size() {
378                     return delegate.size();
379                 }
380             };
381         }
382     }
383 
384     protected <Q, T> List<Q> list(List<T> bindingList, Function<T, Q> mapper) {
385         return new WrappingList<>(bindingList, mapper);
386     }
387 
388     private static class WrappingList<Q, T> extends AbstractList<Q> {
389 
390         private final List<T> delegate;
391         private final Function<T, Q> mapper;
392 
393         WrappingList(List<T> delegate, Function<T, Q> mapper) {
394             this.delegate = delegate;
395             this.mapper = mapper;
396         }
397 
398         @Override
399         public Q get(int index) {
400             return mapper.apply(delegate.get(index));
401         }
402 
403         @Override
404         public int size() {
405             return delegate.size();
406         }
407     }
408 
409     private void checkCyclicDependency(Key<?> key) {
410         Set<Key<?>> stack = resolutionStack.get();
411         if (stack == null) {
412             stack = new LinkedHashSet<>();
413             resolutionStack.set(stack);
414         }
415         if (!stack.add(key)) {
416             throw new DIException("Cyclic dependency detected: "
417                     + stack.stream().map(Key::getDisplayString).collect(Collectors.joining(" -> "))
418                     + " -> "
419                     + key.getDisplayString());
420         }
421     }
422 
423     private void removeFromResolutionStack(Key<?> key) {
424         Set<Key<?>> stack = resolutionStack.get();
425         if (stack != null) {
426             stack.remove(key);
427             if (stack.isEmpty()) {
428                 resolutionStack.remove();
429             }
430         }
431     }
432 
433     private static class SingletonScope implements Scope {
434         Map<Key<?>, java.util.function.Supplier<?>> cache = new ConcurrentHashMap<>();
435 
436         @Nonnull
437         @SuppressWarnings("unchecked")
438         @Override
439         public <T> java.util.function.Supplier<T> scope(
440                 @Nonnull Key<T> key, @Nonnull java.util.function.Supplier<T> unscoped) {
441             return (java.util.function.Supplier<T>)
442                     cache.computeIfAbsent(key, k -> new java.util.function.Supplier<T>() {
443                         volatile T instance;
444 
445                         @Override
446                         public T get() {
447                             if (instance == null) {
448                                 synchronized (this) {
449                                     if (instance == null) {
450                                         instance = unscoped.get();
451                                     }
452                                 }
453                             }
454                             return instance;
455                         }
456                     });
457         }
458     }
459 }