/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19  package org.eclipse.aether.util.concurrency;
20  
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.LockSupport;

import static java.util.Objects.requireNonNull;
26  
/**
 * A utility class to forward any uncaught {@link Error} or {@link RuntimeException} from a {@link Runnable} executed in
 * a worker thread back to the parent thread. The simplified usage pattern looks like this:
 *
 * <pre>
 * RunnableErrorForwarder errorForwarder = new RunnableErrorForwarder();
 * for ( Runnable task : tasks )
 * {
 *     executor.execute( errorForwarder.wrap( task ) );
 * }
 * errorForwarder.await();
 * </pre>
 *
 * <p>The parent thread is the thread that constructs the forwarder; only that thread may call {@link #await()}.
 */
public final class RunnableErrorForwarder {

    /** The thread that created this forwarder: the only permitted waiter and the target of every unpark. */
    private final Thread thread = Thread.currentThread();

    /** Number of wrapped runnables that have not terminated yet. */
    private final AtomicInteger counter = new AtomicInteger();

    /** The first uncaught error recorded by a wrapped runnable, or {@code null} while none has failed. */
    private final AtomicReference<Throwable> error = new AtomicReference<>();

    /**
     * Creates a new error forwarder for worker threads spawned by the current thread.
     */
    public RunnableErrorForwarder() {}

    /**
     * Wraps the specified runnable into an equivalent runnable that will allow forwarding of uncaught errors. Every
     * call registers one pending runnable that {@link #await()} waits for. The pending entry is released when the
     * returned runnable finishes its first execution, whether normally or by throwing; executing the same wrapper
     * again still runs the delegate and records errors, but does not release anything further.
     *
     * @param runnable The runnable from which to forward errors, must not be {@code null}.
     * @return The error-forwarding runnable to eventually execute, never {@code null}.
     * @throws NullPointerException If {@code runnable} is {@code null}.
     */
    public Runnable wrap(final Runnable runnable) {
        requireNonNull(runnable, "runnable cannot be null");

        counter.incrementAndGet();

        // Guards the single decrement that balances the increment above. Without it, a wrapper that is executed
        // more than once would push the counter below the number of outstanding wrappers and let await() return
        // while other wrapped runnables are still pending.
        final AtomicBoolean pending = new AtomicBoolean(true);

        return () -> {
            try {
                runnable.run();
            } catch (RuntimeException | Error e) {
                // Only the first failure is kept; later ones lose the compare-and-set and are not stored.
                error.compareAndSet(null, e);
                throw e;
            } finally {
                if (pending.compareAndSet(true, false)) {
                    counter.decrementAndGet();
                }
                // Wake the parent so it re-checks the counter; an unpark without a waiter just leaves a permit.
                LockSupport.unpark(thread);
            }
        };
    }

    /**
     * Causes the current thread to wait until all previously {@link #wrap(Runnable) wrapped} runnables have terminated
     * and potentially re-throws an uncaught {@link RuntimeException} or {@link Error} from any of the runnables. In
     * case multiple runnables encountered uncaught errors, only the one that was recorded first is re-thrown.
     * <em>Note:</em> This method must be called from the same thread that created this error forwarder instance.
     * The wait itself is not aborted by an interrupt; an interrupt received while waiting is restored on the calling
     * thread before this method continues.
     *
     * @throws IllegalStateException If called from a thread other than the creating one, or if the recorded error is
     *             a {@link ThreadDeath} (which is wrapped as the cause).
     */
    public void await() {
        awaitTerminationOfAllRunnables();

        Throwable error = this.error.get();
        if (error != null) {
            if (error instanceof RuntimeException) {
                throw (RuntimeException) error;
            } else if (error instanceof ThreadDeath) {
                // Checked before the generic Error branch so that a ThreadDeath is wrapped instead of re-thrown as is.
                throw new IllegalStateException(error);
            } else if (error instanceof Error) {
                throw (Error) error;
            }
            // Defensive fallback: wrap() only records RuntimeException and Error instances.
            throw new IllegalStateException(error);
        }
    }

    /**
     * Blocks the creating thread until the pending counter has dropped to zero or below.
     *
     * @throws IllegalStateException If the calling thread is not the thread that created this instance.
     */
    private void awaitTerminationOfAllRunnables() {
        if (!thread.equals(Thread.currentThread())) {
            throw new IllegalStateException(
                    "wrong caller thread, expected " + thread + " and not " + Thread.currentThread());
        }

        boolean interrupted = false;

        while (counter.get() > 0) {
            // Passing a blocker object lets thread dumps show what this thread is parked on. park() may return
            // without a matching unpark, so the loop condition is always re-evaluated.
            LockSupport.park(this);

            // Clear the interrupt flag, otherwise every following park() would return immediately and spin.
            if (Thread.interrupted()) {
                interrupted = true;
            }
        }

        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }
}