Pker.xyz

State Machines in Rust: The Iterator Way

Rust's type system makes it an obvious choice for software that relies on strong guarantees and where correctness is of high importance, which I would argue is most, if not all, software.

It's often useful for our code to have verifiable properties. We want to define these in a formal way such that we can guarantee that the properties are respected. Ideally we also want to be able to test (and change) these properties easily. State machines can sometimes help us achieve these goals.

Context

When it comes to state machines in Rust, a popular pattern is type state (there are already multiple interesting blog posts on the matter, such as this popular one). Type state typically means having one type per state (i.e. a unit struct) and a struct, generic over that state, that implements the valid transitions for each of these states. This works well and is a fine approach, but it can be a bit overkill and cumbersome for many use cases.
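For contrast, here is a minimal sketch of the type state pattern described above. The state names (`Init`, `VersionSent`, `Done`), the `messages_sent` field and the methods are hypothetical and only illustrate the shape of the pattern: every state is its own unit struct, and each transition consumes the old value and returns a value of a different type, so an invalid transition simply does not compile.

```rust
use std::marker::PhantomData;

// Hypothetical states: one unit struct per state.
pub struct Init;
pub struct VersionSent;
pub struct Done;

// The handshake is generic over its current state; the marker is zero-sized.
pub struct Handshake<S> {
    // Hypothetical payload carried across states.
    pub messages_sent: u32,
    _state: PhantomData<S>,
}

impl Handshake<Init> {
    pub fn new() -> Self {
        Handshake { messages_sent: 0, _state: PhantomData }
    }

    // Only callable in the `Init` state; consumes `self`.
    pub fn send_version(self) -> Handshake<VersionSent> {
        Handshake { messages_sent: self.messages_sent + 1, _state: PhantomData }
    }
}

impl Handshake<VersionSent> {
    // Only callable once the version has been sent.
    pub fn finish(self) -> Handshake<Done> {
        Handshake { messages_sent: self.messages_sent, _state: PhantomData }
    }
}

fn main() {
    let done: Handshake<Done> = Handshake::new().send_version().finish();
    // `Handshake::new().finish()` would be a compile error:
    // `finish` does not exist on `Handshake<Init>`.
    println!("messages sent: {}", done.messages_sent);
}
```

The guarantees are strong, but every new state costs a new type and a new impl block, which is the overhead referred to above.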

I was looking at implementing the Bitcoin peer-to-peer network handshake protocol. Like most protocol handshakes, it consists of a few ordered states where each transition depends on the completion of the previous one: a good use case for a state machine, if one wants to formalize the handshake process. You can look at the code for this small project here.

I wanted this code to be extensible and to potentially serve as a baseline for implementing more of the peer-to-peer network protocol. Ideally this meant a library that binaries can interface with and use; in our specific case, a library that exposes a handshake type and lets the caller "run/execute" the handshake.

Concretely, I wanted a user to be able to create a handshake instance and run it. The handshake itself should simply step through the required states until completion and, if it fails, return an error describing when (at which step) and why (the underlying error) it failed.

I wanted the user to be able to do something along the lines of:

rust

    // ...
    let mut handshake = Handshake::new(&stream);

    // this attempts to run the handshake to completion
    let handshake_result = handshake.process();

    match handshake_result {
        Ok(_) => {
            tracing::info!("Handshake with {:?} completed.", peer);
        }
        // here the error should contain the step at which it failed and the err itself
        Err(ref e) => {
            tracing::error!("Handshake with {:?} failed. {}", peer, e);
        }
    }
    // close conn and return if the handshake returned an error
    // (assumes the enclosing fn returns e.g. Result<(), Box<dyn std::error::Error>>)
    stream.shutdown(std::net::Shutdown::Both)?;
    handshake_result?;
    Ok(())

In the example above, we deliberately call handshake.process() before the match and keep its result, so that the connection is closed regardless of the outcome. Consuming the result inside the match would work just as well but would duplicate the shutdown code; either way, that's a detail.

The interesting bit lies in the process() member function.

Implementation

As we saw, our library (a module called handshake) only exposes two functions: new() to create a handshake instance and process() to execute it. We want to define states and have process() iterate through all of them. Let's first define our states:

rust

/// The states of the bitcoin p2p network protocol handshake, in "order".
#[derive(Debug, Clone, Copy)]
pub enum State {
    Init,
    SendVersion,
    RecvVersion,
    SendAck,
    RecvAck,
    Complete,
}

impl State {
    fn new() -> Self {
        Self::Init
    }

    fn completed(&self) -> bool {
        matches!(self, Self::Complete)
    }
}
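Because the states live in a plain enum, their order can also be written down and checked on its own, without any networking. The sketch below is one possible variant, not what the post's step() does (step() performs the transitions inline): it adds a hypothetical `successor()` method that encodes the same ordering and walks it from `Init` to `Complete`.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum State {
    Init,
    SendVersion,
    RecvVersion,
    SendAck,
    RecvAck,
    Complete,
}

impl State {
    fn completed(&self) -> bool {
        matches!(self, Self::Complete)
    }

    // Hypothetical helper: the state reached after a successful step.
    // `Complete` is terminal and maps to itself.
    fn successor(self) -> Self {
        match self {
            Self::Init => Self::SendVersion,
            Self::SendVersion => Self::RecvVersion,
            Self::RecvVersion => Self::SendAck,
            Self::SendAck => Self::RecvAck,
            Self::RecvAck | Self::Complete => Self::Complete,
        }
    }
}

fn main() {
    // Walk the machine from the initial state, recording every state visited.
    let mut state = State::Init;
    let mut visited = vec![state];
    while !state.completed() {
        state = state.successor();
        visited.push(state);
    }
    println!("{:?}", visited);
}
```

Keeping the order in one pure function lets a unit test pin it down, while the I/O-performing step would only decide whether to advance.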
                

An important goal was to have friendly and detailed errors. Let's implement an error type (and the Display trait for our State type). We want the error to carry both the state in which the failure happened and the inner error.

rust

use std::error::Error;
use std::fmt;

impl fmt::Display for State {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Init => write!(f, "Initial state"),
            Self::SendVersion => write!(f, "Sending Version Message"),
            Self::RecvVersion => write!(f, "Receiving Version Message"),
            Self::SendAck => write!(f, "Sending Ack Message"),
            Self::RecvAck => write!(f, "Receiving Ack Message"),
            Self::Complete => write!(f, "Handshake completed"),
        }
    }
}

#[derive(Debug)]
/// Top level error, with handshake state context.
pub enum HandshakeError {
    /// An error that happened at the message layer (an inner module which
    /// contains the interesting details of the protocol itself, not important
    /// to this blog post).
    MessageError(State, MessageError),
    /// IO related errors.
    IOError(State, std::io::Error),
}

impl fmt::Display for HandshakeError {
    // NOTE: Relies on State, MessageError and std::io::Error implementing
    // Display. The Error impl below additionally needs the derived Debug.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::MessageError(state, err) => {
                write!(f, "Message Error: State: {}. err: {}.", state, err)
            }
            Self::IOError(state, err) => write!(f, "IO Error: State: {}. err: {}.", state, err),
        }
    }
}
impl Error for HandshakeError {}
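To show how a step attaches its State to a failure, here is a self-contained sketch. The post's message layer is not shown, so the enums are trimmed to a single variant each, and `send_version` is a hypothetical function that writes a placeholder payload to any `std::io::Write`; a hypothetical always-failing writer simulates a broken connection. The key move is `map_err`, which wraps the underlying error together with the state the handshake was in when it failed.

```rust
use std::error::Error;
use std::fmt;
use std::io::{self, Write};

#[derive(Debug, Clone, Copy)]
pub enum State {
    Init,
}

impl fmt::Display for State {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Init => write!(f, "Initial state"),
        }
    }
}

#[derive(Debug)]
pub enum HandshakeError {
    IOError(State, io::Error),
}

impl fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::IOError(state, err) => write!(f, "IO Error: State: {}. err: {}.", state, err),
        }
    }
}

impl Error for HandshakeError {}

// Hypothetical writer that always fails, standing in for a broken TCP connection.
struct FailingWriter;

impl Write for FailingWriter {
    fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
        Err(io::Error::new(io::ErrorKind::BrokenPipe, "connection reset"))
    }
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

// Hypothetical step: b"version" is a placeholder, not a real version message.
fn send_version<W: Write>(state: State, out: &mut W) -> Result<(), HandshakeError> {
    out.write_all(b"version")
        .map_err(|e| HandshakeError::IOError(state, e))
}

fn main() {
    let err = send_version(State::Init, &mut FailingWriter).unwrap_err();
    println!("{}", err);
}
```

Since the state is captured when the error is built, callers only need `?` to propagate it, and the printed message reads `IO Error: State: Initial state. err: connection reset.`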
                

Let's now implement our handshake type, which holds the connection to the peer we want to handshake with and the current state of the handshake. We also want to implement the process() fn we saw above, which is part of our public API, and the state transition function, which we will call step():

rust

use std::net::TcpStream;

/// A handshake with a single peer, run over an existing TCP connection.
#[derive(Debug)]
pub struct Handshake<'a> {
    /// The current state of the handshake process.
    state: State,
    /// A tcp stream between a local and remote socket
    /// on which the handshake will happen.
    stream: &'a TcpStream,
}

impl<'a> Handshake<'a> {
    /// Constructs and initializes the handshake that will be attempted on
    /// the given TcpStream.
    pub fn new(stream: &'a TcpStream) -> Self {
        let state = State::new();
        Self { state, stream }
    }

    /// Attempts to run an handshake to completion.
    /// Relies on Handshake to implement the Iterator trait.
    pub fn process(&mut self) -> Result<(), HandshakeError> {
        while let Some(state_result) = self.next() {
            match state_result {
                Ok(_) => {
                    tracing::debug!("Reached state: {}", self.state);
                }
                Err(e) => {
                    return Err(e);
                }
            }
        }

        Ok(())
    }

    /// The state transition function. Ignore the method calls; the important
    /// bit is that every step is potentially fallible and, if it fails, returns
    /// a HandshakeError containing the current State and the error itself.
    fn step(&mut self) -> Result<(), HandshakeError> {
        match self.state {
            State::Init => {
                self.send_version()?;
                self.state = State::SendVersion
            }
            State::SendVersion => {
                self.read_version()?;
                self.state = State::RecvVersion
            }
            State::RecvVersion => {
                self.send_verack()?;
                self.state = State::SendAck
            }
            State::SendAck => {
                self.read_verack()?;
                self.state = State::RecvAck
            }
            State::RecvAck => self.state = State::Complete,
            State::Complete => (),
        }
        Ok(())
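    }

    // send_version(), read_version(), send_verack() and read_verack() are
    // omitted here; they handle the actual protocol messages.
}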

The interesting, ergonomic part is being able to run the handshake step by step with "while let Some(state_result) = self.next()" and simply return on the first error. For this we need the key piece: implementing the Iterator trait, whose one required method is next():

rust

// Returns `None` when the handshake is completed.
impl Iterator for Handshake<'_> {
    // important: the next() fn will return an optional result: None means
    // the end of the iterator (handshake complete).
    type Item = Result<(), HandshakeError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.state.completed() {
            return None;
        }

        // Step in, using the step() state transition fn defined in the above snippet.
        Some(self.step())
    }
}
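The snippets above depend on a real TcpStream and on helpers the post omits, so here is a self-contained toy version of the same technique that can be run and tested in isolation. The I/O is replaced by a hypothetical `fail_at` field that makes one chosen state fail, and the error type is reduced to a hypothetical `StepError`; the rest (an enum of states, a fallible `step()`, and an `Iterator` impl that yields `Result`s and returns `None` once complete) mirrors the post.

```rust
use std::fmt;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    Init,
    SendVersion,
    RecvVersion,
    SendAck,
    RecvAck,
    Complete,
}

// Hypothetical error: records the state in which the step failed.
#[derive(Debug, PartialEq, Eq)]
struct StepError(State);

impl fmt::Display for StepError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "step failed in state {:?}", self.0)
    }
}

struct Handshake {
    state: State,
    // Hypothetical stand-in for a network failure in a given state.
    fail_at: Option<State>,
}

impl Handshake {
    fn new(fail_at: Option<State>) -> Self {
        Self { state: State::Init, fail_at }
    }

    // On failure the state is left unchanged, like the post's step().
    fn step(&mut self) -> Result<(), StepError> {
        if self.fail_at == Some(self.state) {
            return Err(StepError(self.state));
        }
        self.state = match self.state {
            State::Init => State::SendVersion,
            State::SendVersion => State::RecvVersion,
            State::RecvVersion => State::SendAck,
            State::SendAck => State::RecvAck,
            State::RecvAck | State::Complete => State::Complete,
        };
        Ok(())
    }

    // Same shape as the post's process(): stop at the first error.
    fn process(&mut self) -> Result<(), StepError> {
        while let Some(result) = self.next() {
            result?;
        }
        Ok(())
    }
}

impl Iterator for Handshake {
    type Item = Result<(), StepError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.state == State::Complete {
            return None;
        }
        Some(self.step())
    }
}

fn main() {
    assert!(Handshake::new(None).process().is_ok());
    if let Err(e) = Handshake::new(Some(State::SendAck)).process() {
        println!("handshake failed: {}", e);
    }
}
```

Because the handshake is a genuine Iterator, the standard adapters come for free: process() could equally be written as `self.try_for_each(|r| r)`, and a test can advance the machine one state at a time with next() and inspect the state in between. A failed step leaves the state unchanged, so calling next() again retries the same step.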
                

Annnnnd that's all we really needed. We now have a fairly small process() fn that leverages the iterator. We expose this fn to our users while remaining free to change the transition steps (and add or remove transitions) as we please behind the curtains.

In this article we looked at a way to implement state machines in Rust using enums and iterators. This approach is simple, gives a good amount of flexibility and lets us choose whether to hide or expose the inner state to the consumer without hassle. It also makes testing, and adding or removing states, easy without changing the API's contract.


Uploaded 03-10-2024