I have a TcpStream (from the standard library) and would like to write to it from multiple threads. TcpStream allows this without additional synchronization due to impl Write for &TcpStream. The payloads are packaged such that I make a single .write_all() call per payload.
use std::io::Write;
use std::net::TcpStream;

pub struct Publisher {
    stream: TcpStream,
}

impl Publisher {
    pub fn send(&self, payload: &[u8]) {
        // ignore errors for now
        let _ = (&self.stream).write_all(payload);
    }
}
But does this really work?
My worry is that .write_all() may involve multiple .write() calls to send the full payload, and thus concurrent calls may end up interleaving writes from the different threads. I don't see any special handling for TcpStream::write_all and thus it just uses the default trait implementation.
Is my concern well-founded? Is there a "clever" way to avoid the problem? Or do I simply need to wrap it in a Mutex regardless?
Yes, you're responsible for synchronizing writes to a TcpStream, so your concern is well-founded. That Write is implemented for a shared reference mostly reflects the underlying implementations, which commonly use just an integer to refer to an open TcpStream and thus don't need any references or Rust struct mutability to write to it.
In fact, nothing in the documentation suggests that write_all calls on a TcpStream are synchronized. Reviewing the code doesn't reveal any internal synchronization either. And indeed, you can observe interleaving with the following test program with a sufficiently large N:
use std::collections::HashMap;
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::thread;

const N: usize = 50_000_000;

fn main() -> Result<(), std::io::Error> {
    let listener = TcpListener::bind(("127.0.0.1", 0))?;
    let address = listener.local_addr()?;
    let handle = thread::spawn(read(listener));
    let stream = TcpStream::connect(address)?;
    thread::scope(|s| {
        s.spawn(write(b'a', &stream));
        s.spawn(write(b'b', &stream));
    });
    handle.join().unwrap();
    Ok(())
}

fn read(listener: TcpListener) -> impl FnOnce() {
    move || {
        while let Ok((mut s, _)) = listener.accept() {
            let mut chars = HashMap::new();
            let mut buf = [0u8; 1024];
            while let Ok(n) = s.read(&mut buf) {
                for &c in &buf[..n] {
                    chars.entry(c).and_modify(|v| *v += 1).or_insert(1);
                }
                if ![Some(N), None].contains(&chars.get(&b'a').copied())
                    && ![Some(N), None].contains(&chars.get(&b'b').copied())
                {
                    if chars[&b'a'] > chars[&b'b'] {
                        eprintln!("received 'b' before done receiving 'a' {chars:?}");
                    } else {
                        eprintln!("received 'a' before done receiving 'b' {chars:?}");
                    }
                    return;
                }
            }
        }
    }
}

fn write(c: u8, mut s: &TcpStream) -> impl FnOnce() + '_ {
    let data = vec![c; N];
    move || {
        _ = s.write_all(&data);
    }
}
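To see why interleaving is possible at all, it helps to look at what a write_all loop does. The following is a simplified sketch in the spirit of the default trait implementation, not the standard library's actual source. The names write_all_sketch and ShortWriter are hypothetical; ShortWriter stands in for a socket that accepts only part of the buffer per call.

```rust
use std::io::{self, ErrorKind, Write};

/// Simplified sketch of a `write_all`-style loop: keep calling `write`
/// until the whole buffer has been accepted.
fn write_all_sketch<W: Write>(w: &mut W, mut buf: &[u8]) -> io::Result<()> {
    while !buf.is_empty() {
        match w.write(buf) {
            Ok(0) => {
                return Err(io::Error::new(
                    ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ))
            }
            // A short write: only `n` bytes were accepted, so loop again.
            // Nothing stops another thread from writing between iterations.
            Ok(n) => buf = &buf[n..],
            Err(e) if e.kind() == ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    Ok(())
}

/// Hypothetical writer that accepts at most `limit` bytes per call.
struct ShortWriter {
    limit: usize,
    calls: usize,
    out: Vec<u8>,
}

impl Write for ShortWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let n = buf.len().min(self.limit);
        self.out.extend_from_slice(&buf[..n]);
        self.calls += 1;
        Ok(n)
    }

    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
```

With a limit of 4, writing a 10-byte payload takes three separate write calls. On a shared socket, each gap between those calls is a point where another thread's bytes can land.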
I would use a Mutex to add the required synchronization.
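A minimal sketch of what that could look like, adapting the Publisher from the question. Making it generic over the writer, along with the new and into_inner helpers, is my own assumption to keep the example testable without a socket; Publisher<TcpStream> would be the case discussed here.

```rust
use std::io::{self, Write};
use std::sync::Mutex;

/// Hypothetical variant of the question's `Publisher`, generic over the
/// writer so it works with `TcpStream` as well as an in-memory `Vec<u8>`.
pub struct Publisher<W: Write> {
    stream: Mutex<W>,
}

impl<W: Write> Publisher<W> {
    pub fn new(stream: W) -> Self {
        Self {
            stream: Mutex::new(stream),
        }
    }

    /// Holds the lock for the entire `write_all`, so the individual
    /// `write` calls of one payload cannot interleave with another's.
    pub fn send(&self, payload: &[u8]) -> io::Result<()> {
        // Still use the stream if another thread panicked while holding the lock.
        let mut guard = self.stream.lock().unwrap_or_else(|e| e.into_inner());
        guard.write_all(payload)
    }

    /// Returns the wrapped writer (an illustrative helper, mainly for tests).
    pub fn into_inner(self) -> W {
        self.stream.into_inner().unwrap_or_else(|e| e.into_inner())
    }
}
```

Note that if write_all returns an error partway through, part of the payload may already have been sent, so the caller probably wants to treat the connection as broken rather than keep sending.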
The same applies to the other types which have an impl Write for &T in the standard library: