1/*
2 * inet_diag.c	Module for monitoring INET transport protocols sockets.
3 *
4 * Authors:	Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
5 *
6 *	This program is free software; you can redistribute it and/or
7 *      modify it under the terms of the GNU General Public License
8 *      as published by the Free Software Foundation; either version
9 *      2 of the License, or (at your option) any later version.
10 */
11
12#include <linux/kernel.h>
13#include <linux/module.h>
14#include <linux/types.h>
15#include <linux/fcntl.h>
16#include <linux/random.h>
17#include <linux/slab.h>
18#include <linux/cache.h>
19#include <linux/init.h>
20#include <linux/time.h>
21
22#include <net/icmp.h>
23#include <net/tcp.h>
24#include <net/ipv6.h>
25#include <net/inet_common.h>
26#include <net/inet_connection_sock.h>
27#include <net/inet_hashtables.h>
28#include <net/inet_timewait_sock.h>
29#include <net/inet6_hashtables.h>
30#include <net/netlink.h>
31
32#include <linux/inet.h>
33#include <linux/stddef.h>
34
35#include <linux/inet_diag.h>
36#include <linux/sock_diag.h>
37
/*
 * Table of registered per-protocol diag handlers, indexed by protocol
 * number.  Allocated with IPPROTO_MAX slots in inet_diag_init(); slots are
 * written by inet_diag_register()/inet_diag_unregister() under
 * inet_diag_table_mutex.
 */
static const struct inet_diag_handler **inet_diag_table;
39
/*
 * Socket identity in the form the filter bytecode interpreter
 * (inet_diag_bc_run) compares against.
 */
struct inet_diag_entry {
	const __be32 *saddr;	/* local address words, set by entry_fill_addrs() */
	const __be32 *daddr;	/* remote address words, set by entry_fill_addrs() */
	u16 sport;		/* local port, host byte order */
	u16 dport;		/* remote port, host byte order */
	u16 family;		/* address family of the socket (sk_family) */
	u16 userlocks;		/* sk_userlocks, tested by INET_DIAG_BC_AUTO */
};
48
/*
 * Serialises updates to inet_diag_table and is held while a handler
 * obtained through inet_diag_lock_handler() is in use.
 */
static DEFINE_MUTEX(inet_diag_table_mutex);
50
51static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
52{
53	if (!inet_diag_table[proto])
54		request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
55			       NETLINK_SOCK_DIAG, AF_INET, proto);
56
57	mutex_lock(&inet_diag_table_mutex);
58	if (!inet_diag_table[proto])
59		return ERR_PTR(-ENOENT);
60
61	return inet_diag_table[proto];
62}
63
/*
 * Counterpart of inet_diag_lock_handler(): drops inet_diag_table_mutex.
 * @handler is not inspected, so passing the ERR_PTR() returned by a failed
 * lookup is fine.
 */
