/*
 * Copyright (C) 2003-2014 The Music Player Daemon Project
 * http://www.musicpd.org
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */

#include "config.h"

#ifdef HAVE_STRUCT_UCRED
#define _GNU_SOURCE 1
#endif

#include "ServerSocket.hxx"
#include "system/SocketUtil.hxx"
#include "system/SocketError.hxx"
#include "event/SocketMonitor.hxx"
#include "system/Resolver.hxx"
#include "system/fd_util.h"
#include "fs/AllocatedPath.hxx"
#include "fs/FileSystem.hxx"
#include "util/Alloc.hxx"
#include "util/Error.hxx"
#include "util/Domain.hxx"
#include "Log.hxx"

#include <string>
#include <algorithm>
#include <utility>

#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <string.h>
#include <unistd.h>
#include <stdlib.h>
#include <assert.h>
#include <errno.h>

#ifdef WIN32
#include <ws2tcpip.h>
#include <winsock.h>
#else
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <netdb.h>
#endif

#define DEFAULT_PORT	6600

62
class OneServerSocket final : private SocketMonitor {
63
	ServerSocket &parent;
64

65
	const unsigned serial;
66

67
	AllocatedPath path;
68 69

	size_t address_length;
70 71
	struct sockaddr *address;

72
public:
73
	OneServerSocket(EventLoop &_loop, ServerSocket &_parent,
74
			unsigned _serial,
75 76
			const struct sockaddr *_address,
			size_t _address_length)
77 78
		:SocketMonitor(_loop),
		 parent(_parent), serial(_serial),
79
		 path(AllocatedPath::Null()),
80
		 address_length(_address_length),
81
		 address((sockaddr *)xmemdup(_address, _address_length))
82 83 84 85 86 87 88 89 90
	{
		assert(_address != nullptr);
		assert(_address_length > 0);
	}

	OneServerSocket(const OneServerSocket &other) = delete;
	OneServerSocket &operator=(const OneServerSocket &other) = delete;

	~OneServerSocket() {
91
		free(address);
92 93 94

		if (IsDefined())
			Close();
95 96
	}

97 98 99 100
	unsigned GetSerial() const {
		return serial;
	}

101 102
	void SetPath(AllocatedPath &&_path) {
		assert(path.IsNull());
103

104
		path = std::move(_path);
105 106
	}

107
	bool Open(Error &error);
108

109
	using SocketMonitor::IsDefined;
110
	using SocketMonitor::Close;
111

112 113 114 115
	gcc_pure
	std::string ToString() const {
		return sockaddr_to_string(address, address_length);
	}
116

117 118 119 120
	void SetFD(int _fd) {
		SocketMonitor::Open(_fd);
		SocketMonitor::ScheduleRead();
	}
121 122

	void Accept();
123 124

private:
125
	virtual bool OnSocketReady(unsigned flags) override;
126 127
};

128
static constexpr Domain server_socket_domain("server_socket");
129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154

/**
 * Determine the user id of the process on the other end of the
 * connected socket @p fd.
 *
 * @return the peer's uid, or -1 if it cannot be determined (the
 * platform offers no way to query it, or the query failed)
 */
static int
get_remote_uid(int fd)
{
#ifdef HAVE_STRUCT_UCRED
	struct ucred cred;
	socklen_t len = sizeof(cred);

	/* report a failed query as "unknown" (-1), like the other
	   branches do; returning 0 would claim the peer is root */
	if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &cred, &len) < 0)
		return -1;

	return cred.uid;
#elif defined(HAVE_GETPEEREID)
	uid_t euid;
	gid_t egid;

	if (getpeereid(fd, &euid, &egid) == 0)
		return euid;

	return -1;
#else
	(void)fd;
	return -1;
#endif
}

