ServerSocket.cxx 8.05 KB
Newer Older
1
/*
2
 * Copyright 2003-2018 The Music Player Daemon Project
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
 * http://www.musicpd.org
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */

#include "config.h"
21
#include "ServerSocket.hxx"
22
#include "net/IPv4Address.hxx"
23
#include "net/IPv6Address.hxx"
24
#include "net/StaticSocketAddress.hxx"
25
#include "net/AllocatedSocketAddress.hxx"
26
#include "net/SocketAddress.hxx"
27 28
#include "net/SocketUtil.hxx"
#include "net/SocketError.hxx"
29
#include "net/UniqueSocketDescriptor.hxx"
30
#include "net/Resolver.hxx"
31
#include "net/AddressInfo.hxx"
32
#include "net/ToString.hxx"
33
#include "event/SocketMonitor.hxx"
34
#include "fs/AllocatedPath.hxx"
35
#include "util/RuntimeError.hxx"
36
#include "util/Domain.hxx"
37
#include "Log.hxx"
38

39
#include <string>
40
#include <algorithm>
41

42 43
#include <assert.h>

44 45
#ifdef HAVE_UN
#include <sys/stat.h>
46 47
#endif

48
/**
 * One listening socket owned by a #ServerSocket.
 *
 * Sockets added between two increments of ServerSocket::next_serial
 * (e.g. all addresses of one AddPort() or AddHost() call) carry the
 * same serial number; ServerSocket::Open() treats such a group as
 * successful if any one of its members can be opened.
 */
class ServerSocket::OneServerSocket final : private SocketMonitor {
	/** the owner; receives every accepted connection */
	ServerSocket &parent;

	/** group number, see the class comment */
	const unsigned serial;

#ifdef HAVE_UN
	/**
	 * The filesystem path of a UNIX domain socket (set via
	 * SetPath()), or a null path otherwise.  Open() uses it to
	 * chmod() the socket file.
	 */
	AllocatedPath path{nullptr};
#endif

	/** the address this socket is (to be) bound to */
	const AllocatedSocketAddress address;

public:
	template<typename A>
	OneServerSocket(EventLoop &_loop, ServerSocket &_parent,
			unsigned _serial,
			A &&_address) noexcept
		:SocketMonitor(_loop),
		 parent(_parent), serial(_serial),
		 address(std::forward<A>(_address))
	{
	}

	OneServerSocket(const OneServerSocket &other) = delete;
	OneServerSocket &operator=(const OneServerSocket &other) = delete;

	~OneServerSocket() noexcept {
		if (!IsDefined())
			return;

		Close();
	}

	unsigned GetSerial() const noexcept {
		return serial;
	}

#ifdef HAVE_UN
	/**
	 * Remember the path of the socket file.  Must only be called
	 * while no path has been set yet.
	 */
	void SetPath(AllocatedPath &&_path) noexcept {
		assert(path.IsNull());

		path = std::move(_path);
	}
#endif

	/**
	 * Create, bind and listen on #address, then start watching the
	 * socket.  Errors are reported by exception.
	 */
	void Open();

	using SocketMonitor::IsDefined;
	using SocketMonitor::Close;

	/** a human-readable representation of #address */
	gcc_pure
	std::string ToString() const noexcept {
		return ::ToString(address);
	}

	/**
	 * Take ownership of a (listening) socket descriptor and wait
	 * for it to become readable.
	 */
	void SetFD(UniqueSocketDescriptor _fd) noexcept {
		/* qualified: our own Open() hides the base class one */
		SocketMonitor::Open(_fd.Release());
		SocketMonitor::ScheduleRead();
	}

	void Accept() noexcept;

private:
	/* virtual methods from class SocketMonitor */
	bool OnSocketReady(unsigned flags) noexcept override;
};

114
/** log domain used for all messages emitted from this file */
static constexpr Domain server_socket_domain{"server_socket"};

/**
 * Query the user id of the process on the other end of a connected
 * socket.
 *
 * Uses SO_PEERCRED where struct ucred is available, else
 * getpeereid() (effective uid), else nothing.  Presumably only
 * meaningful for local (AF_UNIX) connections — TODO confirm with
 * callers.
 *
 * @param fd the connected socket descriptor
 * @return the peer's uid, or -1 if the query failed or the platform
 * offers no mechanism for it
 */
static int
get_remote_uid(int fd)
{
#if defined(HAVE_STRUCT_UCRED)
	struct ucred peer_cred;
	socklen_t cred_size = sizeof(peer_cred);

	if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED,
		       &peer_cred, &cred_size) < 0)
		return -1;

	return static_cast<int>(peer_cred.uid);
