Search code examples
Tags: unit-testing, go, microsoft-graph-api

Mock MS Graph SDK in Golang unit test?


I'm currently building a service that lets users grant themselves access to certain applications on a self-service basis; access to each application is managed through an Entra ID group.

I have the following Go code to check whether a user is a member of a specified group (this is my first real work with Go and the associated Graph SDK, so apologies if this sucks!)

package entra

import (    
    "context"
    
    azidentity "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
    msgraphsdk "github.com/microsoftgraph/msgraph-sdk-go"
    "github.com/microsoftgraph/msgraph-sdk-go/models"
    "github.com/microsoftgraph/msgraph-sdk-go/models/odataerrors"
    "github.com/spf13/viper"
    "go.uber.org/zap"
)

// Credentialer abstracts construction of an Azure client-secret credential so
// that GetGraphClient can be driven by a substitute implementation.
type Credentialer interface {
    // NewClientSecretCredential builds a credential for the given tenant ID,
    // client ID and client secret, with optional azidentity options.
    NewClientSecretCredential(
        tenantID, clientID, clientSecret string,
        options *azidentity.ClientSecretCredentialOptions,
    ) (*azidentity.ClientSecretCredential, error)
}

// GraphClientCreator abstracts construction of a Microsoft Graph client from a
// client-secret credential and a list of scopes.
//
// NOTE(review): the method returns the concrete *msgraphsdk.GraphServiceClient,
// so a fake implementation can replace client construction but cannot stub the
// Graph responses themselves.
type GraphClientCreator interface {
    NewGraphServiceClientWithCredentials(
        cred *azidentity.ClientSecretCredential,
        scopes []string,
    ) (*msgraphsdk.GraphServiceClient, error)
}

// Service bundles the two factories that GetGraphClient uses to obtain a
// credential and then a Graph client.
type Service struct {
    Credentialer       Credentialer
    GraphClientCreator GraphClientCreator
}

// AzureCredentialer is the Credentialer that IsUserInGroup wires into its
// Service; it delegates straight to the azidentity package.
type AzureCredentialer struct{}

// NewClientSecretCredential forwards every argument unchanged to
// azidentity.NewClientSecretCredential and returns its result as-is.
func (*AzureCredentialer) NewClientSecretCredential(
    tenantID, clientID, clientSecret string,
    options *azidentity.ClientSecretCredentialOptions,
) (*azidentity.ClientSecretCredential, error) {
    return azidentity.NewClientSecretCredential(tenantID, clientID, clientSecret, options)
}

// MsGraphClientCreator is the GraphClientCreator that IsUserInGroup wires into
// its Service; it delegates straight to the msgraphsdk package.
type MsGraphClientCreator struct{}

// NewGraphServiceClientWithCredentials forwards its arguments unchanged to
// msgraphsdk.NewGraphServiceClientWithCredentials and returns its result as-is.
func (*MsGraphClientCreator) NewGraphServiceClientWithCredentials(
    cred *azidentity.ClientSecretCredential,
    scopes []string,
) (*msgraphsdk.GraphServiceClient, error) {
    return msgraphsdk.NewGraphServiceClientWithCredentials(cred, scopes)
}


// GetGraphClient builds a Microsoft Graph client from the client-secret
// configuration stored in viper under "tenant_id", "client_id" and
// "client_secret".
//
// The credential is created through s.Credentialer and the client through
// s.GraphClientCreator, so either step can be substituted. The client is
// requested with the "https://graph.microsoft.com/.default" scope.
//
// Any factory error is passed to printOdataError, logged, and returned
// unchanged (deliberately not wrapped, so code that type-switches on the
// concrete error type keeps working).
func (s *Service) GetGraphClient() (*msgraphsdk.GraphServiceClient, error) {
    tenantID := viper.GetString("tenant_id")
    clientID := viper.GetString("client_id")
    clientSecret := viper.GetString("client_secret")

    // This is a client-secret credential, not a managed identity; the log
    // messages previously claimed otherwise.
    cred, err := s.Credentialer.NewClientSecretCredential(tenantID, clientID, clientSecret, nil)
    if err != nil {
        printOdataError(err)
        zap.S().Error("Error creating client secret credentials: ", err)
        return nil, err
    }
    zap.S().Debug("Client secret credentials created successfully")

    graphClient, err := s.GraphClientCreator.NewGraphServiceClientWithCredentials(
        cred,
        []string{"https://graph.microsoft.com/.default"},
    )
    if err != nil {
        printOdataError(err)
        zap.S().Error("Error creating graph client: ", err)
        return nil, err
    }
    zap.S().Debug("Graph client created successfully")

    return graphClient, nil
}

func IsUserInGroup(groupId string, userId string) (bool, error) {

    service := &Service{
        Credentialer: &AzureCredentialer{},
        GraphClientCreator: &MsGraphClientCreator{},
    }
    
    // Get the graph client
    graphClient, err := service.GetGraphClient()
    if err != nil {
        printOdataError(err)
        zap.S().Error("Error getting graph client: ", err)
        return false, err
    }

    zap.S().Debug("Getting group members...")
    zap.S().Debug("Group ID: ", groupId)
    zap.S().Debug("User ID: ", userId)

    group, err := graphClient.Users().ByUserId(userId).MemberOf().Get(context.Background(), nil)
    if err != nil {
        printOdataError(err)
        zap.S().Error("Error getting group members: ", err)
        return false, err
    }

    zap.S().Debug("Group memberships: ", len(group.GetValue()))

    for _, membership := range group.GetValue() {
        if *membership.GetId() == groupId {
            zap.S().Debug("User is a member of the group")
            return true, nil
        }
    }

    return false, nil
}

