1/*
2 * Copyright (c) 2008-2009 Patrick McHardy <kaber@trash.net>
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License version 2 as
6 * published by the Free Software Foundation.
7 *
8 * Development of this code funded by Astaro AG (http://www.astaro.com/)
9 */
10
11#include <linux/kernel.h>
12#include <linux/init.h>
13#include <linux/module.h>
14#include <linux/netlink.h>
15#include <linux/netfilter.h>
16#include <linux/netfilter/nf_tables.h>
17#include <net/netfilter/nf_tables_core.h>
18#include <net/netfilter/nf_tables.h>
19
/*
 * struct nft_cmp_expr - private data of the generic cmp expression
 * @data: constant the source register is compared against
 * @sreg: source register, as returned by nft_parse_register(); used as an
 *        index into regs->data[] by nft_cmp_eval()
 * @len: number of bytes compared (the memcmp() length in nft_cmp_eval())
 * @op: comparison operator (NFT_CMP_*)
 *
 * sreg and op are 8-bit bitfields so they pack together with len.
 */
struct nft_cmp_expr {
	struct nft_data		data;
	enum nft_registers	sreg:8;
	u8			len;
	enum nft_cmp_ops	op:8;
};
26
27static void nft_cmp_eval(const struct nft_expr *expr,
28			 struct nft_regs *regs,
29			 const struct nft_pktinfo *pkt)
30{
31	const struct nft_cmp_expr *priv = nft_expr_priv(expr);
32	int d;
33
34	d = memcmp(&regs->data[priv->sreg], &priv->data, priv->len);
35	switch (priv->op) {
36	case NFT_CMP_EQ:
37		if (d != 0)
38			goto mismatch;
39		break;
40	case NFT_CMP_NEQ:
41		if (d == 0)
42			goto mismatch;
43		break;
44	case NFT_CMP_LT:
45		if (d == 0)
46			goto mismatch;
47	case NFT_CMP_LTE:
48		if (d > 0)
49			goto mismatch;
50		break;
51	case NFT_CMP_GT:
52		if (d == 0)
53			goto mismatch;
54	case NFT_CMP_GTE:
55		if (d < 0)
56			goto mismatch;
57		break;
58	}
59	return;
60
61mismatch:
62	regs->verdict.code = NFT_BREAK;
63}
64
/*
 * Netlink attribute policy for the cmp expression. Entries are listed in
 * NFTA_CMP_* order.
 */
