【Rust】90 行代码实现 AsyncTeeReader

发布日期:分类:Programming

大家好,好久不见,这里是某昨。

这段 AsyncTeeReader 的代码是我在实现 anni-backend 的时候写的。虽然之后没有用到直接全删了,但个人感觉之后肯定还会有用到的时候,并且在书写这段代码的过程中我也学到了很多东西,故撰写此文以作记录。

代码带简单注释、newuse 和空行一共 90 行,从中或许可以稍微了解一下异步的代码该怎么写(笑)

完整源码可在 Rust Playground[1] 查看。

什么是 TEE?

TEE(1) 的功能是从标准输入读入,并将结果同时写入指定文件标准输出[2]。基本流程如下图所示:

图中的实线表示的是真实的数据流向,而虚线则表示了对于 Program B 而言的数据流向。对 Program B 而言,它不知道自己和 Program A 之间增加了一个中间人,其处理的数据也和直接从 A 处接收的数据完全一致。

至此,我们可以将 TEE 的功能归结为如下两条:

  1. 将接收到的数据原封不动地传递给原接收者
  2. 将数据的副本转发另一个(些)接收者。

Reader?

还是从这张图来看,不过这次我们换一个角度。

从数据的流向上来看,Program ATEEstdin 传送数据,和 TEEProgram Astdout 读取数据,二者是等价的。因此,我们可以把 TEEProgram A 接收数据看作是 TEEProgram AStdout Reader 中读取数据;同理,我们把 TEE 向文件写入看作是向文件的 Writer 写入

Program B 原本是直接从 Program ARead 数据,现在改从 TEERead 数据,因此 TEEProgram B 而言表现成的是一个 Reader

异步读写 Trait

在实现之前,我们需要了解这次我们需要用到的两个 Trait

tokio::io::AsyncRead

TokioAsyncRead 定义如下:

pub trait AsyncRead {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>>;
}

这个 Trait 只定义了一个函数:poll_read。和 Futurefn poll 相比,它只多了一个参数:buf: &mut ReadBuf<'_>,用来存储读取到的内容

当读取成功时,Poll 返回 Poll::Ready(Ok(()));当读取出现问题时,返回 Poll::Ready(Err(e));当 Poll 未完成时,返回 Poll::Pending

poll_read 读取遇到 EOF 时,调用前后的 buf 已占用大小不变,隐式地提示调用者遇到了 EOF

tokio::io::AsyncWrite

AsyncRead 相比,AsyncWrite 相对就复杂一些了。这里省略了带有默认实现的一部分,其余如下所示:

pub trait AsyncWrite {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>>;

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>;

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>;
}

首先是 poll_write。和 poll_read 相比,第三个参数换成了 &[u8],表示待写入的数据;返回类型也不再是 (),而是代表了成功写入的大小(usize)。

其次是 poll_flush,顾名思义,是在写入结束时 flush 用的。

最后是 poll_shutdown。在写入过程出现问题时,可以调用这个函数,尝试终止写入过程。当终止成功时,返回 Ok(())

开始实现

定义

有了上面的基础,我们可以写出 struct 的定义:

struct AsyncTeeReader {
    reader: Pin<Box<dyn AsyncRead>>,
    writer: Pin<Box<dyn AsyncWrite>>,

    state: AsyncTeeReaderState,
    buf: Vec<u8>,
    buf_now: usize;
}

enum AsyncTeeReaderState {
    Reading,
    Writing,
    Flushing,
}

AsyncTeeReader 中,我们记录了读取所需的 reader、写入所需的 writer、当前的工作状态,以及一个 buffer 和它当前写入的位置。

工作状态分为 ReadingWritingFlushing 三种,分别对应 reader.poll_readwriter.poll_writewriter.poll_flush 三种情况下 Pending 的再调用;而 buf 则用于暂时缓存 poll_read 读取到的数据,延迟到 poll_write 结束之后再返回。

Reading

AsyncTeeReaderState::Reading => {
    // step 1: read
    let begin = buf.filled().len();
    let ret = self.reader.as_mut().poll_read(cx, buf);
    match ret {
        Poll::Ready(Ok(())) => {
            let end = buf.filled().len();
            if begin == end {
                // EOF, flush
                self.state = AsyncTeeReaderState::Flushing;
            } else {
                // Write
                self.buf.extend_from_slice(&buf.filled()[begin..end]);
                buf.set_filled(begin);
                self.state = AsyncTeeReaderState::Writing;
            }
            // wake immediately to finish the last part
            cx.waker().wake_by_ref();
            // return pending
            Poll::Pending
        }
        Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
        Poll::Pending => Poll::Pending,
    }
}

在读取开始之前,我们记下 buf 的已用容量 begin,在之后判断是否遇到 EOF 的时候会用到。随后就是调用一次 readerpoll_read。当正常 Poll::Ready 时,执行处理逻辑;当遇到错误时,直接返回错误;而当状态为 Poll::Pending 时,直接返回 Pending

值得注意的是,这里我们之所以可以直接返回 Pending 状态,是因为我们在调用过程中把 cx 传递给了 reader.poll_read。因此,当 reader 可供读取时,cx.waker 会被 reader 调用,因此我们就不需要去关心什么时候 wake 的问题了。

当正常返回 Poll::Ready(Ok()) 时,我们再读取一次 buf 的已用容量,记作 end。如果 beginend 不相等,说明 poll_read 读到了数据,我们把数据暂时存在 self.buf 里,并将状态修改为 Writing;如果 beginend 相等,说明 reader 读到了 EOF,此时没有新的数据,只需要等 writerflush 结束即可,因此将状态修改为 Flushing

在上述判断结束之后,由于一轮还没有结束(Reading-WritingReading-Flushing),因此返回的状态是 Poll::Pending。但如果我们直接返回,那么之后就不会有 wake 的机会了(想一想,为什么?)。因此这里我们需要手动调用一次 cx.waker().wake_by_ref(),使得我们在返回 Poll::Pending 之后能够开始执行下一轮的任务。

Writing

AsyncTeeReaderState::Writing => {
    let me = self.get_mut();
    let ret = me.writer.as_mut().poll_write(cx, &me.buf[me.buf_now..]);
    match ret {
        Poll::Ready(Ok(written)) => {
            me.buf_now += written;
            if me.buf.len() != me.buf_now {
                // partial written
                cx.waker().wake_by_ref();
                Poll::Pending
            } else {
                // fully written, read again
                buf.put_slice(&me.buf);
                me.buf.clear();
                me.buf_now = 0;
                me.state = AsyncTeeReaderState::Reading;
                Poll::Ready(Ok(()))
            }
        }
        Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
        Poll::Pending => Poll::Pending,
    }
}

进入 Writing 状态,过程和 Reading 也大同小异。这里我们需要额外判断的是 Writer部分写入问题。poll_write 返回了写入的大小,我们需要将其与 self.buf 的大小作比对。当完全写入时,自然可以清空缓存,将 self.buf 重新写回 buf,并返回 AsyncReaderReady 状态;而当部分写入时,我们则需要返回 Pending 状态,并手动触发 waker 以确保 poll_read 能够被再次调用。在部分写入之后的下一次写入中,我们也只能向其传递没有写入的内容,需要跳过已写入的部分

Flushing

AsyncTeeReaderState::Flushing => self.writer.as_mut().poll_flush(cx)

AsyncTeeReader 进入这个状态时,读的过程也进入了尾声。我们只需要等待 writerflush 完成就可以了。

Links

  1. https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=9155c923388f8d57da69ecd3c4b9a382
  2. https://man7.org/linux/man-pages/man1/tee.1.html

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注