static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
{
	mutex_unlock(&inet_diag_table_mutex);
}
68
69static void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
70{
71	r->idiag_family = sk->sk_family;
72
73	r->id.idiag_sport = htons(sk->sk_num);
74	r->id.idiag_dport = sk->sk_dport;
75	r->id.idiag_if = sk->sk_bound_dev_if;
76	sock_diag_save_cookie(sk, r->id.idiag_cookie);
77
78#if IS_ENABLED(CONFIG_IPV6)
79	if (sk->sk_family == AF_INET6) {
80		*(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
81		*(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
82	} else
83#endif
84	{
85	memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
86	memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
87
88	r->id.idiag_src[0] = sk->sk_rcv_saddr;
89	r->id.idiag_dst[0] = sk->sk_daddr;
90	}
91}
92
93static size_t inet_sk_attr_size(void)
94{
95	return	  nla_total_size(sizeof(struct tcp_info))
96		+ nla_total_size(1) /* INET_DIAG_SHUTDOWN */
97		+ nla_total_size(1) /* INET_DIAG_TOS */
98		+ nla_total_size(1) /* INET_DIAG_TCLASS */
99		+ nla_total_size(sizeof(struct inet_diag_meminfo))
100		+ nla_total_size(sizeof(struct inet_diag_msg))
101		+ nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
102		+ nla_total_size(TCP_CA_NAME_MAX)
103		+ nla_total_size(sizeof(struct tcpvegas_info))
104		+ 64;
105}
106
107int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
108		      struct sk_buff *skb, const struct inet_diag_req_v2 *req,
109		      struct user_namespace *user_ns,
110		      u32 portid, u32 seq, u16 nlmsg_flags,
111		      const struct nlmsghdr *unlh)
112{
113	const struct inet_sock *inet = inet_sk(sk);
114	const struct tcp_congestion_ops *ca_ops;
115	const struct inet_diag_handler *handler;
116	int ext = req->idiag_ext;
117	struct inet_diag_msg *r;
118	struct nlmsghdr  *nlh;
119	struct nlattr *attr;
120	void *info = NULL;
121
122	handler = inet_diag_table[req->sdiag_protocol];
123	BUG_ON(!handler);
124
125	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
126			nlmsg_flags);
127	if (!nlh)
128		return -EMSGSIZE;
129
130	r = nlmsg_data(nlh);
131	BUG_ON(!sk_fullsock(sk));
132
133	inet_diag_msg_common_fill(r, sk);
134	r->idiag_state = sk->sk_state;
135	r->idiag_timer = 0;
136	r->idiag_retrans = 0;
137
138	if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
139		goto errout;
140
141	/* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
142	 * hence this needs to be included regardless of socket family.
143	 */
144	if (ext & (1 << (INET_DIAG_TOS - 1)))
145		if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
146			goto errout;
147
148#if IS_ENABLED(CONFIG_IPV6)
149	if (r->idiag_family == AF_INET6) {
150		if (ext & (1 << (INET_DIAG_TCLASS - 1)))
151			if (nla_put_u8(skb, INET_DIAG_TCLASS,
152				       inet6_sk(sk)->tclass) < 0)
153				goto errout;
154	}
155#endif
156
157	r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
158	r->idiag_inode = sock_i_ino(sk);
159
160	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
161		struct inet_diag_meminfo minfo = {
162			.idiag_rmem = sk_rmem_alloc_get(sk),
163			.idiag_wmem = sk->sk_wmem_queued,
164			.idiag_fmem = sk->sk_forward_alloc,
165			.idiag_tmem = sk_wmem_alloc_get(sk),
166		};
167
168		if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
169			goto errout;
170	}
171
172	if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
173		if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
174			goto errout;
175
176	if (!icsk) {
177		handler->idiag_get_info(sk, r, NULL);
178		goto out;
179	}
180
181#define EXPIRES_IN_MS(tmo)  DIV_ROUND_UP((tmo - jiffies) * 1000, HZ)
182
183	if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
184	    icsk->icsk_pending == ICSK_TIME_EARLY_RETRANS ||
185	    icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
186		r->idiag_timer = 1;
187		r->idiag_retrans = icsk->icsk_retransmits;
188		r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
189	} else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
190		r->idiag_timer = 4;
191		r->idiag_retrans = icsk->icsk_probes_out;
192		r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
193	} else if (timer_pending(&sk->sk_timer)) {
194		r->idiag_timer = 2;
195		r->idiag_retrans = icsk->icsk_probes_out;
196		r->idiag_expires = EXPIRES_IN_MS(sk->sk_timer.expires);
197	} else {
198		r->idiag_timer = 0;
199		r->idiag_expires = 0;
200	}
201#undef EXPIRES_IN_MS
202
203	if (ext & (1 << (INET_DIAG_INFO - 1))) {
204		attr = nla_reserve(skb, INET_DIAG_INFO,
205				   sizeof(struct tcp_info));
206		if (!attr)
207			goto errout;
208
209		info = nla_data(attr);
210	}
211
212	if (ext & (1 << (INET_DIAG_CONG - 1))) {
213		int err = 0;
214
215		rcu_read_lock();
216		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
217		if (ca_ops)
218			err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
219		rcu_read_unlock();
220		if (err < 0)
221			goto errout;
222	}
223
224	handler->idiag_get_info(sk, r, info);
225
226	if (sk->sk_state < TCP_TIME_WAIT) {
227		union tcp_cc_info info;
228		size_t sz = 0;
229		int attr;
230
231		rcu_read_lock();
232		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
233		if (ca_ops && ca_ops->get_info)
234			sz = ca_ops->get_info(sk, ext, &attr, &info);
235		rcu_read_unlock();
236		if (sz && nla_put(skb, attr, sz, &info) < 0)
237			goto errout;
238	}
239
240out:
241	nlmsg_end(skb, nlh);
242	return 0;
243
244errout:
245	nlmsg_cancel(skb, nlh);
246	return -EMSGSIZE;
247}
248EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
249
/*
 * Convenience wrapper around inet_sk_diag_fill() for sockets that are
 * connection sockets: passes inet_csk(sk) as the icsk argument and
 * forwards everything else unchanged.
 */
