1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.maven.plugin.surefire.extensions;
20
21 import javax.annotation.Nonnull;
22
23 import java.io.Closeable;
24 import java.io.IOException;
25 import java.net.InetAddress;
26 import java.net.InetSocketAddress;
27 import java.net.SocketOption;
28 import java.net.URI;
29 import java.net.URISyntaxException;
30 import java.nio.Buffer;
31 import java.nio.ByteBuffer;
32 import java.nio.channels.AsynchronousServerSocketChannel;
33 import java.nio.channels.AsynchronousSocketChannel;
34 import java.nio.channels.ReadableByteChannel;
35 import java.nio.channels.WritableByteChannel;
36 import java.util.concurrent.ExecutionException;
37 import java.util.concurrent.ExecutorService;
38 import java.util.concurrent.Executors;
39 import java.util.concurrent.Future;
40
41 import org.apache.maven.plugin.surefire.booterclient.output.NativeStdOutStreamConsumer;
42 import org.apache.maven.surefire.api.event.Event;
43 import org.apache.maven.surefire.api.fork.ForkNodeArguments;
44 import org.apache.maven.surefire.extensions.CloseableDaemonThread;
45 import org.apache.maven.surefire.extensions.CommandReader;
46 import org.apache.maven.surefire.extensions.EventHandler;
47 import org.apache.maven.surefire.extensions.ForkChannel;
48 import org.apache.maven.surefire.extensions.util.CountDownLauncher;
49 import org.apache.maven.surefire.extensions.util.CountdownCloseable;
50 import org.apache.maven.surefire.extensions.util.LineConsumerThread;
51
52 import static java.net.StandardSocketOptions.SO_KEEPALIVE;
53 import static java.net.StandardSocketOptions.SO_REUSEADDR;
54 import static java.net.StandardSocketOptions.TCP_NODELAY;
55 import static java.nio.channels.AsynchronousChannelGroup.withThreadPool;
56 import static java.nio.channels.AsynchronousServerSocketChannel.open;
57 import static java.nio.charset.StandardCharsets.US_ASCII;
58 import static org.apache.maven.surefire.api.util.internal.Channels.newBufferedChannel;
59 import static org.apache.maven.surefire.api.util.internal.Channels.newChannel;
60 import static org.apache.maven.surefire.api.util.internal.Channels.newInputStream;
61 import static org.apache.maven.surefire.api.util.internal.Channels.newOutputStream;
62 import static org.apache.maven.surefire.api.util.internal.DaemonThreadFactory.newDaemonThreadFactory;
63 import static org.apache.maven.surefire.shared.lang3.StringUtils.isBlank;
64 import static org.apache.maven.surefire.shared.lang3.StringUtils.isNotBlank;
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80 final class SurefireForkChannel extends ForkChannel {
81 private static final ExecutorService THREAD_POOL = Executors.newCachedThreadPool(newDaemonThreadFactory());
82
83 private final AsynchronousServerSocketChannel server;
84 private final String localHost;
85 private final int localPort;
86 private final String sessionId;
87 private final Bindings bindings = new Bindings(2);
88 private volatile Future<AsynchronousSocketChannel> session;
89 private volatile LineConsumerThread out;
90 private volatile CloseableDaemonThread commandReaderBindings;
91 private volatile CloseableDaemonThread eventHandlerBindings;
92 private volatile EventBindings eventBindings;
93 private volatile CommandBindings commandBindings;
94
95 SurefireForkChannel(@Nonnull ForkNodeArguments arguments) throws IOException {
96 super(arguments);
97 server = open(withThreadPool(THREAD_POOL));
98 setTrueOptions(SO_REUSEADDR, TCP_NODELAY, SO_KEEPALIVE);
99 InetAddress ip = InetAddress.getLoopbackAddress();
100 server.bind(new InetSocketAddress(ip, 0), 1);
101 InetSocketAddress localAddress = (InetSocketAddress) server.getLocalAddress();
102 localHost = localAddress.getHostString();
103 localPort = localAddress.getPort();
104 sessionId = arguments.getSessionId();
105 }
106
107 @Override
108 public void tryConnectToClient() {
109 if (session != null) {
110 throw new IllegalStateException("already accepted TCP client connection");
111 }
112 session = server.accept();
113 }
114
115 @Override
116 public String getForkNodeConnectionString() {
117 try {
118 URI uri = new URI(
119 "tcp",
120 null,
121 localHost,
122 localPort,
123 null,
124 isBlank(sessionId) ? null : "sessionId=" + sessionId,
125 null);
126 return uri.toASCIIString();
127 } catch (URISyntaxException e) {
128 throw new IllegalStateException(e);
129 }
130 }
131
132 @Override
133 public int getCountdownCloseablePermits() {
134 return 3;
135 }
136
137 @Override
138 public void bindCommandReader(@Nonnull CommandReader commands, WritableByteChannel stdIn)
139 throws IOException, InterruptedException {
140 commandBindings = new CommandBindings(commands);
141
142 bindings.countDown();
143 }
144
145 @Override
146 public void bindEventHandler(
147 @Nonnull EventHandler<Event> eventHandler,
148 @Nonnull CountdownCloseable countdown,
149 ReadableByteChannel stdOut)
150 throws IOException, InterruptedException {
151 ForkNodeArguments args = getArguments();
152 out = new LineConsumerThread(
153 "fork-" + args.getForkChannelId() + "-out-thread",
154 stdOut,
155 new NativeStdOutStreamConsumer(args.getConsoleLock()),
156 countdown);
157 out.start();
158
159 eventBindings = new EventBindings(eventHandler, countdown);
160
161 bindings.countDown();
162 }
163
164 @Override
165 public void disable() {
166 if (eventHandlerBindings != null) {
167 eventHandlerBindings.disable();
168 }
169
170 if (commandReaderBindings != null) {
171 commandReaderBindings.disable();
172 }
173 }
174
175 @Override
176 public void close() throws IOException {
177
178 try (Closeable c1 = getChannel();
179 Closeable c2 = server;
180 Closeable c3 = out) {
181
182 } catch (InterruptedException e) {
183 Throwable cause = e.getCause();
184 throw cause instanceof IOException ? (IOException) cause : new IOException(cause);
185 }
186 }
187
188 private void verifySessionId() throws InterruptedException, IOException {
189 try {
190 ByteBuffer buffer = ByteBuffer.allocate(sessionId.length());
191 int read;
192 do {
193 read = getChannel().read(buffer).get();
194 } while (read != -1 && buffer.hasRemaining());
195
196 if (read == -1) {
197 throw new IOException("Channel closed while verifying the client.");
198 }
199
200 ((Buffer) buffer).flip();
201 String clientSessionId = new String(buffer.array(), US_ASCII);
202 if (!clientSessionId.equals(sessionId)) {
203 throw new InvalidSessionIdException(clientSessionId, sessionId);
204 }
205 } catch (ExecutionException e) {
206 Throwable cause = e.getCause();
207 throw cause instanceof IOException ? (IOException) cause : new IOException(cause);
208 }
209 }
210
211 @SafeVarargs
212 private final void setTrueOptions(SocketOption<Boolean>... options) throws IOException {
213 for (SocketOption<Boolean> option : options) {
214 if (server.supportedOptions().contains(option)) {
215 server.setOption(option, true);
216 }
217 }
218 }
219
220 private class EventBindings {
221 private final EventHandler<Event> eventHandler;
222 private final CountdownCloseable countdown;
223
224 private EventBindings(EventHandler<Event> eventHandler, CountdownCloseable countdown) {
225 this.eventHandler = eventHandler;
226 this.countdown = countdown;
227 }
228
229 void bindEventHandler(AsynchronousSocketChannel source) {
230 ForkNodeArguments args = getArguments();
231 String threadName = "fork-" + args.getForkChannelId() + "-event-thread";
232 ReadableByteChannel channel = newBufferedChannel(newInputStream(source));
233 eventHandlerBindings = new EventConsumerThread(threadName, channel, eventHandler, countdown, args);
234 eventHandlerBindings.start();
235 }
236 }
237
238 private class CommandBindings {
239 private final CommandReader commands;
240
241 private CommandBindings(CommandReader commands) {
242 this.commands = commands;
243 }
244
245 void bindCommandSender(AsynchronousSocketChannel source) {
246
247
248
249 ForkNodeArguments args = getArguments();
250 WritableByteChannel channel = newChannel(newOutputStream(source));
251 String threadName = "commands-fork-" + args.getForkChannelId();
252 commandReaderBindings = new StreamFeeder(threadName, channel, commands, args.getConsoleLogger());
253 commandReaderBindings.start();
254 }
255 }
256
257 private class Bindings extends CountDownLauncher {
258 private Bindings(int count) {
259 super(count);
260 }
261
262 @Override
263 protected void job() throws IOException, InterruptedException {
264 AsynchronousSocketChannel channel = getChannel();
265 if (isNotBlank(sessionId)) {
266 verifySessionId();
267 }
268 eventBindings.bindEventHandler(channel);
269 commandBindings.bindCommandSender(channel);
270 }
271 }
272
273 private AsynchronousSocketChannel getChannel() throws InterruptedException, IOException {
274 try {
275 return session == null ? null : session.get();
276 } catch (ExecutionException e) {
277 Throwable cause = e.getCause();
278 throw cause instanceof IOException ? (IOException) cause : new IOException(cause);
279 }
280 }
281 }