Search code examples
simulation, runge-kutta

How to fix incorrect energy conservation problem in mass-spring-system simulation using RK4 method


I am making a simulation in which you create balls of given mass, connected by springs that you can define (in the program below all springs have natural length L and spring constant k). I wrote a function accel(b,BALLS) (where b is the specific ball and BALLS is the list of all ball objects in various stages of the update) which returns the acceleration of that one ball by summing all the forces acting on it (the tensions from the springs connected to it, plus gravity). I believe this function is correct and that the problem lies elsewhere, in the while loop. I then use the RK4 method described on this website: http://spiff.rit.edu/richmond/nbody/OrbitRungeKutta4.pdf in the while loop to update the velocity and position of each ball. To test my understanding of the method I first made a simulation involving only two balls and one spring on Desmos: https://www.desmos.com/calculator/4ag5gkerag There I enabled an energy display and saw that RK4 is indeed much better than the Euler method. I then wrote it in Python, hoping it would work with an arbitrary configuration of balls and springs, but energy isn't conserved even when I have only two balls and one spring! I can't see what I did differently, at least when only two balls are involved. And when I introduce a third ball and a second spring to the system, the energy increases by hundreds every second. This is my first time coding a simulation with RK4, and I expect you can find mistakes in it. My guess is that the problem arises because there are multiple bodies and something goes wrong when I update their kas or kvs at the same time, but then again I can't spot any difference between what this code does when simulating two balls and the method I used in the Desmos file. Here is my code in Python:

    import pygame
    import sys
    import math
    import numpy as np
    
    
    # Open the window and create the font and clock objects used by the main loop.
    pygame.init()
    width = 1200
    height = 900
    SCREEN = pygame.display.set_mode((width, height))
    font = pygame.font.Font(None, 25)
    TIME = pygame.time.Clock()

    # Simulation parameters shared by the whole script.
    dampwall = 1  # a wall bounce multiplies the velocity component by -dampwall
    dt = 0.003  # step size used by the RK4 update in the main loop
    g = 20  # constant y-acceleration that accel() gives every ball
    k=10  # spring constant shared by all springs
    L=200  # natural length shared by all springs
    
    
    def dist(a, b):
        """Return the Euclidean distance between the 2-D points ``a`` and ``b``."""
        dx = a[0] - b[0]
        dy = a[1] - b[1]
        return math.sqrt(dx*dx + dy*dy)
    
    
    def mag(a):
        """Return the length of the 2-D vector ``a`` (its distance from the origin)."""
        origin = [0, 0]
        return dist(a, origin)
    
    def dp(a, b):
        """Return the dot product of the 2-D vectors ``a`` and ``b``."""
        x_term = a[0]*b[0]
        y_term = a[1]*b[1]
        return x_term + y_term
    
    
    def norm(a):
        """Return ``a`` divided by its length, as a plain list.

        NOTE(review): a zero-length ``a`` divides by zero here.
        """
        length = mag(a)
        scaled = np.array(a) / length
        return list(scaled)
    
    
    def reflect(a, b):
        """Return the unit vector obtained by mirroring ``a`` about the direction ``b``."""
        ax, ay = a[0], a[1]
        bx, by = b[0], b[1]
        mirrored_x = 2*ay*bx*by + ax*(bx**2 - by**2)
        mirrored_y = 2*ax*bx*by + ay*(-bx**2 + by**2)
        return norm([mirrored_x, mirrored_y])
    
    
    
    
    class ball:
        """A point mass that can be linked to other balls by springs.

        ``r`` and ``v`` are ``[x, y]`` lists holding position and velocity.
        ``spr`` lists the positions (in the global ball list) of the balls this
        one is attached to; ``ka`` and ``kv`` hold the four RK4 stage
        accelerations and velocities written by the main loop.
        """

        def __init__(self, x, y, vx, vy, mass,spr,index,ka,kv):
            # Bookkeeping supplied by the caller.
            self.index = index
            self.spr = spr
            self.mass = mass
            self.radius = 5

            # Dynamic state.
            self.r = [x, y]
            self.v = [vx, vy]

            # RK4 stage storage.
            self.ka = ka
            self.kv = kv

        def detectbounce(self,width,height):
            """Return True if the ball overlaps a wall while moving into it.

            Falls through (returning None) when no wall is being hit.
            """
            x, y = self.r[0], self.r[1]
            next_x = x + self.v[0]
            next_y = y + self.v[1]

            beyond_max_x = x + self.radius > width/2 and next_x > x
            beyond_min_x = x - self.radius < -width/2 and next_x < x
            beyond_max_y = y + self.radius > height/2 and next_y > y
            beyond_min_y = y - self.radius < -height/2 and next_y < y

            if beyond_max_x or beyond_min_x or beyond_max_y or beyond_min_y:
                return True

        def bounce_walls(self, width, height):
            """Reverse (and scale by ``dampwall``) any velocity component heading into a wall."""
            half_extents = (width/2, height/2)
            for axis, limit in enumerate(half_extents):
                # Upper wall on this axis.
                if self.r[axis] + self.radius > limit and self.r[axis]+self.v[axis] > self.r[axis]:
                    self.v[axis] *= -dampwall

                # Lower wall on this axis.
                if self.r[axis] - self.radius < -limit and self.r[axis]+self.v[axis] < self.r[axis]:
                    self.v[axis] *= -dampwall

        def update_r(self,v, h):
            """Move the position in place by ``v * h``."""
            for axis in (0, 1):
                self.r[axis] += v[axis] * h

        def un_update_r(self,v, h):
            """Move the position in place by ``-v * h`` (reverses ``update_r``)."""
            for axis in (0, 1):
                self.r[axis] += -v[axis] * h

        def KE(self):
            """Return the kinetic energy ``0.5 * mass * |v|**2``."""
            speed = mag(self.v)
            return 0.5 * self.mass * speed**2

        def GPE(self):
            """Return the gravitational potential energy, offset by the window height."""
            elevation = -self.r[1] + height
            return self.mass * g * elevation

        def draw(self, screen, width, height):
            """Draw the ball as a blue circle, shifting the origin to the window centre."""
            centre = (self.r[0] + width / 2, self.r[1] + height / 2)
            pygame.draw.circle(screen, (0, 0, 255), centre, self.radius)
            
    
    
    
    # Constructor argument order, for reference:
    #(self, x, y, vx, vy, mass,spr,index,ka,kv):
    # Two-ball, one-spring test configuration (disabled):
    # balls = [ball(1, 19, 0, 0,5,[1],0,[0,0,0,0],[0,0,0,0]), ball(250, 20, 0,0,1,[0],1,[0,0,0,0],[0,0,0,0])]   
    # springs = [[0, 1]]

    # Four balls joined by five springs. Each ball's spr list holds the positions
    # in `balls` of its spring partners; `springs` lists every pair once and is
    # what the energy check iterates over.
    # NOTE(review): the last three balls all pass index=1; the index attribute is
    # not read anywhere in this listing, so this looks like a harmless
    # copy-paste slip -- confirm.
    balls = [ball(1, 19, 0, 0,5,[1,3],0,[0,0,0,0],[0,0,0,0]), ball(250, 20, 0,0,2,[0,2,3],1,[0,0,0,0],[0,0,0,0]),ball(450, 0, 0,0,2,[1,3],1,[0,0,0,0],[0,0,0,0]),ball(250, -60, 0,0,2,[0,1,2],1,[0,0,0,0],[0,0,0,0])]   
    springs = [[0, 1],[1,2],[0,3],[1,3],[2,3]]
    
    
    
    
    
    
    
    
    
    def accel(b,BALLS):
        """Return the acceleration ``[ax, ay]`` of ball ``b``.

        Starts from the constant acceleration ``[0, g]`` and adds, for every
        spring attached to ``b``, the spring force ``k * (length - L)`` divided
        by the mass of ``b`` and directed toward the partner ball. Partner
        positions are looked up in ``BALLS``.
        """
        ax, ay = 0, g
        for partner_idx in b.spr:
            partner = BALLS[partner_idx]
            offset = list(np.array(partner.r) - np.array(b.r))
            direction = norm(offset)
            stretch = dist(b.r, partner.r) - L
            force = k * stretch
            ax += force/b.mass*direction[0]
            ay += force/b.mass*direction[1]

        return [ax, ay]
            
    # Total energy recorded on the first pass through the loop (0 means "not set yet").
    initE=0
    while True:
        # Limit the loop rate and clear the frame.
        TIME.tick(200)
        SCREEN.fill((0, 0, 0))

        # Quit cleanly when the window is closed.
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        #compute k1a and k1v for all balls
        # (k1a: acceleration at the current state; k1v: the current velocity list itself)
        for ball in balls:

                ball.ka[0]=accel(ball,balls)
                ball.kv[0]=ball.v

        #create newb1 based on 'updated' position of all balls with their own k1v
        # NOTE: append(ball) stores the ball object itself, not a copy, and
        # un_update_r then moves that same object back. After this loop newb
        # refers to exactly the objects in `balls`, at their un-advanced positions.
        newb=[]
        for ball in balls:
                ball.update_r(ball.kv[0], dt/2)
                newb.append(ball)
                ball.un_update_r(ball.kv[0], dt/2)

        #compute k2a and k2v for all balls based on newb1
        # NOTE: because newb aliases `balls`, only the ball being processed is at
        # its half-step position while accel runs; its spring partners are not.
        for ball in balls:
                ball.update_r(ball.kv[0], dt/2)
                ball.ka[1]=accel(ball,newb)
                ball.un_update_r(ball.kv[0], dt/2)

                ball.kv[1]=[ball.v[0]+0.5*dt*ball.ka[0][0],ball.v[1]+0.5*dt*ball.ka[0][1]]

        #create newb2 based on 'updated' position of all balls with their own k2v
        # (same aliasing as above: newb ends up holding the live ball objects)
        newb=[]
        for ball in balls:

                ball.update_r(ball.kv[1], dt/2)
                newb.append(ball)
                ball.un_update_r(ball.kv[1], dt/2)

        #compute k3a and k3v for all balls
        for ball in balls:

                ball.update_r(ball.kv[1], dt/2)
                ball.ka[2]=accel(ball,newb)
                ball.un_update_r(ball.kv[1], dt/2)

                ball.kv[2]=[ball.v[0]+0.5*dt*ball.ka[1][0],ball.v[1]+0.5*dt*ball.ka[1][1]]

        # Intended: all balls advanced a full step dt along k3v (same aliasing applies).
        newb=[]
        for ball in balls:

                ball.update_r(ball.kv[2], dt)
                newb.append(ball)
                ball.un_update_r(ball.kv[2], dt)

        #compute k4a and k4v for all balls
        for ball in balls:
                ball.update_r(ball.kv[2], dt)
                ball.ka[3]=accel(ball,newb)
                ball.un_update_r(ball.kv[2], dt)

                ball.kv[3]=[ball.v[0]+dt*ball.ka[2][0],ball.v[1]+dt*ball.ka[2][1]]

        #final stage of update
        # A ball touching a wall only has its velocity flipped; the weighted RK4
        # update of v and r is skipped for that ball on this step.
        for ball in balls:
            if ball.detectbounce(width,height)==True:
                ball.bounce_walls(width, height)
            else:
                ball.v=[ball.v[0]+dt*(ball.ka[0][0]+2*ball.ka[1][0]+2*ball.ka[2][0]+ball.ka[3][0])/6, ball.v[1]+dt*(ball.ka[0][1]+2*ball.ka[1][1]+2*ball.ka[2][1]+ball.ka[3][1])/6]
                ball.r=[ball.r[0]+dt*(ball.kv[0][0]+2*ball.kv[1][0]+2*ball.kv[2][0]+ball.kv[3][0])/6, ball.r[1]+dt*(ball.kv[0][1]+2*ball.kv[1][1]+2*ball.kv[2][1]+ball.kv[3][1])/6]

        # Draw every ball and a line to each of its spring partners
        # (so each spring is drawn once from either end).
        for ball in balls:
            ball.draw(SCREEN, width, height)
            for i in range(0,len(ball.spr)):
                ball1=ball
                ball2=balls[ball.spr[i]]
                pygame.draw.line(SCREEN, (0, 0, 155), (
                    ball1.r[0]+width/2, ball1.r[1]+height/2), (ball2.r[0]+width/2, ball2.r[1]+height/2))

        #check for energy
        # EPE is summed over the `springs` pair list; KE and GPE over the balls.

        KE = 0
        EPE = 0
        GPE = 0
        for i in range(0, len(springs)):

            EPE += 1/2 * k * \
                (L - dist(balls[springs[i][0]].r,
                 balls[springs[i][1]].r))**2

        for i in range(0, len(balls)):
            KE += balls[i].KE()
            GPE += balls[i].GPE()


        # Remember the total from the first frame as the reference energy.
        if initE == 0:
                initE += KE+EPE+GPE


        # Show the energy breakdown and the drift relative to the reference.
        text = font.render('init Energy: ' + str(round(initE,1))+' '+'KE: ' + str(round(KE, 1)) + ' '+'EPE: ' + str(round(EPE, 1))+' ' + 'GPE: ' + str(round(GPE, 1)) + ' ' + 'Total: ' + str(round(KE+EPE+GPE, 1)) + ' ' + 'Diff: ' + str(round((KE+EPE+GPE-initE), 1)),
                           True, (255, 255, 255))

        textRect = text.get_rect()
        textRect.center = (370, 70)
        SCREEN.blit(text, textRect)


        pygame.display.flip()

