Tags: python, python-3.x, tensorflow, reinforcement-learning

Custom Early Stop Function - Stop When Cost Value Starts Accelerating Upward After Convergence?


I am training a model using TensorFlow in Python 3 and have written my own separate early stopping function. The cost stays fairly low for most of the training run, but then, as usual, it reaches a point where it not only stops improving (minimizing the cost function) but gets exponentially worse, accelerating upward. I have included my cost values below.

I'm wondering if anyone has an idea (pseudo-code, a brainstorm, or a link I haven't found yet) for improving my early stopping function so that it catches this acceleration and enforces the early stop. I don't want to rely on just a static threshold (like > 1.000), in case the cost hits that number while there is still room to search lower. Maybe some sort of acceleration monitoring? A moving average? As you can see from the values and the image, the acceleration at the end is generally quite extreme, and it eventually happens in every training run without fail. I'd like to catch it as soon as possible while still making sure the move is drastic enough to justify stopping. Thanks!
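A rough sketch of the moving-average idea, assuming an exponential moving average (EMA) as the smoother. Rather than measuring acceleration directly, it waits for a sustained rise in the smoothed cost. The class name and the `alpha`, `factor` and `patience` parameters are hypothetical, and the defaults are guesses to tune:

```python
class EmaEarlyStopper:
    """Hypothetical stopper: smooth the cost with an EMA, stop on a sustained rise.

    Stops when the smoothed cost has increased for `patience` consecutive
    epochs AND sits above `factor` times the lowest smoothed cost seen so far.
    """

    def __init__(self, alpha=0.3, factor=5.0, patience=3):
        self.alpha = alpha        # EMA weight on the newest cost (assumed value)
        self.factor = factor      # how far above the best EMA counts as "blown up"
        self.patience = patience  # consecutive rising epochs required
        self.ema = None
        self.best = float("inf")
        self.rising = 0

    def update(self, cost):
        """Feed one epoch's cost; return True if training should stop."""
        prev = self.ema
        # First epoch seeds the EMA; afterwards blend the new cost in.
        self.ema = cost if prev is None else self.alpha * cost + (1 - self.alpha) * prev
        self.best = min(self.best, self.ema)
        # Count consecutive epochs in which the smoothed cost went up.
        self.rising = self.rising + 1 if prev is not None and self.ema > prev else 0
        return self.rising >= self.patience and self.ema > self.factor * self.best
```

The smoothing means a single spike followed by a return to normal resets the rising counter, while a run of increasing costs keeps it climbing.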

Image of Cost Acceleration

epoch:  1   cost:   0.032336
epoch:  2   cost:   0.015083
epoch:  3   cost:   0.003783
epoch:  4   cost:   0.011579
epoch:  5   cost:   0.00436
epoch:  6   cost:   0.003667
epoch:  7   cost:   0.000973
epoch:  8   cost:   0.002916
epoch:  9   cost:   0.016516
epoch:  10  cost:   0.00094
epoch:  11  cost:   0.000656
epoch:  12  cost:   0.001112
epoch:  13  cost:   0.000761
epoch:  14  cost:   0.002976
epoch:  15  cost:   0.004531
epoch:  16  cost:   0.00247
epoch:  17  cost:   0.005809
epoch:  18  cost:   0.011614
epoch:  19  cost:   0.004681
epoch:  20  cost:   0.002704
epoch:  21  cost:   0.001122
epoch:  22  cost:   0.109581
epoch:  23  cost:   0.001352
epoch:  24  cost:   0.000767
epoch:  25  cost:   0.009472
epoch:  26  cost:   0.003918
epoch:  27  cost:   0.007462
epoch:  28  cost:   0.002033
epoch:  29  cost:   0.004985
epoch:  30  cost:   0.006285
epoch:  31  cost:   0.004838
epoch:  32  cost:   0.008076
epoch:  33  cost:   0.008414
epoch:  34  cost:   0.008761
epoch:  35  cost:   0.002719
epoch:  36  cost:   0.002752
epoch:  37  cost:   0.00355
epoch:  38  cost:   0.012253
epoch:  39  cost:   0.052947
epoch:  40  cost:   0.005952
epoch:  41  cost:   0.012556
epoch:  42  cost:   0.018322
epoch:  43  cost:   0.042715
epoch:  44  cost:   0.045315
epoch:  45  cost:   0.051732
epoch:  46  cost:   0.072919
epoch:  47  cost:   0.013907
epoch:  48  cost:   0.088789
epoch:  49  cost:   0.045083
epoch:  50  cost:   0.038073
epoch:  51  cost:   0.033848
epoch:  52  cost:   0.022773
epoch:  53  cost:   0.198873
epoch:  54  cost:   0.020925
epoch:  55  cost:   0.02264
epoch:  56  cost:   0.039353
epoch:  57  cost:   0.055266
epoch:  58  cost:   0.057254
epoch:  59  cost:   0.048848
epoch:  60  cost:   0.072187
epoch:  61  cost:   0.066818
epoch:  62  cost:   0.111698
epoch:  63  cost:   0.121994
epoch:  64  cost:   0.216178
epoch:  65  cost:   0.4132
epoch:  66  cost:   0.243138
epoch:  67  cost:   0.628117
epoch:  68  cost:   0.349325
epoch:  69  cost:   0.413678
epoch:  70  cost:   0.376448
epoch:  71  cost:   0.931199
epoch:  72  cost:   5.495036
epoch:  73  cost:   2.914621
epoch:  74  cost:   7.160439
epoch:  75  cost:   13.324359
epoch:  76  cost:   22.426832
epoch:  77  cost:   116.921036
epoch:  78  cost:   285.824371

Solution

  • You can do this by keeping a window of the last n loss values and computing its range (the max minus the min of the window). Then set a threshold: if the range is larger than m times the window's minimum, stop.
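A minimal Python sketch of this window-range rule follows. `RangeEarlyStopper`, `window` (the n above) and `m` are hypothetical names, and the defaults are placeholders to tune, not recommended values:

```python
from collections import deque


class RangeEarlyStopper:
    """Hypothetical stopper implementing the window-range rule.

    Keeps the last `window` losses and signals a stop once
    max(window) - min(window) > m * min(window).
    """

    def __init__(self, window=5, m=10.0):
        self.window = window
        self.m = m
        # deque with maxlen drops the oldest loss automatically.
        self.losses = deque(maxlen=window)

    def update(self, loss):
        """Record one epoch's loss; return True if training should stop."""
        self.losses.append(loss)
        if len(self.losses) < self.window:
            return False  # not enough history to judge yet
        lo, hi = min(self.losses), max(self.losses)
        return (hi - lo) > self.m * lo
```

Call `update(cost)` once per epoch and `break` out of the training loop when it returns `True`. Because the range is measured against a possibly tiny minimum, the rule as stated also reacts to isolated spikes or sudden drops, such as epoch 22 above (0.109581 between values around 0.001), so n and m likely need tuning on real runs.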