001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.eclipse.aether.util.concurrency;
020
021import java.util.concurrent.atomic.AtomicInteger;
022import java.util.concurrent.atomic.AtomicReference;
023import java.util.concurrent.locks.LockSupport;
024
025import static java.util.Objects.requireNonNull;
026
/**
 * Propagates uncaught {@link RuntimeException}s and {@link Error}s thrown by {@link Runnable}s that run on worker
 * threads back to the thread that created the forwarder. Typical usage:
 *
 * <pre>
 * RunnableErrorForwarder errorForwarder = new RunnableErrorForwarder();
 * for ( Runnable task : tasks )
 * {
 *     executor.execute( errorForwarder.wrap( task ) );
 * }
 * errorForwarder.await();
 * </pre>
 *
 * Only the thread that constructed the instance may call {@link #await()}. If several tasks fail, the failure that
 * is recorded first is the one reported.
 */
public final class RunnableErrorForwarder {

    /** Thread that created this forwarder; the only permitted caller of {@link #await()} and the one workers wake. */
    private final Thread owner = Thread.currentThread();

    /** Count of wrapped runnables that have been handed out but have not finished running yet. */
    private final AtomicInteger pending = new AtomicInteger();

    /** First uncaught failure captured from any wrapped runnable, or {@code null} while none has failed. */
    private final AtomicReference<Throwable> failure = new AtomicReference<>();

    /**
     * Creates a forwarder bound to the calling thread, which is expected to later {@link #await()} the workers.
     */
    public RunnableErrorForwarder() {}

    /**
     * Decorates the given task so that its uncaught failures are remembered for {@link #await()}.
     *
     * <p>Each call registers one more pending task, so every returned runnable is expected to be run exactly once.
     *
     * @param runnable The task whose failures should be forwarded, must not be {@code null}.
     * @return A runnable that runs the task and tracks its completion, never {@code null}.
     */
    public Runnable wrap(final Runnable runnable) {
        requireNonNull(runnable, "runnable cannot be null");
        pending.incrementAndGet();
        return () -> runTracked(runnable);
    }

    /**
     * Runs one task on the current (worker) thread. It records the first uncaught failure, rethrows it unchanged,
     * and always signals completion to the owner thread.
     */
    private void runTracked(final Runnable task) {
        try {
            task.run();
        } catch (RuntimeException | Error e) {
            // Only the first failure is kept; later ones are still rethrown on their own worker thread.
            failure.compareAndSet(null, e);
            throw e;
        } finally {
            pending.decrementAndGet();
            // Wake the owner so it re-checks the pending count; an unpark issued before park() is not lost.
            LockSupport.unpark(owner);
        }
    }

    /**
     * Blocks until every previously {@link #wrap(Runnable) wrapped} runnable has finished. Then rethrows the recorded
     * failure, if there is one. A {@link RuntimeException} or {@link Error} is rethrown as-is, except that a
     * {@link ThreadDeath} is wrapped in an {@link IllegalStateException}. <em>Note:</em> must be invoked from the
     * thread that created this forwarder.
     */
    public void await() {
        awaitTerminationOfAllRunnables();

        final Throwable recorded = failure.get();
        if (recorded == null) {
            return;
        }
        if (recorded instanceof RuntimeException) {
            throw (RuntimeException) recorded;
        }
        if (recorded instanceof Error && !(recorded instanceof ThreadDeath)) {
            throw (Error) recorded;
        }
        // ThreadDeath is deliberately not rethrown directly; it (and defensively any other Throwable) is wrapped.
        throw new IllegalStateException(recorded);
    }

    /**
     * Parks the owner thread until the pending count drops to zero. Any interrupt observed while waiting is
     * re-asserted on the thread afterwards.
     *
     * @throws IllegalStateException if called from a thread other than the one that created this forwarder
     */
    private void awaitTerminationOfAllRunnables() {
        final Thread caller = Thread.currentThread();
        if (!owner.equals(caller)) {
            throw new IllegalStateException("wrong caller thread, expected " + owner + " and not " + caller);
        }

        boolean wasInterrupted = false;
        while (pending.get() > 0) {
            LockSupport.park();
            // park() returns immediately while the interrupt flag is set, so clear it here to keep waiting;
            // remember that it happened so it can be restored below.
            wasInterrupted |= Thread.interrupted();
        }

        if (wasInterrupted) {
            caller.interrupt();
        }
    }
}