diff --git a/include/linux/bpf-cgroup.h b/include/linux/bpf-cgroup.h index 9537a697b002..14a4c1f6fc5d 100755 --- a/include/linux/bpf-cgroup.h +++ b/include/linux/bpf-cgroup.h @@ -113,12 +113,38 @@ int __cgroup_bpf_run_filter_sock_ops(struct sock *sk, __ret; \ }) +#define BPF_CGROUP_RUN_SA_PROG_LOCK(sk, uaddr, type) \ +({ \ + int __ret = 0; \ + if (cgroup_bpf_enabled) { \ + lock_sock(sk); \ + __ret = __cgroup_bpf_run_filter_sock_addr(sk, uaddr, type); \ + release_sock(sk); \ + } \ + __ret; \ +}) + #define BPF_CGROUP_RUN_PROG_INET4_BIND(sk, uaddr) \ BPF_CGROUP_RUN_SA_PROG(sk, uaddr, BPF_CGROUP_INET4_BIND) #define BPF_CGROUP_RUN_PROG_INET6_BIND(sk, uaddr) \ BPF_CGROUP_RUN_SA_PROG(sk, uaddr, BPF_CGROUP_INET6_BIND) +#define BPF_CGROUP_PRE_CONNECT_ENABLED(sk) (cgroup_bpf_enabled && \ + sk->sk_prot->pre_connect) + +#define BPF_CGROUP_RUN_PROG_INET4_CONNECT(sk, uaddr) \ + BPF_CGROUP_RUN_SA_PROG(sk, uaddr, BPF_CGROUP_INET4_CONNECT) + +#define BPF_CGROUP_RUN_PROG_INET6_CONNECT(sk, uaddr) \ + BPF_CGROUP_RUN_SA_PROG(sk, uaddr, BPF_CGROUP_INET6_CONNECT) + +#define BPF_CGROUP_RUN_PROG_INET4_CONNECT_LOCK(sk, uaddr) \ + BPF_CGROUP_RUN_SA_PROG_LOCK(sk, uaddr, BPF_CGROUP_INET4_CONNECT) + +#define BPF_CGROUP_RUN_PROG_INET6_CONNECT_LOCK(sk, uaddr) \ + BPF_CGROUP_RUN_SA_PROG_LOCK(sk, uaddr, BPF_CGROUP_INET6_CONNECT) + #define BPF_CGROUP_RUN_PROG_SOCK_OPS(sock_ops) \ ({ \ int __ret = 0; \ @@ -137,11 +163,16 @@ struct cgroup_bpf {}; static inline void cgroup_bpf_put(struct cgroup *cgrp) {} static inline int cgroup_bpf_inherit(struct cgroup *cgrp) { return 0; } +#define BPF_CGROUP_PRE_CONNECT_ENABLED(sk) (0) #define BPF_CGROUP_RUN_PROG_INET_INGRESS(sk,skb) ({ 0; }) #define BPF_CGROUP_RUN_PROG_INET_EGRESS(sk,skb) ({ 0; }) #define BPF_CGROUP_RUN_PROG_INET_SOCK(sk) ({ 0; }) #define BPF_CGROUP_RUN_PROG_INET4_BIND(sk, uaddr) ({ 0; }) #define BPF_CGROUP_RUN_PROG_INET6_BIND(sk, uaddr) ({ 0; }) +#define BPF_CGROUP_RUN_PROG_INET4_CONNECT(sk, uaddr) ({ 0; }) +#define BPF_CGROUP_RUN_PROG_INET4_CONNECT_LOCK(sk, uaddr) ({ 0; }) +#define BPF_CGROUP_RUN_PROG_INET6_CONNECT(sk, uaddr) ({ 0; }) +#define BPF_CGROUP_RUN_PROG_INET6_CONNECT_LOCK(sk, uaddr) ({ 0; }) #define BPF_CGROUP_RUN_PROG_SOCK_OPS(sock_ops) ({ 0; }) #endif /* CONFIG_CGROUP_BPF */ diff --git a/include/net/addrconf.h b/include/net/addrconf.h index 4e5316a8fbf2..a940891f902e 100755 --- a/include/net/addrconf.h +++ b/include/net/addrconf.h @@ -235,6 +235,13 @@ struct ipv6_stub { }; extern const struct ipv6_stub *ipv6_stub __read_mostly; +/* A stub used by bpf helpers. Similarly ugly as ipv6_stub */ +struct ipv6_bpf_stub { + int (*inet6_bind)(struct sock *sk, struct sockaddr *uaddr, int addr_len, + bool force_bind_address_no_port, bool with_lock); +}; +extern const struct ipv6_bpf_stub *ipv6_bpf_stub __read_mostly; + /* * identify MLD packets for MLD filter exceptions */ diff --git a/include/net/sock.h b/include/net/sock.h index 5c987443029c..3da8670c2fe9 100755 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1054,6 +1054,9 @@ static inline void sk_prot_clear_nulls(struct sock *sk, int size) struct proto { void (*close)(struct sock *sk, long timeout); + int (*pre_connect)(struct sock *sk, + struct sockaddr *uaddr, + int addr_len); int (*connect)(struct sock *sk, struct sockaddr *uaddr, int addr_len); diff --git a/include/net/udp.h b/include/net/udp.h index 8058ca671c73..1ab7dcbe18b1 100755 --- a/include/net/udp.h +++ b/include/net/udp.h @@ -282,6 +282,7 @@ void udp4_hwcsum(struct sk_buff *skb, __be32 src, __be32 dst); int udp_rcv(struct sk_buff *skb); int udp_ioctl(struct sock *sk, int cmd, unsigned long arg); int udp_init_sock(struct sock *sk); +int udp_pre_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len); int __udp_disconnect(struct sock *sk, int flags); int udp_disconnect(struct sock *sk, int flags); unsigned int udp_poll(struct file *file, struct socket *sock, poll_table *wait); diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h index c47f76629456..7d7c209d4aba 100755 --- a/include/uapi/linux/bpf.h +++ b/include/uapi/linux/bpf.h @@ -144,6 +144,8 @@ enum bpf_attach_type { BPF_SK_SKB_STREAM_VERDICT, BPF_CGROUP_INET4_BIND = 8, BPF_CGROUP_INET6_BIND = 9, + BPF_CGROUP_INET4_CONNECT, + BPF_CGROUP_INET6_CONNECT, __MAX_BPF_ATTACH_TYPE }; @@ -684,6 +686,13 @@ union bpf_attr { * Return * 0 on success, or a negative error in case of failure. * + * int bpf_bind(ctx, addr, addr_len) + * Bind socket to address. Only binding to IP is supported, no port can be + * set in addr. + * @ctx: pointer to context of type bpf_sock_addr + * @addr: pointer to struct sockaddr to bind socket to + * @addr_len: length of sockaddr structure + * Return: 0 on success or negative error code */ #define __BPF_FUNC_MAPPER(FN) \ FN(unspec), \ diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c index 852bfdaad239..abb95cb8e7f2 100755 --- a/kernel/bpf/syscall.c +++ b/kernel/bpf/syscall.c @@ -1112,6 +1112,8 @@ bpf_prog_load_check_attach_type(enum bpf_prog_type prog_type, switch (expected_attach_type) { case BPF_CGROUP_INET4_BIND: case BPF_CGROUP_INET6_BIND: + case BPF_CGROUP_INET4_CONNECT: + case BPF_CGROUP_INET6_CONNECT: return 0; default: return -EINVAL; @@ -1351,6 +1353,8 @@ static int bpf_prog_attach(const union bpf_attr *attr) break; case BPF_CGROUP_INET4_BIND: case BPF_CGROUP_INET6_BIND: + case BPF_CGROUP_INET4_CONNECT: + case BPF_CGROUP_INET6_CONNECT: ptype = BPF_PROG_TYPE_CGROUP_SOCK_ADDR; break; case BPF_CGROUP_SOCK_OPS: @@ -1412,6 +1416,8 @@ static int bpf_prog_detach(const union bpf_attr *attr) break; case BPF_CGROUP_INET4_BIND: case BPF_CGROUP_INET6_BIND: + case BPF_CGROUP_INET4_CONNECT: + case BPF_CGROUP_INET6_CONNECT: ptype = BPF_PROG_TYPE_CGROUP_SOCK_ADDR; break; case BPF_CGROUP_SOCK_OPS: @@ -1460,6 +1466,8 @@ static int bpf_prog_query(const union bpf_attr *attr, case BPF_CGROUP_INET_SOCK_CREATE: case BPF_CGROUP_INET4_BIND: case BPF_CGROUP_INET6_BIND: + case BPF_CGROUP_INET4_CONNECT: + case BPF_CGROUP_INET6_CONNECT: case BPF_CGROUP_SOCK_OPS: break; default: diff --git a/net/core/filter.c b/net/core/filter.c index 61af27e80030..776da31d55da 100755 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -3227,6 +3228,52 @@ static const struct bpf_func_proto bpf_setsockopt_proto = { .arg5_type = ARG_CONST_SIZE, }; +const struct ipv6_bpf_stub *ipv6_bpf_stub __read_mostly; +EXPORT_SYMBOL_GPL(ipv6_bpf_stub); + +BPF_CALL_3(bpf_bind, struct bpf_sock_addr_kern *, ctx, struct sockaddr *, addr, + int, addr_len) +{ +#ifdef CONFIG_INET + struct sock *sk = ctx->sk; + int err; + + /* Binding to port can be expensive so it's prohibited in the helper. + * Only binding to IP is supported. + */ + err = -EINVAL; + if (addr->sa_family == AF_INET) { + if (addr_len < sizeof(struct sockaddr_in)) + return err; + if (((struct sockaddr_in *)addr)->sin_port != htons(0)) + return err; + return __inet_bind(sk, addr, addr_len, true, false); +#if IS_ENABLED(CONFIG_IPV6) + } else if (addr->sa_family == AF_INET6) { + if (addr_len < SIN6_LEN_RFC2133) + return err; + if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) + return err; + /* ipv6_bpf_stub cannot be NULL, since it's called from + * bpf_cgroup_inet6_connect hook and ipv6 is already loaded + */ + return ipv6_bpf_stub->inet6_bind(sk, addr, addr_len, true, false); +#endif /* CONFIG_IPV6 */ + } +#endif /* CONFIG_INET */ + + return -EAFNOSUPPORT; +} + +static const struct bpf_func_proto bpf_bind_proto = { + .func = bpf_bind, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM, + .arg3_type = ARG_CONST_SIZE, +}; + static const struct bpf_func_proto * bpf_base_func_proto(enum bpf_func_id func_id) { @@ -3280,6 +3327,14 @@ sock_addr_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) */ case BPF_FUNC_get_current_uid_gid: return &bpf_get_current_uid_gid_proto; + case BPF_FUNC_bind: + switch (prog->expected_attach_type) { + case BPF_CGROUP_INET4_CONNECT: + case BPF_CGROUP_INET6_CONNECT: + return &bpf_bind_proto; + default: + return NULL; + } default: return bpf_base_func_proto(func_id); } @@ -3758,6 +3813,7 @@ static bool sock_addr_is_valid_access(int off, int size, case bpf_ctx_range(struct bpf_sock_addr, user_ip4): switch (prog->expected_attach_type) { case BPF_CGROUP_INET4_BIND: + case BPF_CGROUP_INET4_CONNECT: break; default: return false; @@ -3766,6 +3822,7 @@ static bool sock_addr_is_valid_access(int off, int size, case bpf_ctx_range_till(struct bpf_sock_addr, user_ip6[0], user_ip6[3]): switch (prog->expected_attach_type) { case BPF_CGROUP_INET6_BIND: + case BPF_CGROUP_INET6_CONNECT: break; default: return false; diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 0483091ee32b..39ec51e96202 100755 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -545,12 +545,19 @@ int inet_dgram_connect(struct socket *sock, struct sockaddr *uaddr, int addr_len, int flags) { struct sock *sk = sock->sk; + int err; if (addr_len < sizeof(uaddr->sa_family)) return -EINVAL; if (uaddr->sa_family == AF_UNSPEC) return sk->sk_prot->disconnect(sk, flags); + if (BPF_CGROUP_PRE_CONNECT_ENABLED(sk)) { + err = sk->sk_prot->pre_connect(sk, uaddr, addr_len); + if (err) + return err; + } + if (!inet_sk(sk)->inet_num && inet_autobind(sk)) return -EAGAIN; return sk->sk_prot->connect(sk, uaddr, addr_len); @@ -631,6 +638,12 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, if (sk->sk_state != TCP_CLOSE) goto out; + if (BPF_CGROUP_PRE_CONNECT_ENABLED(sk)) { + err = sk->sk_prot->pre_connect(sk, uaddr, addr_len); + if (err) + goto out; + } + err = sk->sk_prot->connect(sk, uaddr, addr_len); if (err < 0) goto out; diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index 8762416fa53e..bf72aba6a2f1 100755 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -138,6 +138,21 @@ int tcp_twsk_unique(struct sock *sk, struct sock *sktw, void *twp) } EXPORT_SYMBOL_GPL(tcp_twsk_unique); +static int tcp_v4_pre_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + /* This check is replicated from tcp_v4_connect() and intended to + * prevent BPF program called below from accessing bytes that are out + * of the bound specified by user in addr_len. + */ + if (addr_len < sizeof(struct sockaddr_in)) + return -EINVAL; + + sock_owned_by_me(sk); + + return BPF_CGROUP_RUN_PROG_INET4_CONNECT(sk, uaddr); +} + /* This will initiate an outgoing connection. */ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) { @@ -2410,6 +2425,7 @@ struct proto tcp_prot = { .name = "TCP", .owner = THIS_MODULE, .close = tcp_close, + .pre_connect = tcp_v4_pre_connect, .connect = tcp_v4_connect, .disconnect = tcp_disconnect, .accept = inet_csk_accept, diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 75a16dc29e6f..27b8fddd845c 100755 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -1761,6 +1761,19 @@ csum_copy_err: goto try_again; } +int udp_pre_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + /* This check is replicated from __ip4_datagram_connect() and + * intended to prevent BPF program called below from accessing bytes + * that are out of the bound specified by user in addr_len. + */ + if (addr_len < sizeof(struct sockaddr_in)) + return -EINVAL; + + return BPF_CGROUP_RUN_PROG_INET4_CONNECT_LOCK(sk, uaddr); +} +EXPORT_SYMBOL(udp_pre_connect); + int __udp_disconnect(struct sock *sk, int flags) { struct inet_sock *inet = inet_sk(sk); @@ -2847,6 +2860,7 @@ struct proto udp_prot = { .name = "UDP", .owner = THIS_MODULE, .close = udp_lib_close, + .pre_connect = udp_pre_connect, .connect = ip4_datagram_connect, .disconnect = udp_disconnect, .ioctl = udp_ioctl, diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c index b1e5e68acf0f..95fa466f2b6c 100755 --- a/net/ipv6/af_inet6.c +++ b/net/ipv6/af_inet6.c @@ -897,6 +897,10 @@ static const struct ipv6_stub ipv6_stub_impl = { .nd_tbl = &nd_tbl, }; +static const struct ipv6_bpf_stub ipv6_bpf_stub_impl = { + .inet6_bind = __inet6_bind, +}; + static int __init inet6_init(void) { struct list_head *r; @@ -1053,6 +1057,7 @@ static int __init inet6_init(void) /* ensure that ipv6 stubs are visible only after ipv6 is ready */ wmb(); ipv6_stub = &ipv6_stub_impl; + ipv6_bpf_stub = &ipv6_bpf_stub_impl; out: return err; diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c index 20b4a14b9e1a..09e2f88be194 100755 --- a/net/ipv6/tcp_ipv6.c +++ b/net/ipv6/tcp_ipv6.c @@ -115,6 +115,21 @@ static u32 tcp_v6_init_ts_off(const struct net *net, const struct sk_buff *skb) ipv6_hdr(skb)->saddr.s6_addr32); } +static int tcp_v6_pre_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + /* This check is replicated from tcp_v6_connect() and intended to + * prevent BPF program called below from accessing bytes that are out + * of the bound specified by user in addr_len. + */ + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + sock_owned_by_me(sk); + + return BPF_CGROUP_RUN_PROG_INET6_CONNECT(sk, uaddr); +} + static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) { @@ -1918,6 +1933,7 @@ struct proto tcpv6_prot = { .name = "TCPv6", .owner = THIS_MODULE, .close = tcp_close, + .pre_connect = tcp_v6_pre_connect, .connect = tcp_v6_connect, .disconnect = tcp_disconnect, .accept = inet_csk_accept, diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index feedd8d6a568..6bb7b1cd97a1 100755 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -1019,6 +1019,25 @@ static void udp_v6_flush_pending_frames(struct sock *sk) } } +static int udpv6_pre_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + /* The following checks are replicated from __ip6_datagram_connect() + * and intended to prevent BPF program called below from accessing + * bytes that are out of the bound specified by user in addr_len. + */ + if (uaddr->sa_family == AF_INET) { + if (__ipv6_only_sock(sk)) + return -EAFNOSUPPORT; + return udp_pre_connect(sk, uaddr, addr_len); + } + + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + return BPF_CGROUP_RUN_PROG_INET6_CONNECT_LOCK(sk, uaddr); +} + /** * udp6_hwcsum_outgoing - handle outgoing HW checksumming * @sk: socket we are sending on @@ -1601,6 +1620,7 @@ struct proto udpv6_prot = { .name = "UDPv6", .owner = THIS_MODULE, .close = udp_lib_close, + .pre_connect = udpv6_pre_connect, .connect = ip6_datagram_connect, .disconnect = udp_disconnect, .ioctl = udp_ioctl,