Search code examples
winapidirectxms-media-foundationc++-winrtdxgi

How do I connect the Graphics Capture API to IMFSinkWriter


I am attempting to write some code that connects the Windows::Graphics::Capture API to IMFSinkWriter in order to capture the desktop to an MP4 file. I find that the IMFSinkWriter WriteSample function always returns 0x80070057 and I'm trying to understand why. I suspect there is a somewhat obvious mistake as I am not extremely familiar with COM, WinRT, DirectX, etc. Any ideas?

#include <iostream>
#include <Windows.h>

// XXX workaround bug in platform headers where this has a circular declaration
#include "winrt/base.h"
namespace winrt::impl
{
    template <typename Async>
    auto wait_for(Async const& async, Windows::Foundation::TimeSpan const& timeout);
}
// XXX

#include <dxgi.h>
#include <inspectable.h>
#include <dxgi1_2.h>
#include <d3d11.h>
#include <mfapi.h>
#include <mfidl.h>
#include <mfreadwrite.h>
#include <codecapi.h>
#include <strmif.h>
#include <winrt/Windows.Foundation.h>
#include <winrt/Windows.System.h>
#include <winrt/Windows.Graphics.Capture.h>
#include <windows.graphics.capture.interop.h>
#include <windows.graphics.directx.direct3d11.interop.h>

#pragma comment(lib, "Mfuuid.lib")
#pragma comment(lib, "Mfplat.lib")
#pragma comment(lib, "mfreadwrite.lib")
#pragma comment(lib, "Mf.lib")


winrt::com_ptr<IMFSinkWriter> sinkWriter;
std::chrono::steady_clock::time_point firstFrameTime;
std::chrono::steady_clock::time_point lastFrameTime;
bool recordedFirstFrame = false;

void OnFrameArrived(winrt::Windows::Graphics::Capture::Direct3D11CaptureFramePool const& sender, winrt::Windows::Foundation::IInspectable const &) {
    winrt::Windows::Graphics::Capture::Direct3D11CaptureFrame frame = sender.TryGetNextFrame();
    std::chrono::steady_clock::time_point frameTime = std::chrono::steady_clock::now();
    LONGLONG duration = 0;
    LONGLONG frametime100ns;
    if (!recordedFirstFrame) {
        recordedFirstFrame = true;
        firstFrameTime = frameTime;
        frametime100ns = 0;
    }
    else {
        frametime100ns = std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - firstFrameTime).count() / 100;
        duration = std::chrono::duration_cast<std::chrono::milliseconds>(frameTime - lastFrameTime).count();
    }
    auto surface = frame.Surface();
    auto access = surface.as<Windows::Graphics::DirectX::Direct3D11::IDirect3DDxgiInterfaceAccess>();
    winrt::com_ptr<ID3D11Texture2D> texture;
    winrt::check_hresult(access->GetInterface(winrt::guid_of<ID3D11Texture2D>(), texture.put_void()));

    IMFMediaBuffer* buffer;
    MFCreateDXGISurfaceBuffer(__uuidof(ID3D11Texture2D), texture.get(), 0, FALSE, &buffer);


    IMFSample *sample;
    winrt::check_hresult(MFCreateSample(&sample));
    HRESULT hr = sample->AddBuffer(buffer);
    printf("add buffer! %x\n", hr);

    hr = sample->SetSampleTime(frametime100ns);
    printf("set sample time (%lld) %d\n", frametime100ns, hr);
    hr = sample->SetSampleDuration(duration);
    printf("set sample duration (%lld) %d\n", duration, hr);

    hr = sinkWriter->WriteSample(0 /* video stream index */, sample);
    printf("wrote sample %x\n", hr);

    lastFrameTime = frameTime;
}