155 156 157 158 159 160
inline void
OneServerSocket::Accept()
{
	struct sockaddr_storage peer_address;
	size_t peer_address_length = sizeof(peer_address);
	int peer_fd =
161
		accept_cloexec_nonblock(Get(), (struct sockaddr*)&peer_address,
162 163 164
					&peer_address_length);
	if (peer_fd < 0) {
		const SocketErrorMessage msg;
165 166
		FormatError(server_socket_domain,
			    "accept() failed: %s", (const char *)msg);
167 168 169 170 171
		return;
	}

	if (socket_keepalive(peer_fd)) {
		const SocketErrorMessage msg;
172 173 174
		FormatError(server_socket_domain,
			    "Could not set TCP keepalive option: %s",
			    (const char *)msg);
175 176
	}

177 178 179
	parent.OnAccept(peer_fd,
			(const sockaddr &)peer_address,
			peer_address_length, get_remote_uid(peer_fd));
180 181
}

182
bool
183
OneServerSocket::OnSocketReady(gcc_unused unsigned flags)
184
{
185
	Accept();
186
	return true;
187 188
}

189
inline bool
190
OneServerSocket::Open(Error &error)
191
{
192
	assert(!IsDefined());
193 194 195 196

	int _fd = socket_bind_listen(address->sa_family,
				     SOCK_STREAM, 0,
				     address, address_length, 5,
197
				     error);
198 199 200 201 202
	if (_fd < 0)
		return false;

	/* allow everybody to connect */

203 204
	if (!path.IsNull())
		chmod(path.c_str(), 0666);
205 206 207 208 209 210 211 212

	/* register in the GLib main loop */

	SetFD(_fd);

	return true;
}

213 214
ServerSocket::ServerSocket(EventLoop &_loop)
	:loop(_loop), next_serial(1) {}
215 216 217 218 219

/* this is just here to allow the OneServerSocket forward
   declaration */
ServerSocket::~ServerSocket() {}

220
bool
221
ServerSocket::Open(Error &error)
222
{
223
	OneServerSocket *good = nullptr, *bad = nullptr;
224
	Error last_error;
225

226 227
	for (auto &i : sockets) {
		assert(i.GetSerial() > 0);
228
		assert(good == nullptr || i.GetSerial() >= good->GetSerial());
229

230 231
		if (bad != nullptr && i.GetSerial() != bad->GetSerial()) {
			Close();
232
			error = std::move(last_error);
233 234 235
			return false;
		}

236 237
		Error error2;
		if (!i.Open(error2)) {
238
			if (good != nullptr && good->GetSerial() == i.GetSerial()) {
239 240
				const auto address_string = i.ToString();
				const auto good_string = good->ToString();
241 242 243 244
				FormatWarning(server_socket_domain,
					      "bind to '%s' failed: %s "
					      "(continuing anyway, because "
					      "binding to '%s' succeeded)",
245 246 247
					      address_string.c_str(),
					      error2.GetMessage(),
					      good_string.c_str());
248 249
			} else if (bad == nullptr) {
				bad = &i;
250

251
				const auto address_string = i.ToString();
252
				error2.FormatPrefix("Failed to bind to '%s': ",
253
						    address_string.c_str());
254 255 256 257

				last_error = std::move(error2);
			}

258 259 260 261 262 263
			continue;
		}

		/* mark this socket as "good", and clear previous
		   errors */

264
		good = &i;
265

266 267
		if (bad != nullptr) {
			bad = nullptr;
268
			last_error.Clear();
269 270 271
		}
	}

272
	if (bad != nullptr) {
273
		Close();
274
		error = std::move(last_error);
275 276 277 278 279 280
		return false;
	}

	return true;
}

281
void
282
ServerSocket::Close()
283 284
{
	for (auto &i : sockets)
285 286
		if (i.IsDefined())
			i.Close();
287 288
}

289 290
OneServerSocket &
ServerSocket::AddAddress(const sockaddr &address, size_t address_length)
291
{
292 293
	sockets.emplace_back(loop, *this, next_serial,
			     &address, address_length);
294

295
	return sockets.back();
296 297
}

298
bool
299
ServerSocket::AddFD(int fd, Error &error)
300 301 302 303
{
	assert(fd >= 0);

	struct sockaddr_storage address;
304
	socklen_t address_length = sizeof(address);
305 306
	if (getsockname(fd, (struct sockaddr *)&address,
			&address_length) < 0) {
307 308
		SetSocketError(error);
		error.AddPrefix("Failed to get socket address: ");
309 310 311
		return false;
	}

312 313
	OneServerSocket &s = AddAddress((const sockaddr &)address,
					address_length);
314
	s.SetFD(fd);
315 316 317 318

	return true;
}

