android, retrofit, okhttp, refresh-token

How to prevent parallel refresh token requests while using Retrofit/OkHttp's Authenticator?


I just ran into an issue where my app started sending too many parallel refresh token requests to the backend server I built. This caused a race condition: all of these parallel requests were requesting and updating different refresh tokens at the same time.
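
A minimal stdlib-only simulation of that race, under an assumption the question does not state: the backend rotates refresh tokens, accepting each one once and rejecting any reuse. `RotatingAuthServer` and `simulateUnsynchronizedRefresh` are hypothetical names, and no OkHttp or Retrofit code is involved; this is a sketch of the failure mode, not the actual backend.

```kotlin
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.CyclicBarrier
import kotlin.concurrent.thread

// Hypothetical backend that rotates refresh tokens: each refresh token is accepted
// once and then replaced, so reusing an old one is rejected.
class RotatingAuthServer {
    private var validRefreshToken = "refresh-0"
    private var issued = 0

    // Returns a new (accessToken, refreshToken) pair, or null for a stale refresh token.
    @Synchronized
    fun refresh(refreshToken: String): Pair<String, String>? {
        if (refreshToken != validRefreshToken) return null
        issued++
        validRefreshToken = "refresh-$issued"
        return "access-$issued" to validRefreshToken
    }
}

// Simulates `parallel` requests that got a 401 at the same time and each call the
// refresh endpoint on their own, using the refresh token they all read from storage.
fun simulateUnsynchronizedRefresh(parallel: Int): List<Boolean> {
    val server = RotatingAuthServer()
    val storedRefreshToken = "refresh-0"
    val barrier = CyclicBarrier(parallel)
    val succeeded = ConcurrentLinkedQueue<Boolean>()
    val workers = List(parallel) {
        thread {
            barrier.await() // line all clients up so they refresh "at the same time"
            succeeded.add(server.refresh(storedRefreshToken) != null)
        }
    }
    workers.forEach { it.join() }
    return succeeded.toList()
}

fun main() {
    val outcomes = simulateUnsynchronizedRefresh(5)
    // Exactly one client wins; the others are rejected.
    println("successful refreshes: ${outcomes.count { it }} of ${outcomes.size}") // prints "successful refreshes: 1 of 5"
}
```

Under that assumption, every losing client sees its refresh fail even though the server is healthy, and would typically be logged out as a result.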

The only solution I came up with uses a StateFlow, a Channel and an unscoped IO coroutine to observe the refresh state, so that only the first request actually calls the refresh endpoint. While it is refreshing, the other parallel requests block until the first one signals them to retry with the new token.

It works, but I'm new to Kotlin and its coroutine APIs, and it looks hacky. I can't help thinking there's definitely a more sensible way to approach this.

    import android.content.SharedPreferences
    import javax.inject.Inject
    import kotlinx.coroutines.CoroutineScope
    import kotlinx.coroutines.Dispatchers
    import kotlinx.coroutines.channels.Channel
    import kotlinx.coroutines.flow.MutableStateFlow
    import kotlinx.coroutines.launch
    import kotlinx.coroutines.runBlocking
    import okhttp3.Authenticator
    import okhttp3.Request
    import okhttp3.Response
    import okhttp3.Route

    class MyAuthenticator @Inject constructor(
        private val refreshTokenUseCase: RefreshTokenUseCase,
        private val sharedPrefs: SharedPreferences
    ) : Authenticator {

        private val isRefreshingToken = MutableStateFlow(false)
        private val newRequest = Channel<Request>()

        override fun authenticate(route: Route?, response: Response): Request? {

            // If another request is already refreshing the token, wait for it to finish and
            // retry with the new token instead of making another refresh API call
            if (isRefreshingToken.value) {
                CoroutineScope(Dispatchers.IO).launch {
                    isRefreshingToken.collect { isRefreshingToken ->
                        if (!isRefreshingToken) {
                            val newToken = sharedPrefs.getToken().orEmpty()
                            val req = response.request.newBuilder()
                                .header("Authorization", "Bearer $newToken")
                                .build()
                            newRequest.send(req)
                        }
                    }
                }
                return runBlocking(Dispatchers.IO) {
                    newRequest.receive()
                }
            }

            isRefreshingToken.value = true

            // logic to handle refreshing the token
            runBlocking(Dispatchers.IO) {
                refreshTokenUseCase() // internally calls refresh token api then saves the token to shared prefs
            }.let { result ->
                isRefreshingToken.value = false
                return if (result.isSuccess) {
                    val newToken = sharedPrefs.getToken().orEmpty()
                    response.request.newBuilder()
                        .header("Authorization", "Bearer $newToken")
                        .build()
                } else {
                    // logic to handle failure (logout, etc)
                    null
                }
            }

        }
    }

I searched all over Stack Overflow, and while I found many suggested solutions, none of them actually worked. Half of them suggested using synchronization to force the parallel requests to run one after another, which still wastefully calls the refresh token API far too many times.
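
The complaint about plain synchronization can be reproduced with a small simulation. This is a sketch, not code from any of those answers: `SynchronizedOnlyRefresher` and `countRefreshCalls` are hypothetical names, and an `AtomicInteger` counter stands in for the real refresh token API.

```kotlin
import java.util.concurrent.CyclicBarrier
import java.util.concurrent.atomic.AtomicInteger
import kotlin.concurrent.thread

// Hypothetical authenticator core: synchronized, but it never checks whether
// another thread has already refreshed the token while it was waiting for the lock.
class SynchronizedOnlyRefresher {
    val refreshCalls = AtomicInteger(0)

    @Volatile
    var currentToken = "token-0"
        private set

    @Synchronized
    fun refresh(): String {
        // The lock makes callers take turns, but every caller still hits the API.
        val n = refreshCalls.incrementAndGet() // stands in for the refresh token API call
        currentToken = "token-$n"
        return currentToken
    }
}

// Simulates `parallel` requests that all failed with 401 and each ask for a refresh.
fun countRefreshCalls(parallel: Int): Int {
    val refresher = SynchronizedOnlyRefresher()
    val barrier = CyclicBarrier(parallel)
    val workers = List(parallel) {
        thread {
            barrier.await() // all 401s arrive at about the same moment
            refresher.refresh()
        }
    }
    workers.forEach { it.join() }
    return refresher.refreshCalls.get()
}

fun main() {
    println(countRefreshCalls(5)) // prints 5: one refresh per failed request
}
```

The requests are serialized, so there is no data race on the token, but the number of refresh calls still equals the number of failed requests.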


Solution

  • I ended up annotating authenticate() with @Synchronized and, inside it, checking whether the token in the failed request's Authorization header differs from the locally persisted token. If it differs, another thread has already refreshed the token, so the request is simply retried with the stored token instead of calling the refresh endpoint again. Works like a charm. Just make sure the refresh token API call blocks the calling thread (e.g. runBlocking(Dispatchers.IO)), since authenticate() must return the new Request synchronously, and use .commit() instead of .apply() when saving the access token to your SharedPreferences, so the write has finished before authenticate() returns.

    import android.content.SharedPreferences
    import javax.inject.Inject
    import kotlinx.coroutines.Dispatchers
    import kotlinx.coroutines.runBlocking
    import okhttp3.Authenticator
    import okhttp3.Request
    import okhttp3.Response
    import okhttp3.Route

    class MyAuthenticator @Inject constructor(
        private val refreshTokenUseCase: RefreshTokenUseCase,
        private val sharedPrefs: SharedPreferences
    ) : Authenticator {

        // @Synchronized makes parallel OkHttp threads enter authenticate() one at a time,
        // so only the first one to get in actually refreshes the token
        @Synchronized
        override fun authenticate(route: Route?, response: Response): Request? {

            // Prevent redundant refreshes: if the token this request was sent with differs from
            // the locally persisted one, another thread already refreshed it while we waited
            // for the lock, so retry with the stored token.
            val accessToken = sharedPrefs.getToken() // getToken() is the app's own helper (not shown)
            val sentToken = response.request.header("Authorization")?.removePrefix("Bearer ")
            val alreadyRefreshed = accessToken != null && sentToken != null && sentToken != accessToken
            if (alreadyRefreshed) {
                return response.request.newBuilder()
                    .header("Authorization", "Bearer $accessToken")
                    .build()
            }

            // First thread in with a stale token: refresh it
            val result = runBlocking(Dispatchers.IO) {
                refreshTokenUseCase() // internally calls the refresh token API, then saves the token with commit()
            }
            return if (result.isSuccess) {
                val newToken = sharedPrefs.getToken().orEmpty()
                response.request.newBuilder()
                    .header("Authorization", "Bearer $newToken")
                    .build()
            } else {
                // logic to handle failure (logout, etc.)
                null
            }
        }
    }
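
The core of that answer can be checked without OkHttp or Android. Below is a minimal, self-contained simulation, assuming a hypothetical `TokenRefresher` class whose `currentToken` stands in for the SharedPreferences token and whose counter stands in for the refresh API call; `tokenFor` mirrors the shape of the synchronized authenticate() above, but it is a sketch, not the answer's actual code.

```kotlin
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.CyclicBarrier
import java.util.concurrent.atomic.AtomicInteger
import kotlin.concurrent.thread

// Hypothetical stand-in for the persisted token plus the refresh call.
class TokenRefresher(initialToken: String) {
    val refreshCalls = AtomicInteger(0)

    @Volatile
    var currentToken: String = initialToken
        private set

    // Synchronized, then compare the token the failed request was sent with
    // against the stored one, like the authenticate() above.
    @Synchronized
    fun tokenFor(sentWith: String): String {
        // Someone refreshed while we waited for the lock: reuse their token.
        if (sentWith != currentToken) return currentToken

        // First caller with this stale token: refresh exactly once.
        val n = refreshCalls.incrementAndGet() // stands in for the API call + commit()
        currentToken = "token-$n"
        return currentToken
    }
}

// Simulates `parallel` requests that all got a 401 while carrying the same old token.
fun simulateParallel401s(parallel: Int): Pair<Int, Set<String>> {
    val refresher = TokenRefresher("token-0")
    val barrier = CyclicBarrier(parallel)
    val retriedWith = ConcurrentLinkedQueue<String>()
    val workers = List(parallel) {
        thread {
            val sentWith = refresher.currentToken // token the failed request carried
            barrier.await() // every thread has read the old token before anyone refreshes
            retriedWith.add(refresher.tokenFor(sentWith))
        }
    }
    workers.forEach { it.join() }
    return refresher.refreshCalls.get() to retriedWith.toSet()
}

fun main() {
    val (calls, tokens) = simulateParallel401s(5)
    println("refresh calls: $calls, tokens used for retries: $tokens") // prints "refresh calls: 1, tokens used for retries: [token-1]"
}
```

Replacing the body of `tokenFor` with an unconditional refresh makes the counter equal to the number of threads; the `sentWith != currentToken` check is what collapses the parallel 401s into a single refresh call.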