From 504f565d3f6382c5d86189ac0d6ebbd24c2516c7 Mon Sep 17 00:00:00 2001 From: Guy Lewin Date: Sun, 16 Feb 2020 17:49:51 +0200 Subject: [PATCH] tcp multiplexing over http connect tunnel --- models/config/proxy.go | 16 +++++++- models/config/server_common.go | 18 +++++++++ server/controller/resource.go | 3 ++ server/proxy/tcp.go | 59 +++++++++++++++++++++++------- server/service.go | 17 +++++++++ utils/vhost/http.go | 10 ----- utils/vhost/resource.go | 14 +++++++ utils/vhost/tcp.go | 67 ++++++++++++++++++++++++++++++++++ utils/vhost/vhost.go | 10 +++++ 9 files changed, 188 insertions(+), 26 deletions(-) create mode 100644 utils/vhost/tcp.go diff --git a/models/config/proxy.go b/models/config/proxy.go index f4ddba50..dabcd65e 100644 --- a/models/config/proxy.go +++ b/models/config/proxy.go @@ -530,6 +530,7 @@ func (cfg *HealthCheckConf) checkForCli() error { type TcpProxyConf struct { BaseProxyConf BindInfoConf + DomainConf } func (cfg *TcpProxyConf) Compare(cmp ProxyConf) bool { @@ -539,7 +540,8 @@ func (cfg *TcpProxyConf) Compare(cmp ProxyConf) bool { } if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || - !cfg.BindInfoConf.compare(&cmpConf.BindInfoConf) { + !cfg.BindInfoConf.compare(&cmpConf.BindInfoConf) || + !cfg.DomainConf.compare(&cmpConf.DomainConf) { return false } return true @@ -548,6 +550,7 @@ func (cfg *TcpProxyConf) Compare(cmp ProxyConf) bool { func (cfg *TcpProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.UnmarshalFromMsg(pMsg) cfg.BindInfoConf.UnmarshalFromMsg(pMsg) + cfg.DomainConf.UnmarshalFromMsg(pMsg) } func (cfg *TcpProxyConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { @@ -557,12 +560,16 @@ func (cfg *TcpProxyConf) UnmarshalFromIni(prefix string, name string, section in if err = cfg.BindInfoConf.UnmarshalFromIni(prefix, name, section); err != nil { return } + if err = cfg.DomainConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } return } func (cfg *TcpProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { cfg.BaseProxyConf.MarshalToMsg(pMsg) cfg.BindInfoConf.MarshalToMsg(pMsg) + cfg.DomainConf.MarshalToMsg(pMsg) } func (cfg *TcpProxyConf) CheckForCli() (err error) { @@ -572,7 +579,12 @@ func (cfg *TcpProxyConf) CheckForCli() (err error) { return } -func (cfg *TcpProxyConf) CheckForSvr(serverCfg ServerCommonConf) error { return nil } +func (cfg *TcpProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { + if len(cfg.CustomDomains) == 0 && cfg.SubDomain == "" && serverCfg.VhostTcpPort != 0 { + return fmt.Errorf("type [tcp] not support when vhost_http_port is on but no custom domain or subdomain configured") + } + return +} // UDP type UdpProxyConf struct { diff --git a/models/config/server_common.go b/models/config/server_common.go index 20e92f9c..c225c544 100644 --- a/models/config/server_common.go +++ b/models/config/server_common.go @@ -57,6 +57,12 @@ type ServerCommonConf struct { // requests. By default, this value is 0. VhostHttpsPort int `json:"vhost_https_port"` + // VhostTcpPort specifies the port that the server listens for TCP Vhost + // requests. If the value is 0, the server will not multiplex TCP requests + // on one single port. If it's not - it will listen on this value for HTTP + // CONNECT requests. By default, this value is 0. + VhostTcpPort int `json:"vhost_tcp_port"` + // VhostHttpTimeout specifies the response header timeout for the Vhost // HTTP server, in seconds. By default, this value is 60. VhostHttpTimeout int64 `json:"vhost_http_timeout"` @@ -156,6 +162,7 @@ func GetDefaultServerConf() ServerCommonConf { ProxyBindAddr: "0.0.0.0", VhostHttpPort: 0, VhostHttpsPort: 0, + VhostTcpPort: 0, VhostHttpTimeout: 60, DashboardAddr: "0.0.0.0", DashboardPort: 0, @@ -259,6 +266,17 @@ func UnmarshalServerConfFromIni(content string) (cfg ServerCommonConf, err error cfg.VhostHttpsPort = 0 } + if tmpStr, ok = conf.Get("common", "vhost_tcp_port"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid vhost_tcp_port") + return + } else { + cfg.VhostTcpPort = int(v) + } + } else { + cfg.VhostTcpPort = 0 + } + if tmpStr, ok = conf.Get("common", "vhost_http_timeout"); ok { v, errRet := strconv.ParseInt(tmpStr, 10, 64) if errRet != nil || v < 0 { diff --git a/server/controller/resource.go b/server/controller/resource.go index 91332b57..9d9ed5c0 100644 --- a/server/controller/resource.go +++ b/server/controller/resource.go @@ -44,6 +44,9 @@ type ResourceController struct { // For https proxies, route requests to different clients by hostname and other information VhostHttpsMuxer *vhost.HttpsMuxer + // For tcp proxies, route requests to different proxies based on HTTP CONNECT header + VhostTcpMuxer *vhost.TcpMuxer + // Controller for nat hole connections NatHoleController *nathole.NatHoleController } diff --git a/server/proxy/tcp.go b/server/proxy/tcp.go index 0ecfe260..4fadb0e0 100644 --- a/server/proxy/tcp.go +++ b/server/proxy/tcp.go @@ -19,6 +19,7 @@ import ( "net" "github.com/fatedier/frp/models/config" + "github.com/fatedier/frp/utils/vhost" ) type TcpProxy struct { @@ -45,22 +46,52 @@ func (pxy *TcpProxy) Run() (remoteAddr string, err error) { pxy.listeners = append(pxy.listeners, l) xl.Info("tcp proxy listen port [%d] in group [%s]", pxy.cfg.RemotePort, pxy.cfg.Group) } else { - pxy.realPort, err = pxy.rc.TcpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort) - if err != nil { - return - } - defer func() { - if err != nil { - pxy.rc.TcpPortManager.Release(pxy.realPort) + if pxy.serverCfg.VhostTcpPort > 0 { + pxy.realPort = pxy.serverCfg.VhostTcpPort + routeConfig := &vhost.VhostRouteConfig{} + for _, domain := range pxy.cfg.CustomDomains { + if domain == "" { + continue + } + + routeConfig.Domain = domain + l, errRet := pxy.rc.VhostTcpMuxer.Listen(pxy.ctx, routeConfig) + if errRet != nil { + err = errRet + return + } + xl.Info("http tunnel server (tcp proxy) listen for host [%s]", routeConfig.Domain) + pxy.listeners = append(pxy.listeners, l) } - }() - listener, errRet := net.Listen("tcp", fmt.Sprintf("%s:%d", pxy.serverCfg.ProxyBindAddr, pxy.realPort)) - if errRet != nil { - err = errRet - return + + if pxy.cfg.SubDomain != "" { + routeConfig.Domain = pxy.cfg.SubDomain + "." + pxy.serverCfg.SubDomainHost + l, errRet := pxy.rc.VhostTcpMuxer.Listen(pxy.ctx, routeConfig) + if errRet != nil { + err = errRet + return + } + xl.Info("http tunnel server (tcp proxy) listen for host [%s]", routeConfig.Domain) + pxy.listeners = append(pxy.listeners, l) + } + } else { + pxy.realPort, err = pxy.rc.TcpPortManager.Acquire(pxy.name, pxy.cfg.RemotePort) + if err != nil { + return + } + defer func() { + if err != nil { + pxy.rc.TcpPortManager.Release(pxy.realPort) + } + }() + listener, errRet := net.Listen("tcp", fmt.Sprintf("%s:%d", pxy.serverCfg.ProxyBindAddr, pxy.realPort)) + if errRet != nil { + err = errRet + return + } + pxy.listeners = append(pxy.listeners, listener) + xl.Info("tcp proxy listen port [%d]", pxy.cfg.RemotePort) } - pxy.listeners = append(pxy.listeners, listener) - xl.Info("tcp proxy listen port [%d]", pxy.cfg.RemotePort) } pxy.cfg.RemotePort = pxy.realPort diff --git a/server/service.go b/server/service.go index 1ad7e281..7f969201 100644 --- a/server/service.go +++ b/server/service.go @@ -215,6 +215,23 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) { log.Info("https service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHttpsPort) } + // Create tcp vhost muxer. + if cfg.VhostTcpPort > 0 { + var l net.Listener + l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.VhostTcpPort)) + if err != nil { + err = fmt.Errorf("Create server listener error, %v", err) + return + } + + svr.rc.VhostTcpMuxer, err = vhost.NewTcpMuxer(l, 30*time.Second) + if err != nil { + err = fmt.Errorf("Create vhost tcpMuxer error, %v", err) + return + } + log.Info("tcp service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostTcpPort) + } + // frp tls listener svr.tlsListener = svr.muxer.Listen(1, 1, func(data []byte) bool { return int(data[0]) == frpNet.FRP_TLS_HEAD_BYTE diff --git a/utils/vhost/http.go b/utils/vhost/http.go index 7c2d7a3a..c651f806 100644 --- a/utils/vhost/http.go +++ b/utils/vhost/http.go @@ -34,16 +34,6 @@ var ( ErrNoDomain = errors.New("no such domain") ) -func getHostFromAddr(addr string) (host string) { - strs := strings.Split(addr, ":") - if len(strs) > 1 { - host = strs[0] - } else { - host = addr - } - return -} - type HttpReverseProxyOptions struct { ResponseHeaderTimeoutS int64 } diff --git a/utils/vhost/resource.go b/utils/vhost/resource.go index 5c084306..1ee77ce1 100644 --- a/utils/vhost/resource.go +++ b/utils/vhost/resource.go @@ -98,3 +98,17 @@ func noAuthResponse() *http.Response { } return res } + +func okResponse() *http.Response { + header := make(http.Header) + + res := &http.Response{ + Status: "OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: header, + } + return res +} diff --git a/utils/vhost/tcp.go b/utils/vhost/tcp.go new file mode 100644 index 00000000..1bde2cf0 --- /dev/null +++ b/utils/vhost/tcp.go @@ -0,0 +1,67 @@ +// Copyright 2019 guylewin, guy@lewin.co.il +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vhost + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" + "time" +) + +type TcpMuxer struct { + *VhostMuxer +} + +func NewTcpMuxer(listener net.Listener, timeout time.Duration) (*TcpMuxer, error) { + mux, err := NewVhostMuxer(listener, getTcpServiceName, nil, sendHttpOk, timeout) + return &TcpMuxer{mux}, err +} + +func readHttpConnectRequest(rd io.Reader) (host string, err error) { + bufioReader := bufio.NewReader(rd) + + req, err := http.ReadRequest(bufioReader) + if err != nil { + return + } + + if req.Method != "CONNECT" { + err = fmt.Errorf("connections to tcp vhost must be of method CONNECT") + return + } + + host = getHostFromAddr(req.Host) + return +} + +func sendHttpOk(c net.Conn, _ string) (_ net.Conn, err error) { + okResp := okResponse() + err = okResp.Write(c) + return c, err +} + +func getTcpServiceName(c net.Conn) (_ net.Conn, _ map[string]string, err error) { + reqInfoMap := make(map[string]string, 0) + host, err := readHttpConnectRequest(c) + if err != nil { + return nil, reqInfoMap, err + } + reqInfoMap["Host"] = host + reqInfoMap["Scheme"] = "tcp" + return c, reqInfoMap, nil +} diff --git a/utils/vhost/vhost.go b/utils/vhost/vhost.go index 57f82394..9650204c 100644 --- a/utils/vhost/vhost.go +++ b/utils/vhost/vhost.go @@ -225,3 +225,13 @@ func (l *Listener) Name() string { func (l *Listener) Addr() net.Addr { return (*net.TCPAddr)(nil) } + +func getHostFromAddr(addr string) (host string) { + strs := strings.Split(addr, ":") + if len(strs) > 1 { + host = strs[0] + } else { + host = addr + } + return +}