static int inet_csk_diag_fill(struct sock *sk,
			      struct sk_buff *skb,
			      const struct inet_diag_req_v2 *req,
			      struct user_namespace *user_ns,
			      u32 portid, u32 seq, u16 nlmsg_flags,
			      const struct nlmsghdr *unlh)
{
	return inet_sk_diag_fill(sk, inet_csk(sk), skb, req,
				 user_ns, portid, seq, nlmsg_flags, unlh);
}
260
261static int inet_twsk_diag_fill(struct sock *sk,
262			       struct sk_buff *skb,
263			       u32 portid, u32 seq, u16 nlmsg_flags,
264			       const struct nlmsghdr *unlh)
265{
266	struct inet_timewait_sock *tw = inet_twsk(sk);
267	struct inet_diag_msg *r;
268	struct nlmsghdr *nlh;
269	long tmo;
270
271	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
272			nlmsg_flags);
273	if (!nlh)
274		return -EMSGSIZE;
275
276	r = nlmsg_data(nlh);
277	BUG_ON(tw->tw_state != TCP_TIME_WAIT);
278
279	tmo = tw->tw_timer.expires - jiffies;
280	if (tmo < 0)
281		tmo = 0;
282
283	inet_diag_msg_common_fill(r, sk);
284	r->idiag_retrans      = 0;
285
286	r->idiag_state	      = tw->tw_substate;
287	r->idiag_timer	      = 3;
288	r->idiag_expires      = jiffies_to_msecs(tmo);
289	r->idiag_rqueue	      = 0;
290	r->idiag_wqueue	      = 0;
291	r->idiag_uid	      = 0;
292	r->idiag_inode	      = 0;
293
294	nlmsg_end(skb, nlh);
295	return 0;
296}
297
298static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
299			      u32 portid, u32 seq, u16 nlmsg_flags,
300			      const struct nlmsghdr *unlh)
301{
302	struct inet_diag_msg *r;
303	struct nlmsghdr *nlh;
304	long tmo;
305
306	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
307			nlmsg_flags);
308	if (!nlh)
309		return -EMSGSIZE;
310
311	r = nlmsg_data(nlh);
312	inet_diag_msg_common_fill(r, sk);
313	r->idiag_state = TCP_SYN_RECV;
314	r->idiag_timer = 1;
315	r->idiag_retrans = inet_reqsk(sk)->num_retrans;
316
317	BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
318		     offsetof(struct sock, sk_cookie));
319
320	tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
321	r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
322	r->idiag_rqueue	= 0;
323	r->idiag_wqueue	= 0;
324	r->idiag_uid	= 0;
325	r->idiag_inode	= 0;
326
327	nlmsg_end(skb, nlh);
328	return 0;
329}
330
331static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
332			const struct inet_diag_req_v2 *r,
333			struct user_namespace *user_ns,
334			u32 portid, u32 seq, u16 nlmsg_flags,
335			const struct nlmsghdr *unlh)
336{
337	if (sk->sk_state == TCP_TIME_WAIT)
338		return inet_twsk_diag_fill(sk, skb, portid, seq,
339					   nlmsg_flags, unlh);
340
341	if (sk->sk_state == TCP_NEW_SYN_RECV)
342		return inet_req_diag_fill(sk, skb, portid, seq,
343					  nlmsg_flags, unlh);
344
345	return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
346				  nlmsg_flags, unlh);
347}
348
349int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
350			    struct sk_buff *in_skb,
351			    const struct nlmsghdr *nlh,
352			    const struct inet_diag_req_v2 *req)
353{
354	struct net *net = sock_net(in_skb->sk);
355	struct sk_buff *rep;
356	struct sock *sk;
357	int err;
358
359	err = -EINVAL;
360	if (req->sdiag_family == AF_INET)
361		sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0],
362				 req->id.idiag_dport, req->id.idiag_src[0],
363				 req->id.idiag_sport, req->id.idiag_if);
364#if IS_ENABLED(CONFIG_IPV6)
365	else if (req->sdiag_family == AF_INET6)
366		sk = inet6_lookup(net, hashinfo,
367				  (struct in6_addr *)req->id.idiag_dst,
368				  req->id.idiag_dport,
369				  (struct in6_addr *)req->id.idiag_src,
370				  req->id.idiag_sport,
371				  req->id.idiag_if);
372#endif
373	else
374		goto out_nosk;
375
376	err = -ENOENT;
377	if (!sk)
378		goto out_nosk;
379
380	err = sock_diag_check_cookie(sk, req->id.idiag_cookie);
381	if (err)
382		goto out;
383
384	rep = nlmsg_new(inet_sk_attr_size(), GFP_KERNEL);
385	if (!rep) {
386		err = -ENOMEM;
387		goto out;
388	}
389
390	err = sk_diag_fill(sk, rep, req,
391			   sk_user_ns(NETLINK_CB(in_skb).sk),
392			   NETLINK_CB(in_skb).portid,
393			   nlh->nlmsg_seq, 0, nlh);
394	if (err < 0) {
395		WARN_ON(err == -EMSGSIZE);
396		nlmsg_free(rep);
397		goto out;
398	}
399	err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
400			      MSG_DONTWAIT);
401	if (err > 0)
402		err = 0;
403
404out:
405	if (sk)
406		sock_gen_put(sk);
407
408out_nosk:
409	return err;
410}
411EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
412
413static int inet_diag_get_exact(struct sk_buff *in_skb,
414			       const struct nlmsghdr *nlh,
415			       const struct inet_diag_req_v2 *req)
416{
417	const struct inet_diag_handler *handler;
418	int err;
419
420	handler = inet_diag_lock_handler(req->sdiag_protocol);
421	if (IS_ERR(handler))
422		err = PTR_ERR(handler);
423	else
424		err = handler->dump_one(in_skb, nlh, req);
425	inet_diag_unlock_handler(handler);
426
427	return err;
428}
429
/*
 * Compare the leading @bits bits of two big-endian word arrays.
 * Returns 1 when they are identical, 0 otherwise.  @bits == 0 always
 * matches.
 */
