Update: As of jetty-9.4.15.v20190215 support for port unification is built into Jetty; see this answer.
Yes We Can
This is possible and we have done it. The code here works with Jetty 8; I have not tested with Jetty 9, but this answer has similar code for Jetty 9.
By the way, this is called port unification, and it has apparently long been supported in Glassfish using Grizzly.
Outline
The basic idea is to produce an implementation of org.eclipse.jetty.server.Connector which can look ahead at the first byte of the client's request. Luckily both HTTP and HTTPS have the client start the communication. For HTTPS (and TLS/SSL generally) the first byte will be 0x16 (TLS), or >= 0x80 (SSLv2). For HTTP the first byte will be good old printable 7-bit ASCII. Now, depending on the first byte, the Connector will either produce an SSL connection or a plain connection.
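The first-byte test described above can be sketched in isolation. This is only an illustration: FirstByteSniffer and its Protocol enum are hypothetical names, not part of Jetty, and the check is stricter than the one in the full code further down (which treats any non-printable byte as SSL).

```java
/** Hypothetical helper, not part of Jetty: classifies a connection by its first byte. */
final class FirstByteSniffer {
    enum Protocol { TLS, SSLV2, PLAIN }

    static Protocol classify(byte first) {
        int b = first & 0xFF;                  // Java bytes are signed; mask to compare as 0..255
        if (b == 0x16) return Protocol.TLS;    // TLS record type "handshake"
        if (b >= 0x80) return Protocol.SSLV2;  // SSLv2 record header has the high bit set
        return Protocol.PLAIN;                 // anything else is treated as plain HTTP
    }

    public static void main(String[] args) {
        System.out.println(classify((byte) 0x16)); // TLS
        System.out.println(classify((byte) 0x80)); // SSLV2
        System.out.println(classify((byte) 'G'));  // PLAIN
    }
}
```

Masking with 0xFF matters because a Java byte holding 0x80 or more is negative, so a naive b >= 0x80 comparison on the raw byte would never be true.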
In the code here we take advantage of the fact that Jetty's SslSelectChannelConnector itself extends SelectChannelConnector, and has a newPlainConnection() method (calling its superclass to produce a non-SSL connection) as well as a newConnection() method (to produce an SSL connection). So our new Connector can extend SslSelectChannelConnector and delegate to one of those methods after observing the first byte from the client.
Unfortunately, we will be expected to create an instance of AsyncConnection before the first byte is available. Some methods of that instance may even be called before the first byte is available. So we create a LazyConnection implements AsyncConnection which can figure out later which kind of connection it will delegate to, or even return sensible default responses to some methods before it knows.
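The lazy-delegation idea can be shown without any Jetty types. In this minimal sketch, MiniConnection is a hypothetical two-method stand-in for AsyncConnection, and the Supplier (which assumes Java 8 or later) plays the role of "look at the first byte and pick a connection, or return null if it has not arrived yet".

```java
import java.util.function.Supplier;

/** Hypothetical stand-in for Jetty's AsyncConnection, reduced to two methods. */
interface MiniConnection {
    String handle();
    boolean isIdle();
}

/** Defers the choice of the real connection until the chooser can make it. */
final class LazyMiniConnection implements MiniConnection {
    private final Supplier<MiniConnection> chooser; // returns null while undecided
    private MiniConnection delegate;

    LazyMiniConnection(Supplier<MiniConnection> chooser) {
        this.chooser = chooser;
    }

    private MiniConnection resolve() {
        if (delegate == null) delegate = chooser.get(); // retry until a choice is possible
        return delegate;
    }

    @Override
    public String handle() {
        MiniConnection c = resolve();
        return c == null ? "undecided" : c.handle();    // default answer before the choice
    }

    @Override
    public boolean isIdle() {
        MiniConnection c = resolve();
        return c != null && c.isIdle();                 // default answer: not idle
    }
}
```

Once the chooser returns a non-null connection it is cached, so every later call goes straight to the chosen delegate.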
Being NIO-based, our Connector will work with a SocketChannel. Luckily we can extend SocketChannel to create a ReadAheadSocketChannelWrapper which delegates to the "real" SocketChannel but can inspect and store the first bytes of the client's message.
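The read-ahead-and-replay trick can be tried on a simpler type than SocketChannel. The following is a rough sketch, assuming a one-byte look-ahead; PeekingChannel is a hypothetical class that wraps any ReadableByteChannel, whereas the real wrapper below has to extend SocketChannel so that Jetty accepts it.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.charset.StandardCharsets;

/** Hypothetical sketch: peeks at the first byte of a channel, then replays it to readers. */
final class PeekingChannel implements ReadableByteChannel {
    private final ReadableByteChannel source;
    private final ByteBuffer peeked = ByteBuffer.allocate(1);
    private boolean peekDone;

    PeekingChannel(ReadableByteChannel source) {
        this.source = source;
    }

    /** Returns the first byte as 0..255, or -1 if it has not arrived yet or the stream ended. */
    int peekFirstByte() throws IOException {
        if (!peekDone) {
            int n = source.read(peeked);   // may read 0 bytes on a non-blocking channel
            if (n == 0) return -1;         // nothing yet; try again on the next call
            peeked.flip();                 // one byte to replay, or none at end-of-stream
            peekDone = true;
        }
        return peeked.limit() > 0 ? peeked.get(0) & 0xFF : -1;
    }

    @Override
    public int read(ByteBuffer dst) throws IOException {
        peekFirstByte();
        if (!peekDone) return 0;           // first byte still pending
        int replayed = 0;
        if (peeked.hasRemaining() && dst.hasRemaining()) {
            dst.put(peeked.get());         // hand the stored byte to the reader first
            replayed = 1;
        }
        if (!dst.hasRemaining()) return replayed;
        int n = source.read(dst);
        if (n < 0) return replayed > 0 ? replayed : -1;
        return replayed + n;
    }

    @Override
    public boolean isOpen() {
        return source.isOpen();
    }

    @Override
    public void close() throws IOException {
        source.close();
    }

    public static void main(String[] args) throws IOException {
        byte[] request = "GET / HTTP/1.1\r\n".getBytes(StandardCharsets.US_ASCII);
        PeekingChannel ch = new PeekingChannel(Channels.newChannel(new ByteArrayInputStream(request)));
        System.out.println("first byte: " + (char) ch.peekFirstByte());
        ByteBuffer buf = ByteBuffer.allocate(64);
        int n = ch.read(buf);              // includes the byte that was peeked
        System.out.print(new String(buf.array(), 0, n, StandardCharsets.US_ASCII));
    }
}
```

The reader downstream never notices the peek: the stored byte is delivered first, and only then does the wrapper fall through to the underlying channel.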
Some Details
One very hacky bit. One of the methods our Connector must override is customize(EndPoint,Request). If we end up with an SSL-based EndPoint we can just pass to our superclass; otherwise the superclass will throw a ClassCastException, but only after both passing to its superclass and setting the scheme on the Request. So we pass to the superclass, but undo setting the scheme when we see the exception.
We also override isConfidential() and isIntegral() to ensure that our servlets can properly use HttpServletRequest.isSecure() to figure out whether HTTP or HTTPS was used.
Attempting to read the first byte from the client may throw an IOException, but we may have to attempt that in a place where an IOException isn't expected, in which case we keep the exception around and throw it later.
Extending SocketChannel looks different in Java >= 7 and Java 6. In the latter case, just comment out the methods that Java 6 SocketChannel doesn't have.
The Code
public class PortUnificationSelectChannelConnector extends SslSelectChannelConnector {
public PortUnificationSelectChannelConnector() {
super();
}
public PortUnificationSelectChannelConnector(SslContextFactory sslContextFactory) {
super(sslContextFactory);
}
@Override
protected SelectChannelEndPoint newEndPoint(SocketChannel channel, SelectSet selectSet, SelectionKey key) throws IOException {
return super.newEndPoint(new ReadAheadSocketChannelWrapper(channel, 1), selectSet, key);
}
@Override
protected AsyncConnection newConnection(SocketChannel channel, AsyncEndPoint endPoint) {
return new LazyConnection((ReadAheadSocketChannelWrapper)channel, endPoint);
}
@Override
public void customize(EndPoint endpoint, Request request) throws IOException {
String scheme = request.getScheme();
try {
super.customize(endpoint, request);
} catch (ClassCastException e) {
request.setScheme(scheme);
}
}
@Override
public boolean isConfidential(Request request) {
if (request.getAttribute("javax.servlet.request.cipher_suite") != null) return true;
else return isForwarded() && request.getScheme().equalsIgnoreCase(HttpSchemes.HTTPS);
}
@Override
public boolean isIntegral(Request request) {
return isConfidential(request);
}
class LazyConnection implements AsyncConnection {
private final ReadAheadSocketChannelWrapper channel;
private final AsyncEndPoint endPoint;
private final long timestamp;
private AsyncConnection connection;
public LazyConnection(ReadAheadSocketChannelWrapper channel, AsyncEndPoint endPoint) {
this.channel = channel;
this.endPoint = endPoint;
this.timestamp = System.currentTimeMillis();
this.connection = determineNewConnection(channel, endPoint, false);
}
public Connection handle() throws IOException {
if (connection == null) {
connection = determineNewConnection(channel, endPoint, false);
channel.throwPendingException();
}
if (connection != null) return connection.handle();
else return this;
}
public long getTimeStamp() {
return timestamp;
}
public void onInputShutdown() throws IOException {
if (connection == null) connection = determineNewConnection(channel, endPoint, true);
connection.onInputShutdown();
}
public boolean isIdle() {
if (connection == null) connection = determineNewConnection(channel, endPoint, false);
if (connection != null) return connection.isIdle();
else return false;
}
public boolean isSuspended() {
if (connection == null) connection = determineNewConnection(channel, endPoint, false);
if (connection != null) return connection.isSuspended();
else return false;
}
public void onClose() {
if (connection == null) connection = determineNewConnection(channel, endPoint, true);
connection.onClose();
}
public void onIdleExpired(long l) {
if (connection == null) connection = determineNewConnection(channel, endPoint, true);
connection.onIdleExpired(l);
}
AsyncConnection determineNewConnection(ReadAheadSocketChannelWrapper channel, AsyncEndPoint endPoint, boolean force) {
byte[] bytes = channel.getBytes();
if ((bytes == null || bytes.length == 0) && !force) return null;
if (looksLikeSsl(bytes)) {
return PortUnificationSelectChannelConnector.super.newConnection(channel, endPoint);
} else {
return PortUnificationSelectChannelConnector.super.newPlainConnection(channel, endPoint);
}
}
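// TLS first byte is 0x16
// SSLv2 first byte is >= 0x80
// HTTP is guaranteed many bytes of ASCII
private boolean looksLikeSsl(byte[] bytes) {
if (bytes == null || bytes.length == 0) return false; // force HTTP
int b = bytes[0] & 0xFF; // Java bytes are signed; mask so 0x80..0xFF compare as unsigned
// anything that is not printable ASCII or common whitespace is treated as SSL/TLS
return b >= 0x7F || (b < 0x20 && b != '\n' && b != '\r' && b != '\t');
}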
}
static class ReadAheadSocketChannelWrapper extends SocketChannel {
private final SocketChannel channel;
private final ByteBuffer start;
private byte[] bytes;
private IOException pendingException;
private int leftToRead;
public ReadAheadSocketChannelWrapper(SocketChannel channel, int readAheadLength) throws IOException {
super(channel.provider());
this.channel = channel;
start = ByteBuffer.allocate(readAheadLength);
leftToRead = readAheadLength;
readAhead();
}
public synchronized void readAhead() throws IOException {
if (leftToRead > 0) {
int n = channel.read(start);
if (n == -1) {
leftToRead = -1;
} else {
leftToRead -= n;
}
if (leftToRead <= 0) {
start.flip();
bytes = new byte[start.remaining()];
start.get(bytes);
start.rewind();
}
}
}
public byte[] getBytes() {
if (pendingException == null) {
try {
readAhead();
} catch (IOException e) {
pendingException = e;
}
}
return bytes;
}
public void throwPendingException() throws IOException {
if (pendingException != null) {
IOException e = pendingException;
pendingException = null;
throw e;
}
}
private int readFromStart(ByteBuffer dst) throws IOException {
int sr = start.remaining();
int dr = dst.remaining();
if (dr == 0) return 0;
int n = Math.min(dr, sr);
dst.put(bytes, start.position(), n);
start.position(start.position() + n);
return n;
}
public synchronized int read(ByteBuffer dst) throws IOException {
throwPendingException();
readAhead();
if (leftToRead > 0) return 0;
int sr = start.remaining();
if (sr > 0) {
int n = readFromStart(dst);
if (n < sr) return n;
}
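int n = channel.read(dst);
if (n < 0) return sr > 0 ? sr : n; // report replayed bytes before signalling end-of-stream
return sr + n;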
}
public synchronized long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
throwPendingException();
if (offset + length > dsts.length || length < 0 || offset < 0) {
throw new IndexOutOfBoundsException();
}
readAhead();
if (leftToRead > 0) return 0;
int sr = start.remaining();
int newOffset = offset;
if (sr > 0) {
int accum = 0;
for (; newOffset < offset + length; newOffset++) {
accum += readFromStart(dsts[newOffset]);
if (accum == sr) break;
}
if (accum < sr) return accum;
}
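long n = channel.read(dsts, newOffset, length - newOffset + offset);
if (n < 0) return sr > 0 ? sr : n; // report replayed bytes before signalling end-of-stream
return sr + n;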
}
public int hashCode() {
return channel.hashCode();
}
public boolean equals(Object obj) {
return channel.equals(obj);
}
public String toString() {
return channel.toString();
}
public Socket socket() {
return channel.socket();
}
public boolean isConnected() {
return channel.isConnected();
}
public boolean isConnectionPending() {
return channel.isConnectionPending();
}
public boolean connect(SocketAddress remote) throws IOException {
return channel.connect(remote);
}
public boolean finishConnect() throws IOException {
return channel.finishConnect();
}
public int write(ByteBuffer src) throws IOException {
return channel.write(src);
}
public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
return channel.write(srcs, offset, length);
}
@Override
protected void implCloseSelectableChannel() throws IOException {
channel.close();
}
@Override
protected void implConfigureBlocking(boolean block) throws IOException {
channel.configureBlocking(block);
}
// public SocketAddress getLocalAddress() throws IOException {
// return channel.getLocalAddress();
// }
//
// public <T> T getOption(java.net.SocketOption<T> name) throws IOException {
// return channel.getOption(name);
// }
//
// public Set<java.net.SocketOption<?>> supportedOptions() {
// return channel.supportedOptions();
// }
//
// public SocketChannel bind(SocketAddress local) throws IOException {
// return channel.bind(local);
// }
//
// public SocketAddress getRemoteAddress() throws IOException {
// return channel.getRemoteAddress();
// }
//
// public <T> SocketChannel setOption(java.net.SocketOption<T> name, T value) throws IOException {
// return channel.setOption(name, value);
// }
//
// public SocketChannel shutdownInput() throws IOException {
// return channel.shutdownInput();
// }
//
// public SocketChannel shutdownOutput() throws IOException {
// return channel.shutdownOutput();
// }
}
}