Tags: python, apache-beam, ptransform

How can I use `DoOutputsTuple` as either the parameter or the return of a `PTransform` in Apache Beam


I have some DoFns that return multiple output PCollections, where one represents the "good" data path and the other captures errors to redirect them. Something like the following:

output, error = (
        pcol
        | "Fix Timestamps"
        >> ParDo(ConvertTimestamp(), timestamp_field)
        .with_outputs(self.fail_tag, main=self.success_tag)
)

But I want to simplify and standardize by putting several of these ParDo/DoFn steps into a single `PTransform` that can be called directly. Something like the following:

from collections.abc import Iterable
from datetime import datetime
from typing import Any

from apache_beam import DoFn, ParDo, PTransform, pvalue
from apache_beam.pvalue import PCollection

# FAIL_TAG, SUCCESS_TAG, Failure, and TimestampError are defined elsewhere.


class ConvertToBQFriendlyTypes(PTransform):
    def __init__(
        self,
        timestamp_fields: tuple[str, ...],
        fail_tag: str = FAIL_TAG,
        success_tag: str = SUCCESS_TAG,
    ):
        super().__init__()
        self.fail_tag = fail_tag
        self.success_tag = success_tag
        self.timestamp_fields = timestamp_fields

    class _ConvertSingleTimestamp(DoFn):
        def __init__(self, fail_tag: str = FAIL_TAG):
            super().__init__()
            self.fail_tag = fail_tag

        def process(
            self,
            element: dict,
            field_name: str,
        ) -> Iterable[dict[str, Any]] | Iterable[pvalue.TaggedOutput]:
            timestamp_raw = element[field_name]
            if hasattr(timestamp_raw, "to_utc_datetime"):
                timestamp_utc = timestamp_raw.to_utc_datetime(has_tz=True)  # type: ignore
            else:
                timestamp_utc = timestamp_raw
            if hasattr(timestamp_utc, "timestamp"):
                timestamp_utc = datetime.fromtimestamp(round(timestamp_utc.timestamp()))
            if hasattr(timestamp_utc, "strftime"):
                result = timestamp_utc.strftime("%Y-%m-%d %H:%M:%S.%f")  # type: ignore
            elif isinstance(timestamp_utc, str) or timestamp_utc is None:
                result = timestamp_utc
            else:
                # `Failure` is a simple data class with these attributes:
                result = Failure(
                    pipeline_step="ConvertToBQFriendlyTypes",
                    element=element,
                    exception=TimestampError(
                        f'Field "{field_name}" has no means to convert time to a '
                        "string, which is needed for writing to BigQuery."
                    ),
                )
            if isinstance(result, Failure):
                yield pvalue.TaggedOutput(self.fail_tag, result)
            else:
                element[field_name] = result
                yield element

    def expand(
        self, pcoll: PCollection[dict[str, Any]] | pvalue.PValue
    ) -> PCollection[dict[str, Any]] | pvalue.PValue:
        for timestamp_field in self.timestamp_fields:
            pcoll = pcoll | f'Convert "{timestamp_field}"' >> ParDo(
                self._ConvertSingleTimestamp(self.fail_tag), timestamp_field
            ).with_outputs(self.fail_tag, main=self.success_tag)
        return pcoll

But this fails because a `PTransform`'s `expand` method can only accept a `PCollection`, not a `DoOutputsTuple` object like the one returned by any `ParDo` that produces multiple outputs via `with_outputs`. Moreover, it seems that a `PTransform` is only allowed to return a `PCollection`, not a `DoOutputsTuple`.

I have tried a few different ways to manage this: either splitting the outputs and managing each separately (in which case I get either `TypeError: cannot unpack non-iterable PCollection object` or `TypeError: 'PCollection' object is not subscriptable`, depending upon how I try to unpack it), or working directly with the unpacked `DoOutputsTuple` (in which case I get the dreaded `TypeError: '_InvalidUnpickledPCollection' object is not subscriptable...` error).


Does anyone know how I can create a `PTransform` that both receives a `DoOutputsTuple` as the first argument to its `expand` and returns a `DoOutputsTuple`? If not, does anyone have any better suggestions for how to manage this?

I am tempted to use the pasgarde package, as it seems elegant, but (a) I don't want to create a dependency on a lightly adopted open-source package, and (b) it boxes me in too tightly, allowing only `Map`, `FlatMap`, and `Filter`.


Solution

  • That is something I also struggled with recently. To solve this problem, I looked at built-in PTransforms and tried to write something similar.

    For example, CoGroupByKey expects two individual PCollections as input, while WriteToBigQuery returns multiple results.

    CoGroupByKey utilizes Python dictionaries, while WriteToBigQuery uses a custom class to return the output. The latter was (for my use case) a bit overkill, so I used a dictionary for the output as well.
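
    For reference, a minimal sketch of the custom-class route could look like this (the `SplitResult` and `SplitEvenOdd` names are mine for illustration, not part of the Beam API):

    import apache_beam as beam
    
    
    class SplitResult:
      """Plain Python object bundling the PCollections of a multi-output ParDo."""
      def __init__(self, good, failed):
        self.good = good
        self.failed = failed
    
    
    class SplitEvenOdd(beam.PTransform):
      def expand(self, pcoll):
        outputs = (
          pcoll
          | beam.ParDo(
              # Evens go to the main output, odds to the "failed" tag.
              lambda x: [x] if x % 2 == 0 else [beam.pvalue.TaggedOutput("failed", x)]
            ).with_outputs("failed", main="good")
        )
        # expand() may return an arbitrary object, as long as the PCollections
        # it wraps were produced inside the transform; this is essentially how
        # WriteToBigQuery returns its WriteResult.
        return SplitResult(outputs["good"], outputs["failed"])

    Callers then access `result.good` and `result.failed` as attributes instead of indexing a dict.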

    As an MWE, consider:

    import apache_beam as beam
    
    class DoFnWithOutputs(beam.DoFn):
      def process(self, element):
        if element == 1:
          yield "one"
        else:
          yield beam.pvalue.TaggedOutput("not_one", False)
    
    
    class SecondDoFn(beam.DoFn):
      def process(self, element):
        yield element + 5
    
          
    class MyPtransform(beam.PTransform):
      def expand(self, pcoll):
    
        my_tagged_output = (
          pcoll['first_input']
          | beam.ParDo(DoFnWithOutputs()).with_outputs("not_one", main="one")
        )
    
        my_other_output = (
          pcoll['second_input']
          | beam.ParDo(SecondDoFn())
        )
        # we need to define an output dict
        # it might be possible to just use one k/v pair for the tagged output (not tested)  
        return {
          'not_one': my_tagged_output['not_one'],
          'one': my_tagged_output['one'],
          'other': my_other_output
        }      
            
    
    with beam.Pipeline() as pipeline:
      first_input = pipeline | 'Create first input' >> beam.Create([1, 2, 3, 1])
    
      second_input = pipeline | 'Create second input' >> beam.Create([1, 2, 3, 4])
    
      ptransform_output = (
        {'first_input': first_input, 'second_input': second_input}  # Create the input dict
        | MyPtransform()
      )        
      
      (
        ptransform_output['one']
        | "map 1" >> beam.Map(lambda x: f"one: {x}")    
        | "print 1" >> beam.Map(print)
      )
    
      (
        ptransform_output['not_one']
        | "map != 1" >> beam.Map(lambda x: f"not_one: {x}") 
        | "print != 1" >> beam.Map(print)
      )
    
      (
        ptransform_output['other']
        | "map other" >> beam.Map(lambda x: f"other: {x}") 
        | "print other" >> beam.Map(print)
      )  
      
    

    which can be executed as-is with the default DirectRunner.
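
    Applied back to the question's `ConvertToBQFriendlyTypes`, the same dict pattern might look like the following sketch (untested; it reuses the `_ConvertSingleTimestamp` DoFn from the question, feeds each field's successes into the next ParDo, and merges the per-field failure outputs with Flatten):

    import apache_beam as beam
    
    
    class ConvertToBQFriendlyTypes(beam.PTransform):
      def __init__(self, timestamp_fields, fail_tag=FAIL_TAG, success_tag=SUCCESS_TAG):
        super().__init__()
        self.timestamp_fields = timestamp_fields
        self.fail_tag = fail_tag
        self.success_tag = success_tag
    
      # _ConvertSingleTimestamp exactly as defined in the question.
    
      def expand(self, pcoll):
        failures = []
        for timestamp_field in self.timestamp_fields:
          outputs = pcoll | f'Convert "{timestamp_field}"' >> beam.ParDo(
            self._ConvertSingleTimestamp(self.fail_tag), timestamp_field
          ).with_outputs(self.fail_tag, main=self.success_tag)
          # Only successes flow into the next field's conversion.
          pcoll = outputs[self.success_tag]
          failures.append(outputs[self.fail_tag])
        return {
          self.success_tag: pcoll,
          # Assumes at least one timestamp field; Flatten merges the failures.
          self.fail_tag: failures | "Flatten failures" >> beam.Flatten(),
        }

    The caller then selects by key, e.g. `good = outputs[SUCCESS_TAG]` and `bad = outputs[FAIL_TAG]`, instead of tuple-unpacking a `DoOutputsTuple`.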