
Inefficient Random Dice Roll in Python/Jython


As I learn Python (specifically Jython, if the difference is important here), I'm writing a simple terminal game that uses skills, and dice rolls based on the level of those skills, to determine success or failure at an attempted action. I hope to use this code eventually in a larger game project.

Under a stress test, the code uses 0.5 GB of RAM and seems to take quite a while to get a result (~50 seconds). It could just be that the task really is that intensive, but as a noob I'm betting I'm just doing things inefficiently. Could anyone give some tips on both:

  • how to improve the efficiency of this code

  • and how to write this code in a more pythonic way?

    import random
    
    def DiceRoll(maxNum=100,dice=2,minNum=0):
      return sum(random.randint(minNum,maxNum) for i in xrange(dice))
    
    def RollSuccess(max):
      x = DiceRoll()
      if(x <= (max/10)):
        return 2
      elif(x <= max):
        return 1
      elif(x >= 100-(100-max)/10):
        return -1
      return 0
    
    def RollTesting(skill=50,rolls=10000000):
      cfail = 0
      fail = 0
      success = 0
      csuccess = 0
      for i in range(rolls+1):
        roll = RollSuccess(skill)
        if(roll == -1):
          cfail = cfail + 1
        elif(roll == 0):
          fail = fail + 1
        elif(roll == 1):
          success = success + 1
        else:
          csuccess = csuccess + 1
      print "CFails: %.4f. Fails: %.4f. Successes: %.4f. CSuccesses: %.4f." % (float(cfail)/float(rolls), float(fail)/float(rolls), float(success)/float(rolls), float(csuccess)/float(rolls))
    
    RollTesting()
    

EDIT - Here's my code now:

    from random import random

    def DiceRoll():
      return 50 * (random() + random())

    def RollSuccess(suclim):
      x = DiceRoll()
      if(x <= (suclim/10)):
        return 2
      elif(x <= suclim):
        return 1
      elif(x >= 90-suclim/10):
        return -1
      return 0

    def RollTesting(skill=50,rolls=10000000):
      from time import clock
      start = clock()
      cfail = fail = success = csuccess = 0.0
      for _ in xrange(rolls):
        roll = RollSuccess(skill)
        if(roll == -1):
          cfail += 1
        elif(roll == 0):
          fail += 1
        elif(roll == 1):
          success += 1
        else:
          csuccess += 1
      stop = clock()
      print "Last time this statement was manually updated, DiceRoll and RollSuccess totaled 12 LOC."
      print "It took %.3f seconds to do %d dice rolls and calculate their success." % (stop-start,rolls)
      print "At skill level %d, the distribution is as follows" % (skill)
      print "CFails: %.4f. Fails: %.4f. Successes: %.4f. CSuccesses: %.4f." % (cfail/rolls, fail/rolls, success/rolls, csuccess/rolls)

    RollTesting(50)

And the output:

    Last time this statement was manually updated, DiceRoll and RollSuccess totaled 12 LOC.
    It took 6.558 seconds to do 10000000 dice rolls and calculate their success.
    At skill level 50, the distribution is as follows
    CFails: 0.0450. Fails: 0.4548. Successes: 0.4952. CSuccesses: 0.0050.

Note that this isn't exactly equivalent: I changed the random calculation enough to produce noticeably different output (the original was supposed to give 0-100, but I forgot to divide by the number of dice). Memory usage now looks to be about 0.2 GB. Also, the previous implementation couldn't do 100 million tests; I've run this one at up to 1 billion tests (it took 8 minutes, and memory usage doesn't seem significantly different).
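The fix described above (dividing by the number of dice) can be generalized so the range stays 0-100 however many dice are rolled. A minimal sketch, assuming a float result is acceptable; `scaled_dice_roll` and its parameters are hypothetical names, not part of the original code:

```python
import random


def scaled_dice_roll(dice=2, max_total=100.0):
    """Roll `dice` uniform dice and scale the total to lie within 0..max_total.

    With dice=2 this is the same as 50 * (random() + random()).
    """
    # random.random() is in [0.0, 1.0), so the sum of `dice` draws is in
    # [0.0, dice); multiplying by max_total / dice rescales it to 0..max_total.
    return max_total / dice * sum(random.random() for _ in range(dice))
```

More dice keep the same range but bunch the results more tightly around the midpoint.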


Solution

  • You're doing 10 million iterations. The looping overhead alone is probably 10% of your total time. On top of that, if the whole loop doesn't fit into cache at once, it may slow things down even more.
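    To check how much of the run time is pure loop overhead on your interpreter, you could time an empty loop against the same loop doing the dice rolls. A rough sketch using the standard `timeit.default_timer` clock; the function names and the 200,000-iteration size are arbitrary choices, and the ratio will vary by interpreter and machine:

```python
import random
import timeit


def empty_loop(n):
    # Pure iteration cost: no work inside the loop body
    for _ in range(n):
        pass


def roll_loop(n):
    # The same iteration plus two randint calls per pass, like DiceRoll()
    for _ in range(n):
        random.randint(0, 100) + random.randint(0, 100)


def timed(func, n):
    # timeit.default_timer picks the most precise clock for the platform
    start = timeit.default_timer()
    func(n)
    return timeit.default_timer() - start


def loop_overhead_fraction(n=200000):
    """Rough share of roll_loop's time spent just on looping."""
    return timed(empty_loop, n) / timed(roll_loop, n)
```

    On Python 2 (including Jython), `range(n)` builds a full n-element list before looping, so `xrange` is the better choice there; the question's original `range(rolls+1)` builds a 10-million-element list.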

    Is there a way to avoid doing all those loops in Python? Yes, you can do them in Java.

    The obvious way to do that is to actually write and call Java code. But you don't have to do that.


    A list comprehension, or a generator expression driven by a native builtin, will also do the looping in Java. So, on top of being more compact and simpler, this should also be faster:

    import collections

    # Counter consumes the generator and tallies each result code
    attempts = (RollSuccess(skill) for _ in xrange(rolls))
    counts = collections.Counter(attempts)
    cfail, fail, success, csuccess = counts[-1], counts[0], counts[1], counts[2]
    

    Unfortunately, while this does seem to be faster in Jython 2.7b1, it's actually slower in 2.5.2.
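    Putting the pieces together, here is a self-contained sketch of the generator-plus-Counter version. It assumes `collections.Counter`, which is only in the standard library from Python 2.7 on (on 2.5 you would need a backport), and restates the question's RollSuccess with `//` so the integer division behaves the same on Python 2 and 3; the lower-case names are illustrative:

```python
import collections
import random


def dice_roll(max_num=100, dice=2, min_num=0):
    # Sum of `dice` inclusive randint draws, like the question's DiceRoll
    return sum(random.randint(min_num, max_num) for _ in range(dice))


def roll_success(skill):
    # Same thresholds as the question's RollSuccess; // keeps Python 2's
    # integer division on Python 3 as well
    x = dice_roll()
    if x <= skill // 10:
        return 2   # critical success
    elif x <= skill:
        return 1   # success
    elif x >= 100 - (100 - skill) // 10:
        return -1  # critical failure
    return 0       # failure


def roll_testing(skill=50, rolls=10000000):
    # Counter consumes the generator and tallies each result code;
    # on Python 2, use xrange here to avoid building a list of `rolls` items
    counts = collections.Counter(roll_success(skill) for _ in range(rolls))
    total = float(rolls)
    # (cfail, fail, success, csuccess) as fractions of all rolls
    return (counts[-1] / total, counts[0] / total,
            counts[1] / total, counts[2] / total)
```

    For example, `cfail, fail, success, csuccess = roll_testing(50, 100000)`. With two 0-100 dice and skill 50, the exact probabilities work out to about 0.55, 0.32, 0.13 and 0.002, so the estimates should land near those.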


    Another way to speed up loops is to use a vectorization library. Unfortunately, I don't know what Jython people use for this, but in CPython with numpy, it looks something like this:

    import numpy as np

    def DiceRolls(count, maxNum=100, dice=2, minNum=0):
        # randint's upper bound is exclusive, so add 1 to match random.randint
        return sum(np.random.randint(minNum, maxNum + 1, count) for die in range(dice))

    def RollTesting(skill=50, rolls=10000000):
        dicerolls = DiceRolls(rolls)
        # RollSuccess treats totals >= this value as critical failures
        cfail_min = 100 - (100 - skill) // 10
        csuccess = np.count_nonzero(dicerolls <= skill // 10)
        success = np.count_nonzero((dicerolls > skill // 10) & (dicerolls <= skill))
        fail = np.count_nonzero((dicerolls > skill) & (dicerolls < cfail_min))
        cfail = np.count_nonzero(dicerolls >= cfail_min)
        return csuccess, success, fail, cfail
    

    This speeds things up by a factor of about 8.
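    Because the vectorized version has to reproduce RollSuccess's thresholds exactly (an easy place for an off-by-one between `>` and `>=`), it can be worth listing which dice totals fall into each category. A small pure-Python sketch, assuming two dice of 0-100 (totals 0-200); `roll_category` and `category_ranges` are hypothetical helpers:

```python
def roll_category(total, skill=50):
    # Mirrors the question's RollSuccess for a given dice total;
    # // reproduces Python 2's integer division on Python 3 too
    if total <= skill // 10:
        return 2
    elif total <= skill:
        return 1
    elif total >= 100 - (100 - skill) // 10:
        return -1
    return 0


def category_ranges(skill=50, max_total=200):
    # Map each category to the (lowest, highest) total that produces it;
    # the categories are contiguous, so a min/max pair describes each one
    ranges = {}
    for total in range(max_total + 1):
        category = roll_category(total, skill)
        low, high = ranges.get(category, (total, total))
        ranges[category] = (min(low, total), max(high, total))
    return ranges
```

    For skill 50 this gives 0-5 critical success, 6-50 success, 51-94 failure and 95-200 critical failure, which is what the numpy masks should count.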

    I suspect that in Jython things aren't nearly as nice as with numpy, and you're expected to import Java libraries like the Apache Commons numerics or PColt and figure out the Java-vs.-Python issues yourself… but better to search and/or ask than to assume.


    You may also want to use a different interpreter. CPython 2.5 or 2.7 doesn't seem to be much different from Jython 2.5 here, but it does let you use numpy to get an 8x improvement. PyPy 2.0, meanwhile, is 11x faster with no code changes.

    Even if your main program has to run in Jython, if you've got something slow enough to dwarf the cost of starting a new process, you can move it into a separate script that you run via subprocess. For example:

    subscript.py:

    import sys

    # ... everything up to RollTesting's last line
        return csuccess, success, fail, cfail
    
    skill = int(sys.argv[1]) if len(sys.argv) > 1 else 50
    rolls = int(sys.argv[2]) if len(sys.argv) > 2 else 10000000
    csuccess, success, fail, cfail = RollTesting(skill, rolls)
    print csuccess
    print success
    print fail
    print cfail
    

    mainscript.py:

    import subprocess32

    def RollTesting(skill, rolls):
        results = subprocess32.check_output(['pypy', 'subscript.py', 
                                             str(skill), str(rolls)])
        csuccess, success, fail, cfail = (int(line.rstrip()) for line in results.splitlines())
        print "CFails: %.4f. Fails: %.4f. Successes: %.4f. CSuccesses: %.4f." % (float(cfail)/float(rolls), float(fail)/float(rolls), float(success)/float(rolls), float(csuccess)/float(rolls))
    

    (I used the subprocess32 module to get the backport of check_output, which isn't available in Python 2.5, Jython or otherwise. You could also just borrow the source for check_output from 2.7's implementation.)
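    If you go the borrowing route, a much-simplified stand-in only needs `Popen` and `communicate`, which exist in Python 2.5. This is a sketch, not 2.7's actual implementation (which also forwards keyword arguments to Popen and attaches the output to the exception); `check_output_compat` is a hypothetical name:

```python
import subprocess


def check_output_compat(args):
    # Run the command, capturing stdout (stderr is left alone)
    process = subprocess.Popen(args, stdout=subprocess.PIPE)
    output = process.communicate()[0]
    # Mirror check_output: a non-zero exit status raises an error
    if process.returncode:
        raise subprocess.CalledProcessError(process.returncode, args)
    return output
```

    In mainscript.py it would replace the subprocess32 call, e.g. `check_output_compat(['pypy', 'subscript.py', str(skill), str(rolls)])`.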

    Note that Jython 2.5.2 has some serious bugs in subprocess (which will be fixed in 2.5.3 and 2.7.0, but that doesn't help you today). Fortunately, they don't affect this code.

    In a quick test, the overhead (mostly spawning a new interpreter process, but there's also marshalling the parameters and results, etc.) added more than 10% to the cost, meaning I only got a 9x improvement instead of 11x. That will be a little worse on Windows, but not enough to negate the benefits for any script that's taking on the order of a minute to run.
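    To relate the two speedups to that overhead: if T is the original run time, the PyPy run alone takes about T/11 and the subprocess version about T/9, so measured against the PyPy run time the extra cost is roughly

$$\frac{T/9 - T/11}{T/11} = \frac{11}{9} - 1 = \frac{2}{9} \approx 22\%$$

    which fits the "more than 10%" figure (the speedups are rounded, so treat this as approximate).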


    Finally, if you're doing more complicated stuff, you can use execnet, which wraps up Jython<->CPython<->PyPy to let you use whatever's best in each part of the code without having to do all that explicit subprocess stuff.