
Multiplying two half-precision floats in procedural VHDL


I'm trying to implement a procedural half-float multiply function in VHDL. As things stand, I have this:

function "*" (l,r: half_float) return half_float is
    variable l_full_mantissa, r_full_mantissa: unsigned(l.mantissa'length downto 0);
    variable multiplied_mantissae: unsigned((l.mantissa'length + 1) * 2 - 1 downto 0);
    variable new_exponent: unsigned(fp16_exponent_len - 1 downto 0);
    begin

        if (l = fp16_zero or r = fp16_zero) then
            return fp16_zero;
        end if;
        report "L-Mantissa: " & to_string(l.mantissa);
        report "R-Mantissa: " & to_string(r.mantissa);
        report "L-Exponent: " & integer'image(to_integer(unsigned(l.exponent)));
        report "R-Exponent: " & integer'image(to_integer(unsigned(r.exponent)));
        new_exponent := unsigned(l.exponent) + (unsigned(r.exponent) - to_unsigned(fp16_exponent_bias, get_width_for_unsigned(fp16_exponent_bias))); -- Subtract the bias to prevent double counting
        report "N-Exponent: " & integer'image(to_integer(unsigned(new_exponent)));

        -- Prepend the leading 1s
        l_full_mantissa := unsigned('1' & l.mantissa);
        r_full_mantissa := unsigned('1' & r.mantissa);

        report integer'image(to_integer(l_full_mantissa));
        report integer'image(to_integer(r_full_mantissa));

        -- Multiply the mantissae
        multiplied_mantissae := l_full_mantissa * r_full_mantissa;
        multiplied_mantissae := multiplied_mantissae sll 2; -- <-- not sure about this

        report integer'image(to_integer(multiplied_mantissae));
        report to_string(multiplied_mantissae);

        return (l.sign xor r.sign, std_logic_vector(new_exponent), std_logic_vector(multiplied_mantissae(multiplied_mantissae'high downto multiplied_mantissae'high - fp16_mantissa_len + 1)));
    end function;

And this seems to work for all of these test cases except for the last:

        report "Testing 5x2";
        assert to_float(five) * to_float(two) = to_float(ten);

        report "Testing 2x42";
        assert to_float(two) * to_float(forty_two) = to_float(eighty_four);

        report "Testing 2.5x4";
        assert to_float(two_point_five_slv) * to_float(four) = to_float(ten);

        report "Testing -4x2.5";
        assert to_float(minus_four) * to_float(two_point_five_slv) = to_float(minus_ten);

        report "Testing 0.25x0.25";
        assert to_float(one_quarter) * to_float(one_quarter) = to_float(one_sixteenth);

        report "Testing 0.45x4" severity note;
        assert to_float(point_four_five) * to_float(four) = to_float(one_point_eight);

        report "Testing 0.0005x640";
        assert to_float(point_o_o_o_five) * to_float(six_forty) = to_float(point_three_two);

        report "Testing -0.96x-0.96";
        assert to_float(minus_point_nine_six) * to_float(minus_point_nine_six) = to_float(point_nine_two_one_four);

For -0.96x-0.96 I get 0.4214 rather than 0.9216.

My function contains an sll 2 that I don't fully understand the need for, and I can't see what makes the last case different from the others. I suspect that's where I'm missing something.

Any ideas?


Solution

  • This isn't complete: it doesn't address all edge cases, infinities or rounding modes, and it doesn't support subnormals. However, it seems to work:

    1. Check if multiplying by zero, return zero if so.
    2. Check if multiplying by NaN, return NaN if so.
    3. Calculate an interim exponent by adding the two exponents. We need to subtract the bias from this to avoid it being double-added.
    4. Prepend the inferred 1s to the mantissae.
    5. Multiply the mantissae.
    6. Truncate the product to the length of your stored mantissa plus 2 bits. (x.yyyy... * w.zzzz... always has the form aa.bbbb... if yyyy... and zzzz... are the same length.)
    7. Both inputs are in [1, 2), so the product is in [1, 4) and the 2 high bits will always be in {11, 10, 01}. We check the high bit and right shift by 1 if it's set, incrementing the exponent whenever we do this.
    8. We then take the low n-2 bits of the n-bit truncated value as our new mantissa.
    9. The resultant sign is obtained with l.sign xor r.sign.
    10. Concatenate the whole lot to get the result.
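
    As a worked check of steps 6 and 7 against the failing -0.96x-0.96 case from the question, here is the arithmetic with exact decimal values rather than the rounded half-precision ones. The symbols m_l and m_r are introduced here for the two mantissae with their leading 1s prepended:

    ```latex
    \begin{align*}
    1 \le m_l,\, m_r < 2 \;&\Rightarrow\; 1 \le m_l m_r < 4 \\
    0.96 &= 1.92 \times 2^{-1} \\
    1.92 \times 1.92 &= 3.6864 = 1.8432 \times 2^{1} \\
    0.96 \times 0.96 &= 1.8432 \times 2^{(-1) + (-1) + 1} = 1.8432 \times 2^{-1} = 0.9216
    \end{align*}
    ```

    The mantissa product 3.6864 is at least 2, so its high bit is set and it needs the shift and the exponent increment. Dropping the two integer bits without renormalising would instead leave a mantissa of 1.6864 with exponent -2, which is about 0.4216 and close to the 0.4214 reported in the question once truncation is taken into account. The passing test cases all appear to have mantissa products below 2 (for example 1.25 x 1.0 for 5x2), which would explain why they never hit this path.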

    In VHDL, the steps above look like:

        function "*" (l,r: float) return float is
            variable l_full_mantissa, r_full_mantissa: unsigned(l.mantissa'length downto 0);
            variable multiplied_mantissae: unsigned((l.mantissa'length + 1) * 2 - 1 downto 0);
            variable new_exponent: unsigned(fp_exponent_len - 1 downto 0);
            variable truncated_mantissa: std_logic_vector(l_full_mantissa'length downto 0);
        begin
    
            -- Check edge conditions
    
            if (l = fp_zero or r = fp_zero) then
                return fp_zero;
            end if;
    
            if (l = nan or r = nan) then
                return nan;
            end if;
    
            -- Calculate the interim exponent
            new_exponent := (unsigned(l.exponent) + unsigned(r.exponent)) - to_unsigned(fp_exponent_bias, get_width_for_unsigned(fp_exponent_bias)); -- Subtract the bias to prevent double counting
    
            -- Prepend the leading 1s
            l_full_mantissa := unsigned('1' & l.mantissa);
            r_full_mantissa := unsigned('1' & r.mantissa);
    
            -- Multiply the mantissae
            multiplied_mantissae := l_full_mantissa * r_full_mantissa;
    
            -- Truncate the product to the stored mantissa length plus 2 bits
            truncated_mantissa := std_logic_vector(multiplied_mantissae(multiplied_mantissae'high downto multiplied_mantissae'high - l_full_mantissa'length));
    
            -- Renormalise the result
            -- Result will always be xx.yyyyyyyy...
            -- xx will always be 11, 10 or 01, so if the high bit is 1 we need to right shift once
            if (truncated_mantissa(truncated_mantissa'high) = '1') then
                truncated_mantissa := truncated_mantissa srl 1;
                new_exponent := new_exponent + 1;
            end if;
    
            report  to_string((l.sign xor r.sign) & std_logic_vector(new_exponent) & std_logic_vector(truncated_mantissa(truncated_mantissa'high - 2 downto 0)));
            return ((l.sign xor r.sign), std_logic_vector(new_exponent), std_logic_vector(truncated_mantissa(truncated_mantissa'high - 2 downto 0)));
        end function;
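
For anyone who wants to try the exponent and renormalisation logic in isolation, here is a minimal stand-alone sketch of steps 3 to 10. It works on raw 16-bit vectors rather than the float record above, so it compiles without the rest of the original package. The package name fp16_mul_sketch, the function name fp16_mul and the constants EXP_LEN, MANT_LEN and BIAS are made up for this example. It assumes normal, non-zero inputs, ignores overflow and underflow, and truncates rather than rounds.

```vhdl
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

-- Hypothetical package for illustration only; not part of the code above.
package fp16_mul_sketch is
    constant EXP_LEN  : natural := 5;   -- half-precision exponent bits
    constant MANT_LEN : natural := 10;  -- half-precision stored mantissa bits
    constant BIAS     : natural := 15;  -- half-precision exponent bias

    -- Truncating multiply of two normal, non-zero half-precision values.
    -- Assumption: no zeros, subnormals, infinities, NaNs or exponent overflow.
    function fp16_mul(l, r : std_logic_vector(15 downto 0))
        return std_logic_vector;
end package;

package body fp16_mul_sketch is
    function fp16_mul(l, r : std_logic_vector(15 downto 0))
        return std_logic_vector is
        variable l_m, r_m : unsigned(MANT_LEN downto 0);          -- 1.ffffffffff
        variable prod     : unsigned(2 * MANT_LEN + 1 downto 0);  -- xx.ffff...
        variable new_exp  : unsigned(EXP_LEN - 1 downto 0);
        variable frac     : unsigned(MANT_LEN - 1 downto 0);
        variable result   : std_logic_vector(15 downto 0);
    begin
        -- Add the biased exponents and remove one copy of the bias.
        -- The arithmetic wraps modulo 2**EXP_LEN, which still gives the
        -- right answer as long as the final exponent fits in EXP_LEN bits.
        new_exp := unsigned(l(14 downto 10)) + unsigned(r(14 downto 10))
                   - to_unsigned(BIAS, EXP_LEN);

        -- Prepend the implicit leading 1 to each stored mantissa.
        l_m := '1' & unsigned(l(9 downto 0));
        r_m := '1' & unsigned(r(9 downto 0));

        -- 11 bits times 11 bits gives a 22-bit product in [1, 4).
        prod := l_m * r_m;

        if prod(prod'high) = '1' then
            -- Product is in [2, 4): take the fraction one bit higher
            -- (equivalent to a right shift by 1) and bump the exponent.
            frac    := prod(prod'high - 1 downto prod'high - MANT_LEN);
            new_exp := new_exp + 1;
        else
            -- Product is in [1, 2): the second bit is the leading 1.
            frac := prod(prod'high - 2 downto prod'high - MANT_LEN - 1);
        end if;

        result := (l(15) xor r(15)) & std_logic_vector(new_exp)
                  & std_logic_vector(frac);
        return result;
    end function;
end package body;
```

Taking x"BBAE" as the nearest half-precision encoding of -0.96, fp16_mul(x"BBAE", x"BBAE") should return x"3B5F", which is roughly 0.9214. Selecting a different slice of the product instead of shifting a truncated copy is just a design choice for this sketch; it picks out the same bits as the srl 1 in the function above.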