1/***********************************************************************
2 * Copyright (c) 2009, Secure Endpoints Inc.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 *
9 * - Redistributions of source code must retain the above copyright
10 *   notice, this list of conditions and the following disclaimer.
11 *
12 * - Redistributions in binary form must reproduce the above copyright
13 *   notice, this list of conditions and the following disclaimer in
14 *   the documentation and/or other materials provided with the
15 *   distribution.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
20 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
21 * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
22 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
25 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
26 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
28 * OF THE POSSIBILITY OF SUCH DAMAGE.
29 *
30 **********************************************************************/
31
32#include <config.h>
33#include <windows.h>
34#include <dlfcn.h>
35#include <strsafe.h>
36
/* Maximum error-string length.  NOTE(review): not referenced by any code
   shown in this file; possibly a leftover -- confirm before removing. */
#define ERR_STR_LEN 256

/* TLS index holding each thread's current dlerror() string.  Starts as
   TLS_OUT_OF_INDEXES and is published once, lazily, by get_tl_error_slot()
   via InterlockedCompareExchange (which is why it is a volatile LONG). */
static volatile LONG dlfcn_tls = TLS_OUT_OF_INDEXES;
40
41static DWORD get_tl_error_slot(void)
42{
43    if (dlfcn_tls == TLS_OUT_OF_INDEXES) {
44        DWORD slot = TlsAlloc();
45        DWORD old_slot;
46
47        if (slot == TLS_OUT_OF_INDEXES)
48            return dlfcn_tls;
49
50        if ((old_slot = InterlockedCompareExchange(&dlfcn_tls, slot,
51                                                   TLS_OUT_OF_INDEXES)) !=
52            TLS_OUT_OF_INDEXES) {
53
54            /* Lost a race */
55            TlsFree(slot);
56            return old_slot;
57        } else {
58            return slot;
59        }
60    }
61
62    return dlfcn_tls;
63}
64
65static void set_error(const char * e)
66{
67    char * s;
68    char * old_s;
69    size_t len;
70
71    DWORD slot = get_tl_error_slot();
72
73    if (slot == TLS_OUT_OF_INDEXES)
74        return;
75
76    len = strlen(e) * sizeof(char) + sizeof(char);
77    s = LocalAlloc(LMEM_FIXED, len);
78    if (s == NULL)
79        return;
80
81    old_s = (char *) TlsGetValue(slot);
82    TlsSetValue(slot, (LPVOID) s);
83
84    if (old_s != NULL)
85        LocalFree(old_s);
86}
87
88static void set_error_from_last(void) {
89    DWORD slot = get_tl_error_slot();
90    char * s = NULL;
91    char * old_s;
92
93    if (slot == TLS_OUT_OF_INDEXES)
94        return;
95
96    FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER,
97		  0, GetLastError(), 0,
98		  (LPTSTR) &s, 0,
99		  NULL);
100    if (s == NULL)
101        return;
102
103    old_s = (char *) TlsGetValue(slot);
104    TlsSetValue(slot, (LPVOID) s);
105
106    if (old_s != NULL)
107        LocalFree(old_s);
108}
109
110ROKEN_LIB_FUNCTION int ROKEN_LIB_CALL
111dlclose(void * vhm)
112{
113    BOOL brv;
114
115    brv = FreeLibrary((HMODULE) vhm);
116    if (!brv) {
117	set_error_from_last();
118    }
119    return !brv;
120}
121
122ROKEN_LIB_FUNCTION char  * ROKEN_LIB_CALL
123dlerror(void)
124{
125    DWORD slot = get_tl_error_slot();
126
127    if (slot == TLS_OUT_OF_INDEXES)
128        return NULL;
129
130    return (char *) TlsGetValue(slot);
131}
132
133ROKEN_LIB_FUNCTION void  * ROKEN_LIB_CALL
134dlopen(const char *fn, int flags)
135{
136    HMODULE hm;
137    UINT    old_error_mode;
138
139    /* We don't support dlopen(0, ...) on Windows.*/
140    if ( fn == NULL ) {
141	set_error("Not implemented");
142	return NULL;
143    }
144
145    old_error_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
146
147    hm = LoadLibrary(fn);
148
149    if (hm == NULL) {
150	set_error_from_last();
151    }
152
153    SetErrorMode(old_error_mode);
154
155    return (void *) hm;
156}
157
158ROKEN_LIB_FUNCTION DLSYM_RET_TYPE ROKEN_LIB_CALL
159dlsym(void * vhm, const char * func_name)
160{
161    HMODULE hm = (HMODULE) vhm;
162
163    return (DLSYM_RET_TYPE)(ULONG_PTR)GetProcAddress(hm, func_name);
164}
165
166