ServerSocket.cxx 8.42 KB
Newer Older
1
/*
Max Kellermann's avatar
Max Kellermann committed
2
 * Copyright 2003-2019 The Music Player Daemon Project
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
 * http://www.musicpd.org
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */

#include "config.h"
21
#include "ServerSocket.hxx"
22
#include "net/IPv4Address.hxx"
23
#include "net/IPv6Address.hxx"
24
#include "net/StaticSocketAddress.hxx"
25
#include "net/AllocatedSocketAddress.hxx"
26 27
#include "net/SocketUtil.hxx"
#include "net/SocketError.hxx"
28
#include "net/UniqueSocketDescriptor.hxx"
29
#include "net/Resolver.hxx"
30
#include "net/AddressInfo.hxx"
31
#include "net/ToString.hxx"
32
#include "event/SocketMonitor.hxx"
33
#include "fs/AllocatedPath.hxx"
34
#include "util/RuntimeError.hxx"
35
#include "util/Domain.hxx"
36
#include "Log.hxx"
37

38
#include <string>
39
#include <utility>
40

41 42
#include <assert.h>

43 44
#ifdef HAVE_UN
#include <sys/stat.h>
45 46
#endif

47
/**
 * One listener socket managed by a #ServerSocket.  It watches its
 * file descriptor through the private SocketMonitor base and hands
 * each accepted connection to the parent via OnAccept().
 */
class ServerSocket::OneServerSocket final : private SocketMonitor {
	/** the #ServerSocket which receives accepted connections */
	ServerSocket &parent;

	/** the serial number passed to the constructor; see GetSerial() */
	const unsigned serial;

#ifdef HAVE_UN
	/** the path set by SetPath(); "null" as long as none was set */
	AllocatedPath path;
#endif

	/** the address passed to socket_bind_listen() by Open() */
	const AllocatedSocketAddress address;

public:
	/**
	 * @param event_loop handed to the SocketMonitor base
	 * @param owner the #ServerSocket this listener reports to
	 * @param serial_number the value later returned by GetSerial()
	 * @param bind_address forwarded into the stored address
	 */
	template<typename A>
	OneServerSocket(EventLoop &event_loop, ServerSocket &owner,
			unsigned serial_number,
			A &&bind_address) noexcept
		:SocketMonitor(event_loop),
		 parent(owner),
		 serial(serial_number),
#ifdef HAVE_UN
		 path(nullptr),
#endif
		 address(std::forward<A>(bind_address))
	{
	}

	/* not copyable */
	OneServerSocket(const OneServerSocket &) = delete;
	OneServerSocket &operator=(const OneServerSocket &) = delete;

	~OneServerSocket() noexcept {
		if (!IsDefined())
			/* nothing to close */
			return;

		Close();
	}

	unsigned GetSerial() const noexcept {
		return serial;
	}

#ifdef HAVE_UN
	/**
	 * Remember the path this listener belongs to.  May only be
	 * called while no path has been set yet.
	 */
	void SetPath(AllocatedPath &&new_path) noexcept {
		assert(path.IsNull());

		path = std::move(new_path);
	}
#endif

	/**
	 * Create the listener socket for #address and register it;
	 * defined out of line.
	 */
	void Open();

	/* re-export these two members of the private base class */
	using SocketMonitor::IsDefined;
	using SocketMonitor::Close;

	/**
	 * @return the string form of #address as produced by the
	 * global ToString() function
	 */
	gcc_pure
	std::string ToString() const noexcept {
		return ::ToString(address);
	}

	/**
	 * Take over the given socket descriptor, pass it to the
	 * SocketMonitor base and schedule read events for it.
	 */
	void SetFD(UniqueSocketDescriptor fd) noexcept {
		SocketMonitor::Open(fd.Release());
		SocketMonitor::ScheduleRead();
	}

	/**
	 * Accept one connection and pass it to the parent; defined
	 * out of line.
	 */
	void Accept() noexcept;

private:
	/* virtual method from class SocketMonitor */
	bool OnSocketReady(unsigned flags) noexcept override;
};

113
/** the log domain passed to the FormatError() calls in this file */
static constexpr Domain server_socket_domain("server_socket");
114 115 116 117 118 119 120 121 122

/**
 * Determine the user id of the process on the other end of a
 * connected socket.
 *
 * Uses SO_PEERCRED if HAVE_STRUCT_UCRED is defined, otherwise
 * getpeereid() if HAVE_GETPEEREID is defined.
 *
 * @param fd the connected socket
 * @return the peer's uid, or -1 if the query failed or neither
 * mechanism was compiled in
 */
static int
get_remote_uid(int fd)
{
#if defined(HAVE_STRUCT_UCRED)
	struct ucred peer;
	socklen_t peer_size = sizeof (peer);

	const int result = getsockopt(fd, SOL_SOCKET, SO_PEERCRED,
				      &peer, &peer_size);
	if (result >= 0)
		return peer.uid;

	return -1;
#elif defined(HAVE_GETPEEREID)
	uid_t peer_uid;
	gid_t peer_gid;

	if (getpeereid(fd, &peer_uid, &peer_gid) != 0)
		return -1;

	return peer_uid;
#else
	/* no way to query the peer credentials on this platform */
	(void)fd;
	return -1;
#endif
}

