1 package org.apache.maven.plugin.surefire.extensions;
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 import org.apache.maven.plugin.surefire.booterclient.output.NativeStdOutStreamConsumer;
23 import org.apache.maven.surefire.api.event.Event;
24 import org.apache.maven.surefire.api.fork.ForkNodeArguments;
25 import org.apache.maven.surefire.extensions.CloseableDaemonThread;
26 import org.apache.maven.surefire.extensions.CommandReader;
27 import org.apache.maven.surefire.extensions.EventHandler;
28 import org.apache.maven.surefire.extensions.ForkChannel;
29 import org.apache.maven.surefire.extensions.util.CountDownLauncher;
30 import org.apache.maven.surefire.extensions.util.CountdownCloseable;
31 import org.apache.maven.surefire.extensions.util.LineConsumerThread;
32
33 import javax.annotation.Nonnull;
34 import java.io.Closeable;
35 import java.io.IOException;
36 import java.net.InetAddress;
37 import java.net.InetSocketAddress;
38 import java.net.SocketOption;
39 import java.nio.Buffer;
40 import java.nio.ByteBuffer;
41 import java.nio.channels.AsynchronousServerSocketChannel;
42 import java.nio.channels.AsynchronousSocketChannel;
43 import java.nio.channels.ReadableByteChannel;
44 import java.nio.channels.WritableByteChannel;
45 import java.util.concurrent.ExecutionException;
46 import java.util.concurrent.ExecutorService;
47 import java.util.concurrent.Executors;
48 import java.util.concurrent.Future;
49
50 import static java.net.StandardSocketOptions.SO_KEEPALIVE;
51 import static java.net.StandardSocketOptions.SO_REUSEADDR;
52 import static java.net.StandardSocketOptions.TCP_NODELAY;
53 import static java.nio.channels.AsynchronousChannelGroup.withThreadPool;
54 import static java.nio.channels.AsynchronousServerSocketChannel.open;
55 import static java.nio.charset.StandardCharsets.US_ASCII;
56 import static org.apache.maven.surefire.api.util.internal.Channels.newBufferedChannel;
57 import static org.apache.maven.surefire.api.util.internal.Channels.newChannel;
58 import static org.apache.maven.surefire.api.util.internal.Channels.newInputStream;
59 import static org.apache.maven.surefire.api.util.internal.Channels.newOutputStream;
60 import static org.apache.maven.surefire.api.util.internal.DaemonThreadFactory.newDaemonThreadFactory;
61 import static org.apache.maven.surefire.shared.lang3.StringUtils.isBlank;
62 import static org.apache.maven.surefire.shared.lang3.StringUtils.isNotBlank;
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78 final class SurefireForkChannel extends ForkChannel
79 {
80 private static final ExecutorService THREAD_POOL = Executors.newCachedThreadPool( newDaemonThreadFactory() );
81
82 private final AsynchronousServerSocketChannel server;
83 private final String localHost;
84 private final int localPort;
85 private final String sessionId;
86 private final Bindings bindings = new Bindings( 2 );
87 private volatile Future<AsynchronousSocketChannel> session;
88 private volatile LineConsumerThread out;
89 private volatile CloseableDaemonThread commandReaderBindings;
90 private volatile CloseableDaemonThread eventHandlerBindings;
91 private volatile EventBindings eventBindings;
92 private volatile CommandBindings commandBindings;
93
94 SurefireForkChannel( @Nonnull ForkNodeArguments arguments ) throws IOException
95 {
96 super( arguments );
97 server = open( withThreadPool( THREAD_POOL ) );
98 setTrueOptions( SO_REUSEADDR, TCP_NODELAY, SO_KEEPALIVE );
99 InetAddress ip = InetAddress.getLoopbackAddress();
100 server.bind( new InetSocketAddress( ip, 0 ), 1 );
101 InetSocketAddress localAddress = (InetSocketAddress) server.getLocalAddress();
102 localHost = localAddress.getHostString();
103 localPort = localAddress.getPort();
104 sessionId = arguments.getSessionId();
105 }
106
107 @Override
108 public void tryConnectToClient()
109 {
110 if ( session != null )
111 {
112 throw new IllegalStateException( "already accepted TCP client connection" );
113 }
114 session = server.accept();
115 }
116
117 @Override
118 public String getForkNodeConnectionString()
119 {
120 return "tcp://" + localHost + ":" + localPort + ( isBlank( sessionId ) ? "" : "?sessionId=" + sessionId );
121 }
122
123 @Override
124 public int getCountdownCloseablePermits()
125 {
126 return 3;
127 }
128
129 @Override
130 public void bindCommandReader( @Nonnull CommandReader commands, WritableByteChannel stdIn )
131 throws IOException, InterruptedException
132 {
133 commandBindings = new CommandBindings( commands );
134
135 bindings.countDown();
136 }
137
138 @Override
139 public void bindEventHandler( @Nonnull EventHandler<Event> eventHandler,
140 @Nonnull CountdownCloseable countdown,
141 ReadableByteChannel stdOut )
142 throws IOException, InterruptedException
143 {
144 ForkNodeArguments args = getArguments();
145 out = new LineConsumerThread( "fork-" + args.getForkChannelId() + "-out-thread", stdOut,
146 new NativeStdOutStreamConsumer( args.getConsoleLock() ), countdown );
147 out.start();
148
149 eventBindings = new EventBindings( eventHandler, countdown );
150
151 bindings.countDown();
152 }
153
154 @Override
155 public void disable()
156 {
157 if ( eventHandlerBindings != null )
158 {
159 eventHandlerBindings.disable();
160 }
161
162 if ( commandReaderBindings != null )
163 {
164 commandReaderBindings.disable();
165 }
166 }
167
168 @Override
169 public void close() throws IOException
170 {
171
172 try ( Closeable c1 = getChannel(); Closeable c2 = server; Closeable c3 = out )
173 {
174
175 }
176 catch ( InterruptedException e )
177 {
178 Throwable cause = e.getCause();
179 throw cause instanceof IOException ? (IOException) cause : new IOException( cause );
180 }
181 }
182
183 private void verifySessionId() throws InterruptedException, IOException
184 {
185 try
186 {
187 ByteBuffer buffer = ByteBuffer.allocate( sessionId.length() );
188 int read;
189 do
190 {
191 read = getChannel().read( buffer ).get();
192 } while ( read != -1 && buffer.hasRemaining() );
193
194 if ( read == -1 )
195 {
196 throw new IOException( "Channel closed while verifying the client." );
197 }
198
199 ( (Buffer) buffer ).flip();
200 String clientSessionId = new String( buffer.array(), US_ASCII );
201 if ( !clientSessionId.equals( sessionId ) )
202 {
203 throw new InvalidSessionIdException( clientSessionId, sessionId );
204 }
205 }
206 catch ( ExecutionException e )
207 {
208 Throwable cause = e.getCause();
209 throw cause instanceof IOException ? (IOException) cause : new IOException( cause );
210 }
211 }
212
213 @SafeVarargs
214 private final void setTrueOptions( SocketOption<Boolean>... options )
215 throws IOException
216 {
217 for ( SocketOption<Boolean> option : options )
218 {
219 if ( server.supportedOptions().contains( option ) )
220 {
221 server.setOption( option, true );
222 }
223 }
224 }
225
226 private class EventBindings
227 {
228 private final EventHandler<Event> eventHandler;
229 private final CountdownCloseable countdown;
230
231 private EventBindings( EventHandler<Event> eventHandler, CountdownCloseable countdown )
232 {
233 this.eventHandler = eventHandler;
234 this.countdown = countdown;
235 }
236
237 void bindEventHandler( AsynchronousSocketChannel source )
238 {
239 ForkNodeArguments args = getArguments();
240 String threadName = "fork-" + args.getForkChannelId() + "-event-thread";
241 ReadableByteChannel channel = newBufferedChannel( newInputStream( source ) );
242 eventHandlerBindings = new EventConsumerThread( threadName, channel, eventHandler, countdown, args );
243 eventHandlerBindings.start();
244 }
245 }
246
247 private class CommandBindings
248 {
249 private final CommandReader commands;
250
251 private CommandBindings( CommandReader commands )
252 {
253 this.commands = commands;
254 }
255
256 void bindCommandSender( AsynchronousSocketChannel source )
257 {
258
259
260
261 ForkNodeArguments args = getArguments();
262 WritableByteChannel channel = newChannel( newOutputStream( source ) );
263 String threadName = "commands-fork-" + args.getForkChannelId();
264 commandReaderBindings = new StreamFeeder( threadName, channel, commands, args.getConsoleLogger() );
265 commandReaderBindings.start();
266 }
267 }
268
269 private class Bindings extends CountDownLauncher
270 {
271 private Bindings( int count )
272 {
273 super( count );
274 }
275
276 @Override
277 protected void job() throws IOException, InterruptedException
278 {
279 AsynchronousSocketChannel channel = getChannel();
280 if ( isNotBlank( sessionId ) )
281 {
282 verifySessionId();
283 }
284 eventBindings.bindEventHandler( channel );
285 commandBindings.bindCommandSender( channel );
286 }
287 }
288
289 private AsynchronousSocketChannel getChannel()
290 throws InterruptedException, IOException
291 {
292 try
293 {
294 return session == null ? null : session.get();
295 }
296 catch ( ExecutionException e )
297 {
298 Throwable cause = e.getCause();
299 throw cause instanceof IOException ? (IOException) cause : new IOException( cause );
300 }
301 }
302 }