This is the edited code, corrected as suggested by Lutz Lehmann and with some extra improvements:

import pygame
import sys
import math
import numpy as np


# Open the window and create the font and clock objects used by the main loop.
pygame.init()
width = 1200
height = 900
SCREEN = pygame.display.set_mode((width, height))
font = pygame.font.Font(None, 25)
TIME = pygame.time.Clock()

# Simulation parameters shared by the whole script.
dampwall = 1  # a wall bounce multiplies the velocity component by -dampwall
dt = 0.003  # step size used by the RK4 update in the main loop
g = 5  # constant y-acceleration that accel() gives every ball
k = 10  # spring constant shared by all springs
L = 200  # natural length shared by all springs

# Number of decimal places shown in the on-screen energy read-out.
digits = 6


def dist(a, b):
    """Return the Euclidean distance between the 2-D points ``a`` and ``b``."""
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    return math.sqrt(dx*dx + dy*dy)


def mag(a):
    """Return the length of the 2-D vector ``a`` (its distance from the origin)."""
    origin = [0, 0]
    return dist(a, origin)


def dp(a, b):
    """Return the dot product of the 2-D vectors ``a`` and ``b``."""
    x_term = a[0]*b[0]
    y_term = a[1]*b[1]
    return x_term + y_term


def norm(a):
    """Return the unit vector pointing along the 2-D vector ``a``.

    A zero-length vector has no direction. Dividing it by its length would
    produce NaN/inf components (and a NumPy runtime warning) that then
    propagate through every force computed from the result, so a zero vector
    of the same size is returned instead.

    Args:
        a: Sequence whose first two entries are the x and y components.

    Returns:
        A list with the components of ``a`` divided by its length, or a list
        of zeros when that length is 0.
    """
    # Same formula as mag(a): sqrt(x*x + y*y).
    length = math.sqrt(a[0] * a[0] + a[1] * a[1])
    if length == 0:
        return [0.0] * len(a)
    return list(np.array(a) / length)