static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
{
	int full_words = bits >> 5;
	int tail_bits = bits & 0x1f;
	__be32 tail_mask;

	/* Whole 32-bit words first. */
	if (full_words && memcmp(a1, a2, full_words << 2))
		return 0;

	if (!tail_bits)
		return 1;

	/* Then the top tail_bits bits of the following word. */
	tail_mask = htonl((0xffffffff) << (32 - tail_bits));

	return !((a1[full_words] ^ a2[full_words]) & tail_mask);
}
455
/*
 * Run the filter bytecode carried in attribute @_bc against @entry.
 *
 * Every instruction is an inet_diag_bc_op.  After evaluating it, execution
 * advances by op->yes bytes when the condition held and by op->no bytes
 * otherwise.  The socket is accepted (return 1) when execution lands
 * exactly on the end of the bytecode, and rejected (return 0) when it
 * jumps beyond it.
 */
static int inet_diag_bc_run(const struct nlattr *_bc,
			    const struct inet_diag_entry *entry)
{
	const void *bc = nla_data(_bc);
	int len = nla_len(_bc);

	while (len > 0) {
		int yes = 1;
		const struct inet_diag_bc_op *op = bc;

		switch (op->code) {
		case INET_DIAG_BC_NOP:
			break;
		case INET_DIAG_BC_JMP:
			/* Unconditional jump: always take the "no" offset. */
			yes = 0;
			break;
		/* Port comparisons: the port to compare against is stored in
		 * the "no" field of the op that follows this one.
		 */
		case INET_DIAG_BC_S_GE:
			yes = entry->sport >= op[1].no;
			break;
		case INET_DIAG_BC_S_LE:
			yes = entry->sport <= op[1].no;
			break;
		case INET_DIAG_BC_D_GE:
			yes = entry->dport >= op[1].no;
			break;
		case INET_DIAG_BC_D_LE:
			yes = entry->dport <= op[1].no;
			break;
		case INET_DIAG_BC_AUTO:
			yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
			break;
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND: {
			const struct inet_diag_hostcond *cond;
			const __be32 *addr;

			/* The host condition immediately follows the op. */
			cond = (const struct inet_diag_hostcond *)(op + 1);
			/* A port of -1 matches any port. */
			if (cond->port != -1 &&
			    cond->port != (op->code == INET_DIAG_BC_S_COND ?
					     entry->sport : entry->dport)) {
				yes = 0;
				break;
			}

			if (op->code == INET_DIAG_BC_S_COND)
				addr = entry->saddr;
			else
				addr = entry->daddr;

			if (cond->family != AF_UNSPEC &&
			    cond->family != entry->family) {
				/* An IPv4 condition may still match an IPv6
				 * socket whose address has the form
				 * 0:0:0xffff:<ipv4>; compare the last word.
				 */
				if (entry->family == AF_INET6 &&
				    cond->family == AF_INET) {
					if (addr[0] == 0 && addr[1] == 0 &&
					    addr[2] == htonl(0xffff) &&
					    bitstring_match(addr + 3,
							    cond->addr,
							    cond->prefix_len))
						break;
				}
				yes = 0;
				break;
			}

			/* A zero-length prefix matches every address. */
			if (cond->prefix_len == 0)
				break;
			if (bitstring_match(addr, cond->addr,
					    cond->prefix_len))
				break;
			yes = 0;
			break;
		}
		}

		if (yes) {
			len -= op->yes;
			bc += op->yes;
		} else {
			len -= op->no;
			bc += op->no;
		}
	}
	return len == 0;
}
540
541/* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
542 */
543static void entry_fill_addrs(struct inet_diag_entry *entry,
544			     const struct sock *sk)
545{
546#if IS_ENABLED(CONFIG_IPV6)
547	if (sk->sk_family == AF_INET6) {
548		entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
549		entry->daddr = sk->sk_v6_daddr.s6_addr32;
550	} else
551#endif
552	{
553		entry->saddr = &sk->sk_rcv_saddr;
554		entry->daddr = &sk->sk_daddr;
555	}
556}
557
558int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
559{
560	struct inet_sock *inet = inet_sk(sk);
561	struct inet_diag_entry entry;
562
563	if (!bc)
564		return 1;
565
566	entry.family = sk->sk_family;
567	entry_fill_addrs(&entry, sk);
568	entry.sport = inet->inet_num;
569	entry.dport = ntohs(inet->inet_dport);
570	entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
571
572	return inet_diag_bc_run(bc, &entry);
573}
574EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
575
/*
 * Check that a jump target is an instruction boundary.
 *
 * Walks the bytecode from its start along the "yes" offsets.  @len is the
 * number of bytes remaining at the start, @cc the number of bytes that
 * remain at the jump target.  Returns 1 when some instruction starts with
 * exactly @cc bytes remaining, 0 otherwise (including when a "yes" offset
 * on the way is smaller than 4 or not a multiple of 4).
 */
static int valid_cc(const void *bc, int len, int cc)
{
	const struct inet_diag_bc_op *op;

	for (; len >= 0; len -= op->yes, bc += op->yes) {
		/* Reached or stepped past the target position. */
		if (cc >= len)
			return cc == len;

		op = bc;
		if (op->yes < 4 || (op->yes & 3))
			return 0;
	}
	return 0;
}
592
/* Validate the inet_diag_hostcond that follows @op.
 *
 * @len is the number of bytecode bytes available starting at @op.
 * *@min_len is increased by the size of the host condition and, for a
 * known address family, by the size of its address.  Returns false when
 * the bytecode is too short, the family is unknown, or the prefix length
 * exceeds the address size.
 */
