001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.eclipse.aether.named.ipc;
020
021import java.io.DataInputStream;
022import java.io.DataOutputStream;
023import java.io.IOException;
024import java.net.SocketAddress;
025import java.nio.channels.ByteChannel;
026import java.nio.channels.Channels;
027import java.nio.channels.ServerSocketChannel;
028import java.nio.channels.SocketChannel;
029import java.util.ArrayList;
030import java.util.HashMap;
031import java.util.Iterator;
032import java.util.List;
033import java.util.Map;
034import java.util.concurrent.CompletableFuture;
035import java.util.concurrent.ConcurrentHashMap;
036import java.util.concurrent.CopyOnWriteArrayList;
037import java.util.concurrent.TimeUnit;
038import java.util.concurrent.atomic.AtomicInteger;
039
040/**
041 * Implementation of the server side.
042 * The server instance is bound to a given maven repository.
043 *
044 * @since 2.0.1
045 */
046public class IpcServer {
047    /**
048     * Should the IPC server not fork? (i.e. for testing purposes)
049     *
050     * @configurationSource {@link System#getProperty(String, String)}
051     * @configurationType {@link java.lang.Boolean}
052     * @configurationDefaultValue {@link #DEFAULT_NO_FORK}
053     */
054    public static final String SYSTEM_PROP_NO_FORK = "aether.named.ipc.nofork";
055
056    public static final boolean DEFAULT_NO_FORK = false;
057
058    /**
059     * IPC idle timeout in seconds. If there is no IPC request during idle time, it will stop.
060     *
061     * @configurationSource {@link System#getProperty(String, String)}
062     * @configurationType {@link java.lang.Integer}
063     * @configurationDefaultValue {@link #DEFAULT_IDLE_TIMEOUT}
064     */
065    public static final String SYSTEM_PROP_IDLE_TIMEOUT = "aether.named.ipc.idleTimeout";
066
067    public static final int DEFAULT_IDLE_TIMEOUT = 60;
068
069    /**
070     * IPC socket family to use.
071     *
072     * @configurationSource {@link System#getProperty(String, String)}
073     * @configurationType {@link java.lang.String}
074     * @configurationDefaultValue {@link #DEFAULT_FAMILY}
075     */
076    public static final String SYSTEM_PROP_FAMILY = "aether.named.ipc.family";
077
078    public static final String DEFAULT_FAMILY = "unix";
079
080    /**
081     * Should the IPC server not use native executable?
082     *
083     * @configurationSource {@link System#getProperty(String, String)}
084     * @configurationType {@link java.lang.Boolean}
085     * @configurationDefaultValue {@link #DEFAULT_NO_NATIVE}
086     */
087    public static final String SYSTEM_PROP_NO_NATIVE = "aether.named.ipc.nonative";
088
089    public static final boolean DEFAULT_NO_NATIVE = true;
090
091    /**
092     * The name if the IPC server native executable (without file extension like ".exe")
093     *
094     * @configurationSource {@link System#getProperty(String, String)}
095     * @configurationType {@link java.lang.String}
096     * @configurationDefaultValue {@link #DEFAULT_NATIVE_NAME}
097     */
098    public static final String SYSTEM_PROP_NATIVE_NAME = "aether.named.ipc.nativeName";
099
100    public static final String DEFAULT_NATIVE_NAME = "ipc-sync";
101
102    /**
103     * Should the IPC server log debug messages? (i.e. for testing purposes)
104     *
105     * @configurationSource {@link System#getProperty(String, String)}
106     * @configurationType {@link java.lang.Boolean}
107     * @configurationDefaultValue {@link #DEFAULT_DEBUG}
108     */
109    public static final String SYSTEM_PROP_DEBUG = "aether.named.ipc.debug";
110
111    public static final boolean DEFAULT_DEBUG = false;
112
113    private final ServerSocketChannel serverSocket;
114    private final Map<SocketChannel, Thread> clients = new HashMap<>();
115    private final AtomicInteger counter = new AtomicInteger();
116    private final Map<String, Lock> locks = new ConcurrentHashMap<>();
117    private final Map<String, Context> contexts = new ConcurrentHashMap<>();
118    private static final boolean DEBUG =
119            Boolean.parseBoolean(System.getProperty(SYSTEM_PROP_DEBUG, Boolean.toString(DEFAULT_DEBUG)));
120    private final long idleTimeout;
121    private volatile long lastUsed;
122    private volatile boolean closing;
123
124    public IpcServer(SocketFamily family) throws IOException {
125        serverSocket = family.openServerSocket();
126        long timeout = TimeUnit.SECONDS.toNanos(DEFAULT_IDLE_TIMEOUT);
127        String str = System.getProperty(SYSTEM_PROP_IDLE_TIMEOUT);
128        if (str != null) {
129            try {
130                TimeUnit unit = TimeUnit.SECONDS;
131                if (str.endsWith("ms")) {
132                    unit = TimeUnit.MILLISECONDS;
133                    str = str.substring(0, str.length() - 2);
134                }
135                long dur = Long.parseLong(str);
136                timeout = unit.toNanos(dur);
137            } catch (NumberFormatException e) {
138                error("Property " + SYSTEM_PROP_IDLE_TIMEOUT + " specified with invalid value: " + str, e);
139            }
140        }
141        idleTimeout = timeout;
142    }
143
144    public static void main(String[] args) throws Exception {
145        // When spawning a new process, the child process is create within
146        // the same process group.  This means that a few signals are sent
147        // to the whole group.  This is the case for SIGINT (Ctrl-C) and
148        // SIGTSTP (Ctrl-Z) which are both sent to all the processed in the
149        // group when initiated from the controlling terminal.
150        // This is only a problem when the client creates the daemon, but
151        // without ignoring those signals, a client being interrupted will
152        // also interrupt and kill the daemon.
153        try {
154            sun.misc.Signal.handle(new sun.misc.Signal("INT"), sun.misc.SignalHandler.SIG_IGN);
155            if (IpcClient.IS_WINDOWS) {
156                sun.misc.Signal.handle(new sun.misc.Signal("TSTP"), sun.misc.SignalHandler.SIG_IGN);
157            }
158        } catch (Throwable t) {
159            error("Unable to ignore INT and TSTP signals", t);
160        }
161
162        String family = args[0];
163        String tmpAddress = args[1];
164        String rand = args[2];
165
166        runServer(SocketFamily.valueOf(family), tmpAddress, rand);
167    }
168
169    static IpcServer runServer(SocketFamily family, String tmpAddress, String rand) throws IOException {
170        IpcServer server = new IpcServer(family);
171        run(server::run, false); // this is one-off
172        String address = SocketFamily.toString(server.getLocalAddress());
173        SocketAddress socketAddress = SocketFamily.fromString(tmpAddress);
174        try (SocketChannel socket = SocketChannel.open(socketAddress)) {
175            try (DataOutputStream dos = new DataOutputStream(Channels.newOutputStream(socket))) {
176                dos.writeUTF(rand);
177                dos.writeUTF(address);
178                dos.flush();
179            }
180        }
181
182        return server;
183    }
184
185    private static void debug(String msg, Object... args) {
186        if (DEBUG) {
187            System.out.printf("[ipc] [debug] " + msg + "\n", args);
188        }
189    }
190
191    private static void info(String msg, Object... args) {
192        System.out.printf("[ipc] [info] " + msg + "\n", args);
193    }
194
195    private static void error(String msg, Throwable t) {
196        System.out.println("[ipc] [error] " + msg);
197        t.printStackTrace(System.out);
198    }
199
200    private static void run(Runnable runnable, boolean daemon) {
201        Thread thread = new Thread(runnable);
202        if (daemon) {
203            thread.setDaemon(true);
204        }
205        thread.start();
206    }
207
208    public SocketAddress getLocalAddress() throws IOException {
209        return serverSocket.getLocalAddress();
210    }
211
212    public void run() {
213        try {
214            info("IpcServer started at %s", getLocalAddress().toString());
215            use();
216            run(this::expirationCheck, true);
217            while (!closing) {
218                SocketChannel socket = this.serverSocket.accept();
219                run(() -> client(socket), false);
220            }
221        } catch (Throwable t) {
222            if (!closing) {
223                error("Error running sync server loop", t);
224            }
225        }
226    }
227
228    private void client(SocketChannel socket) {
229        int c;
230        synchronized (clients) {
231            clients.put(socket, Thread.currentThread());
232            c = clients.size();
233        }
234        info("New client connected (%d connected)", c);
235        use();
236        Map<String, Context> clientContexts = new ConcurrentHashMap<>();
237        try {
238            ByteChannel wrapper = new ByteChannelWrapper(socket);
239            DataInputStream input = new DataInputStream(Channels.newInputStream(wrapper));
240            DataOutputStream output = new DataOutputStream(Channels.newOutputStream(wrapper));
241            while (!closing) {
242                int requestId = input.readInt();
243                int sz = input.readInt();
244                List<String> request = new ArrayList<>(sz);
245                for (int i = 0; i < sz; i++) {
246                    request.add(input.readUTF());
247                }
248                if (request.isEmpty()) {
249                    throw new IOException("Received invalid request");
250                }
251                use();
252                String contextId;
253                Context context;
254                String command = request.remove(0);
255                switch (command) {
256                    case IpcMessages.REQUEST_CONTEXT:
257                        if (request.size() != 1) {
258                            throw new IOException("Expected one argument for " + command + " but got " + request);
259                        }
260                        boolean shared = Boolean.parseBoolean(request.remove(0));
261                        context = new Context(shared);
262                        contexts.put(context.id, context);
263                        clientContexts.put(context.id, context);
264                        synchronized (output) {
265                            debug("Created context %s", context.id);
266                            output.writeInt(requestId);
267                            output.writeInt(2);
268                            output.writeUTF(IpcMessages.RESPONSE_CONTEXT);
269                            output.writeUTF(context.id);
270                            output.flush();
271                        }
272                        break;
273                    case IpcMessages.REQUEST_ACQUIRE:
274                        if (request.size() < 1) {
275                            throw new IOException(
276                                    "Expected at least one argument for " + command + " but got " + request);
277                        }
278                        contextId = request.remove(0);
279                        context = contexts.get(contextId);
280                        if (context == null) {
281                            throw new IOException(
282                                    "Unknown context: " + contextId + ". Known contexts = " + contexts.keySet());
283                        }
284                        context.lock(request).thenRun(() -> {
285                            try {
286                                synchronized (output) {
287                                    debug("Locking in context %s", context.id);
288                                    output.writeInt(requestId);
289                                    output.writeInt(1);
290                                    output.writeUTF(IpcMessages.RESPONSE_ACQUIRE);
291                                    output.flush();
292                                }
293                            } catch (IOException e) {
294                                try {
295                                    socket.close();
296                                } catch (IOException ioException) {
297                                    e.addSuppressed(ioException);
298                                }
299                                error("Error writing lock response", e);
300                            }
301                        });
302                        break;
303                    case IpcMessages.REQUEST_CLOSE:
304                        if (request.size() != 1) {
305                            throw new IOException("Expected one argument for " + command + " but got " + request);
306                        }
307                        contextId = request.remove(0);
308                        context = contexts.remove(contextId);
309                        clientContexts.remove(contextId);
310                        if (context == null) {
311                            throw new IOException(
312                                    "Unknown context: " + contextId + ". Known contexts = " + contexts.keySet());
313                        }
314                        context.unlock();
315                        synchronized (output) {
316                            debug("Closing context %s", context.id);
317                            output.writeInt(requestId);
318                            output.writeInt(1);
319                            output.writeUTF(IpcMessages.RESPONSE_CLOSE);
320                            output.flush();
321                        }
322                        break;
323                    case IpcMessages.REQUEST_STOP:
324                        if (request.size() != 0) {
325                            throw new IOException("Expected zero argument for " + command + " but got " + request);
326                        }
327                        synchronized (output) {
328                            debug("Stopping server");
329                            output.writeInt(requestId);
330                            output.writeInt(1);
331                            output.writeUTF(IpcMessages.RESPONSE_STOP);
332                            output.flush();
333                        }
334                        close();
335                        break;
336                    default:
337                        throw new IOException("Unknown request: " + request.get(0));
338                }
339            }
340        } catch (Throwable t) {
341            if (!closing) {
342                error("Error processing request", t);
343            }
344        } finally {
345            if (!closing) {
346                info("Client disconnecting...");
347            }
348            clientContexts.values().forEach(context -> {
349                contexts.remove(context.id);
350                context.unlock();
351            });
352            try {
353                socket.close();
354            } catch (IOException ioException) {
355                // ignore
356            }
357            synchronized (clients) {
358                clients.remove(socket);
359                c = clients.size();
360            }
361            if (!closing) {
362                info("%d clients left", c);
363            }
364        }
365    }
366
367    private void use() {
368        lastUsed = System.nanoTime();
369    }
370
371    private void expirationCheck() {
372        while (true) {
373            long current = System.nanoTime();
374            long left = (lastUsed + idleTimeout) - current;
375            if (left < 0) {
376                info("IpcServer expired, closing");
377                close();
378                break;
379            } else {
380                try {
381                    Thread.sleep(TimeUnit.NANOSECONDS.toMillis(left));
382                } catch (InterruptedException e) {
383                    info("IpcServer expiration check interrupted, closing");
384                    close();
385                    break;
386                }
387            }
388        }
389    }
390
391    void close() {
392        closing = true;
393        try {
394            serverSocket.close();
395        } catch (IOException e) {
396            error("Error closing server socket", e);
397        }
398        clients.forEach((s, t) -> {
399            try {
400                s.close();
401            } catch (IOException e) {
402                // ignore
403            }
404            t.interrupt();
405        });
406    }
407
408    static class Waiter {
409        final Context context;
410        final CompletableFuture<Void> future;
411
412        Waiter(Context context, CompletableFuture<Void> future) {
413            this.context = context;
414            this.future = future;
415        }
416    }
417
418    static class Lock {
419
420        final String key;
421
422        List<Context> holders;
423        List<Waiter> waiters;
424
425        Lock(String key) {
426            this.key = key;
427        }
428
429        public synchronized CompletableFuture<Void> lock(Context context) {
430            if (holders == null) {
431                holders = new ArrayList<>();
432            }
433            if (holders.isEmpty() || holders.get(0).shared && context.shared) {
434                holders.add(context);
435                return CompletableFuture.completedFuture(null);
436            }
437            if (waiters == null) {
438                waiters = new ArrayList<>();
439            }
440
441            CompletableFuture<Void> future = new CompletableFuture<>();
442            waiters.add(new Waiter(context, future));
443            return future;
444        }
445
446        public synchronized void unlock(Context context) {
447            if (holders.remove(context)) {
448                while (waiters != null
449                        && !waiters.isEmpty()
450                        && (holders.isEmpty() || holders.get(0).shared && waiters.get(0).context.shared)) {
451                    Waiter waiter = waiters.remove(0);
452                    holders.add(waiter.context);
453                    waiter.future.complete(null);
454                }
455            } else if (waiters != null) {
456                for (Iterator<Waiter> it = waiters.iterator(); it.hasNext(); ) {
457                    Waiter waiter = it.next();
458                    if (waiter.context == context) {
459                        it.remove();
460                        waiter.future.cancel(false);
461                    }
462                }
463            }
464        }
465    }
466
467    class Context {
468
469        final String id;
470        final boolean shared;
471        final List<String> locks = new CopyOnWriteArrayList<>();
472
473        Context(boolean shared) {
474            this.id = String.format("%08x", counter.incrementAndGet());
475            this.shared = shared;
476        }
477
478        public CompletableFuture<?> lock(List<String> keys) {
479            locks.addAll(keys);
480            CompletableFuture<?>[] futures = keys.stream()
481                    .map(k -> IpcServer.this.locks.computeIfAbsent(k, Lock::new))
482                    .map(l -> l.lock(this))
483                    .toArray(CompletableFuture[]::new);
484            return CompletableFuture.allOf(futures);
485        }
486
487        public void unlock() {
488            locks.stream()
489                    .map(k -> IpcServer.this.locks.computeIfAbsent(k, Lock::new))
490                    .forEach(l -> l.unlock(this));
491        }
492    }
493}