#ifdef HAVE_TCP

/**
 * Add a wildcard (INADDR_ANY) IPv4 socket for @p port with the
 * current serial.
 */
inline void
ServerSocket::AddPortIPv4(unsigned port)
{
	struct sockaddr_in sin;
	memset(&sin, 0, sizeof(sin));
	sin.sin_port = htons(port);
	sin.sin_family = AF_INET;
	sin.sin_addr.s_addr = htonl(INADDR_ANY);

	AddAddress((const sockaddr &)sin, sizeof(sin));
}

#ifdef HAVE_IPV6

/**
 * Add a wildcard IPv6 socket for @p port with the current serial;
 * the all-zero address set by memset() is the IPv6 "any" address.
 */
inline void
ServerSocket::AddPortIPv6(unsigned port)
{
	struct sockaddr_in6 sin;
	memset(&sin, 0, sizeof(sin));
	sin.sin6_port = htons(port);
	sin.sin6_family = AF_INET6;

	AddAddress((const sockaddr &)sin, sizeof(sin));
}

/**
 * Is IPv6 supported by the kernel?  Probes by creating (and
 * immediately closing) an AF_INET6 stream socket.
 *
 * Deliberately not gcc_pure: the function performs system calls and
 * its result depends on the environment, so the compiler must not
 * be allowed to fold or drop calls to it.
 */
static bool
SupportsIPv6()
{
	int fd = socket(AF_INET6, SOCK_STREAM, 0);
	if (fd < 0)
		return false;

	close(fd);
	return true;
}

#endif /* HAVE_IPV6 */

#endif /* HAVE_TCP */

bool
366
ServerSocket::AddPort(unsigned port, Error &error)
367 368 369
{
#ifdef HAVE_TCP
	if (port == 0 || port > 0xffff) {
370
		error.Set(server_socket_domain, "Invalid TCP port");
371 372 373 374
		return false;
	}

#ifdef HAVE_IPV6
375 376
	if (SupportsIPv6())
		AddPortIPv6(port);
377
#endif
378
	AddPortIPv4(port);
379

380
	++next_serial;
381 382 383 384 385

	return true;
#else /* HAVE_TCP */
	(void)port;

386
	error.Set(server_socket_domain, "TCP support is disabled");
387 388 389 390 391
	return false;
#endif /* HAVE_TCP */
}

bool
392
ServerSocket::AddHost(const char *hostname, unsigned port, Error &error)
393 394
{
#ifdef HAVE_TCP
395 396
	struct addrinfo *ai = resolve_host_port(hostname, port,
						AI_PASSIVE, SOCK_STREAM,
397
						error);
398
	if (ai == nullptr)
399 400
		return false;

401
	for (const struct addrinfo *i = ai; i != nullptr; i = i->ai_next)
402
		AddAddress(*i->ai_addr, i->ai_addrlen);
403 404 405

	freeaddrinfo(ai);

406
	++next_serial;
407 408 409 410 411 412

	return true;
#else /* HAVE_TCP */
	(void)hostname;
	(void)port;

413
	error.Set(server_socket_domain, "TCP support is disabled");
414 415 416 417 418
	return false;
#endif /* HAVE_TCP */
}

bool
419
ServerSocket::AddPath(AllocatedPath &&path, Error &error)
420 421 422 423
{
#ifdef HAVE_UN
	struct sockaddr_un s_un;

424
	const size_t path_length = path.length();
425
	if (path_length >= sizeof(s_un.sun_path)) {
426 427
		error.Set(server_socket_domain,
			  "UNIX socket path is too long");
428 429 430
		return false;
	}

431
	RemoveFile(path);
432 433

	s_un.sun_family = AF_UNIX;
434
	memcpy(s_un.sun_path, path.c_str(), path_length + 1);
435

436
	OneServerSocket &s = AddAddress((const sockaddr &)s_un, sizeof(s_un));
437
	s.SetPath(std::move(path));
438 439 440 441 442

	return true;
#else /* !HAVE_UN */
	(void)path;

443 444
	error.Set(server_socket_domain,
		  "UNIX domain socket support is disabled");
445 446 447 448
	return false;
#endif /* !HAVE_UN */
}