static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
			   int *min_len)
{
	struct inet_diag_hostcond *cond;
	int addr_len;

	/* Room for the fixed part of the host condition? */
	*min_len += sizeof(struct inet_diag_hostcond);
	if (len < *min_len)
		return false;

	cond = (struct inet_diag_hostcond *)(op + 1);

	/* Map the address family to the address size in bytes. */
	if (cond->family == AF_UNSPEC)
		addr_len = 0;
	else if (cond->family == AF_INET)
		addr_len = sizeof(struct in_addr);
	else if (cond->family == AF_INET6)
		addr_len = sizeof(struct in6_addr);
	else
		return false;

	/* Room for the address itself? */
	*min_len += addr_len;
	if (len < *min_len)
		return false;

	/* Prefix length is in bits, address length in bytes. */
	return cond->prefix_len <= 8 * addr_len;
}
630
/* Validate a port comparison operator.
 *
 * The port operand lives in a second inet_diag_bc_op right after @op, so
 * *@min_len grows by one op and the bytecode must be at least that long.
 */
static bool valid_port_comparison(const struct inet_diag_bc_op *op,
				  int len, int *min_len)
{
	*min_len += sizeof(struct inet_diag_bc_op);

	return len >= *min_len;
}
641
/*
 * Validate user-supplied filter bytecode before it is ever run.
 *
 * Walks the instructions along their "yes" offsets and checks for each:
 *  - the opcode is known and any operand that follows it fits in the
 *    remaining bytes (valid_hostcond / valid_port_comparison);
 *  - the "yes" and "no" offsets are at least the instruction's minimum
 *    length, a multiple of 4, and at most 4 bytes beyond the end;
 *  - a "no" offset that stays inside the bytecode lands on an instruction
 *    boundary (valid_cc).  NOP instructions skip the "no" checks.
 *
 * Returns 0 when the bytecode is acceptable, -EINVAL otherwise.
 */
static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
{
	const void *bc = bytecode;
	int  len = bytecode_len;

	while (len > 0) {
		/* Grows when the opcode carries a trailing operand. */
		int min_len = sizeof(struct inet_diag_bc_op);
		const struct inet_diag_bc_op *op = bc;

		switch (op->code) {
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND:
			if (!valid_hostcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_S_GE:
		case INET_DIAG_BC_S_LE:
		case INET_DIAG_BC_D_GE:
		case INET_DIAG_BC_D_LE:
			if (!valid_port_comparison(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_AUTO:
		case INET_DIAG_BC_JMP:
		case INET_DIAG_BC_NOP:
			break;
		default:
			return -EINVAL;
		}

		if (op->code != INET_DIAG_BC_NOP) {
			if (op->no < min_len || op->no > len + 4 || op->no & 3)
				return -EINVAL;
			/* Jump target inside the program: must be the start
			 * of an instruction.
			 */
			if (op->no < len &&
			    !valid_cc(bytecode, bytecode_len, len - op->no))
				return -EINVAL;
		}

		if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
			return -EINVAL;
		bc  += op->yes;
		len -= op->yes;
	}
	return len == 0 ? 0 : -EINVAL;
}
687
/*
 * Dump one connection socket as part of a multi-part dump.
 *
 * Returns 0 without emitting anything when the filter @bc rejects @sk;
 * otherwise returns the result of inet_csk_diag_fill(), which is called
 * with NLM_F_MULTI and the requester's identity taken from @cb.
 */
static int inet_csk_diag_dump(struct sock *sk,
			      struct sk_buff *skb,
			      struct netlink_callback *cb,
			      const struct inet_diag_req_v2 *r,
			      const struct nlattr *bc)
{
	if (!inet_diag_bc_sk(bc, sk))
		return 0;

	return inet_csk_diag_fill(sk, skb, r,
				  sk_user_ns(NETLINK_CB(cb->skb).sk),
				  NETLINK_CB(cb->skb).portid,
				  cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
}
702
/*
 * Compile-time layout checks: the family, port and address fields of
 * struct inet_timewait_sock must sit at the same offsets as the matching
 * fields of struct sock / struct inet_sock.
 * NOTE(review): presumably this is what lets the dump code read those
 * fields through sock/inet_sock accessors on timewait entries - confirm.
 */
static void twsk_build_assert(void)
{
	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
		     offsetof(struct sock, sk_family));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
		     offsetof(struct inet_sock, inet_num));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
		     offsetof(struct inet_sock, inet_dport));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
		     offsetof(struct inet_sock, inet_rcv_saddr));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
		     offsetof(struct inet_sock, inet_daddr));

#if IS_ENABLED(CONFIG_IPV6)
	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
		     offsetof(struct sock, sk_v6_rcv_saddr));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
		     offsetof(struct sock, sk_v6_daddr));
