Search code examples
python, ray

Ray pipeline takes longer than expected (unexplained time overhead)


I have a tiny Ray pipeline like this:

import ray
import numpy as np
import time

@ray.remote
class PersonDetector:
    """Ray actor that stands in for a person detector.

    The "model" is a random 100x4 array. Inference returns a random-length
    prefix of it after a fixed sleep that simulates compute time.
    """

    def __init__(self) -> None:
        self.model = self._init_model()

    def _init_model(self):
        """Return the fake model: a 100x4 array of uniform random floats."""
        return np.random.random([100, 4])

    def infer(self, img):
        """Return a random batch of fake boxes; ``img`` is not used.

        Prints the batch shape, then sleeps 4 seconds before returning.
        """
        num_boxes = np.random.randint(100)  # anywhere from 0 to 99 rows
        boxes = self.model[:num_boxes]
        print(boxes.shape)
        time.sleep(4)
        return boxes

@ray.remote
class KptsDetector:
    """Ray actor that stands in for a keypoint detector.

    The "model" is a random 100x17 array. Inference returns one row per
    incoming box after a fixed sleep that simulates compute time.
    """

    def __init__(self) -> None:
        self.model = self._init_model()

    def _init_model(self):
        """Return the fake model: a 100x17 array of uniform random floats."""
        return np.random.random([100, 17])

    def infer(self, img, boxes):
        """Return as many model rows as ``boxes`` has rows; ``img`` is not used.

        Sleeps 2 seconds before returning.
        """
        num_rows = boxes.shape[0]
        keypoints = self.model[:num_rows]
        time.sleep(2)
        return keypoints


@ray.remote
class HandDetector:
    """Ray actor that stands in for a hand detector.

    The "model" is a random 100x4 array. Inference wraps a random-length
    prefix of it in a dict after a fixed sleep that simulates compute time.
    """

    def __init__(self) -> None:
        self.model = self._init_model()

    def _init_model(self):
        """Return the fake model: a 100x4 array of uniform random floats."""
        return np.random.random([100, 4])

    def infer(self, img):
        """Return ``{'hands': <random batch of fake boxes>}``; ``img`` is not used.

        Sleeps 3 seconds before returning.
        """
        num_hands = np.random.randint(100)  # anywhere from 0 to 99 rows
        result = {'hands': self.model[:num_hands]}
        time.sleep(3)
        return result

@ray.remote
def gather_all(hands, kpts):
    """Merge hand boxes with the first four keypoint columns.

    ``hands`` may be the array itself or a dict holding it under ``'hands'``.
    The input with more rows is copied, and the shorter one is added onto its
    leading rows. The time spent in this function is printed before returning.
    """
    started = time.time()
    if isinstance(hands, dict):
        print(f'in hands info: {hands.keys()}')
        hands = hands['hands']

    hand_rows = hands.shape[0]
    kpt_rows = kpts.shape[0]
    if hand_rows <= kpt_rows:
        # Keypoints are at least as long: start from them, add the hands on top.
        out = kpts.copy()[..., :4]
        out[:hand_rows, :] += hands
    else:
        # Hands are longer: start from them, add the keypoints on top.
        out = hands.copy()
        out[:kpt_rows, :] += kpts[..., :4]

    elapsed = time.time() - started
    print(f'[gather time] {elapsed}')
    return out

# How can this DAG be expressed inside classes?
person_actor = PersonDetector.remote()
kpts_actor = KptsDetector.remote()
hand_actor = HandDetector.remote()

started = time.time()
img = []
# Submit the three inference calls; the keypoint step consumes the person boxes.
boxes_ref = person_actor.infer.remote(img)
hands_ref = hand_actor.infer.remote(img)
kpts_ref = kpts_actor.infer.remote(img, boxes_ref)

merged_ref = gather_all.remote(hands_ref, kpts_ref)
submitted = time.time()
print(submitted - started)

# Block until the merged result is available, then report the timings.
merged = ray.get(merged_ref)
finished = time.time()
print(finished - started)
print(finished - submitted)
print(merged.shape)


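I am using time.sleep() to simulate time-consuming work. As you can see, HandDetector (3 s) should run in a separate process, in parallel with the other detectors, so the whole pipeline should take 6 s: 4 s for PersonDetector followed by 2 s for KptsDetector, which depends on its boxes.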

But I got the following (you can try it on your own computer):

6.45377516746521
6.4489240646362305

Why is there an extra 0.4 s?


Solution

  • It looks like you posted the same question to the Ray GitHub (link to comment). I am copying the answer here so other StackOverflow users can benefit.

    @robertnishihara: You are recreating the actors every time you call do_it, and there is some overhead to creating actors (it starts a new Python process).

    If you rewrite it as follows, the overhead drops.

    import ray
    import numpy as np
    import time
    
    ray.init()
    
    @ray.remote
    class PersonDetector:
        """Ray actor that stands in for a person detector.

        The "model" is a random 100x4 array. Inference returns a random-length
        prefix of it after a fixed sleep that simulates compute time.
        """

        def __init__(self) -> None:
            self.model = self._init_model()

        def _init_model(self):
            """Return the fake model: a 100x4 array of uniform random floats."""
            return np.random.random([100, 4])

        def infer(self, img):
            """Return a random batch of fake boxes; ``img`` is not used.

            Prints the batch shape, then sleeps 4 seconds before returning.
            """
            num_boxes = np.random.randint(100)  # anywhere from 0 to 99 rows
            boxes = self.model[:num_boxes]
            print(boxes.shape)
            time.sleep(4)
            return boxes
    
    @ray.remote
    class KptsDetector:
        """Ray actor that stands in for a keypoint detector.

        The "model" is a random 100x17 array. Inference returns one row per
        incoming box after a fixed sleep that simulates compute time.
        """

        def __init__(self) -> None:
            self.model = self._init_model()

        def _init_model(self):
            """Return the fake model: a 100x17 array of uniform random floats."""
            return np.random.random([100, 17])

        def infer(self, img, boxes):
            """Return as many model rows as ``boxes`` has rows; ``img`` is not used.

            Sleeps 2 seconds before returning.
            """
            num_rows = boxes.shape[0]
            keypoints = self.model[:num_rows]
            time.sleep(2)
            return keypoints
    
    
    @ray.remote
    class HandDetector:
        """Ray actor that stands in for a hand detector.

        The "model" is a random 100x4 array. Inference wraps a random-length
        prefix of it in a dict after a fixed sleep that simulates compute time.
        """

        def __init__(self) -> None:
            self.model = self._init_model()

        def _init_model(self):
            """Return the fake model: a 100x4 array of uniform random floats."""
            return np.random.random([100, 4])

        def infer(self, img):
            """Return ``{'hands': <random batch of fake boxes>}``; ``img`` is not used.

            Sleeps 3 seconds before returning.
            """
            num_hands = np.random.randint(100)  # anywhere from 0 to 99 rows
            result = {'hands': self.model[:num_hands]}
            time.sleep(3)
            return result
    
    @ray.remote
    def gather_all(hands, kpts):
        """Merge hand boxes with the first four keypoint columns.

        ``hands`` may be the array itself or a dict holding it under ``'hands'``.
        The input with more rows is copied, and the shorter one is added onto
        its leading rows. The time spent here is printed before returning.
        """
        started = time.time()
        if isinstance(hands, dict):
            print(f'in hands info: {hands.keys()}')
            hands = hands['hands']

        hand_rows = hands.shape[0]
        kpt_rows = kpts.shape[0]
        if hand_rows <= kpt_rows:
            # Keypoints are at least as long: start from them, add the hands on top.
            out = kpts.copy()[..., :4]
            out[:hand_rows, :] += hands
        else:
            # Hands are longer: start from them, add the keypoints on top.
            out = hands.copy()
            out[:kpt_rows, :] += kpts[..., :4]

        elapsed = time.time() - started
        print(f'[gather time] {elapsed}')
        return out
    
    
    def do_it(i, P, K, H):
        """Run the detection pipeline once and print how long the result took.

        Args:
            i: Run index, used only to label the printed timing.
            P: Actor handle whose ``infer(img)`` yields the person boxes.
            K: Actor handle whose ``infer(img, boxes)`` yields the keypoints.
            H: Actor handle whose ``infer(img)`` yields the hand boxes.

        Prints ``"<i>. <seconds>"`` followed by the shape of the merged result.
        """
        img = []  # placeholder input handed to every detector

        # Submit the whole graph first; the keypoint step consumes the person
        # boxes, while the hand step has no dependency on either of them.
        boxes = P.infer.remote(img)
        hands = H.infer.remote(img)
        kpts = K.infer.remote(img, boxes)
        out = gather_all.remote(hands, kpts)

        # Time only the blocking wait for the final result.
        t1 = time.time()
        out_value = ray.get(out)
        t2 = time.time()
        print(f'{i}. {t2 - t1}')
        print(out_value.shape)
    
    
    if __name__ == '__main__':
        # Create each actor once up front so every do_it() call reuses them.
        actors = (
            PersonDetector.remote(),
            KptsDetector.remote(),
            HandDetector.remote(),
        )
        for run in range(3):
            do_it(run, *actors)
    

    In this case, I get the output

    (PersonDetector pid=36472) (35, 4)
    0. 7.073965072631836
    (95, 4)
    (PersonDetector pid=36472) (34, 4)
    (gather_all pid=36421) in hands info: dict_keys(['hands'])
    (gather_all pid=36421) [gather time] 6.890296936035156e-05
    1. 6.016530275344849
    (47, 4)
    (PersonDetector pid=36472) (87, 4)
    (gather_all pid=36421) in hands info: dict_keys(['hands'])
    (gather_all pid=36421) [gather time] 9.608268737792969e-05
    2. 6.006845235824585
    (88, 4)
    
    (gather_all pid=36421) in hands info: dict_keys(['hands'])
    (gather_all pid=36421) [gather time] 9.870529174804688e-05