140
inline void
141
ServerSocket::OneServerSocket::Accept() noexcept
142
{
143
	StaticSocketAddress peer_address;
144
	UniqueSocketDescriptor peer_fd(GetSocket().AcceptNonBlock(peer_address));
145
	if (!peer_fd.IsDefined()) {
146
		const SocketErrorMessage msg;
147 148
		FormatError(server_socket_domain,
			    "accept() failed: %s", (const char *)msg);
149 150 151
		return;
	}

152
	if (!peer_fd.SetKeepAlive()) {
153
		const SocketErrorMessage msg;
154 155 156
		FormatError(server_socket_domain,
			    "Could not set TCP keepalive option: %s",
			    (const char *)msg);
157 158
	}

159 160 161
	const auto uid = get_remote_uid(peer_fd.Get());

	parent.OnAccept(std::move(peer_fd), peer_address, uid);
162 163
}

164
bool
165
ServerSocket::OneServerSocket::OnSocketReady(gcc_unused unsigned flags) noexcept
166
{
167
	Accept();
168
	return true;
169 170
}

171
inline void
172
ServerSocket::OneServerSocket::Open()
173
{
174
	assert(!IsDefined());
175

176 177 178
	auto _fd = socket_bind_listen(address.GetFamily(),
				      SOCK_STREAM, 0,
				      address, 5);
179

180 181 182 183 184 185 186
#ifdef HAVE_UN
	/* allow everybody to connect */

	if (!path.IsNull())
		chmod(path.c_str(), 0666);
#endif

187
	/* register in the EventLoop */	
188

189
	SetFD(std::move(_fd));
190 191
}

192
/**
 * Construct a #ServerSocket which remembers the given EventLoop; it
 * is later passed to every listener created by AddAddress().
 */
ServerSocket::ServerSocket(EventLoop &_loop) noexcept
	:loop(_loop)
{
}
194 195 196

/* this is just here to allow the OneServerSocket forward
   declaration */
197
ServerSocket::~ServerSocket() noexcept = default;
198

199 200
void
ServerSocket::Open()
201
{
202
	OneServerSocket *good = nullptr, *bad = nullptr;
203
	std::exception_ptr last_error;
204

205 206
	for (auto &i : sockets) {
		assert(i.GetSerial() > 0);
207
		assert(good == nullptr || i.GetSerial() >= good->GetSerial());
208

209 210 211 212 213
		if (i.IsDefined())
			/* already open - was probably added by
			   AddFD() */
			continue;

214 215
		if (bad != nullptr && i.GetSerial() != bad->GetSerial()) {
			Close();
216
			std::rethrow_exception(last_error);
217 218
		}

219 220
		try {
			i.Open();
221
		} catch (...) {
222
			if (good != nullptr && good->GetSerial() == i.GetSerial()) {
223 224
				const auto address_string = i.ToString();
				const auto good_string = good->ToString();
225
				FormatError(std::current_exception(),
226 227 228 229 230
					    "bind to '%s' failed "
					    "(continuing anyway, because "
					    "binding to '%s' succeeded)",
					    address_string.c_str(),
					    good_string.c_str());
231 232
			} else if (bad == nullptr) {
				bad = &i;
233

234
				const auto address_string = i.ToString();
235

236 237 238 239 240 241
				try {
					std::throw_with_nested(FormatRuntimeError("Failed to bind to '%s'",
										  address_string.c_str()));
				} catch (...) {
					last_error = std::current_exception();
				}
242 243
			}

244 245 246 247 248 249
			continue;
		}

		/* mark this socket as "good", and clear previous
		   errors */

250
		good = &i;
251

252 253
		if (bad != nullptr) {
			bad = nullptr;
254
			last_error = nullptr;
255 256 257
		}
	}

258
	if (bad != nullptr) {
259
		Close();
260
		std::rethrow_exception(last_error);
261 262 263
	}
}

264
void
265
ServerSocket::Close() noexcept
266 267
{
	for (auto &i : sockets)
268 269
		if (i.IsDefined())
			i.Close();
270 271
}

272
template<typename A>
273
ServerSocket::OneServerSocket &
274
ServerSocket::AddAddress(A &&address) noexcept
275
{
276
	sockets.emplace_back(loop, *this, next_serial,
277
			     std::forward<A>(address));
278 279 280 281

	return sockets.back();
}

282
void
283
ServerSocket::AddFD(UniqueSocketDescriptor fd)
284
{
285
	assert(fd.IsDefined());
286

287 288 289
	StaticSocketAddress address = fd.GetLocalAddress();
	if (!address.IsDefined())
		throw MakeSocketError("Failed to get socket address");
290 291

	OneServerSocket &s = AddAddress(address);
292
	s.SetFD(std::move(fd));
293 294
}

