1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.maven.surefire.api.util.internal;
20
21 import java.io.BufferedInputStream;
22 import java.io.BufferedOutputStream;
23 import java.io.IOException;
24 import java.io.InputStream;
25 import java.io.OutputStream;
26 import java.net.InetAddress;
27 import java.net.InetSocketAddress;
28 import java.net.SocketOption;
29 import java.nio.channels.AsynchronousChannelGroup;
30 import java.nio.channels.AsynchronousServerSocketChannel;
31 import java.nio.channels.AsynchronousSocketChannel;
32 import java.nio.charset.StandardCharsets;
33 import java.util.ArrayList;
34 import java.util.List;
35 import java.util.concurrent.CountDownLatch;
36 import java.util.concurrent.ExecutorService;
37 import java.util.concurrent.Executors;
38 import java.util.concurrent.Future;
39 import java.util.concurrent.ThreadFactory;
40 import java.util.concurrent.ThreadPoolExecutor;
41 import java.util.concurrent.TimeUnit;
42 import java.util.concurrent.atomic.AtomicLong;
43
44 import org.junit.Ignore;
45 import org.junit.Test;
46
47 import static java.net.StandardSocketOptions.SO_KEEPALIVE;
48 import static java.net.StandardSocketOptions.SO_REUSEADDR;
49 import static java.net.StandardSocketOptions.TCP_NODELAY;
50 import static org.apache.maven.surefire.api.util.internal.Channels.newInputStream;
51 import static org.apache.maven.surefire.api.util.internal.Channels.newOutputStream;
52 import static org.assertj.core.api.Assertions.assertThat;
53
54
55
56
57 @SuppressWarnings("checkstyle:magicnumber")
58 @Ignore("Can be flaky on slow machine")
59 public class AsyncSocketTest {
60 private static final String LONG_STRING =
61 "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789";
62
63 private final CountDownLatch barrier = new CountDownLatch(1);
64 private final AtomicLong writeTime = new AtomicLong();
65 private final AtomicLong readTime = new AtomicLong();
66
67 private volatile InetSocketAddress address;
68
69 @Test(timeout = 10_000L)
70 public void test() throws Exception {
71 int forks = 2;
72 ThreadFactory factory = DaemonThreadFactory.newDaemonThreadFactory();
73 ExecutorService executorService = Executors.newCachedThreadPool(factory);
74 if (executorService instanceof ThreadPoolExecutor) {
75 ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) executorService;
76 threadPoolExecutor.setCorePoolSize(
77 Math.min(forks, Runtime.getRuntime().availableProcessors()));
78 threadPoolExecutor.prestartCoreThread();
79 }
80 AsynchronousChannelGroup group = AsynchronousChannelGroup.withThreadPool(executorService);
81 AsynchronousServerSocketChannel server = AsynchronousServerSocketChannel.open(group);
82 setTrueOptions(server, SO_REUSEADDR, TCP_NODELAY, SO_KEEPALIVE);
83 InetAddress ip = InetAddress.getLoopbackAddress();
84 server.bind(new InetSocketAddress(ip, 0), 1);
85 address = (InetSocketAddress) server.getLocalAddress();
86
87 System.gc();
88 TimeUnit.SECONDS.sleep(3L);
89
90 Thread tc = new Thread() {
91 @Override
92 public void run() {
93 try {
94 client();
95 } catch (Exception e) {
96 e.printStackTrace();
97 }
98 }
99 };
100 tc.setDaemon(true);
101 tc.start();
102
103 Future<AsynchronousSocketChannel> acceptFuture = server.accept();
104 AsynchronousSocketChannel worker = acceptFuture.get();
105 if (!worker.isOpen()) {
106 throw new IOException("client socket closed");
107 }
108 final InputStream is = newInputStream(worker);
109 final OutputStream os = new BufferedOutputStream(newOutputStream(worker), 64 * 1024);
110
111 Thread tt = new Thread() {
112 public void run() {
113 try {
114 byte[] b = new byte[1024];
115 is.read(b);
116 } catch (Exception e) {
117
118 }
119 }
120 };
121 tt.setName("fork-1-event-thread-");
122 tt.setDaemon(true);
123 tt.start();
124
125 Thread t = new Thread() {
126 @SuppressWarnings("checkstyle:magicnumber")
127 public void run() {
128 try {
129 byte[] data = LONG_STRING.getBytes(StandardCharsets.US_ASCII);
130 long t1 = System.currentTimeMillis();
131 int i = 0;
132 for (; i < 320_000; i++) {
133 os.write(data);
134 long t2 = System.currentTimeMillis();
135 long spent = t2 - t1;
136
137 if (i % 100_000 == 0) {
138 System.out.println("spent " + spent + " ms: " + i);
139 }
140 }
141 os.flush();
142 long spent = System.currentTimeMillis() - t1;
143 writeTime.set(spent);
144 System.out.println("spent " + spent + " ms: " + i);
145 } catch (IOException e) {
146 e.printStackTrace();
147 }
148 }
149 };
150 t.setName("commands-fork-1");
151 t.setDaemon(true);
152 t.start();
153
154 barrier.await();
155 tt.join();
156 t.join();
157 tc.join();
158 worker.close();
159 server.close();
160
161
162
163 assertThat(writeTime.get()).isLessThan(1000L);
164
165
166
167 assertThat(readTime.get()).isLessThan(1000L);
168 }
169
170 @SuppressWarnings("checkstyle:magicnumber")
171 private void client() throws Exception {
172 InetSocketAddress hostAddress = new InetSocketAddress(InetAddress.getLoopbackAddress(), address.getPort());
173 AsynchronousSocketChannel clientSocketChannel = AsynchronousSocketChannel.open();
174 clientSocketChannel.connect(hostAddress).get();
175 InputStream is = new BufferedInputStream(newInputStream(clientSocketChannel), 64 * 1024);
176 List<byte[]> bytes = new ArrayList<>();
177 long t1 = System.currentTimeMillis();
178 for (int i = 0; i < 320_000; i++) {
179 byte[] b = new byte[100];
180 is.read(b);
181 bytes.add(b);
182 }
183 long t2 = System.currentTimeMillis();
184 long spent = t2 - t1;
185 readTime.set(spent);
186 System.out.println("string read: " + new String(bytes.get(bytes.size() - 1)));
187 System.out.println("received within " + spent + " ms");
188 clientSocketChannel.close();
189 barrier.countDown();
190 }
191
192 @SafeVarargs
193 private static void setTrueOptions(AsynchronousServerSocketChannel server, SocketOption<Boolean>... options)
194 throws IOException {
195 for (SocketOption<Boolean> option : options) {
196 if (server.supportedOptions().contains(option)) {
197 server.setOption(option, true);
198 }
199 }
200 }
201 }