1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.eclipse.aether.named.ipc;
20
21 import java.io.DataInputStream;
22 import java.io.DataOutputStream;
23 import java.io.IOException;
24 import java.net.SocketAddress;
25 import java.nio.channels.ByteChannel;
26 import java.nio.channels.Channels;
27 import java.nio.channels.ServerSocketChannel;
28 import java.nio.channels.SocketChannel;
29 import java.util.ArrayList;
30 import java.util.Iterator;
31 import java.util.List;
32 import java.util.Map;
33 import java.util.concurrent.CompletableFuture;
34 import java.util.concurrent.ConcurrentHashMap;
35 import java.util.concurrent.CopyOnWriteArrayList;
36 import java.util.concurrent.TimeUnit;
37 import java.util.concurrent.atomic.AtomicInteger;
38
39
40
41
42
43
44
45 public class IpcServer {
46
47
48
49
50
51
52
53 public static final String SYSTEM_PROP_NO_FORK = "aether.named.ipc.nofork";
54
55 public static final boolean DEFAULT_NO_FORK = false;
56
57
58
59
60
61
62
63
64 public static final String SYSTEM_PROP_IDLE_TIMEOUT = "aether.named.ipc.idleTimeout";
65
66 public static final int DEFAULT_IDLE_TIMEOUT = 300;
67
68
69
70
71
72
73
74
75 public static final String SYSTEM_PROP_FAMILY = "aether.named.ipc.family";
76
77 public static final String DEFAULT_FAMILY = "unix";
78
79
80
81
82
83
84
85
86 public static final String SYSTEM_PROP_NO_NATIVE = "aether.named.ipc.nonative";
87
88 public static final boolean DEFAULT_NO_NATIVE = true;
89
90
91
92
93
94
95
96
97 public static final String SYSTEM_PROP_NATIVE_NAME = "aether.named.ipc.nativeName";
98
99 public static final String DEFAULT_NATIVE_NAME = "ipc-sync";
100
101
102
103
104
105
106
107
108 public static final String SYSTEM_PROP_DEBUG = "aether.named.ipc.debug";
109
110 public static final boolean DEFAULT_DEBUG = false;
111
112 private final ServerSocketChannel serverSocket;
113 private final Map<SocketChannel, Thread> clients = new ConcurrentHashMap<>();
114 private final AtomicInteger counter = new AtomicInteger();
115 private final Map<String, Lock> locks = new ConcurrentHashMap<>();
116 private final Map<String, Context> contexts = new ConcurrentHashMap<>();
117 private static final boolean DEBUG =
118 Boolean.parseBoolean(System.getProperty(SYSTEM_PROP_DEBUG, Boolean.toString(DEFAULT_DEBUG)));
119 private final long idleTimeout;
120 private volatile long lastUsed;
121 private volatile boolean closing;
122
123 public IpcServer(SocketFamily family) throws IOException {
124 serverSocket = family.openServerSocket();
125 long timeout = TimeUnit.SECONDS.toNanos(DEFAULT_IDLE_TIMEOUT);
126 String str = System.getProperty(SYSTEM_PROP_IDLE_TIMEOUT);
127 if (str != null) {
128 try {
129 TimeUnit unit = TimeUnit.SECONDS;
130 if (str.endsWith("ms")) {
131 unit = TimeUnit.MILLISECONDS;
132 str = str.substring(0, str.length() - 2);
133 }
134 long dur = Long.parseLong(str);
135 timeout = unit.toNanos(dur);
136 } catch (NumberFormatException e) {
137 error("Property " + SYSTEM_PROP_IDLE_TIMEOUT + " specified with invalid value: " + str, e);
138 }
139 }
140 idleTimeout = timeout;
141 }
142
143 public static void main(String[] args) throws Exception {
144
145
146
147
148
149
150
151
152 try {
153 sun.misc.Signal.handle(new sun.misc.Signal("INT"), sun.misc.SignalHandler.SIG_IGN);
154 if (!IpcClient.IS_WINDOWS) {
155 sun.misc.Signal.handle(new sun.misc.Signal("TSTP"), sun.misc.SignalHandler.SIG_IGN);
156 }
157 } catch (Throwable t) {
158 error("Unable to ignore INT and TSTP signals", t);
159 }
160
161 String family = args[0];
162 String tmpAddress = args[1];
163 String rand = args[2];
164
165 runServer(SocketFamily.valueOf(family), tmpAddress, rand);
166 }
167
168 static IpcServer runServer(SocketFamily family, String tmpAddress, String rand) throws IOException {
169 IpcServer server = new IpcServer(family);
170 run(server::run, false);
171 String address = SocketFamily.toString(server.getLocalAddress());
172 SocketAddress socketAddress = SocketFamily.fromString(tmpAddress);
173 try (SocketChannel socket = SocketChannel.open(socketAddress)) {
174 try (DataOutputStream dos = new DataOutputStream(Channels.newOutputStream(socket))) {
175 dos.writeUTF(rand);
176 dos.writeUTF(address);
177 dos.flush();
178 }
179 }
180
181 return server;
182 }
183
184 private static void debug(String msg, Object... args) {
185 if (DEBUG) {
186 System.out.printf("[ipc] [debug] " + msg + "\n", args);
187 }
188 }
189
190 private static void info(String msg, Object... args) {
191 System.out.printf("[ipc] [info] " + msg + "\n", args);
192 }
193
194 private static void error(String msg, Throwable t) {
195 System.out.println("[ipc] [error] " + msg);
196 t.printStackTrace(System.out);
197 }
198
199 private static void run(Runnable runnable, boolean daemon) {
200 Thread thread = new Thread(runnable);
201 if (daemon) {
202 thread.setDaemon(true);
203 }
204 thread.start();
205 }
206
207 public SocketAddress getLocalAddress() throws IOException {
208 return serverSocket.getLocalAddress();
209 }
210
211 public void run() {
212 try {
213 info("IpcServer started at %s", getLocalAddress().toString());
214 use();
215 run(this::expirationCheck, true);
216 while (!closing) {
217 SocketChannel socket = this.serverSocket.accept();
218 run(() -> client(socket), false);
219 }
220 } catch (Throwable t) {
221 if (!closing) {
222 error("Error running sync server loop", t);
223 }
224 }
225 }
226
227 private void client(SocketChannel socket) {
228 int c;
229 synchronized (clients) {
230 clients.put(socket, Thread.currentThread());
231 c = clients.size();
232 }
233 info("New client connected (%d connected)", c);
234 use();
235 Map<String, Context> clientContexts = new ConcurrentHashMap<>();
236 try {
237 ByteChannel wrapper = new ByteChannelWrapper(socket);
238 DataInputStream input = new DataInputStream(Channels.newInputStream(wrapper));
239 DataOutputStream output = new DataOutputStream(Channels.newOutputStream(wrapper));
240 while (!closing) {
241 int requestId = input.readInt();
242 int sz = input.readInt();
243 List<String> request = new ArrayList<>(sz);
244 for (int i = 0; i < sz; i++) {
245 request.add(input.readUTF());
246 }
247 if (request.isEmpty()) {
248 throw new IOException("Received invalid request");
249 }
250 use();
251 String contextId;
252 Context context;
253 String command = request.remove(0);
254 switch (command) {
255 case IpcMessages.REQUEST_CONTEXT:
256 if (request.size() != 1) {
257 throw new IOException("Expected one argument for " + command + " but got " + request);
258 }
259 boolean shared = Boolean.parseBoolean(request.remove(0));
260 context = new Context(shared);
261 contexts.put(context.id, context);
262 clientContexts.put(context.id, context);
263 synchronized (output) {
264 debug("Created context %s", context.id);
265 output.writeInt(requestId);
266 output.writeInt(2);
267 output.writeUTF(IpcMessages.RESPONSE_CONTEXT);
268 output.writeUTF(context.id);
269 output.flush();
270 }
271 break;
272 case IpcMessages.REQUEST_ACQUIRE:
273 if (request.isEmpty()) {
274 throw new IOException(
275 "Expected at least one argument for " + command + " but got " + request);
276 }
277 contextId = request.remove(0);
278 context = contexts.get(contextId);
279 if (context == null) {
280 throw new IOException(
281 "Unknown context: " + contextId + ". Known contexts = " + contexts.keySet());
282 }
283 context.lock(request).thenRun(() -> {
284 try {
285 synchronized (output) {
286 debug("Locking in context %s", context.id);
287 output.writeInt(requestId);
288 output.writeInt(1);
289 output.writeUTF(IpcMessages.RESPONSE_ACQUIRE);
290 output.flush();
291 }
292 } catch (IOException e) {
293 try {
294 socket.close();
295 } catch (IOException ioException) {
296 e.addSuppressed(ioException);
297 }
298 error("Error writing lock response", e);
299 }
300 });
301 break;
302 case IpcMessages.REQUEST_CLOSE:
303 if (request.size() != 1) {
304 throw new IOException("Expected one argument for " + command + " but got " + request);
305 }
306 contextId = request.remove(0);
307 context = contexts.remove(contextId);
308 clientContexts.remove(contextId);
309 if (context == null) {
310 throw new IOException(
311 "Unknown context: " + contextId + ". Known contexts = " + contexts.keySet());
312 }
313 context.unlock();
314 synchronized (output) {
315 debug("Closing context %s", context.id);
316 output.writeInt(requestId);
317 output.writeInt(1);
318 output.writeUTF(IpcMessages.RESPONSE_CLOSE);
319 output.flush();
320 }
321 break;
322 case IpcMessages.REQUEST_STOP:
323 if (!request.isEmpty()) {
324 throw new IOException("Expected zero argument for " + command + " but got " + request);
325 }
326 synchronized (output) {
327 debug("Stopping server");
328 output.writeInt(requestId);
329 output.writeInt(1);
330 output.writeUTF(IpcMessages.RESPONSE_STOP);
331 output.flush();
332 }
333 close();
334 break;
335 default:
336 throw new IOException("Unknown request: " + request.get(0));
337 }
338 }
339 } catch (Throwable t) {
340 if (!closing) {
341 error("Error processing request", t);
342 }
343 } finally {
344 if (!closing) {
345 info("Client disconnecting...");
346 }
347 clientContexts.values().forEach(context -> {
348 contexts.remove(context.id);
349 context.unlock();
350 });
351 try {
352 socket.close();
353 } catch (IOException ioException) {
354
355 }
356 synchronized (clients) {
357 clients.remove(socket);
358 c = clients.size();
359 }
360 if (!closing) {
361 info("%d clients remained", c);
362 }
363 }
364 }
365
366 private void use() {
367 lastUsed = System.nanoTime();
368 }
369
370 private void expirationCheck() {
371 while (true) {
372 long current = System.nanoTime();
373 long left = (lastUsed + idleTimeout) - current;
374 if (clients.isEmpty() && left < 0) {
375 info("IpcServer expired, closing");
376 close();
377 break;
378 } else {
379 try {
380 Thread.sleep(Math.max(1, TimeUnit.NANOSECONDS.toMillis(left)));
381 } catch (InterruptedException e) {
382 info("IpcServer expiration check interrupted, closing");
383 close();
384 break;
385 }
386 }
387 }
388 }
389
390 void close() {
391 closing = true;
392 try {
393 serverSocket.close();
394 } catch (IOException e) {
395 error("Error closing server socket", e);
396 }
397 clients.forEach((s, t) -> {
398 try {
399 s.close();
400 } catch (IOException e) {
401
402 }
403 t.interrupt();
404 });
405 }
406
407 static class Waiter {
408 final Context context;
409 final CompletableFuture<Void> future;
410
411 Waiter(Context context, CompletableFuture<Void> future) {
412 this.context = context;
413 this.future = future;
414 }
415 }
416
417 static class Lock {
418
419 final String key;
420
421 List<Context> holders;
422 List<Waiter> waiters;
423
424 Lock(String key) {
425 this.key = key;
426 }
427
428 public synchronized CompletableFuture<Void> lock(Context context) {
429 if (holders == null) {
430 holders = new ArrayList<>();
431 }
432 if (holders.isEmpty() || holders.get(0).shared && context.shared) {
433 holders.add(context);
434 return CompletableFuture.completedFuture(null);
435 }
436 if (waiters == null) {
437 waiters = new ArrayList<>();
438 }
439
440 CompletableFuture<Void> future = new CompletableFuture<>();
441 waiters.add(new Waiter(context, future));
442 return future;
443 }
444
445 public void unlock(Context context) {
446 List<CompletableFuture<Void>> toComplete;
447 synchronized (this) {
448 toComplete = new ArrayList<>();
449 if (holders.remove(context)) {
450 while (waiters != null
451 && !waiters.isEmpty()
452 && (holders.isEmpty() || holders.get(0).shared && waiters.get(0).context.shared)) {
453 Waiter waiter = waiters.remove(0);
454 holders.add(waiter.context);
455 toComplete.add(waiter.future);
456 }
457 } else if (waiters != null) {
458 for (Iterator<Waiter> it = waiters.iterator(); it.hasNext(); ) {
459 Waiter waiter = it.next();
460 if (waiter.context == context) {
461 it.remove();
462 waiter.future.cancel(false);
463 }
464 }
465 }
466 }
467 toComplete.forEach(f -> f.complete(null));
468 }
469
470 public synchronized boolean isEmpty() {
471 return (holders == null || holders.isEmpty()) && (waiters == null || waiters.isEmpty());
472 }
473 }
474
475 class Context {
476
477 final String id;
478 final boolean shared;
479 final List<String> locks = new CopyOnWriteArrayList<>();
480
481 Context(boolean shared) {
482 this.id = String.format("%08x", counter.incrementAndGet());
483 this.shared = shared;
484 }
485
486 public CompletableFuture<?> lock(List<String> keys) {
487 locks.addAll(keys);
488 CompletableFuture<?>[] futures = keys.stream()
489 .map(k -> IpcServer.this.locks.computeIfAbsent(k, Lock::new))
490 .map(l -> l.lock(this))
491 .toArray(CompletableFuture[]::new);
492 return CompletableFuture.allOf(futures);
493 }
494
495 public void unlock() {
496 locks.stream()
497 .map(k -> IpcServer.this.locks.computeIfAbsent(k, Lock::new))
498 .forEach(l -> {
499 l.unlock(this);
500 IpcServer.this.locks.compute(l.key, (k, v) -> (v == l && v.isEmpty()) ? null : v);
501 });
502 }
503 }
504 }