295 296 297 298 299 300 301 302 303 304 305 306
/**
 * Add a listener for an already existing socket whose address is
 * supplied by the caller.
 *
 * @param fd a defined socket descriptor; ownership is taken over
 * @param address the (non-null, defined) address stored for this
 * listener
 */
void
ServerSocket::AddFD(UniqueSocketDescriptor fd,
		    AllocatedSocketAddress &&address) noexcept
{
	assert(fd.IsDefined());
	assert(!address.IsNull());
	assert(address.IsDefined());

	OneServerSocket &listener = AddAddress(std::move(address));
	listener.SetFD(std::move(fd));
}

307 308
#ifdef HAVE_TCP

309
inline void
310
ServerSocket::AddPortIPv4(unsigned port) noexcept
311
{
312
	AddAddress(IPv4Address(port));
313 314 315
}

#ifdef HAVE_IPV6
316

317
inline void
318
ServerSocket::AddPortIPv6(unsigned port) noexcept
319
{
320
	AddAddress(IPv6Address(port));
321
}
322 323 324 325 326 327

/**
 * Is IPv6 supported by the kernel?
 *
 * This is probed by creating an AF_INET6 stream socket, which is
 * closed again immediately.
 *
 * @return true if the socket could be created
 */
gcc_pure
static bool
SupportsIPv6() noexcept
{
	const int probe_fd = socket(AF_INET6, SOCK_STREAM, 0);
	if (probe_fd >= 0) {
		close(probe_fd);
		return true;
	}

	return false;
}

338 339 340 341
#endif /* HAVE_IPV6 */

#endif /* HAVE_TCP */

342 343
/**
 * Add listeners for the given TCP port: an IPv6 one (only if built
 * with HAVE_IPV6 and SupportsIPv6() returns true) followed by an
 * IPv4 one.  Both get the same serial number, which is incremented
 * afterwards.
 *
 * Throws std::runtime_error if the port is not within 1..65535, or
 * if the build has no TCP support.
 */
void
ServerSocket::AddPort(unsigned port)
{
#ifdef HAVE_TCP
	const bool port_valid = port >= 1 && port <= 0xffff;
	if (!port_valid)
		throw std::runtime_error("Invalid TCP port");

#ifdef HAVE_IPV6
	if (SupportsIPv6())
		AddPortIPv6(port);
#endif
	AddPortIPv4(port);

	/* the next Add*() call starts a new group */
	next_serial++;
#else /* HAVE_TCP */
	(void)port;

	throw std::runtime_error("TCP support is disabled");
#endif /* HAVE_TCP */
}

363 364
/**
 * Resolve the given host name and port (passive, stream sockets) and
 * add one listener for each resulting address.  All of them get the
 * same serial number, which is incremented afterwards.
 *
 * Throws std::runtime_error if the build has no TCP support.
 * NOTE(review): Resolve() presumably throws on lookup failure -
 * confirm against net/Resolver.hxx.
 */
void
ServerSocket::AddHost(const char *hostname, unsigned port)
{
#ifdef HAVE_TCP
	for (const auto &resolved_address : Resolve(hostname, port,
						    AI_PASSIVE, SOCK_STREAM))
		AddAddress(resolved_address);

	/* the next Add*() call starts a new group */
	next_serial++;
#else /* HAVE_TCP */
	(void)hostname;
	(void)port;

	throw std::runtime_error("TCP support is disabled");
#endif /* HAVE_TCP */
}

380 381
/**
 * Add a listener on a local socket with the given file system path.
 * An existing file at that path is unlinked first (the result of
 * unlink() is not checked), and the path is stored in the new
 * listener.
 *
 * Throws std::runtime_error if the build has no local socket
 * support.
 */
void
ServerSocket::AddPath(AllocatedPath &&path)
{
#ifdef HAVE_UN
	/* remove what is currently at this path */
	unlink(path.c_str());

	AllocatedSocketAddress local_address;
	local_address.SetLocal(path.c_str());

	/* the address was built from the path above, so the path may
	   be moved away only now */
	AddAddress(std::move(local_address)).SetPath(std::move(path));
#else /* !HAVE_UN */
	(void)path;

	throw std::runtime_error("Local socket support is disabled");
#endif /* !HAVE_UN */
}

398 399 400 401

/**
 * Add a listener on an abstract local socket.
 *
 * @param name the socket name; must not be nullptr and must start
 * with '@'
 *
 * Throws std::runtime_error when not built for Linux, or when built
 * without local socket support.
 */
void
ServerSocket::AddAbstract(const char *name)
{
#if defined(__linux__) && defined(HAVE_UN)
	assert(name != nullptr);
	assert(*name == '@');

	AllocatedSocketAddress abstract_address;
	abstract_address.SetLocal(name);

	AddAddress(std::move(abstract_address));
#elif defined(__linux__)
	/* Linux, but no local socket support */
	(void)name;

	throw std::runtime_error("Local socket support is disabled");
#else
	(void)name;

	throw std::runtime_error("Abstract sockets are only available on Linux");
#endif
}