Tags: rust, match, xlsxwriter, pyo3

Matching multiple possible types?


I'm very, very new to Rust and I'm struggling with it, since I come from a strong background in weakly typed languages.

The code below should write data received from Python via PyO3 into an XLSX worksheet. I just don't know how to handle the last match, because "value" is of type PyAny (that is, its extract method can return several types, such as String, f32, etc., and I want specific behavior depending on the extracted type).

Maybe I could just chain matches for each possible extracted type (if the first returns Err, try the next), but I suspect there is a better way. Maybe I'm just approaching the problem with the wrong design. Any insights are welcome.

use std::collections::HashMap;

use pyo3::types::{PyDict, PyList};
use xlsxwriter::{Format, Workbook, Worksheet, XlsxError};

pub trait WriteValue {
    fn write_value(&self, worksheet: &mut Worksheet, row: u32, col: u16, format: Option<&Format>) -> Result<(), XlsxError>;
}

impl WriteValue for String {
    fn write_value(&self, worksheet: &mut Worksheet, row: u32, col: u16, format: Option<&Format>) -> Result<(), XlsxError> {
        worksheet.write_string(row, col, self, format)
    }
}

impl WriteValue for f32 {
    fn write_value(&self, worksheet: &mut Worksheet, row: u32, col: u16, format: Option<&Format>) -> Result<(), XlsxError> {
        worksheet.write_number(row, col, f64::from(*self), format)
    }
}

fn _write(path: &str, data: HashMap<u32, &PyList>, _highlight: Option<&PyDict>) -> Result<(), XlsxError> {
    let workbook = Workbook::new(path);
    let mut worksheet = workbook.add_worksheet(None)?;

    let format_bold = workbook.add_format().set_bold();

    for (row_index, values) in data {
        // The first row is the header and is written in bold.
        let row_format = if row_index == 0 { Some(&format_bold) } else { None };

        // enumerate() yields 0-based column indices directly.
        for (col_idx, value) in values.iter().enumerate() {
            match value.extract::<String>() {
                Ok(x) => x.write_value(&mut worksheet, row_index, col_idx as u16, row_format)?,
                // This is the problem: every non-String value (f32, ...) is silently ignored here.
                Err(_) => {}
            }
        }
    }
    workbook.close()
}

Solution

  • This is mostly a PyO3 API question, and I don't think PyO3 has a built-in "multi-extract", though I'm not very familiar with it, so it may.

    That said, first: since you don't care about the Err case, you could simplify your code by chaining if let statements. They're only syntactic sugar, but for conditions with one or two outcomes (a single pattern, optionally with an else) they're really convenient, e.g.:

    if let Ok(x) = value.extract::<String>() {
        x.write_value(...)?;
    } else if let Ok(x) = value.extract::<f32>() {
        x.write_value(...)?;
        // ...and possibly a bunch more cases like this one
    } else {
        // no type matched (this branch can be omitted if such values should just be ignored)
    }
    

    Second, PyO3 lets you derive FromPyObject for enums. Since WriteValue is apparently your own trait, it would make sense to define the corresponding enum and derive it:

    #[derive(FromPyObject)]
    enum Writables {
        #[pyo3(transparent, annotation = "str")]
        String(String),
        #[pyo3(transparent, annotation = "float")]
        Float(f32),
        // put the other cases here
    }
    

    Then you can extract to that enum and match all the variants at once (and handle unsupported types separately, in the Err case).

    In fact, at that point the trait is probably unnecessary unless it's used for other things: you could put the write_value method directly on the enum.

    Side note: extracting a Python float (which is a double) to an f32 and then immediately widening it to an f64 to write it out seems... odd. Why not extract an f64 in the first place?
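The chained if let approach from the answer can be tried without Python or Excel using a small model. This is only a simplified sketch under stated assumptions: FakePyAny is a hypothetical stand-in for PyAny (its extract_string and extract_f64 methods loosely mimic value.extract::<String>() and value.extract::<f64>() returning Err on a type mismatch), and MockSheet is a hypothetical stand-in for the xlsxwriter Worksheet that just records what would be written.

```rust
// Hypothetical stand-in for a Python object; not part of PyO3.
#[derive(Debug, Clone)]
pub enum FakePyAny {
    Str(String),
    Float(f64),
    NoneValue,
}

impl FakePyAny {
    // Loosely mimics `value.extract::<String>()`: Ok only for a str, Err otherwise.
    pub fn extract_string(&self) -> Result<String, String> {
        match self {
            FakePyAny::Str(s) => Ok(s.clone()),
            _ => Err("not a str".to_string()),
        }
    }

    // Loosely mimics `value.extract::<f64>()`: Ok only for a float here, Err otherwise.
    pub fn extract_f64(&self) -> Result<f64, String> {
        match self {
            FakePyAny::Float(f) => Ok(*f),
            _ => Err("not a float".to_string()),
        }
    }
}

// Hypothetical stand-in for the xlsxwriter Worksheet: records (row, col, rendered value).
#[derive(Debug, Default)]
pub struct MockSheet {
    pub cells: Vec<(u32, u16, String)>,
}

// Chained `if let`: try each type in turn and fall through to the next one on Err.
pub fn write_cell(sheet: &mut MockSheet, row: u32, col: u16, value: &FakePyAny) -> bool {
    if let Ok(s) = value.extract_string() {
        sheet.cells.push((row, col, format!("string:{}", s)));
        true
    } else if let Ok(n) = value.extract_f64() {
        sheet.cells.push((row, col, format!("number:{}", n)));
        true
    } else {
        // No supported type matched; report it instead of writing anything.
        false
    }
}
```

write_cell returns false for values it cannot handle, so the caller decides whether to skip, log, or reject them. In the real code, each branch would call write_value on the extracted value instead of pushing a string.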
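The enum-based approach from the answer can also be modeled without Python or Excel. This is a sketch with hypothetical names (FakePyAny, MockSheet, Writable::extract, write_row) and f64 instead of f32, as the side note suggests. Writable::extract hand-writes roughly what the PyO3 documentation describes #[derive(FromPyObject)] doing for enums: each variant is tried in declaration order and the first successful extraction wins. Here that fallback chain is spelled out with Result::or_else, and write_value lives directly on the enum, so no separate trait is needed.

```rust
// Hypothetical stand-in for a Python object; not part of PyO3.
pub enum FakePyAny {
    Str(String),
    Float(f64),
    List(Vec<f64>),
}

// Loosely mimic `value.extract::<String>()` / `value.extract::<f64>()`: Err on mismatch.
fn extract_string(value: &FakePyAny) -> Result<String, String> {
    match value {
        FakePyAny::Str(s) => Ok(s.clone()),
        _ => Err("not a str".to_string()),
    }
}

fn extract_f64(value: &FakePyAny) -> Result<f64, String> {
    match value {
        FakePyAny::Float(f) => Ok(*f),
        _ => Err("not a float".to_string()),
    }
}

// Hypothetical stand-in for the xlsxwriter Worksheet: records (row, col, rendered value, bold).
#[derive(Debug, Default)]
pub struct MockSheet {
    pub cells: Vec<(u32, u16, String, bool)>,
}

impl MockSheet {
    pub fn write_string(&mut self, row: u32, col: u16, text: &str, bold: bool) {
        self.cells.push((row, col, format!("str:{}", text), bold));
    }

    pub fn write_number(&mut self, row: u32, col: u16, number: f64, bold: bool) {
        self.cells.push((row, col, format!("num:{}", number), bold));
    }
}

// Plays the role of the answer's derived `Writables` enum, with f64 instead of f32.
#[derive(Debug, PartialEq)]
pub enum Writable {
    Text(String),
    Number(f64),
}

impl Writable {
    // Try each variant in declaration order; the first successful extraction wins.
    pub fn extract(value: &FakePyAny) -> Result<Writable, String> {
        extract_string(value)
            .map(Writable::Text)
            .or_else(|_| extract_f64(value).map(Writable::Number))
            .map_err(|_| "unsupported type".to_string())
    }

    // write_value as an inherent method: one exhaustive match, no trait needed.
    pub fn write_value(&self, sheet: &mut MockSheet, row: u32, col: u16, bold: bool) {
        match self {
            Writable::Text(s) => sheet.write_string(row, col, s, bold),
            Writable::Number(n) => sheet.write_number(row, col, *n, bold),
        }
    }
}

// Writes one row (bold if it is row 0) and returns how many values were skipped.
pub fn write_row(sheet: &mut MockSheet, row: u32, values: &[FakePyAny]) -> usize {
    let mut skipped = 0;
    for (col, value) in values.iter().enumerate() {
        match Writable::extract(value) {
            Ok(w) => w.write_value(sheet, row, col as u16, row == 0),
            // Unsupported types are handled separately from the writable ones.
            Err(_) => skipped += 1,
        }
    }
    skipped
}
```

Because variants are tried in order, order also matters with the real derive: as far as I know, PyO3's f64 extraction also accepts Python ints (and bools), so if integers should be written differently, an integer variant would need to come before the float one.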