1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  package org.apache.maven.plugin.surefire.extensions;
20  
21  import javax.annotation.Nonnull;
22  
23  import java.io.Closeable;
24  import java.io.IOException;
25  import java.net.InetAddress;
26  import java.net.InetSocketAddress;
27  import java.net.SocketOption;
28  import java.net.URI;
29  import java.net.URISyntaxException;
30  import java.nio.Buffer;
31  import java.nio.ByteBuffer;
32  import java.nio.channels.AsynchronousServerSocketChannel;
33  import java.nio.channels.AsynchronousSocketChannel;
34  import java.nio.channels.ReadableByteChannel;
35  import java.nio.channels.WritableByteChannel;
36  import java.util.concurrent.ExecutionException;
37  import java.util.concurrent.ExecutorService;
38  import java.util.concurrent.Executors;
39  import java.util.concurrent.Future;
40  
41  import org.apache.maven.plugin.surefire.booterclient.output.NativeStdOutStreamConsumer;
42  import org.apache.maven.surefire.api.event.Event;
43  import org.apache.maven.surefire.api.fork.ForkNodeArguments;
44  import org.apache.maven.surefire.extensions.CloseableDaemonThread;
45  import org.apache.maven.surefire.extensions.CommandReader;
46  import org.apache.maven.surefire.extensions.EventHandler;
47  import org.apache.maven.surefire.extensions.ForkChannel;
48  import org.apache.maven.surefire.extensions.util.CountDownLauncher;
49  import org.apache.maven.surefire.extensions.util.CountdownCloseable;
50  import org.apache.maven.surefire.extensions.util.LineConsumerThread;
51  
52  import static java.net.StandardSocketOptions.SO_KEEPALIVE;
53  import static java.net.StandardSocketOptions.SO_REUSEADDR;
54  import static java.net.StandardSocketOptions.TCP_NODELAY;
55  import static java.nio.channels.AsynchronousChannelGroup.withThreadPool;
56  import static java.nio.channels.AsynchronousServerSocketChannel.open;
57  import static java.nio.charset.StandardCharsets.US_ASCII;
58  import static org.apache.maven.surefire.api.util.internal.Channels.newBufferedChannel;
59  import static org.apache.maven.surefire.api.util.internal.Channels.newChannel;
60  import static org.apache.maven.surefire.api.util.internal.Channels.newInputStream;
61  import static org.apache.maven.surefire.api.util.internal.Channels.newOutputStream;
62  import static org.apache.maven.surefire.api.util.internal.DaemonThreadFactory.newDaemonThreadFactory;
63  import static org.apache.maven.surefire.shared.lang3.StringUtils.isBlank;
64  import static org.apache.maven.surefire.shared.lang3.StringUtils.isNotBlank;
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  final class SurefireForkChannel extends ForkChannel {
81      private static final ExecutorService THREAD_POOL = Executors.newCachedThreadPool(newDaemonThreadFactory());
82  
83      private final AsynchronousServerSocketChannel server;
84      private final String localHost;
85      private final int localPort;
86      private final String sessionId;
87      private final Bindings bindings = new Bindings(2);
88      private volatile Future<AsynchronousSocketChannel> session;
89      private volatile LineConsumerThread out;
90      private volatile CloseableDaemonThread commandReaderBindings;
91      private volatile CloseableDaemonThread eventHandlerBindings;
92      private volatile EventBindings eventBindings;
93      private volatile CommandBindings commandBindings;
94  
95      SurefireForkChannel(@Nonnull ForkNodeArguments arguments) throws IOException {
96          super(arguments);
97          server = open(withThreadPool(THREAD_POOL));
98          setTrueOptions(SO_REUSEADDR, TCP_NODELAY, SO_KEEPALIVE);
99          InetAddress ip = InetAddress.getLoopbackAddress();
100         server.bind(new InetSocketAddress(ip, 0), 1);
101         InetSocketAddress localAddress = (InetSocketAddress) server.getLocalAddress();
102         localHost = localAddress.getHostString();
103         localPort = localAddress.getPort();
104         sessionId = arguments.getSessionId();
105     }
106 
107     @Override
108     public void tryConnectToClient() {
109         if (session != null) {
110             throw new IllegalStateException("already accepted TCP client connection");
111         }
112         session = server.accept();
113     }
114 
115     @Override
116     public String getForkNodeConnectionString() {
117         try {
118             URI uri = new URI(
119                     "tcp",
120                     null,
121                     localHost,
122                     localPort,
123                     null,
124                     isBlank(sessionId) ? null : "sessionId=" + sessionId,
125                     null);
126             return uri.toASCIIString();
127         } catch (URISyntaxException e) {
128             throw new IllegalStateException(e);
129         }
130     }
131 
132     @Override
133     public int getCountdownCloseablePermits() {
134         return 3;
135     }
136 
137     @Override
138     public void bindCommandReader(@Nonnull CommandReader commands, WritableByteChannel stdIn)
139             throws IOException, InterruptedException {
140         commandBindings = new CommandBindings(commands);
141 
142         bindings.countDown();
143     }
144 
145     @Override
146     public void bindEventHandler(
147             @Nonnull EventHandler<Event> eventHandler,
148             @Nonnull CountdownCloseable countdown,
149             ReadableByteChannel stdOut)
150             throws IOException, InterruptedException {
151         ForkNodeArguments args = getArguments();
152         out = new LineConsumerThread(
153                 "fork-" + args.getForkChannelId() + "-out-thread",
154                 stdOut,
155                 new NativeStdOutStreamConsumer(args.getConsoleLock()),
156                 countdown);
157         out.start();
158 
159         eventBindings = new EventBindings(eventHandler, countdown);
160 
161         bindings.countDown();
162     }
163 
164     @Override
165     public void disable() {
166         if (eventHandlerBindings != null) {
167             eventHandlerBindings.disable();
168         }
169 
170         if (commandReaderBindings != null) {
171             commandReaderBindings.disable();
172         }
173     }
174 
175     @Override
176     public void close() throws IOException {
177         
178         try (Closeable c1 = getChannel();
179                 Closeable c2 = server;
180                 Closeable c3 = out) {
181             
182         } catch (InterruptedException e) {
183             Throwable cause = e.getCause();
184             throw cause instanceof IOException ? (IOException) cause : new IOException(cause);
185         }
186     }
187 
188     private void verifySessionId() throws InterruptedException, IOException {
189         try {
190             ByteBuffer buffer = ByteBuffer.allocate(sessionId.length());
191             int read;
192             do {
193                 read = getChannel().read(buffer).get();
194             } while (read != -1 && buffer.hasRemaining());
195 
196             if (read == -1) {
197                 throw new IOException("Channel closed while verifying the client.");
198             }
199 
200             ((Buffer) buffer).flip();
201             String clientSessionId = new String(buffer.array(), US_ASCII);
202             if (!clientSessionId.equals(sessionId)) {
203                 throw new InvalidSessionIdException(clientSessionId, sessionId);
204             }
205         } catch (ExecutionException e) {
206             Throwable cause = e.getCause();
207             throw cause instanceof IOException ? (IOException) cause : new IOException(cause);
208         }
209     }
210 
211     @SafeVarargs
212     private final void setTrueOptions(SocketOption<Boolean>... options) throws IOException {
213         for (SocketOption<Boolean> option : options) {
214             if (server.supportedOptions().contains(option)) {
215                 server.setOption(option, true);
216             }
217         }
218     }
219 
220     private class EventBindings {
221         private final EventHandler<Event> eventHandler;
222         private final CountdownCloseable countdown;
223 
224         private EventBindings(EventHandler<Event> eventHandler, CountdownCloseable countdown) {
225             this.eventHandler = eventHandler;
226             this.countdown = countdown;
227         }
228 
229         void bindEventHandler(AsynchronousSocketChannel source) {
230             ForkNodeArguments args = getArguments();
231             String threadName = "fork-" + args.getForkChannelId() + "-event-thread";
232             ReadableByteChannel channel = newBufferedChannel(newInputStream(source));
233             eventHandlerBindings = new EventConsumerThread(threadName, channel, eventHandler, countdown, args);
234             eventHandlerBindings.start();
235         }
236     }
237 
238     private class CommandBindings {
239         private final CommandReader commands;
240 
241         private CommandBindings(CommandReader commands) {
242             this.commands = commands;
243         }
244 
245         void bindCommandSender(AsynchronousSocketChannel source) {
246             
247             
248             
249             ForkNodeArguments args = getArguments();
250             WritableByteChannel channel = newChannel(newOutputStream(source));
251             String threadName = "commands-fork-" + args.getForkChannelId();
252             commandReaderBindings = new StreamFeeder(threadName, channel, commands, args.getConsoleLogger());
253             commandReaderBindings.start();
254         }
255     }
256 
257     private class Bindings extends CountDownLauncher {
258         private Bindings(int count) {
259             super(count);
260         }
261 
262         @Override
263         protected void job() throws IOException, InterruptedException {
264             AsynchronousSocketChannel channel = getChannel();
265             if (isNotBlank(sessionId)) {
266                 verifySessionId();
267             }
268             eventBindings.bindEventHandler(channel);
269             commandBindings.bindCommandSender(channel);
270         }
271     }
272 
273     private AsynchronousSocketChannel getChannel() throws InterruptedException, IOException {
274         try {
275             return session == null ? null : session.get();
276         } catch (ExecutionException e) {
277             Throwable cause = e.getCause();
278             throw cause instanceof IOException ? (IOException) cause : new IOException(cause);
279         }
280     }
281 }