def reflect(a, b):
    """Return the unit vector obtained by mirroring ``a`` about the direction ``b``."""
    ax, ay = a[0], a[1]
    bx, by = b[0], b[1]
    mirrored_x = 2*ay*bx*by + ax*(bx**2 - by**2)
    mirrored_y = 2*ax*bx*by + ay*(-bx**2 + by**2)
    return norm([mirrored_x, mirrored_y])


class Ball:
    """A point mass that can be attached to other balls by springs.

    Attributes:
        r: Position ``[x, y]``; ``draw`` shifts it by half the window size,
            so the origin is the centre of the window.
        v: Velocity ``[vx, vy]``.
        radius: Radius used for drawing and wall contact (fixed at 5).
        mass: Mass of the ball.
        spr: Positions, in the list of all balls, of the balls this one is
            connected to by a spring.
        index: Identifier supplied by the caller; stored but not used here.
        ka: Storage for the four RK4 stage accelerations.
        kv: Storage for the four RK4 stage velocities.
    """

    def __init__(self, x, y, vx, vy, mass, spr, index, ka, kv):
        self.r = [x, y]
        self.v = [vx, vy]

        self.radius = 5
        self.mass = mass
        self.spr = spr
        self.index = index
        self.ka = ka
        self.kv = kv

    @staticmethod
    def _clone_stages(stages):
        """Return a copy of a stage list, also copying any nested lists."""
        return [list(s) if isinstance(s, list) else s for s in stages]

    def copy(self):
        """Return an independent snapshot of this ball.

        Every list the ball owns (``r``, ``v``, ``spr``, ``ka``, ``kv`` and
        the per-stage lists inside ``ka``/``kv``) is duplicated, so mutating
        the snapshot can never change the original or vice versa.
        """
        return Ball(self.r[0], self.r[1], self.v[0], self.v[1], self.mass,
                    list(self.spr), self.index,
                    self._clone_stages(self.ka), self._clone_stages(self.kv))

    def detectbounce(self, width, height):
        """Return True if the ball overlaps a wall while moving into it, else False.

        The walls sit at ``+/- width/2`` and ``+/- height/2``.
        """
        x, y = self.r[0], self.r[1]
        next_x = x + self.v[0]
        next_y = y + self.v[1]
        return bool(
            x + self.radius > width/2 and next_x > x
            or x - self.radius < -width/2 and next_x < x
            or y + self.radius > height/2 and next_y > y
            or y - self.radius < -height/2 and next_y < y
        )

    def bounce_walls(self, width, height):
        """Reverse (and scale by ``dampwall``) any velocity component heading into a wall."""
        if self.r[0] + self.radius > width/2 and self.r[0]+self.v[0] > self.r[0]:
            self.v[0] *= -dampwall

        if self.r[0] - self.radius < -width/2 and self.r[0]+self.v[0] < self.r[0]:
            self.v[0] *= -dampwall

        if self.r[1] + self.radius > height/2 and self.r[1]+self.v[1] > self.r[1]:
            self.v[1] *= -dampwall

        if self.r[1] - self.radius < -height/2 and self.r[1]+self.v[1] < self.r[1]:
            self.v[1] *= -dampwall

    def update_r(self, v, h):
        """Move the position in place by ``v * h``."""
        self.r[0] += v[0] * h
        self.r[1] += v[1] * h

    def un_update_r(self, v, h):
        """Move the position in place by ``-v * h`` (reverses ``update_r``)."""
        self.r[0] += -v[0] * h
        self.r[1] += -v[1] * h

    def KE(self):
        """Return the kinetic energy ``0.5 * mass * |v|**2``."""
        return 0.5 * self.mass * mag(self.v)**2

    def GPE(self):
        """Return the gravitational potential energy.

        It decreases as ``r[1]`` increases and is offset by the window height.
        """
        return self.mass * g * (-self.r[1] + height)

    def draw(self, screen, width, height):
        """Draw the ball as a blue circle, shifting the origin to the window centre."""
        pygame.draw.circle(screen, (0, 0, 255), (self.r[0] +
                           width / 2, self.r[1] + height / 2), self.radius)


