Search code examples
Tags: matplotlib, random-walk

Viewing a 2D DNA walk with different colours


I am interested in creating a form of random walk that uses a DNA sequence to drive the walk (e.g. T = up, A = down, G = right, C = left). I have written the code below; however, I would like to know whether it is possible to assign each of the four bases its own colour, instead of the final plot being drawn in a single colour.

import matplotlib.pyplot as plt

# Unit step (dx, dy) for each base. Any other character leaves the walker
# where it is, but the (unchanged) position is still recorded.
moves = {'T': (0, 1),   # up
         'A': (0, -1),  # down
         'G': (1, 0),   # right
         'C': (-1, 0)}  # left

dna_seq = ('GGACTTCCCTATGGTGCTAACAAAGAGGCAGACAAA')

x = y = 0
x_values, y_values = [x], [y]

for base in dna_seq:
    step_x, step_y = moves.get(base, (0, 0))
    x += step_x
    y += step_y
    x_values.append(x)
    y_values.append(y)

# Draw the whole walk as one green line.
fig, ax = plt.subplots()
ax.plot(x_values, y_values, c='g')
plt.show()

Solution

  • A multicoloured line (as in matplotlib's "Multicolored lines" example) can be used. The idea is to split the path into individual segments, one per base, and plot them with a LineCollection. Each line in the collection can have its own colour.

    Because the random walk traverses some segments more than once, the repeated segments are shifted slightly sideways so that they do not hide one another.

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.collections import LineCollection

    # Draw a 2D "DNA walk" with one colour per base: each base moves the walker
    # one unit (T = up, A = down, G = right, C = left) and every step is drawn
    # as its own segment, coloured by the base that produced it.

    dna_seq = 'GGACTTCCCTATGGTGCTAACAAAGAGGCAGACAAA'

    # Unit step (dx, dy) and line colour for each base.
    step_lookup = {'T': (0, 1),    # up
                   'A': (0, -1),   # down
                   'G': (1, 0),    # right
                   'C': (-1, 0)}   # left
    color_lookup = {'A': 'red',
                    'T': 'green',
                    'G': 'blue',
                    'C': 'orange'}

    # --- 1. Walk the sequence, recording every visited point. ---------------
    x = y = 0
    x_values = [x]
    y_values = [y]
    colors = []
    for base in dna_seq.upper():  # accept lower-case sequences too
        if base not in step_lookup:
            raise ValueError(f'unexpected character {base!r} in DNA sequence; '
                             'expected only A, C, G or T')
        dx, dy = step_lookup[base]
        x += dx
        y += dy
        x_values.append(x)
        y_values.append(y)
        colors.append(color_lookup[base])

    # --- 2. Turn the path into one segment per step. -------------------------
    # points has shape (n_steps + 1, 1, 2) and segments has shape
    # (n_steps, 2, 2), i.e. segments[i] = [start_point, end_point] of step i.
    # A float dtype is required so the sideways shifts below are not truncated.
    points = np.array([x_values, y_values], dtype=float).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    # --- 3. Separate segments that retrace an already-drawn edge. -----------
    # A retraced edge (in either direction) would hide the earlier one, so it
    # is moved perpendicular to itself by whole multiples of `delta` until it
    # lands on a free "lane". Keeping the occupied edges in a set makes this a
    # single pass instead of repeated list scans.
    delta = 0.1

    def edge_key(seg):
        """Return a direction-independent key for *seg*, tolerant of float noise."""
        start, end = (tuple(np.round(p, 6).tolist()) for p in seg)
        return min(start, end), max(start, end)

    occupied = set()
    for segment in segments:  # each `segment` is a view, edited in place
        # Vertical segments are shifted along x, horizontal ones along y.
        axis = 0 if segment[0, 0] == segment[1, 0] else 1
        base_coord = segment[0, axis]
        lane = 0
        while edge_key(segment) in occupied:
            lane += 1
            segment[:, axis] = base_coord + lane * delta
        occupied.add(edge_key(segment))

    # --- 4. Plot. -------------------------------------------------------------
    fig, ax = plt.subplots()
    lc = LineCollection(segments, colors=colors)
    lc.set_linewidth(2)
    ax.add_collection(lc)

    # Fit the view to the *shifted* segments, not just the original grid
    # points, so lanes pushed past the outermost grid line are not clipped.
    # The start point is included so an empty sequence still gives a range.
    all_xy = np.concatenate([points.reshape(-1, 2), segments.reshape(-1, 2)])
    pad = 0.1
    ax.set_aspect('equal')
    ax.set_xlim(all_xy[:, 0].min() - pad, all_xy[:, 0].max() + pad)
    ax.set_ylim(all_xy[:, 1].min() - pad, all_xy[:, 1].max() + pad)

    plt.show()
    

    (Image: the resulting DNA walk, drawn with one colour per base.)