
Can a Python function be both a generator and a "non-generator"?


I have a function that should either yield bytes to the caller (generator behaviour) or write them to a file (non-generator behaviour), depending on whether the save boolean is set. Is that possible?

def encode_file(source, save=False, destination=None):
    # encode the contents of an input file 3 bytes at a time
    print('hello')
    with open(source, 'rb') as infile:
        # save bytes to destination file
        if save:
            print(f'saving to file {destination}')
            with open(destination, 'wb') as outfile:
                while (bytes_to_encode := infile.read(3)):
                    l = len(bytes_to_encode)
                    if l < 3:
                        bytes_to_encode += (b'\x00' * (3 - l))
                    outfile.write(bytes_to_encode)
            return
        # yield bytes to caller
        else:
            while (bytes_to_encode := infile.read(3)):
                l = len(bytes_to_encode)
                if l < 3:
                    bytes_to_encode += (b'\x00' * (3 - l)) # pad bits if short
                yield encode(bytes_to_encode)
            return

In the above implementation, the function always behaves as a generator. When I call

encode_file('file.bin', save=True, destination='output.base64')

it does not print "hello"; instead, it returns a generator object. This does not make sense to me. Shouldn't "hello" be printed, and shouldn't control then pass to the if save: portion of the code, avoiding the part of the function that yields entirely?


Solution

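  • A function can’t be both a generator and not one: any function whose body contains yield is a generator function, so calling it returns a generator object without running any of its code. That is why "hello" is not printed when you call it. You can, however, decide at call time whether to return a generator object by defining a helper generator function. To avoid duplicating the with that reads the source (and to reduce redundancy in general), make one branch a client of the other: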

    def encode_file(source, save=False, destination=None):
        # encode the contents of an input file 3 bytes at a time
        print('hello')
        # save bytes to destination file
        if save:
            print(f'saving to file {destination}')
            with open(destination, 'wb') as outfile:
                for encoded in encode_file(source):  # chunks arrive already encoded
                    outfile.write(encoded)
        # yield bytes to caller
        else:
            def g():
                with open(source, 'rb') as infile:
                    while (bytes_to_encode := infile.read(3)):
                        l = len(bytes_to_encode)
                        if l < 3:
                            bytes_to_encode += (b'\x00' * (3 - l)) # pad with zero bytes if short
                        yield encode(bytes_to_encode)
            return g()
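
The same pattern can also be written with the generator as a separate top-level helper, which may be easier to test. This is only a sketch: iter_encoded is a hypothetical name, and encode here is a stand-in (Base64 of each 3-byte chunk), since the question does not show the real encode.

```python
import base64
import os
import tempfile


def encode(chunk):
    # Assumed stand-in for the question's encode(): Base64 of one 3-byte chunk.
    return base64.b64encode(chunk)


def iter_encoded(source):
    # Generator function (hypothetical name): yields one encoded chunk per
    # 3 input bytes, padding a short final chunk with zero bytes.
    with open(source, 'rb') as infile:
        while (chunk := infile.read(3)):
            yield encode(chunk.ljust(3, b'\x00'))


def encode_file(source, save=False, destination=None):
    # Ordinary function: it contains no yield, so its body runs when called.
    if save:
        with open(destination, 'wb') as outfile:
            outfile.writelines(iter_encoded(source))
        return None
    return iter_encoded(source)


# Small demonstration on a temporary 4-byte file.
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, 'file.bin')
    dst = os.path.join(tmp, 'output.base64')
    with open(src, 'wb') as f:
        f.write(b'abcd')
    chunks = list(encode_file(src))                        # generator path
    saved = encode_file(src, save=True, destination=dst)   # file path
    with open(dst, 'rb') as f:
        written = f.read()
```

With save=True the call does its work immediately and returns None; otherwise the caller gets a generator and nothing is read until it is iterated.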
    

    (Thanks to interjay for pointing out the need for the with in g.)
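
To see why the original function never printed "hello" when called, a minimal sketch (maybe_generator is a made-up name) shows that a single yield anywhere in the body is enough to make every call return a generator object, even when the branch that is taken never reaches the yield:

```python
import inspect


def maybe_generator(flag):
    print('hello')       # not executed when the function is called
    if flag:
        return 'done'    # ends the generator; the value rides on StopIteration
    yield 1


result = maybe_generator(True)
is_gen = inspect.isgenerator(result)   # a generator object, nothing printed yet

returned = None
try:
    next(result)         # the body runs only now: prints 'hello', then returns
except StopIteration as stop:
    returned = stop.value
```

The body starts running only on the first next() call (or the first step of a for loop), and a return inside a generator function ends the iteration rather than handing a value back to the caller directly.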