# (self, x, y, vx, vy, mass,spr,index,ka,kv):


# balls = [Ball(1, 19, 0, 0, 1, [1], 0, [0, 0, 0, 0], [0, 0, 0, 0]),
#          Ball(250, 20, 0, 0, 1, [0], 1, [0, 0, 0, 0], [0, 0, 0, 0])]
# springs = [[0, 1]]

# Four balls joined by five springs. Each ball's ``spr`` list holds the
# positions in this list of its spring partners, and ``index`` is the ball's
# own position in the list (it used to be 1 for the last three balls).
balls = [
    Ball(1, 19, 0, 0, 5, [1, 3], 0, [0, 0, 0, 0], [0, 0, 0, 0]),
    Ball(250, 20, 0, 0, 2, [0, 2, 3], 1, [0, 0, 0, 0], [0, 0, 0, 0]),
    Ball(450, 0, 0, 0, 2, [1, 3], 2, [0, 0, 0, 0], [0, 0, 0, 0]),
    Ball(250, -60, 0, 0, 2, [0, 1, 2], 3, [0, 0, 0, 0], [0, 0, 0, 0]),
]

# n=5
# resprings=[]

# for i in range(0,n):
#     for j in range(0,n):
#         if i==0 and j==0:
#             resprings.append([1,2,n,n+1,2*n])
#         if i==n and j==0:
#             resprings.apend([n*(n-1)+1,n*(n-1)+2,n*(n-2),n*(n-3),n*(n-2)+1])
#         if j==0 and i!=0 or i!=n:
#             resprings.append([(i-1)*n+1,(i-1)*n+2,(i-2)*n,(i-2)*n+1,(i)*n,(i)*n+1])
        
            

