
Split BAM by clusters and then merge BAM by cluster using checkpoint


I have three single-cell BAM files from three different samples that I need to split into smaller BAM files by cluster. I then need to merge the BAM files from different samples for the same cluster. I tried using a checkpoint but got kind of lost. https://snakemake.readthedocs.io/en/stable/snakefiles/rules.html

It is a continuation of this question I posted: "split bam files to (variable) pre-defined number of small bam files depending on the sample".

SAMPLE_cluster = { "SampleA" : [ "1", "2", "3" ], "SampleB" :  [ "1" ], "SampleC" : [ "1", "2" ] }

CLUSTERS = []
for sample in SAMPLE_cluster:
    CLUSTERS.extend(SAMPLE_cluster[sample])
CLUSTERS = sorted(set(CLUSTERS))

rule all:
    input: expand("01merged_bam/{cluster_id}.bam", cluster_id = CLUSTERS)

checkpoint split_bam:
    input: "{sample}.bam"
    output: directory("01split_bam/{sample}/")
    shell:
       """
       split_bam.sh {input} 
       """
## the split_bam.sh will split the bam file to "01split_bam/{sample}/{sample}_{cluster_id}.bam" 

def merge_bam_input(wildcards):
    checkpoint_output = checkpoints.split_bam.get(**wildcards).output[0]
    return expand("01split_bam/{sample}/{sample}_{{cluster_id}}.bam", \
                sample = glob_wildcards(os.path.join(checkpoint_output, "{sample}_{cluster_id}.bam")).sample)


rule merge_bam_per_cluster:
    input: merge_bam_input
    output: "01merged_bam/{cluster_id}.bam"
    log: "00log/{cluster_id}.merge_bam.log"
    threads: 2
    shell:
        """
        samtools merge -@ 2 -r {output} {input}
        """


Depending on the cluster number, the input of rule merge_bam_per_cluster changes:

For cluster 1: "01split_bam/SampleA/SampleA_1.bam", "01split_bam/SampleB/SampleB_1.bam", "01split_bam/SampleC/SampleC_1.bam".

For cluster 2: "01split_bam/SampleA/SampleA_2.bam", "01split_bam/SampleC/SampleC_2.bam".

For cluster 3: "01split_bam/SampleA/SampleA_3.bam".
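
The per-cluster inputs listed above can be derived directly from SAMPLE_cluster. Below is a minimal sketch in plain Python, assuming the split files follow the "01split_bam/{sample}/{sample}_{cluster_id}.bam" layout described earlier. The helper name expected_merge_inputs is hypothetical and only builds path strings; it does not check that the files exist.

```python
# Mapping from the question: sample name -> cluster IDs present in that sample.
SAMPLE_cluster = {"SampleA": ["1", "2", "3"], "SampleB": ["1"], "SampleC": ["1", "2"]}


def expected_merge_inputs(cluster_id, sample_cluster=SAMPLE_cluster):
    """Return the split BAM paths that should be merged for one cluster.

    Hypothetical helper for illustration: it only formats paths and
    never touches the filesystem.
    """
    return [
        f"01split_bam/{sample}/{sample}_{cluster_id}.bam"
        for sample, clusters in sample_cluster.items()
        if cluster_id in clusters  # keep only samples that contain this cluster
    ]


for cluster_id in ["1", "2", "3"]:
    print(cluster_id, expected_merge_inputs(cluster_id))
```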


Solution

  • I decided not to use a checkpoint and instead use an input function to get the input for rule merge_bam_per_cluster:

    
    SAMPLE_cluster = { "SampleA" : [ "1", "2", "3" ], "SampleB" :  [ "1" ], "SampleC" : [ "1", "2" ] }
    
    # reverse the mapping
    cluster_sample = {'1':['SampleA','SampleB','SampleC'], '2':['SampleA', 'SampleC'], '3':['SampleA']}
    
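    rule split_bam:
        input: "{sample}.bam"
        # one flag file per sample, so the {sample} wildcard can be resolved from the output
        output: "01split_bam/{sample}.split.touch"
        shell:
           """
           split_bam.sh {input}
           touch {output}
           """
    rule index_split_bam:
        input: "01split_bam/{sample}.split.touch"
        # must match the paths requested by get_merge_bam_input
        output: "01split_bam/{sample}/{sample}_{cluster_id}.bam.bai"
        shell:
            """
            samtools index 01split_bam/{wildcards.sample}/{wildcards.sample}_{wildcards.cluster_id}.bam
            """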
    
    def get_merge_bam_input(wildcards):
        samples = cluster_sample[wildcards.cluster_id]
        return expand("01split_bam/{sample}/{sample}_{{cluster_id}}.bam.bai", sample = samples)
    
    
    rule merge_bam_per_cluster:
        input: get_merge_bam_input
        output: "01merged_bam/{cluster_id}.bam"
        params:
            # strip the ".bai" suffix so samtools merge receives the BAM paths
            bam = lambda wildcards, input: " ".join(input).replace(".bai", "")
        log: "00log/{cluster_id}.merge_bam.log"
        threads: 2
        shell:
            """
            samtools merge -@ {threads} -r {output} {params.bam} 2> {log}
            """
    

    It seems to be working.
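
The reversed mapping above is written out by hand. As a sketch, it could instead be derived from SAMPLE_cluster so that the two dictionaries cannot drift apart. The function name invert_mapping is an assumption for illustration and is not part of the original answer.

```python
from collections import defaultdict

# Mapping from the question: sample name -> cluster IDs present in that sample.
SAMPLE_cluster = {"SampleA": ["1", "2", "3"], "SampleB": ["1"], "SampleC": ["1", "2"]}


def invert_mapping(sample_cluster):
    """Turn {sample: [cluster_id, ...]} into {cluster_id: [sample, ...]}.

    Hypothetical helper; samples keep the order in which they appear
    in the input dictionary.
    """
    cluster_sample = defaultdict(list)
    for sample, clusters in sample_cluster.items():
        for cluster_id in clusters:
            cluster_sample[cluster_id].append(sample)
    return dict(cluster_sample)  # plain dict, so a missing cluster raises KeyError


cluster_sample = invert_mapping(SAMPLE_cluster)
print(cluster_sample)
```

In a Snakefile, this could replace the hard-coded dictionary, and sorted(cluster_sample) would give the cluster IDs needed by rule all.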