diff --git a/server/service.go b/server/service.go index 5b0ba8a7..35f75451 100644 --- a/server/service.go +++ b/server/service.go @@ -65,6 +65,7 @@ const ( forwardHost = "remote.agi7.ai" forwardCookieName = "agi7.forward.auth" sseName = "proxy_status" + setCookieHeader = "x-agi7-set-cookie" ) func init() { @@ -687,13 +688,79 @@ const ( AuthFailed = 2 ) +type Cookie struct { + Name string `json:"name"` + Value string `json:"value"` + ExpiredAt string `json:"expiredAt"` +} + func (m authMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if !strings.HasSuffix(request.Host, forwardHost) { m.next.ServeHTTP(writer, request) return } - var domain = strings.SplitN(request.Host, ".", 2)[0] + name := request.URL.Query().Get("name") + domain := strings.SplitN(request.Host, ".", 2)[0] + if name == fmt.Sprintf("%s.%s", forwardCookieName, domain) { + var expiredAt = time.Now().Add(time.Hour) + var expiredAtValue = request.URL.Query().Get("expiredAt") + if ee, err := strconv.ParseInt(expiredAtValue, 10, 64); err == nil { + expiredAt = time.Unix(ee, 0) + } else { + err = fmt.Errorf("failed to parse expiredAt field, expiredAt=%s", expiredAtValue) + writer.WriteHeader(http.StatusBadRequest) + writer.Write([]byte(err.Error())) + return + } + + http.SetCookie(writer, &http.Cookie{ + Name: name, + Value: request.URL.Query().Get("value"), + Path: "/", + Domain: request.Host, + Expires: expiredAt, + }) + http.Redirect(writer, request, "/", http.StatusTemporaryRedirect) + return + } + + setCookie := strings.TrimSpace(request.Header.Get(setCookieHeader)) + if setCookie != "" { + var cc Cookie + if err := json.Unmarshal([]byte(setCookie), &cc); err != nil { + err = fmt.Errorf("failed to decode cookie json data, cookie=%s", setCookie) + log.Errorf(err.Error()) + writer.WriteHeader(http.StatusBadRequest) + writer.Write([]byte(err.Error())) + return + } + + var expiredAt = time.Now().Add(time.Hour) + if ee, err := strconv.ParseInt(cc.ExpiredAt, 10, 64); err == nil { + expiredAt = time.Unix(ee, 0) + } else { + err = fmt.Errorf("failed to parse expiredAt field, expiredAt=%s", cc.ExpiredAt) + writer.WriteHeader(http.StatusBadRequest) + writer.Write([]byte(err.Error())) + return + } + + http.SetCookie(writer, &http.Cookie{ + Name: cc.Name, + Value: cc.Value, + Path: "/", + Domain: request.Host, + Expires: expiredAt, + Secure: true, + SameSite: http.SameSiteNoneMode, + HttpOnly: true, + }) + writer.WriteHeader(http.StatusOK) + writer.Write([]byte("ok")) + return + } + var cookieName = fmt.Sprintf("%s.%s", forwardCookieName, domain) cookie, err := request.Cookie(cookieName) if err != nil { @@ -721,7 +788,8 @@ func (m authMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Requ cookieData := request.Header.Get("Cookie") var cc string for _, v := range strings.Split(cookieData, ";") { - if strings.HasPrefix(v, cookieName) { + v = strings.TrimSpace(v) + if strings.HasPrefix(v, forwardCookieName+".") { continue } cc += v + ";"