Search code examples
Tags: rust, pagination, rust-diesel

Is it possible to get pagination info from a Rust Diesel pagination query result?


I am using Rust Diesel to run a pagination query. Right now I take the pagination information from the request, and I was wondering whether it is possible to get the pagination information from the query result instead. This is my code:

/// Loads one page of `favorites` rows whose `like_status` is 1 and echoes the requested
/// page position back in a [`Paginated`] value whose `query` field holds the loaded rows.
///
/// The generic parameter `T` is not used by the body; callers name it via turbofish.
/// The total page count computed by `load_and_count_pages` is discarded here.
///
/// NOTE(review): `request.userId` is never used as a filter, so rows for every user are
/// returned — confirm that this is intended.
///
/// # Panics
/// Panics (via `unwrap`) if the database query fails.
pub fn fav_music_query<T>(request: Json<FavMusicRequest>) -> Paginated<Vec<Favorites>> {
    use crate::model::diesel::rhythm::rhythm_schema::favorites::dsl::*;
    let conn = config::establish_music_connection();
    let (page_num, page_size) = (request.pageNum, request.pageSize);
    let (records, _total_pages) = favorites
        .filter(like_status.eq(1))
        .paginate(page_num)
        .per_page(page_size)
        .load_and_count_pages::<Favorites>(&conn)
        .unwrap();
    Paginated {
        query: records,
        page: page_num,
        per_page: page_size,
        is_sub_query: false,
    }
}

From the request I can only get `pageSize` and `pageNum`, but not the total size. What is the best way to get the pagination information? This is my pagination code:

use diesel::prelude::*;
use diesel::query_dsl::methods::LoadQuery;
use diesel::query_builder::{QueryFragment, Query, AstPass};
use diesel::pg::Pg;
use diesel::sql_types::BigInt;
use diesel::QueryId;
use serde::{Serialize, Deserialize};

/// Extension trait that adds `.paginate(page)` to Postgres query fragments.
pub trait PaginateForQueryFragment: Sized {
    /// Wraps `self` in a [`Paginated`] positioned at the given (1-based) `page`.
    fn paginate(self, page: i64) -> Paginated<Self>;
}

/// Blanket impl: any `QueryFragment<Pg>` can be paginated. It is rendered as a
/// parenthesised sub-select (`is_sub_query: true`) with a default page size of 10.
impl<T: QueryFragment<Pg>> PaginateForQueryFragment for T {
    fn paginate(self, page: i64) -> Paginated<Self> {
        let default_page_size = 10;
        Paginated {
            query: self,
            page,
            per_page: default_page_size,
            is_sub_query: true,
        }
    }
}

/// A query wrapped with page-based `LIMIT`/`OFFSET` pagination.
///
/// It is also used as a response payload (`fav_music_query` stores the loaded records in
/// `query`). Field declaration order is the serialized field order, so keep it stable.
#[derive(Debug, Clone, Copy, QueryId, Serialize, Deserialize, Default)]
pub struct Paginated<T> {
    /// The inner query, or the loaded records when used as a response payload.
    pub query: T,
    /// 1-based page number; the SQL `OFFSET` is derived from `(page - 1) * per_page`.
    pub page: i64,
    /// Maximum number of rows per page, emitted as the SQL `LIMIT`.
    pub per_page: i64,
    /// When true, the inner query is wrapped in parentheses as a sub-select.
    pub is_sub_query: bool
}

impl<T> Paginated<T> {
    /// Consumes the query and returns it with the page size replaced by `per_page`.
    pub fn per_page(self, per_page: i64) -> Self {
        Paginated { per_page, ..self }
    }