#elif defined(HAVE_GETPEEREID)
	uid_t peer_euid;
	gid_t peer_egid;

	if (getpeereid(fd, &peer_euid, &peer_egid) != 0)
		return -1;

	return static_cast<int>(peer_euid);
#else
	(void)fd;
	return -1;
#endif
}

141
inline void
142
ServerSocket::OneServerSocket::Accept() noexcept
143
{
144
	StaticSocketAddress peer_address;
145
	UniqueSocketDescriptor peer_fd(GetSocket().AcceptNonBlock(peer_address));
146
	if (!peer_fd.IsDefined()) {
147
		const SocketErrorMessage msg;
148 149
		FormatError(server_socket_domain,
			    "accept() failed: %s", (const char *)msg);
150 151 152
		return;
	}

153
	if (!peer_fd.SetKeepAlive()) {
154
		const SocketErrorMessage msg;
155 156 157
		FormatError(server_socket_domain,
			    "Could not set TCP keepalive option: %s",
			    (const char *)msg);
158 159
	}

160 161 162
	const auto uid = get_remote_uid(peer_fd.Get());

	parent.OnAccept(std::move(peer_fd), peer_address, uid);
163 164
}

165
bool
166
ServerSocket::OneServerSocket::OnSocketReady(gcc_unused unsigned flags) noexcept
167
{
168
	Accept();
169
	return true;
170 171
}

172
inline void
173
ServerSocket::OneServerSocket::Open()
174
{
175
	assert(!IsDefined());
176

177 178 179
	auto _fd = socket_bind_listen(address.GetFamily(),
				      SOCK_STREAM, 0,
				      address, 5);
180

181 182 183 184 185 186 187
#ifdef HAVE_UN
	/* allow everybody to connect */

	if (!path.IsNull())
		chmod(path.c_str(), 0666);
#endif

188
	/* register in the EventLoop */	
189

190
	SetFD(std::move(_fd));
191 192
}

193
/**
 * @param _loop the EventLoop every listening socket added later will
 * be registered in
 */
ServerSocket::ServerSocket(EventLoop &_loop) noexcept
	:loop(_loop)
{
}

/* this is just here to allow the OneServerSocket forward
   declaration */
198
ServerSocket::~ServerSocket() noexcept = default;
199

200 201
void
ServerSocket::Open()
202
{
203
	OneServerSocket *good = nullptr, *bad = nullptr;
204
	std::exception_ptr last_error;
205

206 207
	for (auto &i : sockets) {
		assert(i.GetSerial() > 0);
208
		assert(good == nullptr || i.GetSerial() >= good->GetSerial());
209

210 211 212 213 214
		if (i.IsDefined())
			/* already open - was probably added by
			   AddFD() */
			continue;

215 216
		if (bad != nullptr && i.GetSerial() != bad->GetSerial()) {
			Close();
217
			std::rethrow_exception(last_error);
218 219
		}

220 221
		try {
			i.Open();
222
		} catch (...) {
223
			if (good != nullptr && good->GetSerial() == i.GetSerial()) {
224 225
				const auto address_string = i.ToString();
				const auto good_string = good->ToString();
226
				FormatError(std::current_exception(),
227 228 229 230 231
					    "bind to '%s' failed "
					    "(continuing anyway, because "
					    "binding to '%s' succeeded)",
					    address_string.c_str(),
					    good_string.c_str());
232 233
			} else if (bad == nullptr) {
				bad = &i;
234

235
				const auto address_string = i.ToString();
236

237 238 239 240 241 242
				try {
					std::throw_with_nested(FormatRuntimeError("Failed to bind to '%s'",
										  address_string.c_str()));
				} catch (...) {
					last_error = std::current_exception();
				}
243 244
			}

245 246 247 248 249 250
			continue;
		}

		/* mark this socket as "good", and clear previous
		   errors */

251
		good = &i;
252

253 254
		if (bad != nullptr) {
			bad = nullptr;
255
			last_error = nullptr;
256 257 258
		}
	}

259
	if (bad != nullptr) {
260
		Close();
261
		std::rethrow_exception(last_error);
262 263 264
	}
}

265
void
266
ServerSocket::Close() noexcept
267 268
{
	for (auto &i : sockets)
269 270
		if (i.IsDefined())
			i.Close();
271 272
}

273
template<typename A>
274
ServerSocket::OneServerSocket &
275
ServerSocket::AddAddress(A &&address) noexcept
276
{
277
	sockets.emplace_back(loop, *this, next_serial,
278
			     std::forward<A>(address));
279 280 281 282

	return sockets.back();
}

