From 151f2e89b217bccca8f3ac8c9f83f720de1a98b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=99=BE=E9=87=8C=28barry=29?= Date: Fri, 12 Jul 2024 17:25:45 +0800 Subject: [PATCH] fix: barry 2024-07-12 17:25:45 --- server/service.go | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/server/service.go b/server/service.go index 5b0ba8a7..29a99c0c 100644 --- a/server/service.go +++ b/server/service.go @@ -65,6 +65,7 @@ const ( forwardHost = "remote.agi7.ai" forwardCookieName = "agi7.forward.auth" sseName = "proxy_status" + setCookieHeader = "x-agi7-set-cookie" ) func init() { @@ -687,12 +688,41 @@ const ( AuthFailed = 2 ) +type Cookie struct { + Name string `json:"name"` + Value string `json:"value"` + ExpiredAt string `json:"expiredAt"` +} + func (m authMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if !strings.HasSuffix(request.Host, forwardHost) { m.next.ServeHTTP(writer, request) return } + setCookie := strings.TrimSpace(request.Header.Get(setCookieHeader)) + if setCookie != "" { + var cc Cookie + if err := json.Unmarshal([]byte(setCookie), &cc); err != nil { + log.Errorf("failed to decode cookie json data, cookie=%s", setCookie) + } + + var expiredAt = time.Now().Add(time.Hour) + if ee, err := strconv.ParseInt(cc.ExpiredAt, 10, 64); err == nil { + expiredAt = time.Unix(ee, 0) + } + + http.SetCookie(writer, &http.Cookie{ + Name: cc.Name, + Value: cc.Value, + Path: "/", + Domain: request.Host, + Expires: expiredAt, + }) + writer.Write([]byte("ok")) + return + } + var domain = strings.SplitN(request.Host, ".", 2)[0] var cookieName = fmt.Sprintf("%s.%s", forwardCookieName, domain) cookie, err := request.Cookie(cookieName) @@ -721,7 +751,8 @@ func (m authMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Requ cookieData := request.Header.Get("Cookie") var cc string for _, v := range strings.Split(cookieData, ";") { - if strings.HasPrefix(v, cookieName) { + v = strings.TrimSpace(v) + if strings.HasPrefix(v, forwardCookieName+".") { continue } cc += v + ";"