    /// Runs the query and returns `(records, total_pages)`.
    ///
    /// The total row count is taken from the `COUNT(*) OVER ()` column of the first
    /// returned row. Because it comes from returned rows, requesting a page past the
    /// last one yields no rows and therefore reports `0` pages.
    ///
    /// # Errors
    /// Propagates any database error produced by `load`.
    pub fn load_and_count_pages<U>(self, conn: &PgConnection) -> QueryResult<(Vec<U>, i64)>
    where
        Self: LoadQuery<PgConnection, (U, i64)>,
    {
        let per_page = self.per_page;
        let results = self.load::<(U, i64)>(conn)?;
        let total = results.first().map_or(0, |row| row.1);
        let records = results.into_iter().map(|row| row.0).collect();
        // Integer ceiling division: exact for all i64 values (no f64 rounding), cannot
        // overflow, and a non-positive page size yields 0 instead of dividing by zero.
        let total_pages = if per_page > 0 {
            total / per_page + i64::from(total % per_page != 0)
        } else {
            0
        };
        Ok((records, total_pages))
    }
}

/// A paginated row is the inner query's row plus the trailing `BigInt` produced by
/// `COUNT(*) OVER ()`.
impl<T> Query for Paginated<T>
where
    T: Query,
{
    type SqlType = (T::SqlType, BigInt);
}

/// Enables `load` and the other `RunQueryDsl` methods on paginated queries for `PgConnection`.
impl<T> RunQueryDsl<PgConnection> for Paginated<T> {}


/// Renders `SELECT *, COUNT(*) OVER () FROM <query> t LIMIT $1 OFFSET $2`.
///
/// `COUNT(*) OVER ()` attaches the total number of matching rows to every returned row.
/// When `is_sub_query` is set, the inner query is wrapped in parentheses.
///
/// Pages are 1-based. A page number below 1 is treated as the first page, so
/// user-supplied page numbers cannot produce the negative `OFFSET` that PostgreSQL rejects.
impl<T> QueryFragment<Pg> for Paginated<T>
where
    T: QueryFragment<Pg>,
{
    fn walk_ast(&self, mut out: AstPass<Pg>) -> QueryResult<()> {
        out.push_sql("SELECT *, COUNT(*) OVER () FROM ");
        if self.is_sub_query {
            out.push_sql("(");
        }
        self.query.walk_ast(out.reborrow())?;
        if self.is_sub_query {
            out.push_sql(")");
        }
        out.push_sql(" t LIMIT ");
        out.push_bind_param::<BigInt, _>(&self.per_page)?;
        out.push_sql(" OFFSET ");
        // Clamp to page 1 and saturate so huge page numbers cannot overflow i64.
        let offset = (self.page.max(1) - 1).saturating_mul(self.per_page);
        out.push_bind_param::<BigInt, _>(&offset)?;
        Ok(())
    }
}

/// Adapter that lets a `QuerySource` be paginated by rendering it through its
/// `from_clause()`.
#[derive(Clone, Copy, Debug, QueryId)]
pub struct QuerySourceToQueryFragment<T> {
    query_source: T,
}

/// Emits only the wrapped source's FROM clause.
impl<FC, T> QueryFragment<Pg> for QuerySourceToQueryFragment<T>
where
    FC: QueryFragment<Pg>,
    T: QuerySource<FromClause = FC>,
{
    fn walk_ast(&self, out: AstPass<Pg>) -> QueryResult<()> {
        let from_clause = self.query_source.from_clause();
        from_clause.walk_ast(out)
    }
}

/// Extension trait that adds `.paginate(page)` to any `QuerySource`.
pub trait PaginateForQuerySource: Sized {
    /// Wraps `self` (via its FROM clause) in a [`Paginated`] at the given (1-based) `page`.
    fn paginate(self, page: i64) -> Paginated<QuerySourceToQueryFragment<Self>>;
}

/// Blanket impl: any `QuerySource` can be paginated. It is rendered without sub-select
/// parentheses (`is_sub_query: false`) and uses a default page size of 10.
impl<T: QuerySource> PaginateForQuerySource for T {
    fn paginate(self, page: i64) -> Paginated<QuerySourceToQueryFragment<Self>> {
        let wrapped = QuerySourceToQueryFragment { query_source: self };
        Paginated {
            query: wrapped,
            page,
            per_page: 10,
            is_sub_query: false,
        }
    }
}

And this is my `FavMusicRequest`, which defines the pagination query parameters:

