1/*
2 * Copyright (C) 2007, 2010 Apple Inc. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 * 1. Redistributions of source code must retain the above copyright
8 *    notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 *    notice, this list of conditions and the following disclaimer in the
11 *    documentation and/or other materials provided with the distribution.
12 *
13 * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
14 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL APPLE INC. OR
17 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
18 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
19 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
20 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
21 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26#ifndef COMPtr_h
27#define COMPtr_h
28
29#ifndef NOMINMAX
30#define NOMINMAX
31#endif
32
33#include <unknwn.h>
34#include <wtf/Assertions.h>
35#include <wtf/HashTraits.h>
36
37#if !OS(WINCE)
38#include <guiddef.h>
39#endif
40
// HRESULT is the COM status-code type returned by copyRefTo() below. It is presumably
// already declared by <unknwn.h>; C++ permits redeclaring a typedef to the same type, so
// this is harmless as long as the SDK also defines it as long. NOTE(review): confirm it is still needed.
typedef long HRESULT;
42
// FIXME: Should these declarations move into the WebCore namespace and be exported
// with a "using" declaration, as we do for things in WTF?
45
// Tag values that select a particular COMPtr constructor:
//   AdoptCOM - take over the caller's existing reference without calling AddRef.
//   Query    - obtain the pointer through QueryInterface.
//   Create   - obtain the pointer through CoCreateInstance.
enum AdoptCOMTag {
    AdoptCOM
};
enum QueryTag {
    Query
};
enum CreateTag {
    Create
};
49
50template<typename T> class COMPtr {
51public:
52    COMPtr() : m_ptr(0) { }
53    COMPtr(T* ptr) : m_ptr(ptr) { if (m_ptr) m_ptr->AddRef(); }
54    COMPtr(AdoptCOMTag, T* ptr) : m_ptr(ptr) { }
55    COMPtr(const COMPtr& o) : m_ptr(o.m_ptr) { if (T* ptr = m_ptr) ptr->AddRef(); }
56
57    COMPtr(QueryTag, IUnknown* ptr) : m_ptr(copyQueryInterfaceRef(ptr)) { }
58    template<typename U> COMPtr(QueryTag, const COMPtr<U>& ptr) : m_ptr(copyQueryInterfaceRef(ptr.get())) { }
59
60    COMPtr(CreateTag, const IID& clsid) : m_ptr(createInstance(clsid)) { }
61
62    // Hash table deleted values, which are only constructed and never copied or destroyed.
63    COMPtr(WTF::HashTableDeletedValueType) : m_ptr(hashTableDeletedValue()) { }
64    bool isHashTableDeletedValue() const { return m_ptr == hashTableDeletedValue(); }
65
66    ~COMPtr() { if (m_ptr) m_ptr->Release(); }
67
68    T* get() const { return m_ptr; }
69
70    void clear();
71    T* leakRef();
72
73    T& operator*() const { return *m_ptr; }
74    T* operator->() const { return m_ptr; }
75
76    T** operator&() { ASSERT(!m_ptr); return &m_ptr; }
77
78    bool operator!() const { return !m_ptr; }
79
80    // This conversion operator allows implicit conversion to bool but not to other integer types.
81    typedef T* (COMPtr::*UnspecifiedBoolType)() const;
82    operator UnspecifiedBoolType() const { return m_ptr ? &COMPtr::get : 0; }
83
84    COMPtr& operator=(const COMPtr&);
85    COMPtr& operator=(T*);
86    template<typename U> COMPtr& operator=(const COMPtr<U>&);
87
88    void query(IUnknown* ptr) { adoptRef(copyQueryInterfaceRef(ptr)); }
89    template<typename U> void query(const COMPtr<U>& ptr) { query(ptr.get()); }
90
91    void create(const IID& clsid) { adoptRef(createInstance(clsid)); }
92
93    template<typename U> HRESULT copyRefTo(U**);
94    void adoptRef(T*);
95
96private:
97    static T* copyQueryInterfaceRef(IUnknown*);
98    static T* createInstance(const IID& clsid);
99    static T* hashTableDeletedValue() { return reinterpret_cast<T*>(-1); }
100
101    T* m_ptr;
102};
103
104template<typename T> inline COMPtr<T> adoptCOM(T *ptr)
105{
106    return COMPtr<T>(AdoptCOM, ptr);
107}
108
109template<typename T> inline void COMPtr<T>::clear()
110{
111    if (T* ptr = m_ptr) {
112        m_ptr = 0;
113        ptr->Release();
114    }
115}
116
117template<typename T> inline T* COMPtr<T>::leakRef()
118{
119    T* ptr = m_ptr;
120    m_ptr = 0;
121    return ptr;
122}
123
124template<typename T> inline T* COMPtr<T>::createInstance(const IID& clsid)
125{
126    T* result;
127    if (FAILED(CoCreateInstance(clsid, 0, CLSCTX_ALL, __uuidof(result), reinterpret_cast<void**>(&result))))
128        return 0;
129    return result;
130}
131
132template<typename T> inline T* COMPtr<T>::copyQueryInterfaceRef(IUnknown* ptr)
133{
134    if (!ptr)
135        return 0;
136    T* result;
137    if (FAILED(ptr->QueryInterface(&result)))
138        return 0;
139    return result;
140}
141
142template<typename T> template<typename U> inline HRESULT COMPtr<T>::copyRefTo(U** ptr)
143{
144    if (!ptr)
145        return E_POINTER;
146    *ptr = m_ptr;
147    if (m_ptr)
148        m_ptr->AddRef();
149    return S_OK;
150}
151
152template<typename T> inline void COMPtr<T>::adoptRef(T *ptr)
153{
154    if (m_ptr)
155        m_ptr->Release();
156    m_ptr = ptr;
157}
158
159template<typename T> inline COMPtr<T>& COMPtr<T>::operator=(const COMPtr<T>& o)
160{
161    T* optr = o.get();
162    if (optr)
163        optr->AddRef();
164    T* ptr = m_ptr;
165    m_ptr = optr;
166    if (ptr)
167        ptr->Release();
168    return *this;
169}
170
171template<typename T> template<typename U> inline COMPtr<T>& COMPtr<T>::operator=(const COMPtr<U>& o)
172{
173    T* optr = o.get();
174    if (optr)
175        optr->AddRef();
176    T* ptr = m_ptr;
177    m_ptr = optr;
178    if (ptr)
179        ptr->Release();
180    return *this;
181}
182
183template<typename T> inline COMPtr<T>& COMPtr<T>::operator=(T* optr)
184{
185    if (optr)
186        optr->AddRef();
187    T* ptr = m_ptr;
188    m_ptr = optr;
189    if (ptr)
190        ptr->Release();
191    return *this;
192}
193
194template<typename T, typename U> inline bool operator==(const COMPtr<T>& a, const COMPtr<U>& b)
195{
196    return a.get() == b.get();
197}
198
199template<typename T, typename U> inline bool operator==(const COMPtr<T>& a, U* b)
200{
201    return a.get() == b;
202}
203
204template<typename T, typename U> inline bool operator==(T* a, const COMPtr<U>& b)
205{
206    return a == b.get();
207}
208
209template<typename T, typename U> inline bool operator!=(const COMPtr<T>& a, const COMPtr<U>& b)
210{
211    return a.get() != b.get();
212}
213
214template<typename T, typename U> inline bool operator!=(const COMPtr<T>& a, U* b)
215{
216    return a.get() != b;
217}
218
219template<typename T, typename U> inline bool operator!=(T* a, const COMPtr<U>& b)
220{
221    return a != b.get();
222}
223
224namespace WTF {
225
226    template<typename P> struct HashTraits<COMPtr<P> > : GenericHashTraits<COMPtr<P> > {
227        static const bool emptyValueIsZero = true;
228        static void constructDeletedValue(COMPtr<P>& slot) { new (&slot) COMPtr<P>(HashTableDeletedValue); }
229        static bool isDeletedValue(const COMPtr<P>& value) { return value.isHashTableDeletedValue(); }
230    };
231
232    template<typename P> struct PtrHash<COMPtr<P> > : PtrHash<P*> {
233        using PtrHash<P*>::hash;
234        static unsigned hash(const COMPtr<P>& key) { return hash(key.get()); }
235        using PtrHash<P*>::equal;
236        static bool equal(const COMPtr<P>& a, const COMPtr<P>& b) { return a == b; }
237        static bool equal(P* a, const COMPtr<P>& b) { return a == b; }
238        static bool equal(const COMPtr<P>& a, P* b) { return a == b; }
239    };
240
241    template<typename P> struct DefaultHash<COMPtr<P> > { typedef PtrHash<COMPtr<P> > Hash; };
242}
243
244#endif
245