From d15a73f1d9dfa118b423a8b824c7d166c0ba77d1 Mon Sep 17 00:00:00 2001 From: Kelvin Date: Wed, 24 May 2023 16:51:04 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0edit=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=92=8C=E5=9B=BE=E7=89=87=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.go | 194 ++++++++++++++++++++++++++++++++++++++++++++++++++++- service.go | 105 +++++++++++++++++++++++++++++ 2 files changed, 298 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index e71ba83..649c0c2 100644 --- a/main.go +++ b/main.go @@ -36,6 +36,14 @@ type ChatRequest struct { Messages []Message `json:"messages,omitempty"` Prompt string `json:"prompt,omitempty"` Input string `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + N int `json:"n,omitempty"` +} + +type ImagesRequest struct { + Prompt string `json:"prompt,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` } type ChatResponse struct { @@ -46,6 +54,14 @@ type ChatResponse struct { Usage Usage `json:"usage,omitempty"` Choices []Choice `json:"choices,omitempty"` } +type ImagesResponse struct { + Created int64 `json:"created,omitempty"` + Data []ImagesData `json:"data,omitempty"` +} + +type ImagesData struct { + Url string `json:"url,omitempty"` +} type Usage struct { PromptTokens int `json:"prompt_tokens,omitempty"` @@ -184,6 +200,173 @@ func numTokensFromString(msg string, model string) int { } } +func images(c *gin.Context) { + var imagesRequest ImagesRequest + var model = "images-generations" + if err := c.ShouldBindJSON(&imagesRequest); err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + auth := c.Request.Header.Get("Authorization") + key := auth[7:] + + serverInfo, err := checkBlanceForImages(key, model, imagesRequest.N) + if err != nil { + c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) + log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, model, err) + return + } + + log.Printf("请求的KEY: %v Model: %v", key, model) + + remote, err := url.Parse(serverInfo.ServerAddress) + if err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + proxy := httputil.NewSingleHostReverseProxy(remote) + newReqBody, err := json.Marshal(imagesRequest) + if err != nil { + log.Printf("http request err: %v", err) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + proxy.Director = func(req *http.Request) { + req.Header = c.Request.Header + req.Host = remote.Host + req.URL.Scheme = remote.Scheme + req.URL.Host = remote.Host + req.URL.Path = c.Request.URL.Path + req.ContentLength = int64(len(newReqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Encoding", "") + serverKey := serverInfo.AvailableKey + keyList := strings.Split(serverInfo.AvailableKey, ",") + if len(keyList) > 1 { + // 随机数种子 + rand.Seed(time.Now().UnixNano()) + // 从数组中随机选择一个元素 + serverKey = keyList[rand.Intn(len(keyList))] + } + req.Header.Set("Authorization", "Bearer "+serverKey) + req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + } + sss, err := json.Marshal(imagesRequest) + log.Printf("开始处理返回逻辑 %d", string(sss)) + + proxy.ModifyResponse = func(resp *http.Response) error { + if resp.StatusCode != http.StatusOK { + //退回预扣除的余额 + checkBlanceReturnForImages(key, model, imagesRequest.N) + return nil + } + resp.Header.Set("Openai-Organization", "api2gpt") + var imagesResponse ImagesResponse + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Printf("读取返回数据出错: %v", err) + return err + } + json.Unmarshal(body, &imagesResponse) + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + log.Printf("image size: %v", len(imagesResponse.Data)) + timestamp := time.Now().Unix() + timestampID := "img-" + fmt.Sprintf("%d", timestamp) + //消费余额 + consumptionForImages(key, model, imagesRequest.N, len(imagesResponse.Data), timestampID) + return nil + } + + proxy.ServeHTTP(c.Writer, c.Request) + +} + +func edit(c *gin.Context) { + var prompt_tokens int + var complate_tokens int + var total_tokens int + var chatRequest ChatRequest + if err := c.ShouldBindJSON(&chatRequest); err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + auth := c.Request.Header.Get("Authorization") + key := auth[7:] + //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 + serverInfo, err := checkBlance(key, chatRequest.Model) + if err != nil { + c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) + log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) + return + } + + log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model) + + remote, err := url.Parse(serverInfo.ServerAddress) + if err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + proxy := httputil.NewSingleHostReverseProxy(remote) + newReqBody, err := json.Marshal(chatRequest) + if err != nil { + log.Printf("http request err: %v", err) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + proxy.Director = func(req *http.Request) { + req.Header = c.Request.Header + req.Host = remote.Host + req.URL.Scheme = remote.Scheme + req.URL.Host = remote.Host + req.URL.Path = c.Request.URL.Path + req.ContentLength = int64(len(newReqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Encoding", "") + serverKey := serverInfo.AvailableKey + keyList := strings.Split(serverInfo.AvailableKey, ",") + if len(keyList) > 1 { + // 随机数种子 + rand.Seed(time.Now().UnixNano()) + // 从数组中随机选择一个元素 + serverKey = keyList[rand.Intn(len(keyList))] + } + req.Header.Set("Authorization", "Bearer "+serverKey) + req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + } + sss, err := json.Marshal(chatRequest) + log.Printf("开始处理返回逻辑 %d", string(sss)) + + proxy.ModifyResponse = func(resp *http.Response) error { + resp.Header.Set("Openai-Organization", "api2gpt") + var chatResponse ChatResponse + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Printf("读取返回数据出错: %v", err) + return err + } + json.Unmarshal(body, &chatResponse) + prompt_tokens = chatResponse.Usage.PromptTokens + complate_tokens = chatResponse.Usage.CompletionTokens + total_tokens = chatResponse.Usage.TotalTokens + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", prompt_tokens, complate_tokens, total_tokens) + timestamp := time.Now().Unix() + timestampID := "edit-" + fmt.Sprintf("%d", timestamp) + //消费余额 + consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID) + return nil + } + + proxy.ServeHTTP(c.Writer, c.Request) + +} + func embeddings(c *gin.Context) { var prompt_tokens int var total_tokens int @@ -420,7 +603,7 @@ func completions(c *gin.Context) { func handleGetModels(c *gin.Context) { // BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/3 //models := []string{"gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"} - models := []string{"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"} + models := []string{"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002", "text-davinci-edit-001"} result := ListModelResponse{ Object: "list", } @@ -525,6 +708,8 @@ func main() { r.POST("/v1/chat/completions", checkKeyMid(), completions) r.POST("/v1/completions", checkKeyMid(), completions) r.POST("/v1/embeddings", checkKeyMid(), embeddings) + r.POST("/v1/edits", checkKeyMid(), edit) + r.POST("/v1/images/generations", checkKeyMid(), images) r.POST("/mock1", mockBalanceInquiry) r.POST("/mock2", mockBalanceConsumption) @@ -567,6 +752,13 @@ func main() { // Redis.Set(context.Background(), "model:gpt-3.5-turbo-0301", modelInfoStr, 0) // Redis.Set(context.Background(), "model:text-davinci-003", modelInfoStr2, 0) // Redis.Set(context.Background(), "model:text-embedding-ada-002", modelInfoStr3, 0) + // var modelInfo2 ModelInfo = ModelInfo{ + // ModelName: "images-generations", + // ModelPrice: 0.01, + // ModelPrepayment: 1000, + // } + // modelInfoStr2, _ := json.Marshal(&modelInfo2) + // Redis.Set(context.Background(), "model:images-generations", modelInfoStr2, 0) // var userInfo UserInfo = UserInfo{ // UID: "1", diff --git a/service.go b/service.go index 563372c..a4659e7 100644 --- a/service.go +++ b/service.go @@ -96,6 +96,46 @@ func checkBlance(key string, model string) (ServerInfo, error) { return serverInfo, nil } +// 余额查询 for images +func checkBlanceForImages(key string, model string, n int) (ServerInfo, error) { + var serverInfo ServerInfo + //获取用户信息 + userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() + var userInfo UserInfo + err = json.Unmarshal([]byte(userInfoStr), &userInfo) + //获取服务器信息 + serverInfoStr, err := Redis.Get(context.Background(), "server:"+userInfo.SID).Result() + if err != nil { + return serverInfo, errors.New("服务器信息不存在") + } + err = json.Unmarshal([]byte(serverInfoStr), &serverInfo) + if err != nil { + return serverInfo, errors.New("服务器信息解析失败") + } + //获取模型价格 + modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() + if err != nil { + return serverInfo, errors.New("模型信息不存在") + } + var modelInfo ModelInfo + err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) + if err != nil { + return serverInfo, errors.New("模型信息解析失败") + } + //计算余额-先扣除指定金额 + balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", -(float64(modelInfo.ModelPrepayment*n) * modelInfo.ModelPrice)).Result() + if err != nil { + return serverInfo, errors.New("余额计算失败") + } + log.Printf("用户余额 %f key: %v 预扣了:%f", balance, key, (float64(modelInfo.ModelPrepayment*n) * modelInfo.ModelPrice)) + if balance < 0 { + Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment*n)*modelInfo.ModelPrice).Result() + return serverInfo, errors.New("用户余额不足") + } + + return serverInfo, nil +} + // 预扣返还 func checkBlanceReturn(key string, model string) error { var serverInfo ServerInfo @@ -127,6 +167,37 @@ func checkBlanceReturn(key string, model string) error { return nil } +// 预扣返还 for images +func checkBlanceReturnForImages(key string, model string, n int) error { + var serverInfo ServerInfo + //获取用户信息 + userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() + var userInfo UserInfo + err = json.Unmarshal([]byte(userInfoStr), &userInfo) + //获取服务器信息 + serverInfoStr, err := Redis.Get(context.Background(), "server:"+userInfo.SID).Result() + if err != nil { + return errors.New("服务器信息不存在") + } + err = json.Unmarshal([]byte(serverInfoStr), &serverInfo) + if err != nil { + return errors.New("服务器信息解析失败") + } + //获取模型价格 + modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() + if err != nil { + return errors.New("模型信息不存在") + } + var modelInfo ModelInfo + err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) + if err != nil { + return errors.New("模型信息解析失败") + } + balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", (float64(modelInfo.ModelPrepayment*n) * modelInfo.ModelPrice)).Result() + log.Printf("用户余额 %f key: %v 返还预扣:%f", balance, key, (float64(modelInfo.ModelPrepayment*n) * modelInfo.ModelPrice)) + return nil +} + // 余额消费 func consumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) { //获取用户信息 @@ -160,3 +231,37 @@ func consumption(key string, model string, prompt_tokens int, completion_tokens } return result, nil } + +// 余额消费 for images +func consumptionForImages(key string, model string, n int, dataNum int, msg_id string) (string, error) { + //获取用户信息 + userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() + if err != nil { + return "", errors.New("用户信息不存在") + } + var userInfo UserInfo + err = json.Unmarshal([]byte(userInfoStr), &userInfo) + if err != nil { + return "", errors.New("用户信息解析失败") + } + //获取模型价格 + modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() + if err != nil { + return "", errors.New("模型信息不存在") + } + var modelInfo ModelInfo + err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) + if err != nil { + return "", errors.New("模型信息解析失败") + } + balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", float64(modelInfo.ModelPrepayment*n)*modelInfo.ModelPrice-(float64(1000*dataNum)*modelInfo.ModelPrice)).Result() + + // 余额消费日志请求 + result, err := balanceConsumption(key, model, 0, 1000*dataNum, 1000*dataNum, msg_id) + log.Printf("用户余额:%f 扣费KEY: %s 扣费token数: %d 扣费:%f 扣费日志发送结果 %s", balance, key, 1000*dataNum, float64(1000*dataNum)*modelInfo.ModelPrice, result) + if err != nil { + log.Printf("%s 余额消费日志请求失败 %v", key, err) + return "", err + } + return result, nil +}