How to sample parameters without duplicates in optuna?


I am using optuna for parameter optimisation of my custom models.

Is there a way to keep sampling parameters until the current set is one that has not been tested before? In other words, to sample another set whenever a past trial already used the same parameters.

In some cases this is impossible, for example when the distribution is categorical and n_trials is greater than the number of possible unique values.

What I want: a config parameter such as num_attempts, so that parameters are re-sampled in a for-loop up to num_attempts times until a set is found that has not been tested before; if none is found, the trial runs on the last sampled set.

Why I need this: it costs too much to run heavy models several times on the same parameters.

What I do now: I implement this "for-loop" myself, but it is messy.

If there is a smarter way to do it, I would be very grateful for the information.

Thanks!


Solution

  • To the best of my knowledge, there is no direct way to handle your case for now. As a workaround, you can check for parameter duplication and skip the evaluation as follows:

    import optuna
    
    def objective(trial: optuna.Trial):
        # Sample parameters.
        x = trial.suggest_int('x', 0, 10)
        y = trial.suggest_categorical('y', [-10, -5, 0, 5, 10])
    
        # Check duplication and skip if it's detected.
        for t in trial.study.trials:
            if t.state != optuna.trial.TrialState.COMPLETE:
                continue
    
            if t.params == trial.params:
                return t.value  # Return the previous value without re-evaluating it.
    
                # # Note that if duplicate parameter sets are suggested too frequently,
                # # you can use the pruning mechanism of Optuna to mitigate the problem.
                # # By raising `TrialPruned` instead of just returning the previous value,
                # # the sampler is more likely to avoid sampling the parameters in the succeeding trials.
                #
                # raise optuna.TrialPruned('Duplicate parameter set')
    
        # Evaluate parameters.
        return x + y
    
    # Start study.
    study = optuna.create_study()
    
    unique_trials = 20
    while unique_trials > len(set(str(t.params) for t in study.trials)):
        study.optimize(objective, n_trials=1)
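
The num_attempts idea from the question can also be sketched on its own, independent of Optuna. This is only an illustration under assumptions: `sample_unique`, `freeze`, and `sample_fn` are hypothetical names, and `sample_fn` stands in for whatever draws one parameter set in your setup. It is not an Optuna API.

```python
import random
from typing import Callable, Dict, Hashable, Set, Tuple

Params = Dict[str, Hashable]
Key = Tuple[Tuple[str, Hashable], ...]


def freeze(params: Params) -> Key:
    """Turn a parameter dict into a hashable key that ignores key order."""
    return tuple(sorted(params.items()))


def sample_unique(sample_fn: Callable[[], Params], seen: Set[Key], num_attempts: int) -> Params:
    """Draw up to num_attempts times and return the first unseen set.

    If every draw was already seen, the last draw is returned anyway,
    which matches the fallback described in the question.
    """
    if num_attempts < 1:
        raise ValueError("num_attempts must be at least 1")
    params: Params = {}
    for _ in range(num_attempts):
        params = sample_fn()
        if freeze(params) not in seen:
            break  # Found a set that has not been tested before.
    seen.add(freeze(params))
    return params


# Hypothetical usage with the same search space as the example above.
rng = random.Random(0)


def draw() -> Params:
    return {"x": rng.randint(0, 10), "y": rng.choice([-10, -5, 0, 5, 10])}


seen: Set[Key] = set()
param_sets = [sample_unique(draw, seen, num_attempts=50) for _ in range(20)]
```

Keeping the seen keys in a set makes each duplicate check a constant-time lookup rather than a scan over all past trials, which matters more as the number of trials grows.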