python, assert, testbook

None value from a function after assert checking


I am writing a function that checks the output of cells in a notebook file against expected text using assert statements. However, when I print the function's return value, I get None.

from testbook import testbook

@testbook('ssnn_solved.ipynb', execute=True)
def test_stdout(tb):
    
    score = 0.0
    
    assert tb.cell_output_text(7) == 'Test passed.' 
    return score+1
    
    assert tb.cell_output_text(13) == 'Test passed.'
    return score+1
        
    assert tb.cell_output_text(17) == 'Test passed.'
    return score+1  
    
    assert tb.cell_output_text(21) == 'Test passed.'
    return score+1
    
res = test_stdout()

print(res)

I want the function to return 4.0 once all four asserts have run.
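The None can be traced with control flow alone. If the decorated function's return value reached the caller, the code above would either print 1.0 (the first assert passes and the function returns right away, so the other three asserts never run) or raise an AssertionError (the first assert fails). It could not print None. That suggests the testbook decorator's wrapper calls the test function without passing its result back. The sketch below is plain Python and does not use testbook: `discarding` and `forwarding` are hypothetical decorators that stand in for the two possible behaviors, and `score_checks` mimics the structure of the question's function.

```python
import functools


def discarding(func):
    """Hypothetical decorator: calls func but drops its return value."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        func(*args, **kwargs)  # result is computed, then thrown away
    return wrapper


def forwarding(func):
    """Hypothetical decorator: passes func's return value through."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def score_checks():
    score = 0.0
    assert 1 + 1 == 2
    return score + 1   # the function exits here
    assert 2 + 2 == 4  # unreachable: never executed
    return score + 1


print(discarding(score_checks)())  # None
print(forwarding(score_checks)())  # 1.0
```

If the testbook decorator behaves like `discarding`, any version of the function called through it will print None, regardless of how the score is computed inside.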


Solution

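  • If your goal is to count the number of successful assertions, the count is either 4 or unavailable: a failing assert raises an AssertionError, so the function never reaches a return statement. (In your code the function also returns immediately after the first assert, so the remaining asserts never run.) If you use ordinary comparisons instead of assertions, the function no longer stops at the first check that fails. Each comparison evaluates to True or False, which Python adds as 1 or 0: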

    from testbook import testbook

    @testbook('ssnn_solved.ipynb', execute=True)
    def test_stdout(tb):
        score = 0
        # Each comparison is True (1) or False (0), so a failing cell adds nothing
        score += tb.cell_output_text(7) == 'Test passed.'
        score += tb.cell_output_text(13) == 'Test passed.'
        score += tb.cell_output_text(17) == 'Test passed.'
        score += tb.cell_output_text(21) == 'Test passed.'
        return score
    

    If you instead want to short-circuit, that is, stop at the first failing check and return how many checks passed before it, you need an exit after each test. At that point a loop (arguably) simplifies the code:

    from testbook import testbook

    @testbook('ssnn_solved.ipynb', execute=True)
    def test_stdout(tb):
        cells = [7, 13, 17, 21]
        score = 0
        for cell in cells:
            if tb.cell_output_text(cell) != 'Test passed.':
                break  # stop at the first failing cell
            score += 1
        return score
    

    With the same list of cells, you can also rewrite the first example, which counts every passing check, using sum:

    from testbook import testbook

    @testbook('ssnn_solved.ipynb', execute=True)
    def test_stdout(tb):
        # sum() adds the booleans: one point per cell whose output matches
        return sum(tb.cell_output_text(cell) == 'Test passed.'
                   for cell in [7, 13, 17, 21])
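To compare the two counting strategies without running a notebook, the sketch below moves the checks into plain helper functions that accept any object with a `cell_output_text(cell)` method. `FakeNotebook`, `count_all` and `count_until_failure` are hypothetical names used only for illustration, and the cell outputs are made up. Converting the count with `float()` gives the 4.0 asked for in the question.

```python
EXPECTED = 'Test passed.'
CELLS = [7, 13, 17, 21]


class FakeNotebook:
    """Hypothetical stand-in for a notebook client: maps cell index -> output text."""

    def __init__(self, outputs):
        self.outputs = outputs

    def cell_output_text(self, cell):
        return self.outputs[cell]


def count_all(tb, cells=CELLS):
    # Every check runs; True counts as 1 and False as 0.
    return sum(tb.cell_output_text(cell) == EXPECTED for cell in cells)


def count_until_failure(tb, cells=CELLS):
    # Stop at the first failing cell and report how many passed before it.
    score = 0
    for cell in cells:
        if tb.cell_output_text(cell) != EXPECTED:
            break
        score += 1
    return score


all_pass = FakeNotebook({7: EXPECTED, 13: EXPECTED, 17: EXPECTED, 21: EXPECTED})
second_fails = FakeNotebook({7: EXPECTED, 13: 'Test failed.', 17: EXPECTED, 21: EXPECTED})

print(float(count_all(all_pass)))         # 4.0
print(count_all(second_fails))            # 3
print(count_until_failure(second_fails))  # 1
```

With a real notebook, one option is to call such a helper inside testbook's context-manager form, `with testbook('ssnn_solved.ipynb', execute=True) as tb:`, and print `float(count_all(tb))` in that block. That way the value is computed and printed directly, without depending on what the decorator does with a test function's return value.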