
Snakemake checkpoint aggregate function skips intermediate rules


I have a Python script that takes a set of FASTA and GFF files and groups the sequences by COG ID into individual directories inside a master COG directory. The number of COGs is not known in advance, so I'm using Snakemake's checkpoint feature.

The rule looks like the following:

checkpoint get_COG:
    input:
        rules.AMR_meta.output
    output:
        check=directory(os.path.join("COG_data"))
    threads:
        config['COG']['threads']
    log:
        os.path.join(RESULTS_DIR, "logs/COG_directory_setup.log")
    message:
        "Running python script to set up directory structure for GeneForest"
    run:
        import glob
        import pandas as pd
        import os
        import shutil
        import logging
        from Bio import SeqIO
        import argparse
        from io import StringIO
        import numpy as np
        from multiprocessing import Pool

        from scripts.utils import ParseGFF, GetAllCOG, CreateCOGDirs, GetSequence, GetCoverage, ProcessCOG, GetCoverageSums
        meta_file=pd.read_csv(input[0],sep=',')

        # List all COGs, create dirs
        cog_set=GetAllCOG(meta_file)
        CreateCOGDirs(cog_set)

        # Iterate over all COGs to gather the sequences
        print('Creating gene catalogue...')
        with Pool(threads) as p:
            p.map(ProcessCOG, [[cog, meta_file] for cog in list(cog_set)])

For each COG, this rule creates the following files:

COG_data/COGXXXX/COGXXXX_raw.fasta, COG_data/COGXXXX/COGXXXX_coverage.csv

In subsequent rules, I want to take the FASTA files produced by the checkpoint and build multiple sequence alignments and phylogenetic trees from them:

rule mafft:
    input:
        os.path.join("COG_data/{i}/{i}_raw.fasta")
    output:
        os.path.join("COG_data/{i}/{i}_aln.fasta")
    conda:
        os.path.join("envs/mafft.yaml")
    threads:
        config['MAFFT']['threads']
    log:
        os.path.join(RESULTS_DIR, "logs/{i}.mafft.log")
    message:
        "Getting multiple sequence alignment for each COG"
    shell:
        "(date && mafft --thread {threads} {input} > {output} && date) &> {log}"

rule trimal:
    input:
        os.path.join("COG_data/{i}/{i}_aln.fasta")
    output:
        os.path.join("COG_data/{i}/{i}_trim.fasta")
    conda:
        os.path.join("envs/trimal.yaml")
    log:
        os.path.join(RESULTS_DIR, "logs/{i}.trimal.log")
    message:
        "Getting trimmed alignment sequence for each COG"
    shell:
        "(date && trimal -in {input} -out {output} -automated1 && date) &> {log}"

rule iqtree:
    input:
        os.path.join("COG_data/{i}/{i}_trim.fasta")
    output:
        os.path.join("COG_data/{i}/{i}_trim.fasta.treefile")
    conda:
        os.path.join("envs/iqtree.yaml")
    log:
        os.path.join(RESULTS_DIR, "logs/{i}.iqtree.log")
    message:
        "Getting trees for each COG"
    shell:
        "(date && iqtree -s {input} -m MFP && date) &> {log}"

def COG_trees(wildcards):
    checkpoint_output= checkpoints.get_COG.get(**wildcards).output.check
    return expand(os.path.join("COG_data/{i}/{i}_trim.fasta.treefile"),
        i=glob_wildcards(os.path.join(checkpoint_output, "{i}_trim.fasta.treefile")).i)

rule trees:
    input:
        COG_trees
    output:
        os.path.join(RESULTS_DIR, "COG_trees.done")
    log:
        os.path.join(RESULTS_DIR, "logs/geneforest_is_ready.log")
    message:
        "Creates the COG trees via checkpoints"
    shell:
        "(date && touch {output} && date) &> {log}"
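The three per-COG rules are linked only through their file patterns. Snakemake matches a requested path against each rule's output pattern, extracts `{i}`, and then needs that rule's input, and so on backwards. The plain-Python sketch below imitates that backward resolution for the patterns above. It is a simplified model, not Snakemake's scheduler: the helpers `pattern_to_regex` and `resolve_chain` are hypothetical, the repeated `{i}` is treated as a regex backreference, and wildcard constraints and rule ambiguity are ignored.

```python
import re

# Output and input patterns of the per-COG rules, copied from the Snakefile above.
RULES = {
    "iqtree": ("COG_data/{i}/{i}_trim.fasta.treefile", "COG_data/{i}/{i}_trim.fasta"),
    "trimal": ("COG_data/{i}/{i}_trim.fasta", "COG_data/{i}/{i}_aln.fasta"),
    "mafft": ("COG_data/{i}/{i}_aln.fasta", "COG_data/{i}/{i}_raw.fasta"),
}


def pattern_to_regex(pattern):
    """Hypothetical helper: first {i} becomes a named group, later {i} a backreference."""
    pieces = pattern.split("{i}")
    regex = re.escape(pieces[0]) + "(?P<i>.+)" + "(?P=i)".join(re.escape(p) for p in pieces[1:])
    return re.compile(regex + "$")


def resolve_chain(target):
    """Walk backwards from a requested file until no rule's output pattern matches."""
    steps = []
    current = target
    while True:
        for name, (out_pattern, in_pattern) in RULES.items():
            match = pattern_to_regex(out_pattern).match(current)
            if match:
                steps.append(name)
                current = in_pattern.format(i=match.group("i"))
                break
        else:
            # No rule produces this file, so it must already exist (here: written by the checkpoint).
            return steps, current


steps, source = resolve_chain("COG_data/COG0001/COG0001_trim.fasta.treefile")
print(steps)   # ['iqtree', 'trimal', 'mafft']
print(source)  # COG_data/COG0001/COG0001_raw.fasta
```

Requesting a `.treefile` is therefore enough to pull in mafft, trimal and iqtree. For the aggregate function, what matters is that the COG IDs it passes to `expand` are the right ones.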

Although the COG_data/COGXXXX/COGXXXX_raw.fasta files are created, the intermediate rules (mafft, trimal and iqtree) are never run. The workflow jumps straight to rule trees and produces COG_trees.done.
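A minimal plain-Python reproduction of the symptom, assuming the directory layout described above. Right after the checkpoint finishes, `COG_data` only holds the `_raw.fasta` and `_coverage.csv` files. A search for `_trim.fasta.treefile` files therefore finds no COG IDs, `expand` receives an empty list, and rule trees ends up with no inputs. `glob_ids` is a hypothetical, simplified stand-in for `glob_wildcards`. It walks subdirectories and uses `.+` for the wildcard, which is my understanding of Snakemake's default. `make_checkpoint_output` and the COG IDs are made up for illustration.

```python
import os
import re
import tempfile


def glob_ids(pattern):
    """Simplified, hypothetical stand-in for glob_wildcards with a single wildcard {i}."""
    pieces = pattern.split("{i}")
    regex = re.escape(pieces[0]) + "(?P<i>.+)" + "(?P=i)".join(re.escape(p) for p in pieces[1:])
    compiled = re.compile(regex + "$")
    found = []
    # Walk every subdirectory below the fixed part of the pattern.
    for dirpath, _dirnames, filenames in os.walk(os.path.dirname(pieces[0])):
        for name in filenames:
            match = compiled.match(os.path.join(dirpath, name))
            if match:
                found.append(match.group("i"))
    return sorted(found)


def make_checkpoint_output(root, cog_ids):
    """Recreate what get_COG leaves behind: COG_data/<COG>/<COG>_raw.fasta and _coverage.csv."""
    for cog in cog_ids:
        cog_dir = os.path.join(root, "COG_data", cog)
        os.makedirs(cog_dir)
        for suffix in ("_raw.fasta", "_coverage.csv"):
            open(os.path.join(cog_dir, cog + suffix), "w").close()


with tempfile.TemporaryDirectory() as root:
    make_checkpoint_output(root, ["COG0001", "COG0002"])
    checkpoint_output = os.path.join(root, "COG_data")
    # The original COG_trees globs for tree files, which no rule has produced yet.
    tree_ids = glob_ids(os.path.join(checkpoint_output, "{i}_trim.fasta.treefile"))

print(tree_ids)  # [] -> expand() gets no IDs, so rule trees has no inputs
```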

Is there a way to fix the COG_trees function so that the intermediate rules are run?

Thank you for your help!


Solution

  • It turns out that the aggregate function was wrong. Instead of globbing for the output of the very last rule (rule iqtree), which does not exist yet when the checkpoint finishes, it has to glob for files that the checkpoint itself created:

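    def COG_trees(wildcards):
        # get() defers this function until the checkpoint has been executed
        checkpoint_output = checkpoints.get_COG.get(**wildcards).output.check
        # Glob for files the checkpoint itself wrote; "{i}/{i}" keeps i to the bare COG ID
        return expand(os.path.join("COG_data/{i}/{i}_trim.fasta.treefile"),
            i=glob_wildcards(os.path.join(checkpoint_output, "{i}", "{i}_raw.fasta")).i)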
    

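    Globbing for the checkpoint's own output (the _raw.fasta files, which the first rule after the checkpoint, rule mafft, takes as input) gave the expected output! Snakemake then requests a .treefile for every COG and schedules mafft, trimal and iqtree to build them. :facepalm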
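When globbing the checkpoint directory, one detail is worth checking. As far as I understand, Snakemake's `glob_wildcards` also searches subdirectories, and its default wildcard pattern `.+` matches `/`. With the files stored one level down (`COG_data/COGXXXX/COGXXXX_raw.fasta`), a pattern of the form `COG_data/{i}_raw.fasta` can therefore capture `COGXXXX/COGXXXX` instead of `COGXXXX`. In contrast, `COG_data/{i}/{i}_raw.fasta` captures only the COG ID. The sketch below compares both patterns using a hypothetical helper, `wildcard_regex`, that mimics that matching. It is not Snakemake's own code, so it is worth confirming by printing the list returned inside `COG_trees`.

```python
import re


def wildcard_regex(pattern):
    """Hypothetical helper: {i} -> '.+' (assumed Snakemake default), repeated {i} -> backreference."""
    pieces = pattern.split("{i}")
    regex = re.escape(pieces[0]) + "(?P<i>.+)" + "(?P=i)".join(re.escape(p) for p in pieces[1:])
    return re.compile(regex + "$")


# A file written by the checkpoint, one directory below COG_data.
path = "COG_data/COG0001/COG0001_raw.fasta"

flat = wildcard_regex("COG_data/{i}_raw.fasta").match(path).group("i")
nested = wildcard_regex("COG_data/{i}/{i}_raw.fasta").match(path).group("i")

print(flat)    # COG0001/COG0001  ('.+' also matches '/')
print(nested)  # COG0001

# What the target pattern used in COG_trees becomes with each value of i.
target = "COG_data/{i}/{i}_trim.fasta.treefile"
print(target.format(i=flat))    # COG_data/COG0001/COG0001/COG0001/COG0001_trim.fasta.treefile
print(target.format(i=nested))  # COG_data/COG0001/COG0001_trim.fasta.treefile
```

If the flat pattern does return `COGXXXX/COGXXXX` in a given setup, `expand` builds paths that no `_raw.fasta` file backs, and the nested pattern is the safer choice.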