Tags: android, android-viewmodel, dagger-hilt

Can Hilt be used on Android with by viewModels() to initialize an abstract ViewModel field?


I'm trying to wrap my head around Hilt and the way it deals with ViewModels. I would like my fragments to depend on abstract view models, so I can easily mock them during UI tests. For example:

@AndroidEntryPoint
class MainFragment : Fragment() {
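    // What I'd like: depend only on the abstract type. This compiles, but at runtime
    // the view model factory cannot know which concrete class to instantiate.
    private val vm : AbsViewModel by viewModels()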

    /*
    ...
    */
}

@HiltViewModel
class MainViewModel @Inject constructor(private val dependency: DependencyInterface) : AbsViewModel()

abstract class AbsViewModel : ViewModel()

Is there a way to configure by viewModels() so that it maps abstract view model types to their concrete implementations? Or to pass a custom factoryProducer to viewModels() that creates the concrete view model instance for a requested abstract class?

The same question was asked here, but it is quite old (Hilt was still in alpha at the time): https://github.com/google/dagger/issues/1972. The solution provided there is not very desirable, though, since it uses a string containing the class path of the concrete view model. I don't think that will survive obfuscation or moving files, and it can quickly become a nightmare to maintain. The answer also suggests injecting a concrete view model into the fragment during tests, with all of the view model's dependencies mocked, to gain control over what happens in the test. That automatically makes my UI test depend on the implementation of said view model, which I would very much like to avoid.

Not being able to use abstract view models in my fragments makes me think I'm violating the D in SOLID (the Dependency Inversion Principle), which is something I would also like to avoid.


Solution

  • Not the cleanest solution, but here's what I managed to do.

    First, create a VMClassMapper that maps an abstract view model class to its concrete implementation. I'm using a custom AbsViewModel base class here, but it can be swapped out for the regular ViewModel. Then create a custom view model provider, VMProvider, that depends on the mapper.

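    import androidx.lifecycle.ViewModelProvider
    import androidx.lifecycle.ViewModelStoreOwner
    import dagger.Binds
    import dagger.Module
    import dagger.hilt.InstallIn
    import dagger.hilt.components.SingletonComponent
    import javax.inject.Inject
    import javax.inject.Provider
    import javax.inject.Singleton
    import kotlin.reflect.KClass

    // Looks up the concrete view model class registered for an abstract one.
    // Dagger assembles the map from the @ViewModelKey / @IntoMap entries.
    class VMClassMapper @Inject constructor(
        private val vmClassesMap: MutableMap<Class<out AbsViewModel>, Provider<KClass<out AbsViewModel>>>
    ) : VMClassMapperInterface {
        @Suppress("TYPE_INFERENCE_ONLY_INPUT_TYPES_WARNING")
        override fun getConcreteVMClass(vmClass: Class<out AbsViewModel>): KClass<out AbsViewModel> {
            return vmClassesMap[vmClass]?.get()
                ?: error("Concrete implementation for ${vmClass.canonicalName} not found! Provide one by using the @ViewModelKey")
        }
    }

    interface VMClassMapperInterface {
        fun getConcreteVMClass(vmClass: Class<out AbsViewModel>): KClass<out AbsViewModel>
    }

    // Implemented by a ViewModelStoreOwner (e.g. a Fragment) that declares
    // which abstract view model it depends on.
    interface VMDependant<VM : AbsViewModel> : ViewModelStoreOwner {
        fun getVMClass(): KClass<VM>
    }

    class VMProvider @Inject constructor(private val vmMapper: VMClassMapperInterface) : VMProviderInterface {
        @Suppress("UNCHECKED_CAST")
        override fun <VM : AbsViewModel> provideVM(dependant: VMDependant<VM>): VM {
            val concreteClass = vmMapper.getConcreteVMClass(dependant.getVMClass().java)
            // ViewModelProvider(owner) uses the owner's default factory, which Hilt
            // supplies for @AndroidEntryPoint fragments, so @HiltViewModel classes get
            // their dependencies. The cast is unchecked: each @ViewModelKey entry must
            // map the abstract class to one of its own subclasses.
            return ViewModelProvider(dependant).get(concreteClass.java) as VM
        }
    }

    interface VMProviderInterface {
        fun <VM : AbsViewModel> provideVM(dependant: VMDependant<VM>): VM
    }

    @Module
    @InstallIn(SingletonComponent::class)
    abstract class ViewModelProviderModule {

        @Binds
        abstract fun bindViewModelClassesMapper(mapper: VMClassMapper): VMClassMapperInterface

        @Binds
        @Singleton
        abstract fun bindVMProvider(provider: VMProvider): VMProviderInterface
    }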
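
    The lookup chain above (abstract class -> concrete class -> instance from the owner's store) can be modeled on a plain JVM without Android, Dagger or Hilt. This is a minimal sketch under that assumption, not the answer's code: all *Stub names are hypothetical, a () -> KClass lambda stands in for javax.inject.Provider, and ViewModelStoreStub only imitates the caching that ViewModelProvider.get performs on a real ViewModelStore.

```kotlin
import kotlin.reflect.KClass

// Hypothetical stand-ins for AbsViewModel, a screen contract and its implementation.
abstract class AbsViewModelStub
abstract class MainVMStub : AbsViewModelStub() {
    abstract fun text(): String
}
class MainViewModelStub : MainVMStub() {
    override fun text() = "from MainViewModelStub"
}

// Models VMClassMapper: a () -> KClass lambda stands in for javax.inject.Provider.
class ClassMapperStub(
    private val vmClassesMap: Map<Class<out AbsViewModelStub>, () -> KClass<out AbsViewModelStub>>
) {
    fun getConcreteVMClass(vmClass: Class<out AbsViewModelStub>): KClass<out AbsViewModelStub> =
        vmClassesMap[vmClass]?.invoke()
            ?: error("Concrete implementation for ${vmClass.canonicalName} not found!")
}

// Models ViewModelStore + factory: each concrete class is created once per store,
// after which the cached instance is returned for the same class.
class ViewModelStoreStub(
    private val factories: Map<KClass<out AbsViewModelStub>, () -> AbsViewModelStub>
) {
    private val cache = mutableMapOf<KClass<out AbsViewModelStub>, AbsViewModelStub>()

    fun get(concrete: KClass<out AbsViewModelStub>): AbsViewModelStub =
        cache.getOrPut(concrete) {
            factories[concrete]?.invoke() ?: error("No factory for ${concrete.simpleName}")
        }
}

// Models VMProvider.provideVM: abstract class -> concrete class -> (cached) instance.
class ProviderStub(private val mapper: ClassMapperStub) {
    @Suppress("UNCHECKED_CAST")
    fun <VM : AbsViewModelStub> provideVM(store: ViewModelStoreStub, abstractClass: KClass<VM>): VM =
        store.get(mapper.getConcreteVMClass(abstractClass.java)) as VM
}
```

    Asking the provider twice for MainVMStub::class returns the same instance. In the real code, that is the behavior obtained by going through ViewModelProvider instead of constructing the view model directly: the instance belongs to the fragment's ViewModelStore.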
    

    Then, map each abstract view model to its concrete class using a custom ViewModelKey annotation.

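    import androidx.lifecycle.LiveData
    import dagger.MapKey
    import dagger.Module
    import dagger.Provides
    import dagger.hilt.InstallIn
    import dagger.hilt.components.SingletonComponent
    import dagger.multibindings.IntoMap
    import kotlin.reflect.KClass

    // Map key: the abstract view model class. The KClass member is seen by Dagger as a
    // Class, which matches the Map<Class<out AbsViewModel>, ...> injected into VMClassMapper.
    @Target(
        AnnotationTarget.FUNCTION,
        AnnotationTarget.PROPERTY_GETTER,
        AnnotationTarget.PROPERTY_SETTER
    )
    @kotlin.annotation.Retention(AnnotationRetention.RUNTIME)
    @MapKey
    annotation class ViewModelKey(val value: KClass<out AbsViewModel>)

    @Module
    @InstallIn(SingletonComponent::class)
    object ViewModelsDI {

        // Register each abstract contract with its concrete @HiltViewModel class.
        // MainViewModel must extend MainContracts.VM; SecondViewModel must extend SecondContracts.VM.
        @Provides
        @IntoMap
        @ViewModelKey(MainContracts.VM::class)
        fun provideConcreteClassForMainVM(): KClass<out AbsViewModel> = MainViewModel::class

        @Provides
        @IntoMap
        @ViewModelKey(SecondContracts.VM::class)
        fun provideConcreteClassForSecondVM(): KClass<out AbsViewModel> = SecondViewModel::class
    }

    interface MainContracts {

        abstract class VM : AbsViewModel() {
            abstract val textLiveData: LiveData<String>
            abstract fun onUpdateTextClicked()
            abstract fun onPerformActionClicked()
        }
    }

    interface SecondContracts {

        abstract class VM : AbsViewModel()
    }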
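
    Because VMProvider finishes with an unchecked cast, a @ViewModelKey entry that maps an abstract contract to the wrong concrete class (for example SecondContracts.VM::class mapped to MainViewModel::class) compiles fine and only fails at runtime, with a ClassCastException where the fragment assigns the result of provideVM. One option is to validate the map once, for instance from an init block in VMClassMapper. The sketch below runs on a plain JVM; validateVMMap and the *Stub classes are hypothetical, and a () -> KClass lambda stands in for javax.inject.Provider.

```kotlin
import java.lang.reflect.Modifier
import kotlin.reflect.KClass

// Hypothetical stand-ins mirroring AbsViewModel, two contracts and their implementations.
abstract class AbsViewModelStub
abstract class MainVMStub : AbsViewModelStub()
abstract class SecondVMStub : AbsViewModelStub()
class MainViewModelStub : MainVMStub()
class SecondViewModelStub : SecondVMStub()

// Hypothetical fail-fast check over the abstract -> concrete map: every concrete class
// must extend the key it is registered under and must not itself be abstract.
fun validateVMMap(map: Map<Class<out AbsViewModelStub>, () -> KClass<out AbsViewModelStub>>) {
    for ((abstractClass, concreteProvider) in map) {
        val concrete = concreteProvider().java
        require(abstractClass.isAssignableFrom(concrete)) {
            "${concrete.simpleName} does not extend ${abstractClass.simpleName}; check its @ViewModelKey"
        }
        require(!Modifier.isAbstract(concrete.modifiers)) {
            "${concrete.simpleName} is abstract and cannot be instantiated"
        }
    }
}
```

    The trade-off is a small amount of work when the mapper is first created in exchange for an error message that names the offending entry, instead of a cast failure inside a fragment.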
    

    Finally, your fragment using the abstract view model looks like this:

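    import android.os.Bundle
    import androidx.fragment.app.Fragment
    import dagger.hilt.android.AndroidEntryPoint
    import javax.inject.Inject
    import kotlin.reflect.KClass

    @AndroidEntryPoint
    class MainFragment : Fragment(), VMDependant<MainContracts.VM> {

        // Field-injected by Hilt in onAttach, so it is available in onCreate.
        @Inject lateinit var vmProvider: VMProviderInterface

        private lateinit var vm: MainContracts.VM

        override fun onCreate(savedInstanceState: Bundle?) {
            super.onCreate(savedInstanceState)
            vm = vmProvider.provideVM(this)
        }

        override fun getVMClass(): KClass<MainContracts.VM> = MainContracts.VM::class
    }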
    

    It's a lot of setup, but once the initial infrastructure is in place, all you need to do for each individual fragment is make it implement VMDependant and provide a concrete class for its abstract view model in Hilt using @ViewModelKey.

    In tests, vmProvider can then be easily mocked and forced to do your bidding.
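
    The test seam can be modeled on a plain JVM: the screen knows only the provider interface and the abstract contract, so a test can hand it a fake provider that returns a hand-written fake view model, and MainViewModel and its dependencies never enter the picture. This is a sketch with hypothetical *Stub and Fake* types, using a fake instead of a mocking library, and passing the abstract KClass directly instead of a VMDependant. On a device, the real VMProviderInterface binding would be swapped in the Hilt graph instead, for example by a test module that replaces ViewModelProviderModule via @TestInstallIn from hilt-android-testing.

```kotlin
import kotlin.reflect.KClass

// Hypothetical stand-ins: an abstract contract and the provider seam
// (VMProviderInterface in the answer), simplified to take the abstract KClass.
abstract class AbsViewModelStub
abstract class MainVMStub : AbsViewModelStub() {
    abstract val text: String
    abstract fun onUpdateTextClicked()
}

interface VMProviderStub {
    fun <VM : AbsViewModelStub> provideVM(abstractClass: KClass<VM>): VM
}

// Mirrors MainFragment: it asks for the abstract type only.
class ScreenStub(private val vmProvider: VMProviderStub) {
    private val vm: MainVMStub = vmProvider.provideVM(MainVMStub::class)
    fun render(): String = vm.text
    fun clickUpdate() = vm.onUpdateTextClicked()
}

// Test double: canned text and a click counter, no production view model involved.
class FakeMainVM : MainVMStub() {
    var clicks = 0
    override val text = "fake text"
    override fun onUpdateTextClicked() {
        clicks++
    }
}

// Fake provider: returns pre-built fakes keyed by abstract class, with a checked cast.
class FakeVMProvider(
    private val fakes: Map<KClass<out AbsViewModelStub>, AbsViewModelStub>
) : VMProviderStub {
    override fun <VM : AbsViewModelStub> provideVM(abstractClass: KClass<VM>): VM =
        abstractClass.java.cast(fakes[abstractClass] ?: error("No fake for ${abstractClass.simpleName}"))
}
```

    A UI test then asserts against FakeMainVM (what it rendered, which callbacks were triggered) without depending on how MainViewModel is implemented.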