1/*
2 * Copyright (c) 2008 - 2011 Apple Inc. All rights reserved.
3 *
4 * @APPLE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. Please obtain a copy of the License at
10 * http://www.opensource.apple.com/apsl/ and read it before using this
11 * file.
12 *
13 * The Original Code and all software distributed under the License are
14 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
15 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
16 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
18 * Please see the License for the specific language governing rights and
19 * limitations under the License.
20 *
21 * @APPLE_LICENSE_HEADER_END@
22 */
23
24#include <smbclient/smbclient.h>
25#include <smbclient/ntstatus.h>
26#include <smbclient/smbclient_internal.h>
27
28#include <algorithm>
29#include <vector>
30#include <cstdlib>
31#include <assert.h>
32#include <string>
33
34#include "lmshare.h"
35#include "memory.hpp"
36#include "rpc_helpers.hpp"
37
38extern "C" {
39#include <dce/dcethread.h>
40}
41
42static idl_void_p_t
43share_memalloc(idl_void_p_t context, idl_size_t sz)
44{
45    rpc_mempool * pool = (rpc_mempool *)context;
46    return pool->alloc(sz);
47}
48
49static void
50share_memfree(idl_void_p_t context, idl_void_p_t ptr)
51{
52    rpc_mempool * pool = (rpc_mempool *)context;
53    return pool->free(ptr);
54}
55
56NET_API_STATUS
57NetShareGetInfo(
58        const char * ServerName,
59        const char * NetName,
60        uint32_t Level,
61        PSHARE_INFO * ShareInfo)
62{
63	WCHAR * serverName = SMBConvertFromUTF8ToUTF16(ServerName, 1024, 0);
64	WCHAR * netName = SMBConvertFromUTF8ToUTF16(NetName, 1024, 0);
65
66    if (!serverName || !netName || !ShareInfo) {
67		if (serverName)
68			free(serverName);
69		if (netName)
70			free(netName);
71
72        return ERROR_INVALID_PARAMETER;
73    }
74
75    rpc_binding binding = make_rpc_binding(ServerName, "srvsvc");
76    if (binding.get() == NULL) {
77        SMBLogInfo("make_rpc_binding failed", ASL_LEVEL_DEBUG);
78		if (serverName) {
79			free(serverName);
80        }
81		if (netName) {
82			free(netName);
83        }
84        return ERROR_INVALID_PARAMETER;
85    }
86
87    rpc_ss_allocator_t allocator;
88
89    NET_API_STATUS api_status = NERR_Success;
90    error_status_t rpc_status = rpc_s_ok;
91
92	memset(&allocator, 0, sizeof(allocator));
93
94    std::pair<rpc_mempool *, SHARE_INFO *> result(
95            allocate_rpc_mempool<SHARE_INFO>());
96
97    allocator.p_allocate = share_memalloc;
98    allocator.p_free = share_memfree;
99    allocator.p_context = (idl_void_p_t)result.first;
100
101    rpc_ss_swap_client_alloc_free_ex(&allocator, &allocator);
102
103    DCETHREAD_TRY
104        api_status = NetrShareGetInfo(
105                binding.get(),
106                const_cast<WCHAR *>(serverName),
107                const_cast<WCHAR *>(netName),
108                Level, result.second, &rpc_status);
109    DCETHREAD_CATCH_ALL(exc)
110        /*
111	 * Unmarshalling a response with an unknown level will throw a
112	 * rpc_x_invalid_tag exception since the unknown level is not defined as
113	 * a union discriminator.
114	 */
115        rpc_status = rpc_exception_status(exc);
116    DCETHREAD_ENDTRY
117
118	free(serverName);
119	free(netName);
120    rpc_ss_swap_client_alloc_free_ex(&allocator, &allocator);
121
122    if (rpc_status != rpc_s_ok) {
123        SMBLogInfo("RPC to srvsrvc gave error %#08x", ASL_LEVEL_ERR, rpc_status);
124        NetApiBufferFree(result.second);
125        return RPC_S_PROTOCOL_ERROR;
126    }
127
128    if (api_status == NERR_Success) {
129        *ShareInfo = result.second;
130    } else {
131        NetApiBufferFree(result.second);
132        *ShareInfo = NULL;
133    }
134
135    return api_status;
136}
137
138NET_API_STATUS
139NetShareEnum(
140        const char * ServerName,
141        uint32_t Level,
142        PSHARE_ENUM_STRUCT * InfoStruct)
143{
144	WCHAR * serverName = SMBConvertFromUTF8ToUTF16(ServerName, 1024, 0);
145    if (!serverName || !InfoStruct) {
146		if (serverName)
147			free(serverName);
148        return ERROR_INVALID_PARAMETER;
149    }
150
151    rpc_binding binding = make_rpc_binding(ServerName, "srvsvc");
152    if (binding.get() == NULL) {
153        SMBLogInfo("make_rpc_binding failed", ASL_LEVEL_DEBUG);
154		if (serverName) {
155			free(serverName);
156        }
157        return ERROR_INVALID_PARAMETER;
158    }
159
160    rpc_ss_allocator_t allocator;
161
162    NET_API_STATUS api_status = NERR_Success;
163    error_status_t rpc_status = rpc_s_ok;
164
165	memset(&allocator, 0, sizeof(allocator));
166
167    std::pair<rpc_mempool *, SHARE_ENUM_STRUCT *> result(
168            allocate_rpc_mempool<SHARE_ENUM_STRUCT>());
169
170    DWORD entries = 0;
171    DWORD resume = 0;
172
173    allocator.p_allocate = share_memalloc;
174    allocator.p_free = share_memfree;
175    allocator.p_context = (idl_void_p_t)result.first;
176
177    rpc_ss_swap_client_alloc_free_ex(&allocator, &allocator);
178
179    result.second->Level = Level;
180
181    /*
182	 * Windows requires a valid pointer for the ShareInfo union. It doesn't matter
183	 * which container type we choose here, since they all have the same binary
184	 * layout.
185	 */
186    result.second->ShareInfo.Level0 =
187        (SHARE_INFO_0_CONTAINER *)result.first->alloc(
188                                sizeof(SHARE_INFO_0_CONTAINER));
189    result.second->ShareInfo.Level0->EntriesRead = 0;
190    result.second->ShareInfo.Level0->Buffer = NULL;
191
192	DCETHREAD_TRY
193		api_status = NetrShareEnum(
194                binding.get(),
195                const_cast<WCHAR *>(serverName),
196                result.second,
197                0xffffffff,
198                &entries,
199                &resume,
200                &rpc_status);
201	DCETHREAD_CATCH_ALL(exc)
202		rpc_status = rpc_exception_status(exc);
203	DCETHREAD_ENDTRY
204
205    rpc_ss_swap_client_alloc_free_ex(&allocator, &allocator);
206
207	free(serverName);
208    if (rpc_status != rpc_s_ok) {
209        SMBLogInfo("RPC to srvsrvc gave error %#08x", ASL_LEVEL_ERR, rpc_status);
210        NetApiBufferFree(result.second);
211        return RPC_S_PROTOCOL_ERROR;
212    }
213
214    if (api_status == NERR_Success) {
215        *InfoStruct = result.second;
216    } else {
217        NetApiBufferFree(result.second);
218        *InfoStruct = NULL;
219    }
220
221    return api_status;
222}
223
224void
225NetApiBufferFree(
226        void * bufptr)
227{
228    rpc_mempool * pool;
229
230    if (!bufptr) {
231        return;
232    }
233
234    pool = (rpc_mempool *)(void *)((uint8_t *)bufptr - rpc_mempool::block_size());
235    pool->~rpc_mempool();
236
237    std::free(pool);
238}
239
240/* vim: set sw=4 ts=4 tw=79 et: */
241