def getsprings(B):
    """Return every spring in ``B`` once, as a sorted ``[i, j]`` index pair.

    ``B`` is a list of balls whose ``spr`` attribute lists the positions of
    their spring partners. Pairs appear in the order they are first met.
    """
    S = []
    for position, owner in enumerate(B):
        for partner in owner.spr:
            pair = sorted([position, partner])
            if pair in S:
                continue
            S.append(pair)
    return S
            
    
# Unique list of spring index pairs, derived from each ball's spr list; the
# energy check in the main loop sums the elastic energy over it.
springs = getsprings(balls)    
        
    





def accel(b, BALLS):
    """Return the acceleration ``[ax, ay]`` of ball ``b``.

    Starts from the constant acceleration ``[0, g]`` and adds, for every
    spring attached to ``b``, the spring force ``k * (length - L)`` divided by
    the mass of ``b`` and directed toward the partner ball. Partner positions
    are looked up in ``BALLS``.
    """
    ax, ay = 0, g
    for partner_idx in b.spr:
        partner = BALLS[partner_idx]
        offset = list(np.array(partner.r) - np.array(b.r))
        direction = norm(offset)
        stretch = dist(b.r, partner.r) - L
        force = k * stretch
        ax += force/b.mass*direction[0]
        ay += force/b.mass*direction[1]

    return [ax, ay]


