Tags: rust, error-handling, dependency-injection, traits, rust-axum

Handling dependency-injection and errors in Rust, using Axum


I'm trying to wrap my head around how to handle this particular issue in Rust. I've programmed in Rust for a while, but I'm primarily a C# developer, and some of my knowledge in that language might be confusing me with this problem.

I have a web application built with Axum, and I'm building a data-access layer to abstract away direct sqlx connections. To support dependency injection, I'm collecting all my controller objects into a single State, where they are held as trait objects.

Everything has been working fine so far. I wrap the dyn traits in Arcs, require them to implement Send + Sync, and Axum has no problem passing them from handler to handler.

Example:

    #[async_trait]
    pub trait DataLayer : Send + Sync {
        async fn try_get_user_by_id<'a>(&self, user_id: &'a str) -> Option<UserDbModel>;
        async fn get_user_by_email<'a>(&self, email: &'a str) -> Option<UserDbModel>;

        async fn get_refr_token_by_token<'a>(&self, token: &'a str) -> Option<RefrTokenDbModel>;
        async fn get_refr_token_by_id(&self, token: i32) -> Option<RefrTokenDbModel>;
        async fn create_refr_token(&self, refr_token: CreateRefrTokenDbModel) -> u64;
        async fn revoke_refr_token(&self, token: RevokeRefrTokenDbModel);
    }

This DataLayer trait can then be referenced in my other services:

    #[async_trait]
    pub trait AuthService: Send + Sync {
        async fn try_accept_creds(&self, info: LoginPayload) -> login_error::Result<TokensModel>;
        async fn try_accept_refresh(&self, refr_token: String) -> refresh_error::Result<TokensModel>;
    }

    #[derive(Clone)]
    pub struct CoreAuthService {
        data_layer: Arc<dyn DataLayer>,
        token_service: Arc<dyn TokenService>,
    }
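
A minimal, std-only sketch of the Arc<dyn Trait> injection pattern described above. To stay dependency-free it uses synchronous methods instead of #[async_trait], and UserStore, InMemoryUserStore, AuthChecker and build_checker are hypothetical names, not part of the original code.

```rust
use std::collections::HashMap;
use std::sync::Arc;

// The abstraction a service depends on. `Send + Sync` lets the trait object
// live inside shared state that Axum clones across handlers and threads.
pub trait UserStore: Send + Sync {
    fn email_for_id(&self, user_id: &str) -> Option<String>;
}

// One concrete implementation; a database-backed one would sit alongside it.
pub struct InMemoryUserStore {
    users: HashMap<String, String>,
}

impl UserStore for InMemoryUserStore {
    fn email_for_id(&self, user_id: &str) -> Option<String> {
        self.users.get(user_id).cloned()
    }
}

// The service only knows about `dyn UserStore`, never the concrete type.
#[derive(Clone)]
pub struct AuthChecker {
    store: Arc<dyn UserStore>,
}

impl AuthChecker {
    pub fn new(store: Arc<dyn UserStore>) -> Self {
        Self { store }
    }

    pub fn has_user(&self, user_id: &str) -> bool {
        self.store.email_for_id(user_id).is_some()
    }
}

// Wiring happens once, at the composition root (e.g. when building the Axum state).
pub fn build_checker() -> AuthChecker {
    let mut users = HashMap::new();
    users.insert("u1".to_string(), "a@example.com".to_string());
    // `Arc<InMemoryUserStore>` coerces to `Arc<dyn UserStore>` here.
    AuthChecker::new(Arc::new(InMemoryUserStore { users }))
}
```

Swapping in a test double only changes the line inside build_checker; AuthChecker itself never changes.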

A big problem with the DataLayer trait, however (as you might be able to see), is that I originally set it up to just panic!() whenever it hit any kind of database error. I'd like each trait method's return value to be wrapped in a Result instead.

The problem I'm hitting is that I want this error type to be whatever the implementation uses, rather than something fixed by the trait. So naturally I tried to declare an associated type in the trait:

    #[async_trait]
    pub trait DataLayer : Send + Sync {
        type Error : std::error::Error + Send + Sync;

        async fn try_get_user_by_id<'a>(&self, user_id: &'a str) -> Result<Option<UserDbModel>, Self::Error>;
        async fn get_user_by_email<'a>(&self, email: &'a str) -> Result<Option<UserDbModel>, Self::Error>;

        async fn get_refr_token_by_token<'a>(&self, token: &'a str) -> Result<Option<RefrTokenDbModel>, Self::Error>;
        async fn get_refr_token_by_id(&self, token: i32) -> Result<Option<RefrTokenDbModel>, Self::Error>;
        async fn create_refr_token(&self, refr_token: CreateRefrTokenDbModel) -> Result<u64, Self::Error>;
        async fn revoke_refr_token(&self, token: RevokeRefrTokenDbModel) -> Result<(), Self::Error>;
    }

Then I could define the Error type however I wanted:

    pub struct DbDataLayer {
        db: MySqlPool,
        settings: TokenSettings
    }

    #[async_trait]
    impl DataLayer for DbDataLayer {
        type Error = sqlx::Error;
        async fn try_get_user_by_id<'a>(&self, user_id: &'a str) -> sqlx::Result<Option<UserDbModel>> {
            let user = sqlx::query_as!(UserDbModel, r"
                SELECT id, email, password_hash as pwd_hash, role FROM users
                WHERE id = ?
            ", user_id).fetch_one(&self.db).await;

            match user {
                Ok(user) => Ok(Some(user)),
                Err(sqlx::Error::RowNotFound) => Ok(None),
                Err(e) => Err(e),
            }
        }
    ...

However, part of the point of dependency injection is to avoid injecting tightly coupled dependencies into other services. When I try to build the CoreAuthService from above, the DataLayer I want to inject now requires its Error associated type to be specified. I thought I could just reuse the same Send + Sync bounds:

    #[derive(Clone)]
    pub struct CoreAuthService {
        data_layer: Arc<dyn DataLayer<Error = dyn Error + Send + Sync>>,
        token_service: Arc<dyn TokenService>,
    }

    impl CoreAuthService {
        pub fn new(
            data_layer: Arc<dyn DataLayer<Error = dyn Error + Send + Sync>>,
            token_service: Arc<dyn TokenService>,
        ) -> Self {
            Self {
                data_layer,
                token_service,
            }
        }
    }

However, I then get the following compiler error whenever I call the DataLayer methods inside the CoreAuthService methods:

    the size for values of type `(dyn StdError + Send + Sync + 'static)` cannot be known at compilation time
    the trait `Sized` is not implemented for `(dyn StdError + Send + Sync + 'static)`

I'm not sure how to proceed from here. How could I allow a generic Error type for the injected DataLayer but also make the compiler happy?

I'm also wondering if maybe I'm approaching the infrastructure of my codebase badly altogether (as I said, I'm a C# developer, so I'm importing some of the general best practices into Rust).


Solution

  • After some great input from Shaun the Sheep and Chayim Friedman, I learned what the issue was with this implementation. I was trying to build the DataLayer trait so that injecting it into other services required no implementation details (which is considered best practice for dependency injection). However, by declaring an associated Error type in the trait, which had to be specified wherever the trait object was injected (DataLayer<Error = ...>), I was forcing an implementation detail back into every consumer's definition.

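    Instead of declaring an associated Error type in the trait, I needed my DataLayer methods to return a single type-erased error that works for any implementation. Axum (like Tower) already provides an alias for this, BoxError, defined as Box<dyn std::error::Error + Send + Sync>, so I opted to use that. Because the error is boxed, it has a size known at compile time, which also avoids the Sized error above: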

    pub type Result<T> = std::result::Result<T, BoxError>;
    
    #[async_trait]
    pub trait DataLayer : Send + Sync {
        async fn get_user_by_id<'a>(&self, user_id: &'a str) -> Result<Option<UserDbModel>>;
        async fn get_user_by_email<'a>(&self, email: &'a str) -> Result<Option<UserDbModel>>;
    
        async fn get_refr_token_by_token<'a>(&self, token: &'a str) -> Result<Option<RefrTokenDbModel>>;
        async fn get_refr_token_by_id(&self, token: i32) -> Result<Option<RefrTokenDbModel>>;
        async fn create_refr_token(&self, refr_token: CreateRefrTokenDbModel) -> Result<i32>;
        async fn revoke_refr_token(&self, token: RevokeRefrTokenDbModel) -> Result<()>;
    }
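
    A std-only sketch of why this design works. BoxError is defined locally with the same definition Axum uses, so no crates are needed, and the methods are synchronous for the same reason. Unlike the unsized dyn Error + Send + Sync, the box has a fixed size (a data pointer plus a vtable pointer), so it can sit inside a Result. The two implementations, ParsingLayer and OfflineLayer, are hypothetical: they fail with unrelated error types, yet both fit behind the same Arc<dyn DataLayer>.

```rust
use std::error::Error;
use std::fmt;
use std::sync::Arc;

// Same definition as `axum::BoxError`, repeated so the sketch needs no crates.
pub type BoxError = Box<dyn Error + Send + Sync>;
pub type Result<T> = std::result::Result<T, BoxError>;

// No associated `Error` type: every implementation reports through `BoxError`.
pub trait DataLayer: Send + Sync {
    fn get_user_count(&self, raw: &str) -> Result<u64>;
}

// Implementation A fails with a standard-library error type.
pub struct ParsingLayer;

impl DataLayer for ParsingLayer {
    fn get_user_count(&self, raw: &str) -> Result<u64> {
        // `?` boxes `ParseIntError` via `From<E> for Box<dyn Error + Send + Sync>`.
        Ok(raw.parse::<u64>()?)
    }
}

// Implementation B fails with its own error type.
#[derive(Debug)]
pub struct Offline;

impl fmt::Display for Offline {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "database is offline")
    }
}

impl Error for Offline {}

pub struct OfflineLayer;

impl DataLayer for OfflineLayer {
    fn get_user_count(&self, _raw: &str) -> Result<u64> {
        Err(Box::new(Offline))
    }
}

// A consumer that is completely unaware of either concrete error type.
pub fn count_or_message(layer: &Arc<dyn DataLayer>, raw: &str) -> String {
    match layer.get_user_count(raw) {
        Ok(n) => n.to_string(),
        Err(e) => e.to_string(),
    }
}
```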
    

    My implementation now looks like this:

    #[async_trait]
    impl DataLayer for DbDataLayer {
        async fn get_user_by_id<'a>(&self, user_id: &'a str) -> Result<Option<UserDbModel>> {
            let user = sqlx::query_as!(UserDbModel, r"
                SELECT id, email, password_hash as pwd_hash, role FROM users
                WHERE id = ?
            ", user_id)
                .fetch_one(&self.db).await;

            match user {
                Ok(user) => Ok(Some(user)),
                Err(sqlx::Error::RowNotFound) => Ok(None),
                Err(e) => Err(Box::new(e)),
            }
        }
    ...
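
    The key move in this implementation is treating "row not found" as an expected outcome (Ok(None)) and boxing every other error. A std-only sketch of that mapping, where QueryError and fetch_one are hypothetical stand-ins for sqlx::Error and sqlx's fetch_one. (sqlx also provides fetch_optional, which returns Result<Option<T>, sqlx::Error> directly and would make the RowNotFound arm unnecessary; the sketch keeps the match-based approach used above.)

```rust
use std::error::Error;
use std::fmt;

pub type BoxError = Box<dyn Error + Send + Sync>;

// Hypothetical stand-in for `sqlx::Error`, which also has a `RowNotFound` variant.
#[derive(Debug)]
pub enum QueryError {
    RowNotFound,
    Connection(String),
}

impl fmt::Display for QueryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            QueryError::RowNotFound => write!(f, "no rows returned"),
            QueryError::Connection(msg) => write!(f, "connection error: {msg}"),
        }
    }
}

impl Error for QueryError {}

// Hypothetical stand-in for `fetch_one`: it errors when the row is missing.
fn fetch_one(user_id: &str) -> Result<String, QueryError> {
    match user_id {
        "u1" => Ok("alice@example.com".to_string()),
        "down" => Err(QueryError::Connection("timed out".to_string())),
        _ => Err(QueryError::RowNotFound),
    }
}

// "Not found" becomes `Ok(None)`; anything else is a genuine failure
// and is boxed into a `BoxError` for the caller.
pub fn get_user_email(user_id: &str) -> Result<Option<String>, BoxError> {
    match fetch_one(user_id) {
        Ok(email) => Ok(Some(email)),
        Err(QueryError::RowNotFound) => Ok(None),
        Err(e) => Err(Box::new(e)),
    }
}
```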
    

    The nice thing about this structure is that it's simple to convert these errors into error types further down the pipeline. For example:

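    use axum::{
        http::StatusCode,
        response::{IntoResponse, Response},
        BoxError,
    };
    use thiserror::Error;
    use tracing::error; // assumes the tracing crate; log::error! works the same way

    #[derive(Debug, Error)]
    pub enum LoginError {
        #[error("An internal server error has occurred")]
        DataLayerError(BoxError),
        #[error("The given email {0} doesn't exist")]
        EmailDoesNotExist(String),
        #[error("Password does not match for email {0}")]
        PasswordDoesNotMatch(String),
    }

    // Implement From rather than Into: the standard library's blanket impl then
    // provides Into<LoginError> for BoxError automatically, and the ? operator
    // (which converts errors via From) can use it directly.
    impl From<BoxError> for LoginError {
        fn from(err: BoxError) -> Self {
            LoginError::DataLayerError(err)
        }
    }

    impl IntoResponse for LoginError {
        fn into_response(self) -> Response {
            if let LoginError::DataLayerError(err) = &self {
                // Log the underlying error; the client only sees the generic message
                error!("{:?}", err);
                (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
            } else {
                (StatusCode::BAD_REQUEST, self.to_string()).into_response()
            }
        }
    }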
    

    Then the DataLayer can be used in the following way:

    #[async_trait]
    impl AuthService for CoreAuthService {
        async fn try_accept_creds(&self, payload: LoginPayload) -> auth_error::Result<TokensModel> {
            // Get the user associated with the email (if exists)
            let user = self.data_layer.get_user_by_email(&payload.email).await
                .map_err(|e| e.into())?;
    
            if let Some(user) = user {
    ...
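
    The .map_err(...) step is not strictly required: the ? operator converts error types through From, so once From<BoxError> is implemented for LoginError, a plain ? turns a data-layer error into a LoginError. A std-only sketch, with hand-written Display and Error impls in place of thiserror, and hypothetical lookup_user and try_login functions standing in for get_user_by_email and try_accept_creds:

```rust
use std::error::Error;
use std::fmt;

pub type BoxError = Box<dyn Error + Send + Sync>;

#[derive(Debug)]
pub enum LoginError {
    DataLayerError(BoxError),
    EmailDoesNotExist(String),
}

impl fmt::Display for LoginError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LoginError::DataLayerError(_) => write!(f, "An internal server error has occurred"),
            LoginError::EmailDoesNotExist(email) => write!(f, "The given email {email} doesn't exist"),
        }
    }
}

impl Error for LoginError {}

// `From` (not `Into`) is what `?` looks for when converting the error type.
impl From<BoxError> for LoginError {
    fn from(err: BoxError) -> Self {
        LoginError::DataLayerError(err)
    }
}

// Hypothetical data-layer call that fails for one specific input.
fn lookup_user(email: &str) -> Result<Option<u32>, BoxError> {
    if email == "broken@example.com" {
        return Err("connection reset".into());
    }
    Ok(if email == "a@example.com" { Some(1) } else { None })
}

pub fn try_login(email: &str) -> Result<u32, LoginError> {
    // No `.map_err(...)` needed: on failure `?` calls `LoginError::from`.
    let user = lookup_user(email)?;
    user.ok_or_else(|| LoginError::EmailDoesNotExist(email.to_string()))
}
```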
    

    This is a bit verbose, but it's definitely worth it to make sure errors are handled properly! (The unexpected nature of C# exceptions is easily my least favorite thing about that language.)

    The only other consideration is whether BoxError is the right error type for the Result returned by the DataLayer methods, or whether a type specific to the DataLayer trait itself would be better. For now, I decided that BoxError is suitable for the size of the project, since I don't expect the services that DataLayer is injected into to return a BoxError for any other reason. It might be worth re-examining this later, though.
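
    One point in BoxError's favor for that trade-off: the concrete error is not lost. A caller that needs to react to a specific failure can recover it with downcast_ref, which the standard library provides on dyn Error + Send + Sync. A std-only sketch, where Timeout and is_retryable are hypothetical names:

```rust
use std::error::Error;
use std::fmt;

pub type BoxError = Box<dyn Error + Send + Sync>;

// Hypothetical error type a data-layer implementation might produce.
#[derive(Debug)]
pub struct Timeout {
    pub millis: u64,
}

impl fmt::Display for Timeout {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "query timed out after {} ms", self.millis)
    }
}

impl Error for Timeout {}

// Decide whether a failed call is worth retrying by inspecting the boxed error.
pub fn is_retryable(err: &BoxError) -> bool {
    // `downcast_ref` returns `Some` only if the box holds exactly this type.
    err.downcast_ref::<Timeout>().is_some()
}
```

    If many callers end up downcasting, that is a sign a dedicated DataLayer error enum would serve better.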

    This is my first Stack Overflow post ever, and it has been incredibly helpful. Thanks again to everyone involved!