int main()
{  
    winrt::init_apartment(winrt::apartment_type::multi_threaded);
    winrt::check_hresult(MFStartup(MF_VERSION, MFSTARTUP_NOSOCKET));

    // get a list of monitor handles
    std::vector<HMONITOR> monitors;
    EnumDisplayMonitors(
        nullptr, nullptr,
        [](HMONITOR hmon, HDC, LPRECT, LPARAM lparam) {
            auto& monitors = *reinterpret_cast<std::vector<HMONITOR>*>(lparam);
            monitors.push_back(hmon);
            return TRUE;
        },
        reinterpret_cast<LPARAM>(&monitors)
     );

     //get GraphicsCaptureItem for first monitor
     auto interop_factory = winrt::get_activation_factory<winrt::Windows::Graphics::Capture::GraphicsCaptureItem, IGraphicsCaptureItemInterop>();
     winrt::Windows::Graphics::Capture::GraphicsCaptureItem item = { nullptr };
     winrt::check_hresult(
        interop_factory->CreateForMonitor(
        monitors[0],
            winrt::guid_of<ABI::Windows::Graphics::Capture::IGraphicsCaptureItem>(),
            winrt::put_abi(item)
        )
     );

     // Create Direct 3D Device
     winrt::com_ptr<ID3D11Device> d3dDevice;
     winrt::check_hresult(D3D11CreateDevice(nullptr, D3D_DRIVER_TYPE_HARDWARE, nullptr, D3D11_CREATE_DEVICE_BGRA_SUPPORT,
        nullptr, 0, D3D11_SDK_VERSION, d3dDevice.put(), nullptr, nullptr));


     winrt::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice device;
     const auto dxgiDevice = d3dDevice.as<IDXGIDevice>();
     {
        winrt::com_ptr<::IInspectable> inspectable;
        winrt::check_hresult(CreateDirect3D11DeviceFromDXGIDevice(dxgiDevice.get(), inspectable.put()));
        device = inspectable.as<winrt::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice>();
     }


     auto idxgiDevice2 = dxgiDevice.as<IDXGIDevice2>();
     winrt::com_ptr<IDXGIAdapter> adapter;
     winrt::check_hresult(idxgiDevice2->GetParent(winrt::guid_of<IDXGIAdapter>(), adapter.put_void()));
     winrt::com_ptr<IDXGIFactory2> factory;
     winrt::check_hresult(adapter->GetParent(winrt::guid_of<IDXGIFactory2>(), factory.put_void()));

     ID3D11DeviceContext* d3dContext = nullptr;
     d3dDevice->GetImmediateContext(&d3dContext);


    // setup swap chain
    DXGI_SWAP_CHAIN_DESC1 desc = {};
    desc.Width = static_cast<uint32_t>(item.Size().Width);
    desc.Height = static_cast<uint32_t>(item.Size().Height);
    desc.Format = static_cast<DXGI_FORMAT>(winrt::Windows::Graphics::DirectX::DirectXPixelFormat::R16G16B16A16Float);
    desc.BufferUsage = DXGI_USAGE_RENDER_TARGET_OUTPUT;
    desc.SampleDesc.Count = 1;
    desc.SampleDesc.Quality = 0;
    desc.BufferCount = 2;
    desc.Scaling = DXGI_SCALING_STRETCH;
    desc.SwapEffect = DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL;
    desc.AlphaMode = DXGI_ALPHA_MODE_PREMULTIPLIED;
    winrt::com_ptr<IDXGISwapChain1> swapchain;
    winrt::check_hresult(factory->CreateSwapChainForComposition(d3dDevice.get(), &desc, nullptr, swapchain.put()));

    auto framepool = winrt::Windows::Graphics::Capture::Direct3D11CaptureFramePool::CreateFreeThreaded(device, winrt::Windows::Graphics::DirectX::DirectXPixelFormat::R16G16B16A16Float, 2, item.Size());
    auto session = framepool.CreateCaptureSession(item);
    framepool.FrameArrived(OnFrameArrived);

    //Setup MF output stream
    winrt::com_ptr<IMFDXGIDeviceManager> devManager;
    UINT resetToken;
    winrt::check_hresult(MFCreateDXGIDeviceManager(&resetToken, devManager.put()));
    winrt::check_hresult(devManager->ResetDevice(d3dDevice.get(), resetToken));
    winrt::com_ptr<IMFByteStream> outputStream;
    winrt::check_hresult(MFCreateFile(MF_ACCESSMODE_READWRITE, MF_OPENMODE_DELETE_IF_EXIST, MF_FILEFLAGS_NONE, L"C:\\test.mp4", outputStream.put()));

    //configure MF output media type
    winrt::com_ptr<IMFMediaType> videoMediaType;
    //winrt::com_ptr<IMFMediaType> audioMediaType;

    //for video
    winrt::check_hresult(MFCreateMediaType(videoMediaType.put()));
    winrt::check_hresult(videoMediaType->SetGUID(MF_MT_MAJOR_TYPE, MFMediaType_Video));
    winrt::check_hresult(videoMediaType->SetGUID(MF_MT_SUBTYPE, MFVideoFormat_H264));
    winrt::check_hresult(videoMediaType->SetUINT32(MF_MT_AVG_BITRATE, 2000000));
    winrt::check_hresult(videoMediaType->SetUINT32(MF_MT_INTERLACE_MODE, MFVideoInterlace_Progressive));
    winrt::check_hresult(videoMediaType->SetUINT32(MF_MT_MPEG2_PROFILE, eAVEncH264VProfile_Main));
    winrt::check_hresult(videoMediaType->SetUINT32(MF_MT_YUV_MATRIX, MFVideoTransferMatrix_BT601));
    winrt::check_hresult(MFSetAttributeSize(videoMediaType.get(), MF_MT_FRAME_SIZE, item.Size().Width, item.Size().Height));
    winrt::check_hresult(MFSetAttributeRatio(videoMediaType.get(), MF_MT_FRAME_RATE, 30, 1));
    winrt::check_hresult(MFSetAttributeRatio(videoMediaType.get(), MF_MT_PIXEL_ASPECT_RATIO, 1, 1));

    //Creates a streaming writer
    winrt::com_ptr<IMFMediaSink> mp4StreamSink;
    winrt::check_hresult(MFCreateMPEG4MediaSink(outputStream.get(), videoMediaType.get(), NULL, mp4StreamSink.put()));

    //setup MF Input stream
    winrt::com_ptr<IMFMediaType> inputVideoMediaType;

    HRESULT hr = S_OK;
    GUID majortype = { 0 };
    MFRatio par = { 0 };

    hr = videoMediaType->GetMajorType(&majortype);
    if (majortype != MFMediaType_Video)
    {
        throw new winrt::hresult_invalid_argument();
    }
    // Create a new media type and copy over all of the items.
    // This ensures that extended color information is retained.
    winrt::check_hresult(MFCreateMediaType(inputVideoMediaType.put()));
    winrt::check_hresult(videoMediaType->CopyAllItems(inputVideoMediaType.get()));
    // Set the subtype.
    winrt::check_hresult(inputVideoMediaType->SetGUID(MF_MT_SUBTYPE, MFVideoFormat_ARGB32));
    // Uncompressed means all samples are independent.
    winrt::check_hresult(inputVideoMediaType->SetUINT32(MF_MT_ALL_SAMPLES_INDEPENDENT, TRUE));
    // Fix up PAR if not set on the original type.
    hr = MFGetAttributeRatio(
        inputVideoMediaType.get(),
        MF_MT_PIXEL_ASPECT_RATIO,
        (UINT32*)&par.Numerator,
        (UINT32*)&par.Denominator
    );
    // Default to square pixels.
    if (FAILED(hr))
    {
        winrt::check_hresult(MFSetAttributeRatio(
            inputVideoMediaType.get(),
            MF_MT_PIXEL_ASPECT_RATIO,
            1, 1
        ));
    }

    winrt::check_hresult(MFSetAttributeSize(inputVideoMediaType.get(), MF_MT_FRAME_SIZE, item.Size().Width, item.Size().Height));
    inputVideoMediaType->SetUINT32(MF_MT_VIDEO_ROTATION, MFVideoRotationFormat_0); //XXX where do we get the rotation from?

    winrt::com_ptr<IMFAttributes> attributes;
    winrt::check_hresult(MFCreateAttributes(attributes.put(), 6));
    winrt::check_hresult(attributes->SetGUID(MF_TRANSCODE_CONTAINERTYPE, MFTranscodeContainerType_MPEG4));
    winrt::check_hresult(attributes->SetUINT32(MF_READWRITE_ENABLE_HARDWARE_TRANSFORMS, 1));
    winrt::check_hresult(attributes->SetUINT32(MF_MPEG4SINK_MOOV_BEFORE_MDAT, 1));
    winrt::check_hresult(attributes->SetUINT32(MF_LOW_LATENCY, FALSE)); ///XXX should we?
    winrt::check_hresult(attributes->SetUINT32(MF_SINK_WRITER_DISABLE_THROTTLING, FALSE)); //XX shuold we?
    // Add device manager to attributes. This enables hardware encoding.
    winrt::check_hresult(attributes->SetUnknown(MF_SINK_WRITER_D3D_MANAGER, devManager.get()));

    //winrt::com_ptr<IMFSinkWriter> sinkWriter;
    winrt::check_hresult(MFCreateSinkWriterFromMediaSink(mp4StreamSink.get(), attributes.get(), sinkWriter.put()));
    sinkWriter->SetInputMediaType(0, inputVideoMediaType.get(), nullptr);
 
    winrt::com_ptr<ICodecAPI> encoder;
    sinkWriter->GetServiceForStream(0 /* video stream index */, GUID_NULL, IID_PPV_ARGS(encoder.put()));
    VARIANT var;
    VariantInit(&var);
    var.vt = VT_UI4;
    var.ulVal = eAVEncCommonRateControlMode_Quality;
    winrt::check_hresult(encoder->SetValue(&CODECAPI_AVEncCommonRateControlMode, &var));
    var.ulVal = 70;
    winrt::check_hresult(encoder->SetValue(&CODECAPI_AVEncCommonQuality, &var));

    winrt::check_hresult(sinkWriter->BeginWriting());
    session.StartCapture();

    std::cout << "Hello World!\n";

    Sleep(1000);

    session.Close();
    sinkWriter->Flush(0);
    sinkWriter->Finalize();
}

Solution

  • I was able to track down the problem. The above code had two issues:

    1. Need to call SetCurrentLength() on the IMFMediaBuffer object. It seems silly since the way to get the length is to get the IMF2DBuffer interface from the IMFMediaBuffer object and call GetContiguousLength(), but it works.
    2. Taking the texture straight from the OnFrameArrived() callback and passing it into the IMF sink is also wrong. This will exhaust the framepool (which is declared as having 2 frames) and hang the encoder. One possible solution is to copy the data out into a new texture before passing it to the encoder.