#endif
}
728
/*
 * Dump the pending connection requests (SYN table) of listening socket
 * @sk, honouring the destination-port filter in @r and the bytecode @bc.
 *
 * Resume state lives in @cb: args[3] holds the SYN-table bucket to resume
 * at plus one (0 means "start from the beginning"), args[4] the position
 * inside that bucket.  Both are stored when a fill fails so the next call
 * picks up at the same request.
 *
 * Runs with the accept queue's syn_wait_lock held.  Returns 0, or the
 * negative error from inet_req_diag_fill().
 */
static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
			       struct netlink_callback *cb,
			       const struct inet_diag_req_v2 *r,
			       const struct nlattr *bc)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct inet_sock *inet = inet_sk(sk);
	struct inet_diag_entry entry;
	int j, s_j, reqnum, s_reqnum;
	struct listen_sock *lopt;
	int err = 0;

	s_j = cb->args[3];
	s_reqnum = cb->args[4];

	/* args[3] is stored as bucket index + 1; convert back. */
	if (s_j > 0)
		s_j--;

	entry.family = sk->sk_family;

	spin_lock_bh(&icsk->icsk_accept_queue.syn_wait_lock);

	lopt = icsk->icsk_accept_queue.listen_opt;
	if (!lopt || !listen_sock_qlen(lopt))
		goto out;

	/* These come from the listener and are the same for every request. */
	if (bc) {
		entry.sport = inet->inet_num;
		entry.userlocks = sk->sk_userlocks;
	}

	for (j = s_j; j < lopt->nr_table_entries; j++) {
		struct request_sock *req, *head = lopt->syn_table[j];

		reqnum = 0;
		for (req = head; req; reqnum++, req = req->dl_next) {
			struct inet_request_sock *ireq = inet_rsk(req);

			/* Skip requests already dumped by an earlier call. */
			if (reqnum < s_reqnum)
				continue;
			/* A non-zero requested dport must match exactly. */
			if (r->id.idiag_dport != ireq->ir_rmt_port &&
			    r->id.idiag_dport)
				continue;

			if (bc) {
				/* Note: entry.sport and entry.userlocks are already set */
				entry_fill_addrs(&entry, req_to_sk(req));
				entry.dport = ntohs(ireq->ir_rmt_port);

				if (!inet_diag_bc_run(bc, &entry))
					continue;
			}

			err = inet_req_diag_fill(req_to_sk(req), skb,
						 NETLINK_CB(cb->skb).portid,
						 cb->nlh->nlmsg_seq,
						 NLM_F_MULTI, cb->nlh);
			if (err < 0) {
				/* Remember where to resume on the next call. */
				cb->args[3] = j + 1;
				cb->args[4] = reqnum;
				goto out;
			}
		}

		/* The in-bucket offset only applies to the first bucket. */
		s_reqnum = 0;
	}

out:
	spin_unlock_bh(&icsk->icsk_accept_queue.syn_wait_lock);

	return err;
}
801
/*
 * inet_diag_dump_icsk - dump all sockets in @hashinfo matching request @r
 * @hashinfo: hash tables to walk
 * @skb:      netlink skb being filled
 * @cb:       dump callback state (carries the resume position)
 * @r:        request: state mask, family and optional port filters
 * @bc:       optional filter bytecode, may be NULL
 *
 * The walk has two phases, recorded in cb->args[0]:
 *   0 - listening hash (listeners and their pending requests),
 *   1 - established hash (everything else, including timewait).
 * cb->args[1] is the bucket index and cb->args[2] the position within the
 * bucket to resume at; cb->args[3]/[4] are the per-listener resume state
 * used by inet_diag_dump_reqs().  When a fill fails (skb full) the walk
 * stops and the current position is saved for the next invocation.
 */
void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
			 struct netlink_callback *cb,
			 const struct inet_diag_req_v2 *r, struct nlattr *bc)
{
	struct net *net = sock_net(skb->sk);
	int i, num, s_i, s_num;

	s_i = cb->args[1];
	s_num = num = cb->args[2];

	if (cb->args[0] == 0) {
		if (!(r->idiag_states & (TCPF_LISTEN | TCPF_SYN_RECV)))
			goto skip_listen_ht;

		for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
			struct inet_listen_hashbucket *ilb;
			struct hlist_nulls_node *node;
			struct sock *sk;

			num = 0;
			ilb = &hashinfo->listening_hash[i];
			spin_lock_bh(&ilb->lock);
			sk_nulls_for_each(sk, node, &ilb->head) {
				struct inet_sock *inet = inet_sk(sk);

				if (!net_eq(sock_net(sk), net))
					continue;

				/* Skip entries dumped by an earlier call. */
				if (num < s_num) {
					num++;
					continue;
				}

				if (r->sdiag_family != AF_UNSPEC &&
				    sk->sk_family != r->sdiag_family)
					goto next_listen;

				if (r->id.idiag_sport != inet->inet_sport &&
				    r->id.idiag_sport)
					goto next_listen;

				/* Do not emit the listener itself when
				 * listeners were not asked for, when a dport
				 * filter is set, or when resuming in the
				 * middle of this listener's request queue.
				 */
				if (!(r->idiag_states & TCPF_LISTEN) ||
				    r->id.idiag_dport ||
				    cb->args[3] > 0)
					goto syn_recv;

				if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) {
					spin_unlock_bh(&ilb->lock);
					goto done;
				}

syn_recv:
				if (!(r->idiag_states & TCPF_SYN_RECV))
					goto next_listen;

				if (inet_diag_dump_reqs(skb, sk, cb, r, bc) < 0) {
					spin_unlock_bh(&ilb->lock);
					goto done;
				}

next_listen:
				/* Request-queue resume state is per listener. */
				cb->args[3] = 0;
				cb->args[4] = 0;
				++num;
			}
			spin_unlock_bh(&ilb->lock);

			s_num = 0;
			cb->args[3] = 0;
			cb->args[4] = 0;
		}
