1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.maven.plugin.surefire.extensions;
20
import javax.annotation.Nonnull;

import java.io.Closeable;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketOption;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousServerSocketChannel;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.maven.plugin.surefire.booterclient.output.NativeStdOutStreamConsumer;
import org.apache.maven.surefire.api.event.Event;
import org.apache.maven.surefire.api.fork.ForkNodeArguments;
import org.apache.maven.surefire.extensions.CloseableDaemonThread;
import org.apache.maven.surefire.extensions.CommandReader;
import org.apache.maven.surefire.extensions.EventHandler;
import org.apache.maven.surefire.extensions.ForkChannel;
import org.apache.maven.surefire.extensions.util.CountDownLauncher;
import org.apache.maven.surefire.extensions.util.CountdownCloseable;
import org.apache.maven.surefire.extensions.util.LineConsumerThread;

import static java.net.StandardSocketOptions.SO_KEEPALIVE;
import static java.net.StandardSocketOptions.SO_REUSEADDR;
import static java.net.StandardSocketOptions.TCP_NODELAY;
import static java.nio.channels.AsynchronousChannelGroup.withThreadPool;
import static java.nio.channels.AsynchronousServerSocketChannel.open;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.apache.maven.surefire.api.util.internal.Channels.newBufferedChannel;
import static org.apache.maven.surefire.api.util.internal.Channels.newChannel;
import static org.apache.maven.surefire.api.util.internal.Channels.newInputStream;
import static org.apache.maven.surefire.api.util.internal.Channels.newOutputStream;
import static org.apache.maven.surefire.api.util.internal.DaemonThreadFactory.newDaemonThreadFactory;
import static org.apache.maven.surefire.shared.lang3.StringUtils.isBlank;
import static org.apache.maven.surefire.shared.lang3.StringUtils.isNotBlank;
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78 final class SurefireForkChannel extends ForkChannel {
79 private static final ExecutorService THREAD_POOL = Executors.newCachedThreadPool(newDaemonThreadFactory());
80
81 private final AsynchronousServerSocketChannel server;
82 private final String localHost;
83 private final int localPort;
84 private final String sessionId;
85 private final Bindings bindings = new Bindings(2);
86 private volatile Future<AsynchronousSocketChannel> session;
87 private volatile LineConsumerThread out;
88 private volatile CloseableDaemonThread commandReaderBindings;
89 private volatile CloseableDaemonThread eventHandlerBindings;
90 private volatile EventBindings eventBindings;
91 private volatile CommandBindings commandBindings;
92
93 SurefireForkChannel(@Nonnull ForkNodeArguments arguments) throws IOException {
94 super(arguments);
95 server = open(withThreadPool(THREAD_POOL));
96 setTrueOptions(SO_REUSEADDR, TCP_NODELAY, SO_KEEPALIVE);
97 InetAddress ip = InetAddress.getLoopbackAddress();
98 server.bind(new InetSocketAddress(ip, 0), 1);
99 InetSocketAddress localAddress = (InetSocketAddress) server.getLocalAddress();
100 localHost = localAddress.getHostString();
101 localPort = localAddress.getPort();
102 sessionId = arguments.getSessionId();
103 }
104
105 @Override
106 public void tryConnectToClient() {
107 if (session != null) {
108 throw new IllegalStateException("already accepted TCP client connection");
109 }
110 session = server.accept();
111 }
112
113 @Override
114 public String getForkNodeConnectionString() {
115 return "tcp://" + localHost + ":" + localPort + (isBlank(sessionId) ? "" : "?sessionId=" + sessionId);
116 }
117
118 @Override
119 public int getCountdownCloseablePermits() {
120 return 3;
121 }
122
123 @Override
124 public void bindCommandReader(@Nonnull CommandReader commands, WritableByteChannel stdIn)
125 throws IOException, InterruptedException {
126 commandBindings = new CommandBindings(commands);
127
128 bindings.countDown();
129 }
130
131 @Override
132 public void bindEventHandler(
133 @Nonnull EventHandler<Event> eventHandler,
134 @Nonnull CountdownCloseable countdown,
135 ReadableByteChannel stdOut)
136 throws IOException, InterruptedException {
137 ForkNodeArguments args = getArguments();
138 out = new LineConsumerThread(
139 "fork-" + args.getForkChannelId() + "-out-thread",
140 stdOut,
141 new NativeStdOutStreamConsumer(args.getConsoleLock()),
142 countdown);
143 out.start();
144
145 eventBindings = new EventBindings(eventHandler, countdown);
146
147 bindings.countDown();
148 }
149
150 @Override
151 public void disable() {
152 if (eventHandlerBindings != null) {
153 eventHandlerBindings.disable();
154 }
155
156 if (commandReaderBindings != null) {
157 commandReaderBindings.disable();
158 }
159 }
160
161 @Override
162 public void close() throws IOException {
163
164 try (Closeable c1 = getChannel();
165 Closeable c2 = server;
166 Closeable c3 = out) {
167
168 } catch (InterruptedException e) {
169 Throwable cause = e.getCause();
170 throw cause instanceof IOException ? (IOException) cause : new IOException(cause);
171 }
172 }
173
174 private void verifySessionId() throws InterruptedException, IOException {
175 try {
176 ByteBuffer buffer = ByteBuffer.allocate(sessionId.length());
177 int read;
178 do {
179 read = getChannel().read(buffer).get();
180 } while (read != -1 && buffer.hasRemaining());
181
182 if (read == -1) {
183 throw new IOException("Channel closed while verifying the client.");
184 }
185
186 ((Buffer) buffer).flip();
187 String clientSessionId = new String(buffer.array(), US_ASCII);
188 if (!clientSessionId.equals(sessionId)) {
189 throw new InvalidSessionIdException(clientSessionId, sessionId);
190 }
191 } catch (ExecutionException e) {
192 Throwable cause = e.getCause();
193 throw cause instanceof IOException ? (IOException) cause : new IOException(cause);
194 }
195 }
196
197 @SafeVarargs
198 private final void setTrueOptions(SocketOption<Boolean>... options) throws IOException {
199 for (SocketOption<Boolean> option : options) {
200 if (server.supportedOptions().contains(option)) {
201 server.setOption(option, true);
202 }
203 }
204 }
205
206 private class EventBindings {
207 private final EventHandler<Event> eventHandler;
208 private final CountdownCloseable countdown;
209
210 private EventBindings(EventHandler<Event> eventHandler, CountdownCloseable countdown) {
211 this.eventHandler = eventHandler;
212 this.countdown = countdown;
213 }
214
215 void bindEventHandler(AsynchronousSocketChannel source) {
216 ForkNodeArguments args = getArguments();
217 String threadName = "fork-" + args.getForkChannelId() + "-event-thread";
218 ReadableByteChannel channel = newBufferedChannel(newInputStream(source));
219 eventHandlerBindings = new EventConsumerThread(threadName, channel, eventHandler, countdown, args);
220 eventHandlerBindings.start();
221 }
222 }
223
224 private class CommandBindings {
225 private final CommandReader commands;
226
227 private CommandBindings(CommandReader commands) {
228 this.commands = commands;
229 }
230
231 void bindCommandSender(AsynchronousSocketChannel source) {
232
233
234
235 ForkNodeArguments args = getArguments();
236 WritableByteChannel channel = newChannel(newOutputStream(source));
237 String threadName = "commands-fork-" + args.getForkChannelId();
238 commandReaderBindings = new StreamFeeder(threadName, channel, commands, args.getConsoleLogger());
239 commandReaderBindings.start();
240 }
241 }
242
243 private class Bindings extends CountDownLauncher {
244 private Bindings(int count) {
245 super(count);
246 }
247
248 @Override
249 protected void job() throws IOException, InterruptedException {
250 AsynchronousSocketChannel channel = getChannel();
251 if (isNotBlank(sessionId)) {
252 verifySessionId();
253 }
254 eventBindings.bindEventHandler(channel);
255 commandBindings.bindCommandSender(channel);
256 }
257 }
258
259 private AsynchronousSocketChannel getChannel() throws InterruptedException, IOException {
260 try {
261 return session == null ? null : session.get();
262 } catch (ExecutionException e) {
263 Throwable cause = e.getCause();
264 throw cause instanceof IOException ? (IOException) cause : new IOException(cause);
265 }
266 }
267 }