1/*
2 * Copyright (C) 2010 Apple Inc. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 * 1. Redistributions of source code must retain the above copyright
8 *    notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 *    notice, this list of conditions and the following disclaimer in the
11 *    documentation and/or other materials provided with the distribution.
12 *
13 * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23 * THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26#ifndef GenericCallback_h
27#define GenericCallback_h
28
29#include "APIError.h"
30#include "ProcessThrottler.h"
31#include "ShareableBitmap.h"
32#include "WKAPICast.h"
33#include <functional>
34#include <wtf/HashMap.h>
35#include <wtf/PassRefPtr.h>
36#include <wtf/RefCounted.h>
37#include <wtf/RunLoop.h>
38
39namespace WebKit {
40
41class CallbackBase : public RefCounted<CallbackBase> {
42public:
43    enum class Error {
44        None,
45        Unknown,
46        ProcessExited,
47        OwnerWasInvalidated,
48    };
49
50    virtual ~CallbackBase()
51    {
52    }
53
54    uint64_t callbackID() const { return m_callbackID; }
55
56    template<class T>
57    T* as()
58    {
59        if (T::type() == m_type)
60            return static_cast<T*>(this);
61
62        return nullptr;
63    }
64
65    virtual void invalidate(Error) = 0;
66
67protected:
68    struct TypeTag { };
69    typedef const TypeTag* Type;
70
71    explicit CallbackBase(Type type, std::unique_ptr<ProcessThrottler::BackgroundActivityToken> activityToken)
72        : m_type(type)
73        , m_callbackID(generateCallbackID())
74        , m_activityToken(WTF::move(activityToken))
75    {
76    }
77
78private:
79    static uint64_t generateCallbackID()
80    {
81        ASSERT(RunLoop::isMain());
82        static uint64_t uniqueCallbackID = 1;
83        return uniqueCallbackID++;
84    }
85
86    Type m_type;
87    uint64_t m_callbackID;
88    std::unique_ptr<ProcessThrottler::BackgroundActivityToken> m_activityToken;
89};
90
91template<typename... T>
92class GenericCallback : public CallbackBase {
93public:
94    typedef std::function<void (T..., Error)> CallbackFunction;
95
96    static PassRefPtr<GenericCallback> create(CallbackFunction callback, std::unique_ptr<ProcessThrottler::BackgroundActivityToken> activityToken = nullptr)
97    {
98        return adoptRef(new GenericCallback(callback, WTF::move(activityToken)));
99    }
100
101    virtual ~GenericCallback()
102    {
103        ASSERT(!m_callback);
104    }
105
106    void performCallbackWithReturnValue(T... returnValue)
107    {
108        if (!m_callback)
109            return;
110
111        m_callback(returnValue..., Error::None);
112
113        m_callback = nullptr;
114    }
115
116    void performCallback()
117    {
118        performCallbackWithReturnValue();
119    }
120
121    virtual void invalidate(Error error = Error::Unknown) override final
122    {
123        if (!m_callback)
124            return;
125
126        m_callback(typename std::remove_reference<T>::type()..., error);
127
128        m_callback = nullptr;
129    }
130
131private:
132    GenericCallback(CallbackFunction callback, std::unique_ptr<ProcessThrottler::BackgroundActivityToken> activityToken)
133        : CallbackBase(type(), WTF::move(activityToken))
134        , m_callback(callback)
135    {
136    }
137
138    friend class CallbackBase;
139    static Type type()
140    {
141        static TypeTag tag;
142        return &tag;
143    }
144
145    CallbackFunction m_callback;
146};
147
148template<typename APIReturnValueType, typename InternalReturnValueType = typename APITypeInfo<APIReturnValueType>::ImplType>
149static typename GenericCallback<InternalReturnValueType>::CallbackFunction toGenericCallbackFunction(void* context, void (*callback)(APIReturnValueType, WKErrorRef, void*))
150{
151    return [context, callback](InternalReturnValueType returnValue, CallbackBase::Error error) {
152        callback(toAPI(returnValue), error != CallbackBase::Error::None ? toAPI(API::Error::create().get()) : 0, context);
153    };
154}
155
156typedef GenericCallback<> VoidCallback;
157typedef GenericCallback<const Vector<WebCore::IntRect>&, double> ComputedPagesCallback;
158typedef GenericCallback<const ShareableBitmap::Handle&> ImageCallback;
159
160template<typename T>
161void invalidateCallbackMap(HashMap<uint64_t, T>& callbackMap, CallbackBase::Error error)
162{
163    Vector<T> callbacks;
164    copyValuesToVector(callbackMap, callbacks);
165    for (auto& callback : callbacks)
166        callback->invalidate(error);
167
168    callbackMap.clear();
169}
170
171class CallbackMap {
172public:
173    uint64_t put(PassRefPtr<CallbackBase> callback)
174    {
175        ASSERT(!m_map.contains(callback->callbackID()));
176
177        uint64_t callbackID = callback->callbackID();
178        m_map.set(callbackID, callback);
179        return callbackID;
180    }
181
182    template<unsigned I, typename T, typename... U>
183    struct GenericCallbackType {
184        typedef typename GenericCallbackType<I - 1, U..., T>::type type;
185    };
186
187    template<typename... U>
188    struct GenericCallbackType<1, CallbackBase::Error, U...> {
189        typedef GenericCallback<U...> type;
190    };
191
192    template<typename... T>
193    uint64_t put(std::function<void (T...)> function, std::unique_ptr<ProcessThrottler::BackgroundActivityToken> activityToken)
194    {
195        auto callback = GenericCallbackType<sizeof...(T), T...>::type::create(WTF::move(function), WTF::move(activityToken));
196        return put(callback);
197    }
198
199    template<class T>
200    RefPtr<T> take(uint64_t callbackID)
201    {
202        RefPtr<CallbackBase> base = m_map.take(callbackID);
203        if (!base)
204            return nullptr;
205
206        return adoptRef(base.release().leakRef()->as<T>());
207    }
208
209    void invalidate(CallbackBase::Error error)
210    {
211        invalidateCallbackMap(m_map, error);
212    }
213
214private:
215    HashMap<uint64_t, RefPtr<CallbackBase>> m_map;
216};
217
218} // namespace WebKit
219
220#endif // GenericCallback_h
221