/// Request body for the paginated favorites endpoint.
///
/// The field names double as the JSON keys (serde derive), which is why they stay
/// camelCase and `non_snake_case` is allowed.
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)]
#[allow(non_snake_case)]
pub struct FavMusicRequest {
    /// Requesting user's id. NOTE(review): `fav_music_query` does not currently filter by it.
    pub userId: i64,
    /// 1-based page number.
    pub pageNum: i64,
    /// Number of rows per page.
    pub pageSize: i64
}

And this is the query entity loaded from the database:

/// One row of the `favorites` table, loaded through Diesel's `Queryable`.
///
/// `Queryable` maps columns by position, so the field order must match the column order
/// of the `favorites` table in the schema. Do not reorder the fields.
#[derive( Serialize, Queryable, Deserialize,Default)]
pub struct Favorites {
    pub id: i64,
    pub song_id: Option<i64>,
    pub created_time: i64,
    pub updated_time: i64,
    pub user_id: i64,
    pub source_id: String,
    // NOTE(review): the pagination query filters on `like_status == 1`; presumably 1 = liked.
    pub like_status: i32,
    pub source: i32,
    pub playlist_id: i64,
    pub play_count: i32,
    pub fetched_download_url: Option<i32>,
    pub downloaded: Option<i32>
}

And this is the request handler defined with Rocket:

/// `POST /v1/page`: runs `fav_music_query` for the request and returns the resulting
/// `ApiResponse` serialized as a JSON string.
///
/// # Panics
/// Panics if the response cannot be serialized.
#[post("/v1/page",data = "<request>")]
pub fn page(request: Json<FavMusicRequest>) -> content::Json<String> {
    let response = ApiResponse {
        result: fav_music_query::<Vec<Favorites>>(request),
        ..Default::default()
    };
    let body = serde_json::to_string(&response).unwrap();
    content::Json(body)
}