skip_listen_ht:
		/* Listening phase finished; start the next phase at bucket 0. */
		cb->args[0] = 1;
		s_i = num = s_num = 0;
	}

	/* Nothing beyond LISTEN/SYN_RECV requested: no need to walk ehash. */
	if (!(r->idiag_states & ~(TCPF_LISTEN | TCPF_SYN_RECV)))
		goto out;

	for (i = s_i; i <= hashinfo->ehash_mask; i++) {
		struct inet_ehash_bucket *head = &hashinfo->ehash[i];
		spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
		struct hlist_nulls_node *node;
		struct sock *sk;

		num = 0;

		if (hlist_nulls_empty(&head->chain))
			continue;

		/* The in-bucket offset only applies to the resume bucket. */
		if (i > s_i)
			s_num = 0;

		spin_lock_bh(lock);
		sk_nulls_for_each(sk, node, &head->chain) {
			int state, res;

			if (!net_eq(sock_net(sk), net))
				continue;
			if (num < s_num)
				goto next_normal;
			/* Timewait entries are matched on their substate. */
			state = (sk->sk_state == TCP_TIME_WAIT) ?
				inet_twsk(sk)->tw_substate : sk->sk_state;
			if (!(r->idiag_states & (1 << state)))
				goto next_normal;
			if (r->sdiag_family != AF_UNSPEC &&
			    sk->sk_family != r->sdiag_family)
				goto next_normal;
			if (r->id.idiag_sport != htons(sk->sk_num) &&
			    r->id.idiag_sport)
				goto next_normal;
			if (r->id.idiag_dport != sk->sk_dport &&
			    r->id.idiag_dport)
				goto next_normal;
			twsk_build_assert();

			if (!inet_diag_bc_sk(bc, sk))
				goto next_normal;

			res = sk_diag_fill(sk, skb, r,
					   sk_user_ns(NETLINK_CB(cb->skb).sk),
					   NETLINK_CB(cb->skb).portid,
					   cb->nlh->nlmsg_seq, NLM_F_MULTI,
					   cb->nlh);
			if (res < 0) {
				spin_unlock_bh(lock);
				goto done;
			}
next_normal:
			++num;
		}

		spin_unlock_bh(lock);
	}

done:
	/* Save the position for the next invocation of the dump. */
	cb->args[1] = i;
	cb->args[2] = num;
out:
	;
}
EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
944
945static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
946			    const struct inet_diag_req_v2 *r,
947			    struct nlattr *bc)
948{
949	const struct inet_diag_handler *handler;
950	int err = 0;
951
952	handler = inet_diag_lock_handler(r->sdiag_protocol);
953	if (!IS_ERR(handler))
954		handler->dump(skb, cb, r, bc);
955	else
956		err = PTR_ERR(handler);
957	inet_diag_unlock_handler(handler);
958
959	return err ? : skb->len;
960}
961
962static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
963{
964	int hdrlen = sizeof(struct inet_diag_req_v2);
965	struct nlattr *bc = NULL;
966
967	if (nlmsg_attrlen(cb->nlh, hdrlen))
968		bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
969
970	return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc);
971}
972
/*
 * Map a legacy netlink message type to its IP protocol number.
 * Unknown types map to 0.
 */
