diff --git a/api.go b/api.go index 2fe8c25..75fa846 100644 --- a/api.go +++ b/api.go @@ -90,13 +90,13 @@ func balanceConsumption(key string, model string, prompt_tokens int, completion_ client := &http.Client{} resp, err := client.Do(req2) if err != nil { - panic(err) + return "", err } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { - panic(err) + return "", err } return string(body), nil } diff --git a/main.go b/main.go index 8ddee31..47b48ab 100644 --- a/main.go +++ b/main.go @@ -153,7 +153,7 @@ func numTokensFromMessages(messages []Message, model string) int { tkm, err := tiktoken.EncodingForModel(model) if err != nil { err = fmt.Errorf("getEncoding: %v", err) - panic(err) + return 0 } numTokens := 0 for _, message := range messages { @@ -169,7 +169,7 @@ func numTokensFromString(msg string, model string) int { tkm, err := tiktoken.EncodingForModel(model) if err != nil { err = fmt.Errorf("getEncoding: %v", err) - panic(err) + return 0 } if model == "text-davinci-003" { return len(tkm.Encode(msg, nil, nil)) + 1 @@ -201,7 +201,8 @@ func embeddings(c *gin.Context) { remote, err := url.Parse(serverInfo.ServerAddress) if err != nil { - panic(err) + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return } proxy := httputil.NewSingleHostReverseProxy(remote) @@ -283,7 +284,8 @@ func completions(c *gin.Context) { remote, err := url.Parse(serverInfo.ServerAddress) if err != nil { - panic(err) + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return } proxy := httputil.NewSingleHostReverseProxy(remote) diff --git a/service.go b/service.go index 50432c6..582d666 100644 --- a/service.go +++ b/service.go @@ -148,14 +148,14 @@ func consumption(key string, model string, prompt_tokens int, completion_tokens if err != nil { return "", errors.New("模型信息解析失败") } - Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice-(float64(total_tokens)*modelInfo.ModelPrice)).Result() + balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice-(float64(total_tokens)*modelInfo.ModelPrice)).Result() // 余额消费日志请求 result, err := balanceConsumption(key, model, prompt_tokens, completion_tokens, total_tokens, msg_id) + log.Printf("用户余额:%f 扣费KEY: %s 扣费token数: %d 扣费:%f 扣费日志发送结果 %s", balance, key, total_tokens, float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice-(float64(total_tokens)*modelInfo.ModelPrice), result) if err != nil { - log.Printf("余额消费日志请求失败 %v", err) + log.Printf("%s 余额消费日志请求失败 %v", key, err) return "", err } - log.Printf("扣费KEY: %s 扣费token数: %d 扣费结果 %s", key, total_tokens, result) return result, nil }