However, when I try to unit test it, I can't figure out how to mock the Graph responses, and I'm not sure where to begin.

If anybody could help to point me in the right direction, I'd be very grateful!


Solution

  • Thanks for the detailed example.

    Let's first look at what you actually use. As far as I can see, the only place where you call *msgraphsdk.GraphServiceClient and use its response is

        // The only call whose response the membership check reads: one Get on
        // the user's MemberOf collection, followed by a loop over its values.
        group, err := graphClient.Users().ByUserId(userId).MemberOf().Get(context.Background(), nil)
        /// [...]
        for _, membership := range group.GetValue() {
            /// [...]
        }
    

    So, let us mock that. Since we want to test our business logic (is the user in a group or not), we write a short wrapper for the call:

    // MSGraphClient is a thin adapter over *msgraphsdk.GraphServiceClient that
    // exposes only the calls the business logic needs, so those calls can be
    // replaced by a fake in tests.
    type MSGraphClient struct {
        c *msgraphsdk.GraphServiceClient
    }

    // UserGroupsByUserID returns all entries of the user's memberOf collection.
    // Graph pages this collection, so the method follows @odata.nextLink until
    // the last page instead of returning only the first one. The results are
    // typed as generic directory objects, not groups. The first error from any
    // page request is returned unchanged.
    func (g MSGraphClient) UserGroupsByUserID(
        ctx context.Context, userID string,
    ) ([]models.DirectoryObjectable, error) {
        memberOf := g.c.Users().ByUserId(userID).MemberOf()

        var all []models.DirectoryObjectable
        page, err := memberOf.Get(ctx, nil)
        for {
            if err != nil {
                return nil, err
            }
            if page == nil {
                // NOTE(review): the SDK may return (nil, nil) for an empty body.
                return all, nil
            }
            all = append(all, page.GetValue()...)

            next := page.GetOdataNextLink()
            if next == nil || *next == "" {
                return all, nil
            }
            page, err = memberOf.WithUrl(*next).Get(ctx, nil)
        }
    }
    

    This transforms our code above into

        // The SDK client is now reached only through the MSGraphClient wrapper,
        // and the loop ranges over a plain slice instead of a Graph response.
        gc := MSGraphClient{c: graphClient}
        groups, err := gc.UserGroupsByUserID(context.Background(), userId)
        // [...]
        for _, membership := range groups {
            // [...]
        }
    

    Now that we have simplified things, let's see what our code depends on:

    // GraphClient is the narrow, consumer-side view of Microsoft Graph that the
    // membership check depends on. MSGraphClient satisfies it for real calls;
    // a hand-written fake can satisfy it in tests.
    type GraphClient interface {
        // UserGroupsByUserID lists the directory objects the given user is a
        // member of.
        UserGroupsByUserID(
            ctx context.Context,
            userID string,
        ) ([]models.DirectoryObjectable, error)
    }
    

    Great. Let us refactor our IsUserInGroup so that it is testable, and inject the relevant collaborators:

    // IsUserInGroup reports whether groupID is among the directory objects that
    // graphClient returns for userID. A lookup failure is wrapped with the user
    // ID and returned. Entries that are nil or have no ID are skipped instead of
    // causing a nil-pointer dereference.
    func IsUserInGroup(ctx context.Context, graphClient GraphClient, groupID string, userID string) (bool, error) {
        groups, err := graphClient.UserGroupsByUserID(ctx, userID)
        if err != nil {
            return false, fmt.Errorf("error getting groups for user %s: %w", userID, err)
        }

        for _, membership := range groups {
            if membership == nil {
                continue
            }
            if id := membership.GetId(); id != nil && *id == groupID {
                return true, nil
            }
        }

        return false, nil
    }
    

    Now it can be tested with your own TestGraphClient implementation:

    // TestGraphClient is a fake GraphClient for unit tests. It knows exactly one
    // user, "testUser", who belongs to the single group "group1"; every other
    // user ID produces an empty *odataerrors.ODataError.
    type TestGraphClient struct{}

    // UserGroupsByUserID ignores the context and returns the canned memberships
    // described on TestGraphClient.
    func (TestGraphClient) UserGroupsByUserID(_ context.Context, userID string) ([]models.DirectoryObjectable, error) {
        if userID == "testUser" {
            return []models.DirectoryObjectable{TestDirectoryObjectable{id: "group1"}}, nil
        }

        return nil, odataerrors.NewODataError()
    }
    

    and using a TestDirectoryObjectable like:

    // TestDirectoryObjectable is a minimal fake directory object: GetId is the
    // only method with real behavior, and the remaining stubs panic.
    type TestDirectoryObjectable struct{ id string }

    // GetId returns a pointer to a copy of the stored ID, so writing through the
    // returned pointer never changes the fake itself.
    func (t TestDirectoryObjectable) GetId() *string {
        idCopy := t.id
        return &idCopy
    }

    // [...]

    // OtherFuncs stands in for the remaining interface methods; it panics
    // because the tests are not expected to call it.
    func (TestDirectoryObjectable) OtherFuncs() {
        panic("unimplemented")
    }
    

    Your IDE can generate stubs for the remaining methods of models.DirectoryObjectable. If that's too much boilerplate, you could either let mockery generate the mock for you, or change MSGraphClient to return something more convenient (for example, just the group IDs). The latter means MSGraphClient then contains some logic of its own that needs testing, whereas currently it is just a thin wrapper.