
How do migration operations (e.g. AddField) detect whether you're applying or reverting a migration?


A good example is django.db.migrations.AddField.

Let's say I've created a simple migration such as:

from django.db import migrations, models


class Migration(migrations.Migration):
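    # Hypothetical: a real migration would list the migration that created "foo",
    # e.g. [("myapp", "0001_initial")]; left empty here for brevity.
    dependencies = []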

    operations = [
        migrations.AddField(
            model_name="foo",
            name="bar",
            field=models.TextField(blank=True, null=True),
        ),
    ]

Applying this migration adds the bar field, and migrating back to a previous migration removes it again.

However, when I look under the hood of the AddField class, I don't see any logic that detects which direction the migration is being run in:


class AddField(FieldOperation):
    field: Field[Any, Any] = ...
    preserve_default: bool = ...
    def __init__(
        self,
        model_name: str,
        name: str,
        field: Field[Any, Any],
        preserve_default: bool = ...,
    ) -> None: ...

I need to create a similar operation for collected data: one that applies or reverts changes to that data when the model the data is based on changes. It would be used inside migrations in the same way as AddField, so it needs to detect whether the user is applying or reverting a migration.


Solution

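  • The class shown in the question appears to be a type-stub signature (for example from django-stubs), which contains no implementation at all. In any case, the operation never detects the direction itself: the migrate management command (through Django's MigrationExecutor) works out whether each migration is being applied (forwards) or unapplied (backwards), and calls each operation's database_forwards or database_backwards method accordingly. When unapplying, it also runs a migration's operations in reverse order. The Operation subclass therefore doesn't need to know the direction on init; it just implements both methods, as the real AddField does: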

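    # Excerpt of AddField from django/db/migrations/operations/fields.py
    # (other methods elided); imports added so the snippet stands alone.
    from django.db.migrations.operations.fields import FieldOperation
    from django.db.models import NOT_PROVIDED

    class AddField(FieldOperation):
        ...
        # Called when the migration is applied.
        def database_forwards(self, app_label, schema_editor, from_state, to_state):
            to_model = to_state.apps.get_model(app_label, self.model_name)
            if self.allow_migrate_model(schema_editor.connection.alias, to_model):
                from_model = from_state.apps.get_model(app_label, self.model_name)
                # Use the field as defined in the post-migration state.
                field = to_model._meta.get_field(self.name)
                if not self.preserve_default:
                    field.default = self.field.default
                schema_editor.add_field(
                    from_model,
                    field,
                )
                if not self.preserve_default:
                    field.default = NOT_PROVIDED

        # Called when the migration is unapplied; from_state is the later
        # state, in which the field still exists, so its column is dropped.
        def database_backwards(self, app_label, schema_editor, from_state, to_state):
            from_model = from_state.apps.get_model(app_label, self.model_name)
            if self.allow_migrate_model(schema_editor.connection.alias, from_model):
                schema_editor.remove_field(
                    from_model, from_model._meta.get_field(self.name)
                )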

    This is partly explained in the Django documentation, in the "Reversing migrations" section of the migrations guide and in the section on writing your own migration operations.
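To make the dispatch concrete, here is a small, self-contained Python model of it. It is not Django's code: ToyExecutor, ToyMigration, ToyAddField and ToyFillField are hypothetical stand-ins, and the "database" is just a dict of row lists. The point it illustrates is that the executor picks the direction by comparing the target with what is already applied, while each operation only implements the two methods.

```python
from typing import Any, Dict, List

# Toy "database" (hypothetical): table name -> list of rows (column -> value).
Database = Dict[str, List[Dict[str, Any]]]


class Operation:
    """Simplified stand-in for Django's Operation: no direction flag anywhere."""

    def database_forwards(self, db: Database) -> None:
        raise NotImplementedError

    def database_backwards(self, db: Database) -> None:
        raise NotImplementedError


class ToyAddField(Operation):
    """Adds a column (forwards) or drops it again (backwards)."""

    def __init__(self, model_name: str, name: str) -> None:
        self.model_name = model_name
        self.name = name

    def database_forwards(self, db: Database) -> None:
        for row in db[self.model_name]:
            row[self.name] = None

    def database_backwards(self, db: Database) -> None:
        for row in db[self.model_name]:
            row.pop(self.name, None)


class ToyFillField(Operation):
    """Hypothetical data operation: fills a column, clears it on reversal."""

    def __init__(self, model_name: str, name: str, value: Any) -> None:
        self.model_name = model_name
        self.name = name
        self.value = value

    def database_forwards(self, db: Database) -> None:
        for row in db[self.model_name]:
            row[self.name] = self.value

    def database_backwards(self, db: Database) -> None:
        for row in db[self.model_name]:
            row[self.name] = None


class ToyMigration:
    def __init__(self, name: str, operations: List[Operation]) -> None:
        self.name = name
        self.operations = operations

    def apply(self, db: Database) -> None:
        # Forwards: operations run in the order they are listed.
        for operation in self.operations:
            operation.database_forwards(db)

    def unapply(self, db: Database) -> None:
        # Backwards: operations run in reverse order.
        for operation in reversed(self.operations):
            operation.database_backwards(db)


class ToyExecutor:
    def __init__(self, migrations: List[ToyMigration], db: Database) -> None:
        self.migrations = migrations
        self.db = db
        self.applied = 0  # how many migrations (in order) are applied

    def migrate(self, target: int) -> None:
        # The direction is decided here, not inside any operation.
        if target > self.applied:
            for migration in self.migrations[self.applied:target]:
                migration.apply(self.db)
        else:
            for migration in reversed(self.migrations[target:self.applied]):
                migration.unapply(self.db)
        self.applied = target


if __name__ == "__main__":
    db = {"foo": [{"id": 1}, {"id": 2}]}
    ops = [ToyAddField("foo", "bar"), ToyFillField("foo", "bar", "collected")]
    executor = ToyExecutor([ToyMigration("0002_add_bar", ops)], db)
    executor.migrate(1)  # forwards
    print(db)  # {'foo': [{'id': 1, 'bar': 'collected'}, {'id': 2, 'bar': 'collected'}]}
    executor.migrate(0)  # backwards
    print(db)  # {'foo': [{'id': 1}, {'id': 2}]}
```

In a real project you would instead subclass django.db.migrations.operations.base.Operation and implement state_forwards, database_forwards and database_backwards (and optionally describe), setting reversible = False if the change cannot be undone. For data-only changes, migrations.RunPython(code, reverse_code) follows the same pattern with two callables.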