Solution

  • You can return the extra pagination data as part of the `QueryResult<(...)>` tuple returned by `load_and_count_pages`.

    Here is a working example:

    use diesel::pg::Pg;
    use diesel::prelude::*;
    use diesel::query_builder::*;
    use diesel::query_dsl::methods::LoadQuery;
    use diesel::sql_types::BigInt;
    
    /// Extension trait that adds `.paginate(page)` to any value, typically a Diesel query.
    pub trait Paginate: Sized {
        /// Wraps `self` in a [`Paginated`] for the given (1-based) `page`.
        fn paginate(self, page: i64) -> Paginated<Self>;
    }
    
    /// Blanket impl: every value can be paginated, starting with `DEFAULT_PER_PAGE` rows.
    impl<T> Paginate for T {
        fn paginate(self, page: i64) -> Paginated<Self> {
            let per_page = DEFAULT_PER_PAGE;
            Paginated { query: self, per_page, page }
        }
    }
    
    /// Page size used until `.per_page(..)` overrides it.
    const DEFAULT_PER_PAGE: i64 = 100;
    
    /// A query together with the page position used to render `LIMIT`/`OFFSET`.
    #[derive(Debug, Clone, Copy, QueryId)]
    pub struct Paginated<T> {
        /// The wrapped query.
        query: T,
        /// 1-based page number; `OFFSET` is derived from `(page - 1) * per_page`.
        page: i64,
        /// Rows per page, emitted as the SQL `LIMIT`.
        per_page: i64,
    }
    
    impl<T> Paginated<T> {
        /// Consumes the query and returns it with the page size replaced by `per_page`.
        pub fn per_page(self, per_page: i64) -> Self {
            Paginated { per_page, ..self }
        }

        /// Runs the query and returns `(records, total_pages)`.
        ///
        /// The total row count is read from the `COUNT(*) OVER ()` column of the first
        /// returned row and converted to a number of pages. Because the count comes from
        /// returned rows, requesting a page past the last one reports `0` pages.
        ///
        /// # Errors
        /// Propagates any database error produced by `load`.
        pub fn load_and_count_pages<U>(self, conn: &PgConnection) -> QueryResult<(Vec<U>, i64)>
        where
            Self: LoadQuery<PgConnection, (U, i64)>,
        {
            let per_page = self.per_page;
            let results = self.load::<(U, i64)>(conn)?;
            let total = results.first().map_or(0, |row| row.1);
            let records = results.into_iter().map(|row| row.0).collect();
            // Convert rows to pages (the method name and callers' `total_pages` expect a
            // page count). Integer ceiling division: no overflow, no division by zero.
            let total_pages = if per_page > 0 {
                total / per_page + i64::from(total % per_page != 0)
            } else {
                0
            };
            Ok((records, total_pages))
        }
    }
    
    /// A paginated row is the inner row plus the trailing `BigInt` count column.
    impl<T> Query for Paginated<T>
    where
        T: Query,
    {
        type SqlType = (T::SqlType, BigInt);
    }
    
    /// Enables `load` and the other `RunQueryDsl` methods on paginated queries.
    impl<T> RunQueryDsl<PgConnection> for Paginated<T> {}
    
    /// Renders `SELECT *, COUNT(*) OVER () FROM (<query>) t LIMIT $1 OFFSET $2`.
    ///
    /// `COUNT(*) OVER ()` attaches the total number of matching rows to every returned row.
    /// Pages are 1-based. A page number below 1 is treated as the first page, so
    /// user-supplied input cannot produce the negative `OFFSET` that PostgreSQL rejects.
    impl<T> QueryFragment<Pg> for Paginated<T>
    where
        T: QueryFragment<Pg>,
    {
        fn walk_ast(&self, mut out: AstPass<Pg>) -> QueryResult<()> {
            out.push_sql("SELECT *, COUNT(*) OVER () FROM (");
            self.query.walk_ast(out.reborrow())?;
            out.push_sql(") t LIMIT ");
            out.push_bind_param::<BigInt, _>(&self.per_page)?;
            out.push_sql(" OFFSET ");
            // Clamp to page 1 and saturate so huge page numbers cannot overflow i64.
            let offset = (self.page.max(1) - 1).saturating_mul(self.per_page);
            out.push_bind_param::<BigInt, _>(&offset)?;
            Ok(())
        }
    }
    

    Then you can define types that hold a page of records together with its pagination metadata:

    /// Pagination metadata returned alongside a page of records.
    ///
    /// Serialized with camelCase keys (`totalPages`, `filters`).
    #[derive(Serialize, Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Cursor {
        /// Total number of pages, as reported by the paginated query.
        pub total_pages: i32,
        /// Filters applied to the query; `fav_music_query` sets this to an empty list.
        pub filters: Vec<String>,
    }
    
    #[derive(Serialize)]
    pub struct PaginationResult<T>
    where
        T: Serialize,
    {
        pub records: Vec<T>,
        pub cursor: Cursor,
    }
    

    Now you can write:

    /// Loads one page of `favorites` rows whose `like_status` is 1 and wraps it in a
    /// [`PaginationResult`].
    ///
    /// The count returned by `load_and_count_pages` is stored in `Cursor::total_pages`.
    /// The return type is `PaginationResult<Favorites>`: with `PaginationResult<Vec<Favorites>>`
    /// the `records` field would be `Vec<Vec<Favorites>>` and the example would not compile.
    /// The generic parameter `T` is unused and kept only so turbofish call sites still compile.
    ///
    /// NOTE(review): `request.userId` is not used as a filter — confirm that is intended.
    ///
    /// # Panics
    /// Panics if the database query fails.
    pub fn fav_music_query<T>(request: Json<FavMusicRequest>) -> PaginationResult<Favorites> {
        use crate::model::diesel::rhythm::rhythm_schema::favorites::dsl::*;
        let connection = config::establish_music_connection();
        // Bind to `records` rather than `favorites` to avoid shadowing the table DSL name.
        let (records, pages) = favorites
            .filter(like_status.eq(1))
            .paginate(request.pageNum)
            .per_page(request.pageSize)
            .load_and_count_pages::<Favorites>(&connection)
            .expect("failed to load favorites page");
        PaginationResult {
            records,
            cursor: Cursor {
                // Saturate instead of silently truncating when narrowing i64 -> i32.
                total_pages: pages.min(i64::from(i32::MAX)) as i32,
                filters: vec![],
            },
        }
    }
    

    My example only returns the total number of pages; you can add any other fields you need.