283
void
284
ServerSocket::AddFD(UniqueSocketDescriptor fd)
285
{
286
	assert(fd.IsDefined());
287

288 289 290
	StaticSocketAddress address = fd.GetLocalAddress();
	if (!address.IsDefined())
		throw MakeSocketError("Failed to get socket address");
291 292

	OneServerSocket &s = AddAddress(address);
293
	s.SetFD(std::move(fd));
294 295
}

296 297 298 299 300 301 302 303 304 305 306 307
/**
 * Adopt an already open socket whose address is supplied by the
 * caller, so no query on the descriptor is needed.
 */
void
ServerSocket::AddFD(UniqueSocketDescriptor fd,
		    AllocatedSocketAddress &&address) noexcept
{
	assert(fd.IsDefined());
	assert(!address.IsNull());
	assert(address.IsDefined());

	AddAddress(std::move(address)).SetFD(std::move(fd));
}

308 309
#ifdef HAVE_TCP

310
inline void
311
ServerSocket::AddPortIPv4(unsigned port) noexcept
312
{
313
	AddAddress(IPv4Address(port));
314 315 316
}

#ifdef HAVE_IPV6
317

318
inline void
319
ServerSocket::AddPortIPv6(unsigned port) noexcept
320
{
321
	AddAddress(IPv6Address(port));
322
}
323 324 325 326 327 328

/**
 * Is IPv6 supported by the kernel?
 *
 * Probes by creating, and immediately closing, an AF_INET6 stream
 * socket.
 *
 * Deliberately not gcc_pure: this performs system calls and its
 * result depends on kernel state rather than on memory the compiler
 * can see, so the "pure" promise would be false.
 *
 * NOTE(review): on Windows a socket handle must be released with
 * closesocket(), not close(); confirm HAVE_IPV6 is never defined
 * there or switch to the portable socket wrapper.
 */
static bool
SupportsIPv6() noexcept
{
	const int fd = socket(AF_INET6, SOCK_STREAM, 0);
	if (fd < 0)
		return false;

	close(fd);
	return true;
}

339 340 341 342
#endif /* HAVE_IPV6 */

#endif /* HAVE_TCP */

343 344
/**
 * Listen on the given TCP port over IPv6 (when the kernel supports it)
 * and IPv4.  Both listeners share one serial number, so opening
 * succeeds if at least one of them can be bound.
 *
 * @throws std::runtime_error if the port is out of range or TCP
 * support was disabled at compile time
 */
void
ServerSocket::AddPort(unsigned port)
{
#ifndef HAVE_TCP
	(void)port;

	throw std::runtime_error("TCP support is disabled");
#else
	const bool valid_port = port > 0 && port <= 0xffff;
	if (!valid_port)
		throw std::runtime_error("Invalid TCP port");

#ifdef HAVE_IPV6
	if (SupportsIPv6())
		AddPortIPv6(port);
#endif

	AddPortIPv4(port);

	/* whatever is added next belongs to a new group */
	++next_serial;
#endif
}

364 365
/**
 * Listen on every address the given host name resolves to (passive
 * resolution, stream sockets).  All of them share one serial number.
 *
 * @throws std::runtime_error if TCP support was disabled at compile
 * time; resolver errors propagate from Resolve()
 */
void
ServerSocket::AddHost(const char *hostname, unsigned port)
{
#ifndef HAVE_TCP
	(void)hostname;
	(void)port;

	throw std::runtime_error("TCP support is disabled");
#else
	for (const auto &resolved : Resolve(hostname, port,
					    AI_PASSIVE, SOCK_STREAM))
		AddAddress(resolved);

	/* whatever is added next belongs to a new group */
	++next_serial;
#endif
}

381 382
/**
 * Listen on a UNIX domain socket at the given filesystem path.  The
 * socket file is made world-accessible when it is opened.
 *
 * A stale socket file at that path (e.g. left over from a previous
 * run) is removed first so bind() can succeed.  Anything that is not
 * a socket is left alone: a misconfigured path pointing to a regular
 * file must not cause that file to be deleted (bind() will then fail
 * and report the problem instead).
 *
 * @throws std::runtime_error if UNIX domain socket support was
 * disabled at compile time
 */
void
ServerSocket::AddPath(AllocatedPath &&path)
{
#ifdef HAVE_UN
	struct stat st;
	if (lstat(path.c_str(), &st) == 0 && S_ISSOCK(st.st_mode))
		unlink(path.c_str());

	AllocatedSocketAddress address;
	address.SetLocal(path.c_str());

	OneServerSocket &s = AddAddress(std::move(address));
	s.SetPath(std::move(path));
#else /* !HAVE_UN */
	(void)path;

	throw std::runtime_error("UNIX domain socket support is disabled");
#endif /* !HAVE_UN */
}