If a Java application creates a ServerSocket that accepts TCP connections, is there a way to restrict which processes are allowed to connect to it?
For example, this is my current code:
ServerSocket serverSocket = new ServerSocket(1234);
Socket socket = serverSocket.accept();
and I want to make sure that other devices on my network, and even other processes running on the same machine, are not able to connect to it (it would be a security risk if they did). I was able to solve the former by binding serverSocket only to the loopback address (checking whether socket.getRemoteSocketAddress() points to the local host would work too), but I couldn't find a way to restrict it to my current process.
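For reference, the loopback-only binding mentioned above can be sketched as follows. The class and method names (LoopbackServer, openLoopbackServer) are illustrative only; InetAddress.getLoopbackAddress() and the three-argument ServerSocket constructor are standard java.net APIs.

```java
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;

final class LoopbackServer {
    // Binds only to the loopback interface, so other devices on the network
    // cannot reach the port. Port 0 lets the OS pick a free port; pass 1234
    // to match the code in the question.
    static ServerSocket openLoopbackServer(int port) throws IOException {
        // 50 is the same backlog that new ServerSocket(port) uses by default
        return new ServerSocket(port, 50, InetAddress.getLoopbackAddress());
    }
}
```

This keeps remote devices out, but as the question points out, any process on the same machine can still connect to the port.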
This is even more of a problem on Android. In my application, I want to create a WebView (owned by my process) and point it at serverSocket, but I don't want other apps to be able to connect to it.
Is there a way to solve this problem?
I don't think you can prevent other processes from connecting to the ServerSocket, but once you accept a connection you can definitely determine whether it belongs to you or to some other process. The first step is to figure out whether the connection originated from localhost:
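InetSocketAddress remoteAddress = (InetSocketAddress) socket.getRemoteSocketAddress();
// isLoopbackAddress() checks the IP itself (127.0.0.0/8 or ::1) and avoids the
// reverse DNS lookup getHostName() may do, whose result need not be "localhost"
if (!remoteAddress.getAddress().isLoopbackAddress()) { socket.close(); }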
Alternatively, you can bind the socket to a loopback address such as 127.0.0.1 (as EJP mentioned) and skip this step. Note that 0.0.0.0 is not a loopback address: it is the wildcard address and listens on all interfaces. Once you know that the connection came from localhost, all you have to do is find the remote port and figure out whether your process owns it.
int remotePort = remoteAddress.getPort();
// Close the connection unless the remote port belongs to this process
if (doOwnPort(remotePort) != 1) { socket.close(); }
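Putting the two checks together, an accept loop might look like the sketch below. The class and method names are hypothetical, and the ownership test is passed in as an IntPredicate so the sketch runs without native code; in practice it would wrap a native check like the one in this answer (for example `port -> doOwnPort(port) == 1`).

```java
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.function.IntPredicate;

final class OwnProcessAcceptor {
    // Accepts connections until one comes from a loopback address AND its
    // remote port passes ownsPort (a hypothetical hook for the JNI check).
    // Every other connection is closed immediately.
    static Socket acceptOwn(ServerSocket server, IntPredicate ownsPort) throws IOException {
        while (true) {
            Socket socket = server.accept();
            InetSocketAddress remote = (InetSocketAddress) socket.getRemoteSocketAddress();
            boolean fromLoopback = remote.getAddress().isLoopbackAddress();
            if (fromLoopback && ownsPort.test(remote.getPort())) {
                return socket;
            }
            socket.close(); // reject connections from other hosts or processes
        }
    }
}
```

Because the rejection happens after accept(), a foreign process can still open a connection briefly; this only guarantees that it never gets served.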
As far as I know, Java doesn't have an API that lists the ports your process owns, but you can do that via JNI. On the Java side you would need something like:
private native int doOwnPort(int port);
And on the native side:
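#include <jni.h>
#include <string.h>
#include <unistd.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

JNIEXPORT jint JNICALL Java_something_doOwnPort(JNIEnv *env, jobject object, jint port) {
    // Upper bound on descriptor numbers (POSIX equivalent of getdtablesize())
    long totalFDs = sysconf(_SC_OPEN_MAX);
    struct sockaddr_storage ss; // large enough for IPv4 and IPv6 addresses
    struct stat sb;
    // Iterate through all file descriptors
    for (int i = 0; i < totalFDs; ++i) {
        // Check if "i" is a valid FD
        if (fstat(i, &sb) < 0)
            continue;
        // Check if "i" is a socket
        if (!S_ISSOCK(sb.st_mode))
            continue;
        // Get local address of socket with FD "i"
        memset(&ss, 0, sizeof(ss));
        socklen_t ss_len = sizeof(ss);
        if (getsockname(i, (struct sockaddr *) &ss, &ss_len) < 0)
            continue;
        // Ports are stored in network byte order; convert before comparing
        int localPort = -1;
        if (ss.ss_family == AF_INET)
            localPort = ntohs(((struct sockaddr_in *) &ss)->sin_port);
        else if (ss.ss_family == AF_INET6)
            localPort = ntohs(((struct sockaddr_in6 *) &ss)->sin6_port);
        // Check if the port matches
        if (localPort == port)
            return 1; // We own the port
    }
    return -1; // We don't own the port
}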
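PS: This code is for Linux but should also work on other POSIX systems such as Android and macOS. It won't work as-is on Windows, where sockets are SOCKET handles rather than file descriptors.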
Maybe there is a more direct/efficient way to check if the port is owned by the current process without having to iterate through the FD table but that's a separate problem. HTH!
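On Linux and Android, one possibly cheaper variant is to visit only the descriptors that are actually open, by listing the /proc/self/fd directory, instead of probing every number up to the descriptor limit. The sketch below is plain C without JNI; the name owns_local_port_procfs is hypothetical, and it assumes procfs is mounted (true on Linux and Android, not on macOS or Windows).

```c
#ifndef _POSIX_C_SOURCE
#define _POSIX_C_SOURCE 200809L /* exposes dirfd() and the socket APIs under strict -std modes */
#endif
#include <dirent.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

/* Returns 1 if some socket in this process has local port `port`,
   0 if none does, and -1 if /proc/self/fd cannot be opened. */
int owns_local_port_procfs(int port) {
    DIR *dir = opendir("/proc/self/fd");
    if (dir == NULL)
        return -1;
    int result = 0;
    struct dirent *entry;
    while ((entry = readdir(dir)) != NULL) {
        /* Entries are decimal FD numbers, plus "." and ".." */
        char *end;
        long fd = strtol(entry->d_name, &end, 10);
        if (end == entry->d_name || *end != '\0')
            continue;
        /* Skip the descriptor opendir() itself is using */
        if (fd == dirfd(dir))
            continue;
        struct stat sb;
        if (fstat((int) fd, &sb) < 0 || !S_ISSOCK(sb.st_mode))
            continue;
        struct sockaddr_storage ss;
        socklen_t len = sizeof(ss);
        memset(&ss, 0, sizeof(ss));
        if (getsockname((int) fd, (struct sockaddr *) &ss, &len) < 0)
            continue;
        /* Ports are in network byte order */
        int local = -1;
        if (ss.ss_family == AF_INET)
            local = ntohs(((struct sockaddr_in *) &ss)->sin_port);
        else if (ss.ss_family == AF_INET6)
            local = ntohs(((struct sockaddr_in6 *) &ss)->sin6_port);
        if (local == port) {
            result = 1;
            break;
        }
    }
    closedir(dir);
    return result;
}
```

To use it from Java, Java_something_doOwnPort could call this function in place of the loop over every descriptor number. It still scans every open socket, so it is cheaper rather than constant-time.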