Tags: c#, .net, com, webbrowser-control, com-interop

WebBrowserSite: how to call a private COM interface method in a derived class?


Here is the challenge. I'm deriving from the Framework's WebBrowserSite class. An instance of my derived class, ImprovedWebBrowserSite, is returned via WebBrowser.CreateWebBrowserSiteBase, which I override in my derived version of the WebBrowser class specifically to provide a custom site object. The Framework's WebBrowser implementation then passes it down to the underlying unmanaged WebBrowser ActiveX control.

So far, I've managed to override IDocHostUIHandler in my ImprovedWebBrowserSite implementation (like this). Now I want to do the same for more core COM interfaces, like IOleClientSite, passing the calls through to WebBrowserSite. All of them are exposed to COM with ComImport, but they are declared as private or internal in the Framework's implementation of WebBrowserSite/UnsafeNativeMethods. Thus, I cannot explicitly re-implement them in the derived class; I have to define my own versions, as I did with IDocHostUIHandler.

So, the question is, how do I call a method of a private or internal COM interface defined in WebBrowserSite, from my derived class? For example, I want to call IOleClientSite.GetContainer. I can use reflection (like this), but that would be the last resort, second to re-implementing WebBrowser from scratch.
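For comparison, here is roughly what the reflection fallback could look like. This is a sketch under my own assumptions, not necessarily what the "like this" link showed, and the helper name ComReflectionHelper and its parameters are hypothetical. It locates the base class's private interface by its COM GUID (so it does not depend on the internal name UnsafeNativeMethods.IOleClientSite), then calls the method through MethodInfo.Invoke, which dispatches to the base's explicit implementation. It still depends on the method's name and parameter list.

```csharp
using System;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;

// Hypothetical helper (not part of the Framework): calls a method of a possibly
// non-public interface implemented by baseType, located by its COM GUID.
static class ComReflectionHelper
{
    public static object InvokeInterfaceMethod(
        object target, Type baseType, Guid iid, string methodName, object[] args)
    {
        // GetInterfaces() also returns non-public interfaces implemented by baseType.
        Type iface = baseType.GetInterfaces().FirstOrDefault(t => t.GUID == iid);
        if (iface == null)
            throw new InvalidOperationException($"{baseType} does not implement {iid:B}.");

        // Interface members are public even when the interface type itself is not.
        // Assumes the method name is not overloaded on that interface.
        MethodInfo method = iface.GetMethod(methodName);
        if (method == null)
            throw new MissingMethodException(iface.FullName, methodName);

        // Invoking an interface MethodInfo dispatches to target's implementation,
        // including an explicit (private) one in a base class. Values of out/ref
        // parameters are written back into the args array.
        return method.Invoke(target, args);
    }
}

// Possible use inside ImprovedWebBrowserSite (assumes the Framework's private
// IOleClientSite.GetContainer takes a single out parameter):
//
//   var args = new object[] { null };
//   object hr = ComReflectionHelper.InvokeInterfaceMethod(
//       this, typeof(WebBrowserSite), typeof(IOleClientSite).GUID, "GetContainer", args);
//   // args[0] now holds the container returned by the base implementation.
```

Multiply that by ~20 methods, each with its own argument array and casts, and the appeal of a v-table-based solution is clear.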

My thinking is this: the Framework's private UnsafeNativeMethods.IOleClientSite and my own ImprovedWebBrowserSite.IOleClientSite are both COM interfaces, declared with the ComImport attribute, with the same GUID and identical method signatures. Since .NET 4.0+ supports COM Type Equivalence, there has to be a way to do it without reflection.

[UPDATE] Now that I've got a solution, I believe it opens some new and interesting possibilities for customizing the WinForms version of the WebBrowser control.

This version of the question was created after a commenter called my initial, more abstract formulation of the problem misleading. The comment was later removed, but I decided to keep both versions.

Why did I not want to use reflection to solve this problem? For a few reasons:

  • Dependency on the actual symbolic names of the internal or private methods, as chosen by the implementers of WebBrowserSite. A COM interface, by contrast, is a binary v-table contract.

  • Bulky reflection code. E.g., consider calling the base's private TranslateAccelerator via Type.InvokeMember, and I have ~20 methods like that to call.

  • Although less important, efficiency: a late-bound call via reflection is always less efficient than a direct call to a COM interface method via v-table.


Solution

  • Finally, I believe I've solved the problem using Marshal.CreateAggregatedObject, with some help from @EricBrown.

    Here's the code that makes it possible to customize WebBrowserSite's OLE interfaces, using IOleClientSite as an example and calling through to WebBrowserSite's private, COM-visible implementation. It can be extended to other interfaces, e.g. IDocHostUIHandler.

    using System;
    using System.Diagnostics;
    using System.Linq;
    using System.Runtime.InteropServices;
    using System.Windows.Forms;
    
    namespace CustomWebBrowser
    {
        public partial class MainForm : Form
        {
            public MainForm()
            {
                InitializeComponent();
            }
    
            private void MainForm_Load(object sender, EventArgs e)
            {
                var wb = new ImprovedWebBrowser();
                wb.Dock = DockStyle.Fill;
                this.Controls.Add(wb);
                wb.Visible = true;
                wb.DocumentText = "<b>Hello from ImprovedWebBrowser!</b>";
            }
        }
    
        // ImprovedWebBrowser with custom pass-through IOleClientSite 
        public class ImprovedWebBrowser: WebBrowser
        {
            // provide custom WebBrowserSite,
            // where we override IOleClientSite and call the base implementation
            protected override WebBrowserSiteBase CreateWebBrowserSiteBase()
            {
                return new ImprovedWebBrowserSite(this);
            }
    
            // IOleClientSite
            [ComImport(), Guid("00000118-0000-0000-C000-000000000046")]
            [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
            public interface IOleClientSite
            {
                void SaveObject();
    
                [return: MarshalAs(UnmanagedType.Interface)]
                object GetMoniker(
                    [In, MarshalAs(UnmanagedType.U4)] int dwAssign,
                    [In, MarshalAs(UnmanagedType.U4)] int dwWhichMoniker);
    
                [PreserveSig]
                int GetContainer([Out] out IntPtr ppContainer);
    
                void ShowObject();
    
                void OnShowWindow([In, MarshalAs(UnmanagedType.I4)] int fShow);
    
                void RequestNewObjectLayout();
            }
    
            // ImprovedWebBrowserSite
            protected class ImprovedWebBrowserSite :
                WebBrowserSite,
                IOleClientSite,
                ICustomQueryInterface,
                IDisposable
            {
                IOleClientSite _baseIOleClientSite;
                IntPtr _unkOuter;
                IntPtr _unkInnerAggregated;
                Inner _inner;
    
                #region Inner
                // Inner as aggregated object
                class Inner :
                    ICustomQueryInterface,
                    IDisposable
                {
                    object _outer;
                    Type[] _interfaces;
    
                    public Inner(object outer)
                    {
                        _outer = outer;
                        // the base's private COM interfaces are here
                        _interfaces = _outer.GetType().BaseType.GetInterfaces(); 
                    }
    
                    public CustomQueryInterfaceResult GetInterface(ref Guid iid, out IntPtr ppv)
                    {
                        if (_outer != null)
                        {
                            var ifaceGuid = iid;
                            var iface = _interfaces.FirstOrDefault((t) => t.GUID == ifaceGuid);
                            if (iface != null)
                            {
                                var unk = Marshal.GetComInterfaceForObject(_outer, iface, CustomQueryInterfaceMode.Ignore);
                                if (unk != IntPtr.Zero)
                                {
                                    ppv = unk;
                                    return CustomQueryInterfaceResult.Handled;
                                }
                            }
                        }
                        ppv = IntPtr.Zero;
                        return CustomQueryInterfaceResult.Failed;
                    }
    
                    ~Inner()
                    {
                        // need to work out the reference counting for GC to work correctly
                        Debug.Print("Inner object finalized.");
                    }
    
                    public void Dispose()
                    {
                        _outer = null;
                        _interfaces = null;
                    }
                }
                #endregion
    
                // constructor
                public ImprovedWebBrowserSite(WebBrowser host):
                    base(host)
                {
                    // get the CCW object for this
                    _unkOuter = Marshal.GetIUnknownForObject(this);
                    Marshal.AddRef(_unkOuter);
                    try
                    {
                        // aggregate the CCW object with the helper Inner object
                        _inner = new Inner(this);
                        _unkInnerAggregated = Marshal.CreateAggregatedObject(_unkOuter, _inner);
    
                        // turn private WebBrowserSiteBase.IOleClientSite into our own IOleClientSite
                        _baseIOleClientSite = (IOleClientSite)Marshal.GetTypedObjectForIUnknown(_unkInnerAggregated, typeof(IOleClientSite));
                    }
                    finally
                    {
                        Marshal.Release(_unkOuter);
                    }
                }
    
                ~ImprovedWebBrowserSite()
                {
                    // need to work out the reference counting for GC to work correctly
                    Debug.Print("ImprovedWebBrowserSite finalized.");
                }
    
                public CustomQueryInterfaceResult GetInterface(ref Guid iid, out IntPtr ppv)
                {
                    if (iid == typeof(IOleClientSite).GUID)
                    {
                        // CustomQueryInterfaceMode.Ignore is to avoid infinite loop during QI.
                        ppv = Marshal.GetComInterfaceForObject(this, typeof(IOleClientSite), CustomQueryInterfaceMode.Ignore);
                        return CustomQueryInterfaceResult.Handled;
                    }
                    ppv = IntPtr.Zero;
                    return CustomQueryInterfaceResult.NotHandled;
                }
    
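                // Override WebBrowserSiteBase.Dispose(bool) rather than re-implementing
                // IDisposable explicitly, so the cleanup also runs when the site is
                // disposed through a WebBrowserSiteBase reference.
                protected override void Dispose(bool disposing)
                {
                    if (disposing)
                    {
                        // we may have circular references to ourselves
                        _baseIOleClientSite = null;
    
                        if (_inner != null)
                        {
                            _inner.Dispose();
                            _inner = null;
                        }
    
                        if (_unkInnerAggregated != IntPtr.Zero)
                        {
                            Marshal.Release(_unkInnerAggregated);
                            _unkInnerAggregated = IntPtr.Zero;
                        }
    
                        if (_unkOuter != IntPtr.Zero)
                        {
                            Marshal.Release(_unkOuter);
                            _unkOuter = IntPtr.Zero;
                        }
                    }
    
                    base.Dispose(disposing);
                }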
    
                #region IOleClientSite
                // IOleClientSite
                public void SaveObject()
                {
                    Debug.Print("IOleClientSite.SaveObject");
                    _baseIOleClientSite.SaveObject();
                }
    
                public object GetMoniker(int dwAssign, int dwWhichMoniker)
                {
                    Debug.Print("IOleClientSite.GetMoniker");
                    return _baseIOleClientSite.GetMoniker(dwAssign, dwWhichMoniker);
                }
    
                public int GetContainer(out IntPtr ppContainer)
                {
                    Debug.Print("IOleClientSite.GetContainer");
                    return _baseIOleClientSite.GetContainer(out ppContainer);
                }
    
                public void ShowObject()
                {
                    Debug.Print("IOleClientSite.ShowObject");
                    _baseIOleClientSite.ShowObject();
                }
    
                public void OnShowWindow(int fShow)
                {
                    Debug.Print("IOleClientSite.OnShowWindow");
                    _baseIOleClientSite.OnShowWindow(fShow);
                }
    
                public void RequestNewObjectLayout()
                {
                    Debug.Print("IOleClientSite.RequestNewObjectLayout");
                    _baseIOleClientSite.RequestNewObjectLayout();
                }
                #endregion
            }
        }
    }