# Total energy recorded on the first pass through the loop (0 means "not set yet").
initE = 0
while True:
    # Limit the loop rate and clear the frame.
    TIME.tick(200)
    SCREEN.fill((0, 0, 0))

    # Quit cleanly when the window is closed.
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            pygame.quit()
            sys.exit()
    # Flip the velocity of any ball heading into a wall before integrating, so
    # that every ball still receives the full RK4 update below.
    for ball in balls:
        ball.bounce_walls(width, height)

    # compute k1a and k1v for all balls
    # (k1a: acceleration at the current state; k1v: the current velocity list itself)
    for ball in balls:

        ball.ka[0] = accel(ball, balls)
        ball.kv[0] = ball.v

    # create newb1 based on 'updated' position of all balls with their own k1v
    # Each ball is advanced by dt/2, snapshotted with copy(), then moved back,
    # so newb holds separate objects at the half-step positions.
    # NOTE(review): adding and then subtracting v*h may not restore r
    # bit-for-bit in floating point -- confirm the drift is negligible.
    newb = []
    for ball in balls:
        ball.update_r(ball.kv[0], dt/2)
        newb.append(ball.copy())
        ball.un_update_r(ball.kv[0], dt/2)

    # compute k2a and k2v for all balls based on newb1
    # The ball itself is advanced temporarily so accel sees its own stage
    # position; its partners' stage positions are read from the snapshot newb.
    for ball in balls:
        ball.update_r(ball.kv[0], dt/2)
        ball.ka[1] = accel(ball, newb)
        ball.un_update_r(ball.kv[0], dt/2)

        ball.kv[1] = [ball.v[0]+0.5*dt*ball.ka[0]
                      [0], ball.v[1]+0.5*dt*ball.ka[0][1]]

    # create newb2 based on 'updated' position of all balls with their own k2v
    newb = []
    for ball in balls:

        ball.update_r(ball.kv[1], dt/2)
        newb.append(ball.copy())
        ball.un_update_r(ball.kv[1], dt/2)

    # compute k3a and k3v for all balls
    for ball in balls:

        ball.update_r(ball.kv[1], dt/2)
        ball.ka[2] = accel(ball, newb)
        ball.un_update_r(ball.kv[1], dt/2)

        ball.kv[2] = [ball.v[0]+0.5*dt*ball.ka[1]
                      [0], ball.v[1]+0.5*dt*ball.ka[1][1]]

    # create newb3: snapshot of all balls advanced a full step dt along k3v
    newb = []
    for ball in balls:

        ball.update_r(ball.kv[2], dt)
        newb.append(ball.copy())
        ball.un_update_r(ball.kv[2], dt)

    # compute k4a and k4v for all balls
    for ball in balls:
        ball.update_r(ball.kv[2], dt)
        ball.ka[3] = accel(ball, newb)
        ball.un_update_r(ball.kv[2], dt)

        ball.kv[3] = [ball.v[0]+dt*ball.ka[2][0], ball.v[1]+dt*ball.ka[2][1]]

    # final stage of update
    # Weighted RK4 combination (k1 + 2*k2 + 2*k3 + k4) / 6 for velocity and
    # position; v and r are rebound to new lists here.
    for ball in balls:
        ball.v = [ball.v[0]+dt*(ball.ka[0][0]+2*ball.ka[1][0]+2*ball.ka[2][0]+ball.ka[3][0])/6,
                  ball.v[1]+dt*(ball.ka[0][1]+2*ball.ka[1][1]+2*ball.ka[2][1]+ball.ka[3][1])/6]
        ball.r = [ball.r[0]+dt*(ball.kv[0][0]+2*ball.kv[1][0]+2*ball.kv[2][0]+ball.kv[3][0])/6,
                  ball.r[1]+dt*(ball.kv[0][1]+2*ball.kv[1][1]+2*ball.kv[2][1]+ball.kv[3][1])/6]

    # Draw every ball and a line to each of its spring partners
    # (so each spring is drawn once from either end).
    for ball in balls:
        ball.draw(SCREEN, width, height)
        for i in range(0, len(ball.spr)):
            ball1 = ball
            ball2 = balls[ball.spr[i]]
            pygame.draw.line(SCREEN, (0, 0, 155), (
                ball1.r[0]+width/2, ball1.r[1]+height/2), (ball2.r[0]+width/2, ball2.r[1]+height/2))

    # check for energy
    # EPE is summed over the `springs` pair list; KE and GPE over the balls.

    KE = 0
    EPE = 0
    GPE = 0
    for i in range(0, len(springs)):

        EPE += 1/2 * k * \
            (L - dist(balls[springs[i][0]].r,
             balls[springs[i][1]].r))**2

    for i in range(0, len(balls)):
        KE += balls[i].KE()
        GPE += balls[i].GPE()

    # Remember the total from the first frame as the reference energy.
    if initE == 0:
        initE += KE+EPE+GPE


    # Render one line per energy term, rounded to `digits` decimal places.
    text1 = font.render(f"initial energy: {str(round(initE, digits))}", True, (255, 255, 255))
    text2 = font.render(f"kinetic energy: {str(round(KE, digits))}", True, (255, 255, 255))
    text3 = font.render(f"elastic potential energy: {str(round(EPE, digits))}", True, (255, 255, 255))
    text4 = font.render(f"gravitational energy: {str(round(GPE, digits))}", True, (255, 255, 255))
    text5 = font.render(f"total energy: {str(round(KE + EPE + GPE, digits))}", True, (255, 255, 255))
    text6 = font.render(f"change in energy: {str(round(KE + EPE + GPE - initE, digits))}", True, (255, 255, 255))

    # Stack the lines down the left edge of the window.
    SCREEN.blit(text1, (10, 10))
    SCREEN.blit(text2, (10, 60))
    SCREEN.blit(text3, (10, 110))
    SCREEN.blit(text4, (10, 160))
    SCREEN.blit(text5, (10, 210))
    SCREEN.blit(text6, (10, 260))


    pygame.display.flip()

Solution

  • The immediate error seems to be this

        for ball in balls:
                ...
                newb1.append(ball)
                ...
    

    Here ball is just a reference to an instance of the class ball, so newb1 is a list of references to the very same objects that are in balls. It makes no difference which of the two lists you manipulate: it is always the same data records that get changed.

    You need to apply a copy mechanism. Because each instance holds lists (and lists of lists), you need a deep copy or a dedicated copy member method; otherwise you only copy the array references held by the ball instances, so you get different instances that still point to the same arrays.

    It is probably not an error, but it is still a bad idea to use the class name as a variable name in the same scope.