From 059df03e09569fc6c0af0508da795dcf75dea682 Mon Sep 17 00:00:00 2001 From: yuyulei Date: Thu, 18 May 2017 22:40:56 +0800 Subject: [PATCH] change allow_prot from map to array --- src/cmd/frps/control.go | 3 +- src/models/server/config.go | 103 ++++++++++++++++++++++++------------ 2 files changed, 71 insertions(+), 35 deletions(-) diff --git a/src/cmd/frps/control.go b/src/cmd/frps/control.go index d848ccd3..4d2b2d12 100644 --- a/src/cmd/frps/control.go +++ b/src/cmd/frps/control.go @@ -253,7 +253,8 @@ func doLogin(req *msg.ControlReq, c *conn.Conn) (ret int64, info string, s *serv // and PrivilegeMode is enabled if s.Type == "tcp" { if len(server.PrivilegeAllowPorts) != 0 { - _, ok := server.PrivilegeAllowPorts[s.ListenPort] + // TODO: once linstenPort used, should remove the port from privilege ports + ok := server.ContainsPort(server.PrivilegeAllowPorts, s.ListenPort) if !ok { info = fmt.Sprintf("ProxyName [%s], remote_port [%d] isn't allowed", req.ProxyName, s.ListenPort) log.Warn(info) diff --git a/src/models/server/config.go b/src/models/server/config.go index 048e5659..06bcc439 100644 --- a/src/models/server/config.go +++ b/src/models/server/config.go @@ -49,7 +49,7 @@ var ( SubDomainHost string = "" // if PrivilegeAllowPorts is not nil, tcp proxies which remote port exist in this map can be connected - PrivilegeAllowPorts map[int64]struct{} + PrivilegeAllowPorts [][2]int64 MaxPoolCount int64 = 100 HeartBeatTimeout int64 = 30 UserConnTimeout int64 = 10 @@ -179,40 +179,12 @@ func loadCommonConf(confFile string) error { return fmt.Errorf("Parse conf error: privilege_token must be set if privilege_mode is enabled") } - PrivilegeAllowPorts = make(map[int64]struct{}) - tmpStr, ok = conf.Get("common", "privilege_allow_ports") + allowPortsStr, ok := conf.Get("common", "privilege_allow_ports") + // TODO: check if conflicts exist in port ranges if ok { - // for example: 1000-2000,2001,2002,3000-4000 - portRanges := strings.Split(tmpStr, ",") - for _, portRangeStr := range portRanges { - // 1000-2000 or 2001 - portArray := strings.Split(portRangeStr, "-") - // lenght: only 1 or 2 is correct - rangeType := len(portArray) - if rangeType == 1 { - singlePort, err := strconv.ParseInt(portArray[0], 10, 64) - if err != nil { - return fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) - } - PrivilegeAllowPorts[singlePort] = struct{}{} - } else if rangeType == 2 { - min, err := strconv.ParseInt(portArray[0], 10, 64) - if err != nil { - return fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) - } - max, err := strconv.ParseInt(portArray[1], 10, 64) - if err != nil { - return fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) - } - if max < min { - return fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect") - } - for i := min; i <= max; i++ { - PrivilegeAllowPorts[i] = struct{}{} - } - } else { - return fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect") - } + PrivilegeAllowPorts, err = getPortRanges(allowPortsStr) + if err != nil { + return err } } } @@ -454,3 +426,66 @@ func GetProxyServer(proxyName string) (p *ProxyServer, ok bool) { p, ok = ProxyServers[proxyName] return } + +func getPortRanges(rangeStr string) (portRanges [][2]int64, err error) { + // for example: 1000-2000,2001,2002,3000-4000 + log.Debug("rangeStr is %s", rangeStr) + rangeArray := strings.Split(rangeStr, ",") + for _, portRangeStr := range rangeArray { + // 1000-2000 or 2001 + portArray := strings.Split(portRangeStr, "-") + // lenght: only 1 or 2 is correct + rangeType := len(portArray) + if rangeType == 1 { + singlePort, err := strconv.ParseInt(portArray[0], 10, 64) + if err != nil { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) + } + portRanges = append(portRanges, [2]int64{singlePort, singlePort}) + } else if rangeType == 2 { + min, err := strconv.ParseInt(portArray[0], 10, 64) + if err != nil { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) + } + max, err := strconv.ParseInt(portArray[1], 10, 64) + if err != nil { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect, %v", err) + } + if max < min { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports range incorrect") + } + portRanges = append(portRanges, [2]int64{min, max}) + } else { + return [][2]int64{}, fmt.Errorf("Parse conf error: privilege_allow_ports is incorrect") + } + } + return portRanges, nil +} + +func ContainsPort(portRanges [][2]int64, port int64) bool { + for _, pr := range portRanges { + if port >= pr[0] && port <= pr[1] { + return true + } + } + return false +} + +func PortRangesCut(portRanges [][2]int64, port int64) [][2]int64 { + var tmpRanges [][2]int64 + for _, pr := range portRanges { + if port >= pr[0] && port <= pr[1] { + leftRange := [2]int64{pr[0], port - 1} + rightRange := [2]int64{port + 1, pr[1]} + if leftRange[0] <= leftRange[1] { + tmpRanges = append(tmpRanges, leftRange) + } + if rightRange[0] <= rightRange[1] { + tmpRanges = append(tmpRanges, rightRange) + } + } else { + tmpRanges = append(tmpRanges, pr) + } + } + return tmpRanges +}