swift, async-await, swift-concurrency

Have multiple async code waiting for the same task to finish in Swift


I have a generic request async method that I use to make all the network requests of my app.

The first thing this method does is check whether the access token is still valid; if it is not, the token is refreshed. However, accessTokenRefreshableService must only be called once, because the refreshToken it uses is only valid once.

To ensure that accessTokenRefreshableService is executed only once, and that all requests wait until it has finished, I added a semaphore.wait() and semaphore.signal(). But it feels like the wrong solution, and it also produces the following warning:

Instance method 'wait' is unavailable from asynchronous contexts; Await a Task handle instead; this is an error in Swift 6

What would be the proper way to get this behavior?

func request<T: Decodable, S: APIServiceProtocol>(service: S) async throws -> T {
    // Access token
    semaphore.wait()
    if let accessToken = service.accessToken {
        if JWT.isTokenExpired(accessToken) {
            try await accessTokenRefreshableService.request()
        }
    }
    semaphore.signal()
    
    // Prepare request
    // Make request
    // Handle result
}

Solution

  • I would redesign this as an actor that always gives you a valid access token.

    import Foundation
    
    actor AccessTokenSource {
        private var accessToken = "Some Invalid Token" // perhaps read this from UserDefaults
        
        private var refreshTask: Task<String, Error>?
        
        func get() async throws -> String {
            // A refresh is already in flight: wait for its result instead of starting another.
            if let refreshTask {
                return try await refreshTask.value
            }
            if JWT.isTokenExpired(accessToken) {
                let task = Task {
                    // request should be changed to return the new token
                    return try await accessTokenRefreshableService.request()
                }
                refreshTask = task
                // Clear the task when it finishes (or fails), so a later expiry
                // starts a new refresh instead of reusing this finished one.
                defer { refreshTask = nil }
                let newToken = try await task.value
                accessToken = newToken
            }
            return accessToken
        }
    
        // if you really want to use user defaults
        func updateUserDefaults() async throws {
            let token = try await get()
            UserDefaults.standard.set(token, forKey: ...)
        }
    }
    

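    If the token has expired, the actor requests a new one and, importantly, stores that request in refreshTask, which signals that a new token is being fetched. If another call to get() finds a non-nil refreshTask, it waits for the result of that task instead of requesting a new token. Because the actor checks refreshTask and assigns it with no suspension point in between, two callers cannot both start a refresh.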
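The deduplication described above can be checked with a self-contained sketch. It is an illustration under assumptions, not the answer's exact code: a hypothetical Token type with an expiry date replaces JWT.isTokenExpired, an injected refresh closure stands in for accessTokenRefreshableService.request(), and refreshCount exists only so the demo can count refreshes. It also clears refreshTask with defer once the refresh ends, so a token that expires later triggers a new refresh.

```swift
import Foundation

// Hypothetical token type; stands in for the JWT expiry check.
struct Token: Sendable {
    let value: String
    let expiresAt: Date
    var isExpired: Bool { expiresAt <= Date() }
}

actor TokenSource {
    private var token: Token
    private var refreshTask: Task<Token, Error>?
    // Assumption: stands in for accessTokenRefreshableService.request().
    private let refresh: @Sendable () async throws -> Token
    private(set) var refreshCount = 0

    init(initial: Token, refresh: @escaping @Sendable () async throws -> Token) {
        self.token = initial
        self.refresh = refresh
    }

    func get() async throws -> String {
        // A refresh is already running: join it.
        if let refreshTask {
            return try await refreshTask.value.value
        }
        guard token.isExpired else { return token.value }

        // No suspension point between the checks above and this assignment.
        refreshCount += 1
        let operation = refresh
        let task = Task { try await operation() }
        refreshTask = task
        defer { refreshTask = nil }
        let newToken = try await task.value
        token = newToken
        return newToken.value
    }
}

// 50 concurrent callers all start with an expired token.
func runDemo() async throws -> (tokens: [String], refreshes: Int) {
    let source = TokenSource(
        initial: Token(value: "expired", expiresAt: .distantPast),
        refresh: {
            try await Task.sleep(nanoseconds: 200_000_000) // simulated network delay
            return Token(value: "fresh", expiresAt: Date().addingTimeInterval(3600))
        }
    )
    let tokens = try await withThrowingTaskGroup(of: String.self, returning: [String].self) { group in
        for _ in 0..<50 {
            group.addTask { try await source.get() }
        }
        var collected: [String] = []
        for try await token in group { collected.append(token) }
        return collected
    }
    let count = await source.refreshCount
    return (tokens, count)
}

let result = try await runDemo()
print(Set(result.tokens), result.refreshes) // ["fresh"] 1
```

Every caller either joins the in-flight task or sees the already refreshed token, so the refresh count stays at one regardless of scheduling.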

    Make sure you have only one instance of AccessTokenSource: each instance tracks its own refreshTask, so a second instance would start its own refresh.

    let accessTokens = AccessTokenSource()
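One way to enforce a single instance at compile time, a common Swift pattern rather than something the answer prescribes, is a static shared property combined with a private initializer. In this sketch the refresh logic is elided, and `shared` is an illustrative name.

```swift
// Sketch: one shared source per process; the refresh logic is elided.
actor AccessTokenSource {
    // Created lazily once; every caller receives this same instance.
    static let shared = AccessTokenSource()

    private var accessToken = "Some Invalid Token"

    // Private, so no other code can create a second source that would
    // run its own refresh and spend the single-use refresh token again.
    private init() {}

    func get() async throws -> String {
        // ...expiry check and refresh as in the actor above...
        return accessToken
    }
}

let token = try await AccessTokenSource.shared.get()
print(token)
```

Compared with a plain global `let`, this moves the "only one instance" rule from a convention into something the compiler checks.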
    

    Now you just need to call it from request(service:):

    func request<T: Decodable, S: APIServiceProtocol>(service: S) async throws -> T {
        let token = try await accessTokens.get()
        // or try await accessTokens.updateUserDefaults()
    
        // the rest of the things that this function should do...
    }
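The rest of request(service:) is elided above. The sketch below shows one way the token might be attached to an outgoing request. Assumptions: StubTokenSource is a hypothetical stand-in for AccessTokenSource so the block compiles on its own, makeAuthorizedRequest is a hypothetical helper, and the API is assumed to expect a bearer token in the Authorization header.

```swift
import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking // URLRequest lives here on Linux
#endif

// Hypothetical stand-in for AccessTokenSource.
actor StubTokenSource {
    func get() async throws -> String { "Bar1234" }
}

// Hypothetical helper: get a valid token first (waiting on any in-flight
// refresh), then attach it to the request.
func makeAuthorizedRequest(url: URL, tokens: StubTokenSource) async throws -> URLRequest {
    let token = try await tokens.get()
    var request = URLRequest(url: url)
    // Assumption: the API expects "Authorization: Bearer <token>".
    request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
    return request
}

let request = try await makeAuthorizedRequest(
    url: URL(string: "https://example.com/items")!,
    tokens: StubTokenSource()
)
print(request.value(forHTTPHeaderField: "Authorization") ?? "none") // Bearer Bar1234
```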
    

    Minimal Example:

    actor AccessTokenSource {
        private var token = "Foo"
        
        private var refreshTask: Task<String, Error>?
        
        func get() async throws -> String {
            if let refreshTask {
                return try await refreshTask.value
            }
            if !token.hasPrefix("Bar") { // let's suppose this means the token is expired
                let task = Task<String, Error> {
                    // simulating getting the new token...
                    try await Task.sleep(for: .seconds(1))
    
                    // give a random number to the new token
                    return "Bar\(Int.random(in: 0..<10000))"
                }
                refreshTask = task
                defer { refreshTask = nil } // allow a future refresh once this one is done
                token = try await task.value
            }
            return token
        }
    }
    
    let accessTokens = AccessTokenSource()
    
    func request() async throws {
        let x = try await accessTokens.get()
        print(x) // this prints the same token for all the calls, meaning that only one request is sent
    }
    
    for _ in 0..<100 {
        Task {
            try await request()
        }
    }
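The minimal example never lets its new token expire, so it does not exercise what happens after a refresh ends. The sketch below (hypothetical names RetryingTokenSource and RefreshFailed; a nil token stands in for "expired") makes the first simulated refresh fail on purpose. Because refreshTask is cleared with defer, the second call retries instead of rethrowing the stored failure.

```swift
// Hypothetical error used to simulate a failed refresh.
struct RefreshFailed: Error {}

actor RetryingTokenSource {
    private var token: String?
    private var refreshTask: Task<String, Error>?
    private var attempts = 0

    // Simulated refresh: the first attempt fails, later attempts succeed.
    private func performRefresh() async throws -> String {
        attempts += 1
        try await Task.sleep(nanoseconds: 50_000_000)
        if attempts == 1 { throw RefreshFailed() }
        return "Bar\(attempts)"
    }

    func get() async throws -> String {
        if let refreshTask {
            return try await refreshTask.value
        }
        if let token { return token } // stands in for a "not expired" check

        let task = Task { try await self.performRefresh() }
        refreshTask = task
        // Without this, the failed task would stay in refreshTask and every
        // later call would rethrow RefreshFailed instead of retrying.
        defer { refreshTask = nil }
        let newToken = try await task.value
        token = newToken
        return newToken
    }
}

func demo() async -> [String] {
    let source = RetryingTokenSource()
    var outcomes: [String] = []
    for _ in 0..<2 {
        do {
            outcomes.append(try await source.get())
        } catch {
            outcomes.append("error")
        }
    }
    return outcomes
}

let outcomes = await demo()
print(outcomes) // ["error", "Bar2"]
```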