From e1726b4ccb0b51f527f1bf7a5a39c027294f5174 Mon Sep 17 00:00:00 2001
Subject: 8314063: The socket is not closed in Connection::createSocket when the handshake failed for LDAP connection

---
 .../classes/com/sun/jndi/ldap/Connection.java |  131 +++++----
 .../ldap/LdapSSLHandshakeFailureTest.java     |  249 ++++++++++++++++++
 test/jdk/com/sun/jndi/ldap/ksWithSAN          |    0
 3 files changed, 314 insertions(+), 66 deletions(-)
 create mode 100644 test/jdk/com/sun/jndi/ldap/LdapSSLHandshakeFailureTest.java
 create mode 100644 test/jdk/com/sun/jndi/ldap/ksWithSAN

NOTE(review): test/jdk/com/sun/jndi/ldap/ksWithSAN is added as an empty file
(diffstat 0, blob e69de29bb), yet LdapSSLHandshakeFailureTest loads it as both
its keystore and its truststore. The real keystore contents must be added for
the test's SSL setup to work.

diff --git a/src/java.naming/share/classes/com/sun/jndi/ldap/Connection.java b/src/java.naming/share/classes/com/sun/jndi/ldap/Connection.java
index ebb21bd8b..f71b1bb14 100644
--- a/src/java.naming/share/classes/com/sun/jndi/ldap/Connection.java
+++ b/src/java.naming/share/classes/com/sun/jndi/ldap/Connection.java
@@ -280,79 +280,79 @@ public final class Connection implements Runnable {
     private Socket createSocket(String host, int port, String socketFactory,
             int connectTimeout) throws Exception {
 
-        Socket socket = null;
-
-        if (socketFactory != null) {
+        SocketFactory factory = getSocketFactory(socketFactory);
+        assert factory != null;
+        Socket socket = createConnectionSocket(host, port, factory, connectTimeout);
 
-            // create the factory
+        // the handshake for SSL connection with server and reset timeout for the socket
+        if (socket instanceof SSLSocket sslSocket) {
+            try {
+                initialSSLHandshake(sslSocket, connectTimeout);
+            } catch (Exception e) {
+                // 8314063: close the socket if the SSL handshake fails,
+                // otherwise the connected socket is leaked
+                closeOpenedSocket(socket);
+                throw e;
+            }
+        }
+        return socket;
+    }
 
+    private SocketFactory getSocketFactory(String socketFactoryName) throws Exception {
+        if (socketFactoryName == null) {
+            if (debug) {
+                System.err.println("Connection: using default SocketFactory");
+            }
+            return SocketFactory.getDefault();
+        } else {
+            if (debug) {
+                System.err.println("Connection: loading supplied SocketFactory: " + socketFactoryName);
+            }
             @SuppressWarnings("unchecked")
             Class<? extends SocketFactory> socketFactoryClass =
-                (Class<? extends SocketFactory>)Obj.helper.loadClass(socketFactory);
+                    (Class<? extends SocketFactory>) Obj.helper.loadClass(socketFactoryName);
             Method getDefault =
-                socketFactoryClass.getMethod("getDefault", new Class<?>[]{});
+                    socketFactoryClass.getMethod("getDefault");
             SocketFactory factory = (SocketFactory) getDefault.invoke(null, new Object[]{});
+            return factory;
+        }
+    }
 
-            // create the socket
-
-            if (connectTimeout > 0) {
-
-                InetSocketAddress endpoint =
-                        createInetSocketAddress(host, port);
-
-                // unconnected socket
-                socket = factory.createSocket();
-
-                if (debug) {
-                    System.err.println("Connection: creating socket with " +
-                            "a timeout using supplied socket factory");
-                }
-
-                // connected socket
-                socket.connect(endpoint, connectTimeout);
-            }
-
-            // continue (but ignore connectTimeout)
-            if (socket == null) {
-                if (debug) {
-                    System.err.println("Connection: creating socket using " +
-                        "supplied socket factory");
-                }
-                // connected socket
-                socket = factory.createSocket(host, port);
-            }
-        } else {
-
-            if (connectTimeout > 0) {
-
-                InetSocketAddress endpoint = createInetSocketAddress(host, port);
-
-                socket = new Socket();
+    private Socket createConnectionSocket(String host, int port, SocketFactory factory,
+                                          int connectTimeout) throws Exception {
+        Socket socket = null;
 
-                if (debug) {
-                    System.err.println("Connection: creating socket with " +
-                        "a timeout");
-                }
-                socket.connect(endpoint, connectTimeout);
+        if (connectTimeout > 0) {
+            InetSocketAddress endpoint = createInetSocketAddress(host, port);
+            socket = factory.createSocket(); // unconnected; connected below with a timeout
+            try {
+                socket.connect(endpoint, connectTimeout);
+            } catch (Exception e) {
+                closeOpenedSocket(socket); // do not leak the socket if connect fails
+                throw e;
+            }
+            if (debug) {
+                System.err.println("Connection: creating socket with " +
+                        "a connect timeout");
             }
-
-            // continue (but ignore connectTimeout)
-
-            if (socket == null) {
-                if (debug) {
-                    System.err.println("Connection: creating socket");
-                }
-                // connected socket
-                socket = new Socket(host, port);
+        }
+        if (socket == null) {
+            // create connected socket
+            socket = factory.createSocket(host, port);
+            if (debug) {
+                System.err.println("Connection: creating connected socket with" +
+                        " no connect timeout");
             }
         }
+        return socket;
+    }
+
+    // For LDAP connect timeouts on LDAP over SSL connections must treat
+    // the SSL handshake following socket connection as part of the timeout.
+    // So explicitly set a socket read timeout, trigger the SSL handshake,
+    // then reset the timeout.
+    private void initialSSLHandshake(SSLSocket sslSocket, int connectTimeout) throws Exception {
 
-        // For LDAP connect timeouts on LDAP over SSL connections must treat
-        // the SSL handshake following socket connection as part of the timeout.
-        // So explicitly set a socket read timeout, trigger the SSL handshake,
-        // then reset the timeout.
-        if (socket instanceof SSLSocket) {
-            SSLSocket sslSocket = (SSLSocket) socket;
             if (!IS_HOSTNAME_VERIFICATION_DISABLED) {
                 SSLParameters param = sslSocket.getSSLParameters();
                 param.setEndpointIdentificationAlgorithm("LDAPS");
@@ -365,8 +365,6 @@ public final class Connection implements Runnable {
                 sslSocket.startHandshake();
                 sslSocket.setSoTimeout(socketTimeout);
             }
-        }
-        return socket;
     }
 
     ////////////////////////////////////////////////////////////////////////////
@@ -642,7 +640,7 @@ public final class Connection implements Runnable {
                     flushAndCloseOutputStream();
 
                     // 8313657 socket is not closed until GC is run
-                    closeOpenedSocket();
+                    closeOpenedSocket(sock);
                     tryUnpauseReader();
 
                     if (!notifyParent) {
@@ -695,9 +693,10 @@ public final class Connection implements Runnable {
     }
 
     // close socket
-    private void closeOpenedSocket() {
+    private void closeOpenedSocket(Socket socket) {
         try {
-            sock.close();
+            if (socket != null && !socket.isClosed())
+                socket.close();
         } catch (IOException ioEx) {
             if (debug) {
                 System.err.println("Connection.closeConnectionSocket: Socket close problem: " + ioEx);
diff --git a/test/jdk/com/sun/jndi/ldap/LdapSSLHandshakeFailureTest.java b/test/jdk/com/sun/jndi/ldap/LdapSSLHandshakeFailureTest.java
new file mode 100644
index 000000000..29f74d250
--- /dev/null
+++ b/test/jdk/com/sun/jndi/ldap/LdapSSLHandshakeFailureTest.java
@@ -0,0 +1,249 @@
+/*
+ * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+import jdk.test.lib.net.URIBuilder;
+
+import javax.naming.Context;
+import javax.naming.ldap.InitialLdapContext;
+import javax.naming.ldap.LdapContext;
+import javax.net.SocketFactory;
+import javax.net.ssl.SSLServerSocketFactory;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.InetAddress;
+import java.net.ServerSocket;
+import java.net.Socket;
+import java.net.SocketException;
+import java.util.Hashtable;
+
+/*
+ * @test
+ * @bug 8314063
+ * @library /test/lib
+ * @summary For LDAPS connections, if the value of com.sun.jndi.ldap.connect.timeout is
+ * too small or not an optimal value for the system, the socket can be created and
+ * connected to the server while the SSL handshake between the client and the server
+ * then fails with a socket timeout; before this fix the opened socket was not closed.
+ * In this test case the server can be forced to sleep before serving the connection
+ * (serverSleepingTime: 5 seconds by default, 6 seconds in the @run lines below), which
+ * exceeds the 1 second connect timeout used with the custom socket factory. The
+ * resulting exception must now be caught and the socket closed.
+ *
+ * @run main/othervm LdapSSLHandshakeFailureTest LdapSSLHandshakeFailureTest$CustomSocketFactory true 6000
+ * @run main/othervm LdapSSLHandshakeFailureTest -1000 true 6000
+ * @run main/othervm LdapSSLHandshakeFailureTest -1000 false 6000
+ * @run main/othervm LdapSSLHandshakeFailureTest 2000 false 6000
+ * @run main/othervm LdapSSLHandshakeFailureTest 0 true 6000
+ * @run main/othervm LdapSSLHandshakeFailureTest 0 false 6000
+ * @run main/othervm LdapSSLHandshakeFailureTest true
+ * @run main/othervm LdapSSLHandshakeFailureTest false
+ */
+
+public class LdapSSLHandshakeFailureTest {
+    private static final String SOCKET_CLOSED_MSG = "The socket has been closed.";
+
+    private static int serverSleepingTime = 5000;
+
+    public static void main(String[] args) throws Exception {
+
+        // Set the keystores
+        setKeyStore();
+        boolean serverSlowDown = Boolean.parseBoolean(args[0]);
+        if (args.length >= 2) {
+            serverSlowDown = Boolean.parseBoolean(args[1]);
+        }
+
+        if (args.length >= 3) {
+            serverSleepingTime = Integer.parseInt(args[2]);
+        }
+
+        boolean hasCustomSocketFactory = args[0]
+                .equals("LdapSSLHandshakeFailureTest$CustomSocketFactory");
+        // start the test server first.
+        try (TestServer server = new TestServer(serverSlowDown, serverSleepingTime)) {
+            server.start();
+            Hashtable<String, String> env = new Hashtable<>();
+            env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.ldap.LdapCtxFactory");
+            env.put("java.naming.ldap.version", "3");
+            env.put(Context.PROVIDER_URL, URIBuilder.newBuilder()
+                    .scheme("ldaps")
+                    .loopback()
+                    .port(server.getPortNumber())
+                    .buildUnchecked().toString());
+
+            if (hasCustomSocketFactory) {
+                env.put("java.naming.ldap.factory.socket", args[0]);
+                env.put("com.sun.jndi.ldap.connect.timeout", "1000");
+            }
+
+            if (args.length >= 2 && !hasCustomSocketFactory) {
+                env.put("com.sun.jndi.ldap.connect.timeout", args[0]);
+            }
+
+            env.put(Context.SECURITY_PROTOCOL, "ssl");
+            env.put(Context.SECURITY_AUTHENTICATION, "Simple");
+            env.put(Context.SECURITY_PRINCIPAL, "cn=principal");
+            env.put(Context.SECURITY_CREDENTIALS, "123456");
+            LdapContext ctx = null;
+            try {
+                ctx = new InitialLdapContext(env, null);
+            } catch (Exception e) {
+                if (hasCustomSocketFactory && args.length >= 2 && Boolean.parseBoolean(args[1])
+                        && CustomSocketFactory.customSocket != null
+                        && CustomSocketFactory.customSocket.closeMethodCalledCount() > 0) {
+                    System.out.println(SOCKET_CLOSED_MSG);
+                } else {
+                    throw e;
+                }
+            } finally {
+                if (ctx != null)
+                    ctx.close();
+            }
+        }
+    }
+
+    public static class CustomSocketFactory extends SocketFactory {
+        private static CustomSocket customSocket;
+
+        public static CustomSocketFactory getDefault() {
+            return new CustomSocketFactory();
+        }
+
+        @Override
+        public Socket createSocket() throws SocketException {
+            customSocket = new CustomSocket();
+            return customSocket;
+        }
+
+        @Override
+        public Socket createSocket(String host, int port) {
+            return customSocket;
+        }
+
+        @Override
+        public Socket createSocket(String host, int port, InetAddress localHost,
+                                   int localPort) {
+            return customSocket;
+        }
+
+        @Override
+        public Socket createSocket(InetAddress host, int port) {
+            return customSocket;
+        }
+
+        @Override
+        public Socket createSocket(InetAddress address, int port,
+                                   InetAddress localAddress, int localPort) {
+            return customSocket;
+        }
+    }
+
+    private static class CustomSocket extends Socket {
+        private int closeMethodCalled = 0;
+
+        public CustomSocket() {
+            // plain (non-SSL) socket: Connection does not run the SSL handshake on it
+        }
+
+        public int closeMethodCalledCount() {
+            return closeMethodCalled;
+        }
+
+        @Override
+        public void close() throws java.io.IOException {
+            closeMethodCalled++;
+            super.close();
+        }
+    }
+
+    private static void setKeyStore() {
+
+        String fileName = "ksWithSAN", dir = System.getProperty("test.src", ".") + File.separator;
+
+        System.setProperty("javax.net.ssl.keyStore", dir + fileName);
+        System.setProperty("javax.net.ssl.keyStorePassword", "welcome1");
+        System.setProperty("javax.net.ssl.trustStore", dir + fileName);
+        System.setProperty("javax.net.ssl.trustStorePassword", "welcome1");
+    }
+
+    static class TestServer extends Thread implements AutoCloseable {
+        private boolean isForceToSleep;
+        private int sleepingTime;
+        private final ServerSocket serverSocket;
+        private final int PORT;
+
+        private TestServer(boolean isForceToSleep, int sleepingTime) {
+            this.isForceToSleep = isForceToSleep;
+            this.sleepingTime = sleepingTime;
+            try {
+                SSLServerSocketFactory socketFactory = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault();
+                serverSocket = socketFactory.createServerSocket(0, 0, InetAddress.getLoopbackAddress());
+                PORT = serverSocket.getLocalPort();
+            } catch (IOException ex) {
+                throw new RuntimeException(ex);
+            }
+            setDaemon(true);
+        }
+
+        public int getPortNumber() {
+            return PORT;
+        }
+
+        @Override
+        public void run() {
+            try (Socket socket = serverSocket.accept();
+                 InputStream in = socket.getInputStream();
+                 OutputStream out = socket.getOutputStream()) {
+                if (isForceToSleep) {
+                    Thread.sleep(sleepingTime);
+                }
+                byte[] bindResponse = {0x30, 0x0C, 0x02, 0x01, 0x01, 0x61, 0x07, 0x0A,
+                        0x01, 0x00, 0x04, 0x00, 0x04, 0x00};
+                // read the bindRequest
+                while (in.read() != -1) {
+                    in.skip(in.available());
+                    break;
+                }
+                out.write(bindResponse);
+                out.flush();
+                // ignore the further requests
+                while (in.read() != -1) {
+                    in.skip(in.available());
+                }
+            } catch (Exception e) {
+                e.printStackTrace();
+            }
+        }
+
+        @Override
+        public void close() throws Exception {
+            if (serverSocket != null) {
+                serverSocket.close();
+            }
+        }
+    }
+}
+
+
diff --git a/test/jdk/com/sun/jndi/ldap/ksWithSAN b/test/jdk/com/sun/jndi/ldap/ksWithSAN
new file mode 100644
index 000000000..e69de29bb
-- 
2.22.0