From 4b54bde677fbebba4f51aba4f873aa93aab8b42b Mon Sep 17 00:00:00 2001 From: lvxiu_ext Date: Fri, 12 May 2023 13:15:28 +0800 Subject: [PATCH] =?UTF-8?q?=E6=97=A5=E5=BF=97=E4=BC=98=E5=8C=96=EF=BC=8C?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=8A=B6=E6=80=81=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 ++- main.go | 25 ++++++++++++++++++++----- service.go | 1 + 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index e996274..8dc8474 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -gin.log \ No newline at end of file +gin.log +/logs \ No newline at end of file diff --git a/main.go b/main.go index 387a4a3..c3ad534 100644 --- a/main.go +++ b/main.go @@ -207,6 +207,7 @@ func embeddings(c *gin.Context) { proxy := httputil.NewSingleHostReverseProxy(remote) newReqBody, err := json.Marshal(chatRequest) if err != nil { + log.Printf("http request err: %v", err) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -239,6 +240,7 @@ func embeddings(c *gin.Context) { var chatResponse ChatResponse body, err := ioutil.ReadAll(resp.Body) if err != nil { + log.Printf("读取返回数据出错: %v", err) return err } json.Unmarshal(body, &chatResponse) @@ -287,6 +289,7 @@ func completions(c *gin.Context) { proxy := httputil.NewSingleHostReverseProxy(remote) newReqBody, err := json.Marshal(chatRequest) if err != nil { + log.Printf("http request err: %v", err) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -311,7 +314,10 @@ func completions(c *gin.Context) { req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) } sss, err := json.Marshal(chatRequest) - log.Printf("开始处理返回逻辑 %d", string(sss)) + if err != nil { + log.Printf("chatRequest 转化出错 %v", err) + } + log.Printf("开始处理返回逻辑: %v", string(sss)) if chatRequest.Stream { // 流式回应,处理 proxy.ModifyResponse = func(resp *http.Response) error { @@ -329,7 +335,7 @@ func completions(c *gin.Context) { if err == io.EOF { break } - log.Printf("httpError1 %v", err.Error()) + log.Printf("流式回应,处理 err %v:", err.Error()) break //return err } @@ -338,7 +344,6 @@ func completions(c *gin.Context) { var trimStr = strings.Trim(string(chunk), "data: ") if trimStr != "\n" { json.Unmarshal([]byte(trimStr), &chatResponse) - //log.Printf("trimStr:" + trimStr) if chatResponse.Choices != nil { reqContent += chatResponse.Choices[0].Delta.Content chatRequestId = chatResponse.Id @@ -347,7 +352,7 @@ func completions(c *gin.Context) { // 写回数据 _, err = c.Writer.Write([]byte(string(chunk) + "\n")) if err != nil { - log.Printf("httpError2 %v", err.Error()) + log.Printf("写回数据 err: %v", err.Error()) return err } c.Writer.(http.Flusher).Flush() @@ -372,6 +377,7 @@ func completions(c *gin.Context) { var chatResponse ChatResponse body, err := ioutil.ReadAll(resp.Body) if err != nil { + log.Printf("非流式回应,处理 err: %v", err) return err } json.Unmarshal(body, &chatResponse) @@ -464,6 +470,7 @@ func checkKeyMid() gin.HandlerFunc { log.Printf("key: %v", key) msg, err := checkKeyAndTimeCount(key) if err != nil { + log.Printf("checkKeyMid err: %v", err) c.AbortWithStatusJSON(msg, gin.H{"code": err.Error()}) } } @@ -489,7 +496,7 @@ func main() { r := gin.Default() - //r.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance) + r.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance) r.GET("/v1/models", handleGetModels) r.OPTIONS("/v1/*path", handleOptions) @@ -557,5 +564,13 @@ func main() { //r.Run("127.0.0.1:8080") //docker下使用 + + // 定义一个GET请求测试接口 + r.GET("/ping", func(c *gin.Context) { + c.JSON(200, gin.H{ + "message": "pong from api2gpt", + }) + }) + r.Run("0.0.0.0:8080") } diff --git a/service.go b/service.go index 8d6cc37..e61f211 100644 --- a/service.go +++ b/service.go @@ -123,6 +123,7 @@ func consumption(key string, model string, prompt_tokens int, completion_tokens // 余额消费日志请求 result, err := balanceConsumption(key, model, prompt_tokens, completion_tokens, total_tokens, msg_id) if err != nil { + log.Printf("余额消费日志请求失败 %v", err) return "", err } log.Printf("扣费KEY: %s 扣费token数: %d 扣费结果 %s", key, total_tokens, result)