static const struct nla_policy nft_cmp_policy[NFTA_CMP_MAX + 1] = {
	/* source register, decoded by nft_parse_register() */
	[NFTA_CMP_SREG]		= { .type = NLA_U32 },
	/* comparison operator (enum nft_cmp_ops), big-endian on the wire */
	[NFTA_CMP_OP]		= { .type = NLA_U32 },
	/* the constant, a nested data attribute parsed by nft_data_init() */
	[NFTA_CMP_DATA]		= { .type = NLA_NESTED },
};
70
71static int nft_cmp_init(const struct nft_ctx *ctx, const struct nft_expr *expr,
72			const struct nlattr * const tb[])
73{
74	struct nft_cmp_expr *priv = nft_expr_priv(expr);
75	struct nft_data_desc desc;
76	int err;
77
78	err = nft_data_init(NULL, &priv->data, sizeof(priv->data), &desc,
79			    tb[NFTA_CMP_DATA]);
80	BUG_ON(err < 0);
81
82	priv->sreg = nft_parse_register(tb[NFTA_CMP_SREG]);
83	err = nft_validate_register_load(priv->sreg, desc.len);
84	if (err < 0)
85		return err;
86
87	priv->op  = ntohl(nla_get_be32(tb[NFTA_CMP_OP]));
88	priv->len = desc.len;
89	return 0;
90}
91
92static int nft_cmp_dump(struct sk_buff *skb, const struct nft_expr *expr)
93{
94	const struct nft_cmp_expr *priv = nft_expr_priv(expr);
95
96	if (nft_dump_register(skb, NFTA_CMP_SREG, priv->sreg))
97		goto nla_put_failure;
98	if (nla_put_be32(skb, NFTA_CMP_OP, htonl(priv->op)))
99		goto nla_put_failure;
100
101	if (nft_data_dump(skb, NFTA_CMP_DATA, &priv->data,
102			  NFT_DATA_VALUE, priv->len) < 0)
103		goto nla_put_failure;
104	return 0;
105
106nla_put_failure:
107	return -1;
108}
109
110static struct nft_expr_type nft_cmp_type;
111static const struct nft_expr_ops nft_cmp_ops = {
112	.type		= &nft_cmp_type,
113	.size		= NFT_EXPR_SIZE(sizeof(struct nft_cmp_expr)),
114	.eval		= nft_cmp_eval,
115	.init		= nft_cmp_init,
116	.dump		= nft_cmp_dump,
117};
118
119static int nft_cmp_fast_init(const struct nft_ctx *ctx,
120			     const struct nft_expr *expr,
121			     const struct nlattr * const tb[])
122{
123	struct nft_cmp_fast_expr *priv = nft_expr_priv(expr);
124	struct nft_data_desc desc;
125	struct nft_data data;
126	u32 mask;
127	int err;
128
129	err = nft_data_init(NULL, &data, sizeof(data), &desc,
130			    tb[NFTA_CMP_DATA]);
131	BUG_ON(err < 0);
132
133	priv->sreg = nft_parse_register(tb[NFTA_CMP_SREG]);
134	err = nft_validate_register_load(priv->sreg, desc.len);
135	if (err < 0)
136		return err;
137
138	desc.len *= BITS_PER_BYTE;
139	mask = nft_cmp_fast_mask(desc.len);
140
141	priv->data = data.data[0] & mask;
142	priv->len  = desc.len;
143	return 0;
144}
145
146static int nft_cmp_fast_dump(struct sk_buff *skb, const struct nft_expr *expr)
147{
148	const struct nft_cmp_fast_expr *priv = nft_expr_priv(expr);
149	struct nft_data data;
150
151	if (nft_dump_register(skb, NFTA_CMP_SREG, priv->sreg))
152		goto nla_put_failure;
153	if (nla_put_be32(skb, NFTA_CMP_OP, htonl(NFT_CMP_EQ)))
154		goto nla_put_failure;
155
156	data.data[0] = priv->data;
157	if (nft_data_dump(skb, NFTA_CMP_DATA, &data,
158			  NFT_DATA_VALUE, priv->len / BITS_PER_BYTE) < 0)
159		goto nla_put_failure;
160	return 0;
161
162nla_put_failure:
163	return -1;
164}
165
166const struct nft_expr_ops nft_cmp_fast_ops = {
167	.type		= &nft_cmp_type,
168	.size		= NFT_EXPR_SIZE(sizeof(struct nft_cmp_fast_expr)),
169	.eval		= NULL,	/* inlined */
170	.init		= nft_cmp_fast_init,
171	.dump		= nft_cmp_fast_dump,
172};
173
174static const struct nft_expr_ops *
175nft_cmp_select_ops(const struct nft_ctx *ctx, const struct nlattr * const tb[])
176{
177	struct nft_data_desc desc;
178	struct nft_data data;
179	enum nft_cmp_ops op;
180	int err;
181
182	if (tb[NFTA_CMP_SREG] == NULL ||
183	    tb[NFTA_CMP_OP] == NULL ||
184	    tb[NFTA_CMP_DATA] == NULL)
185		return ERR_PTR(-EINVAL);
186
187	op = ntohl(nla_get_be32(tb[NFTA_CMP_OP]));
188	switch (op) {
189	case NFT_CMP_EQ:
190	case NFT_CMP_NEQ:
191	case NFT_CMP_LT:
192	case NFT_CMP_LTE:
193	case NFT_CMP_GT:
194	case NFT_CMP_GTE:
195		break;
196	default:
197		return ERR_PTR(-EINVAL);
198	}
199
200	err = nft_data_init(NULL, &data, sizeof(data), &desc,
201			    tb[NFTA_CMP_DATA]);
202	if (err < 0)
203		return ERR_PTR(err);
204
205	if (desc.len <= sizeof(u32) && op == NFT_CMP_EQ)
206		return &nft_cmp_fast_ops;
207	else
208		return &nft_cmp_ops;
209}
210
211static struct nft_expr_type nft_cmp_type __read_mostly = {
212	.name		= "cmp",
213	.select_ops	= nft_cmp_select_ops,
214	.policy		= nft_cmp_policy,
215	.maxattr	= NFTA_CMP_MAX,
216	.owner		= THIS_MODULE,
217};
218
219int __init nft_cmp_module_init(void)
220{
221	return nft_register_expr(&nft_cmp_type);
222}
223
224void nft_cmp_module_exit(void)
225{
226	nft_unregister_expr(&nft_cmp_type);
227}
228