static int inet_diag_type2proto(int type)
{
	if (type == TCPDIAG_GETSOCK)
		return IPPROTO_TCP;

	if (type == DCCPDIAG_GETSOCK)
		return IPPROTO_DCCP;

	return 0;
}
984
985static int inet_diag_dump_compat(struct sk_buff *skb,
986				 struct netlink_callback *cb)
987{
988	struct inet_diag_req *rc = nlmsg_data(cb->nlh);
989	int hdrlen = sizeof(struct inet_diag_req);
990	struct inet_diag_req_v2 req;
991	struct nlattr *bc = NULL;
992
993	req.sdiag_family = AF_UNSPEC; /* compatibility */
994	req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
995	req.idiag_ext = rc->idiag_ext;
996	req.idiag_states = rc->idiag_states;
997	req.id = rc->id;
998
999	if (nlmsg_attrlen(cb->nlh, hdrlen))
1000		bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
1001
1002	return __inet_diag_dump(skb, cb, &req, bc);
1003}
1004
1005static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
1006				      const struct nlmsghdr *nlh)
1007{
1008	struct inet_diag_req *rc = nlmsg_data(nlh);
1009	struct inet_diag_req_v2 req;
1010
1011	req.sdiag_family = rc->idiag_family;
1012	req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
1013	req.idiag_ext = rc->idiag_ext;
1014	req.idiag_states = rc->idiag_states;
1015	req.id = rc->id;
1016
1017	return inet_diag_get_exact(in_skb, nlh, &req);
1018}
1019
1020static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
1021{
1022	int hdrlen = sizeof(struct inet_diag_req);
1023	struct net *net = sock_net(skb->sk);
1024
1025	if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
1026	    nlmsg_len(nlh) < hdrlen)
1027		return -EINVAL;
1028
1029	if (nlh->nlmsg_flags & NLM_F_DUMP) {
1030		if (nlmsg_attrlen(nlh, hdrlen)) {
1031			struct nlattr *attr;
1032
1033			attr = nlmsg_find_attr(nlh, hdrlen,
1034					       INET_DIAG_REQ_BYTECODE);
1035			if (!attr ||
1036			    nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
1037			    inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
1038				return -EINVAL;
1039		}
1040		{
1041			struct netlink_dump_control c = {
1042				.dump = inet_diag_dump_compat,
1043			};
1044			return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1045		}
1046	}
1047
1048	return inet_diag_get_exact_compat(skb, nlh);
1049}
1050
1051static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
1052{
1053	int hdrlen = sizeof(struct inet_diag_req_v2);
1054	struct net *net = sock_net(skb->sk);
1055
1056	if (nlmsg_len(h) < hdrlen)
1057		return -EINVAL;
1058
1059	if (h->nlmsg_flags & NLM_F_DUMP) {
1060		if (nlmsg_attrlen(h, hdrlen)) {
1061			struct nlattr *attr;
1062
1063			attr = nlmsg_find_attr(h, hdrlen,
1064					       INET_DIAG_REQ_BYTECODE);
1065			if (!attr ||
1066			    nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
1067			    inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
1068				return -EINVAL;
1069		}
1070		{
1071			struct netlink_dump_control c = {
1072				.dump = inet_diag_dump,
1073			};
1074			return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1075		}
1076	}
1077
1078	return inet_diag_get_exact(skb, h, nlmsg_data(h));
1079}
1080
/* sock_diag handler for AF_INET; registered in inet_diag_init(). */
static const struct sock_diag_handler inet_diag_handler = {
	.family = AF_INET,
	.dump = inet_diag_handler_dump,
};
1085
/*
 * sock_diag handler for AF_INET6; shares the AF_INET dump entry point and
 * is registered in inet_diag_init().
 */
static const struct sock_diag_handler inet6_diag_handler = {
	.family = AF_INET6,
	.dump = inet_diag_handler_dump,
};
1090
1091int inet_diag_register(const struct inet_diag_handler *h)
1092{
1093	const __u16 type = h->idiag_type;
1094	int err = -EINVAL;
1095
1096	if (type >= IPPROTO_MAX)
1097		goto out;
1098
1099	mutex_lock(&inet_diag_table_mutex);
1100	err = -EEXIST;
1101	if (!inet_diag_table[type]) {
1102		inet_diag_table[type] = h;
1103		err = 0;
1104	}
1105	mutex_unlock(&inet_diag_table_mutex);
1106out:
1107	return err;
1108}
1109EXPORT_SYMBOL_GPL(inet_diag_register);
1110
/*
 * inet_diag_unregister - remove a per-protocol diag handler
 * @h: handler; h->idiag_type selects the table slot to clear
 *
 * Out-of-range types are ignored.  The slot is cleared unconditionally,
 * without checking that it currently holds @h.
 */
void inet_diag_unregister(const struct inet_diag_handler *h)
{
	const __u16 type = h->idiag_type;

	if (type >= IPPROTO_MAX)
		return;

	mutex_lock(&inet_diag_table_mutex);
	inet_diag_table[type] = NULL;
	mutex_unlock(&inet_diag_table_mutex);
}
EXPORT_SYMBOL_GPL(inet_diag_unregister);
1123
1124static int __init inet_diag_init(void)
1125{
1126	const int inet_diag_table_size = (IPPROTO_MAX *
1127					  sizeof(struct inet_diag_handler *));
1128	int err = -ENOMEM;
1129
1130	inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1131	if (!inet_diag_table)
1132		goto out;
1133
1134	err = sock_diag_register(&inet_diag_handler);
1135	if (err)
1136		goto out_free_nl;
1137
1138	err = sock_diag_register(&inet6_diag_handler);
1139	if (err)
1140		goto out_free_inet;
1141
1142	sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1143out:
1144	return err;
1145
1146out_free_inet:
1147	sock_diag_unregister(&inet_diag_handler);
1148out_free_nl:
1149	kfree(inet_diag_table);
1150	goto out;
1151}
1152
/*
 * Module exit: unregister both sock_diag handlers and the legacy request
 * receiver, then release the handler table.
 */
static void __exit inet_diag_exit(void)
{
	sock_diag_unregister(&inet6_diag_handler);
	sock_diag_unregister(&inet_diag_handler);
	sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
	kfree(inet_diag_table);
}
1160
/* Module entry/exit points, licence and the sock_diag aliases for
 * AF_INET and AF_INET6.
 */
module_init(inet_diag_init);
module_exit(inet_diag_exit);
MODULE_LICENSE("GPL");
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);
1166