package controller

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"
	"time"

	"api2gpt-mid/common"
	"api2gpt-mid/model"
	"api2gpt-mid/service"

	"github.com/gin-gonic/gin"
)

// bearerKey extracts the caller's API key from the
// "Authorization: Bearer <key>" header. It returns ok=false when the header
// is too short to contain a key — the previous unchecked auth[7:] slice
// panicked on short or missing headers.
func bearerKey(c *gin.Context) (string, bool) {
	auth := c.Request.Header.Get("Authorization")
	if len(auth) < 7 {
		return "", false
	}
	return auth[7:], true
}

// pickServerKey picks one upstream key at random from a comma-separated
// list. A single-key (or empty) list is returned unchanged.
func pickServerKey(keys string) string {
	keyList := strings.Split(keys, ",")
	if len(keyList) <= 1 {
		return keys
	}
	// Seed a local source so repeated calls don't repeat the sequence.
	random := rand.New(rand.NewSource(time.Now().UnixNano()))
	return keyList[random.Intn(len(keyList))]
}

// newDirector builds the reverse-proxy Director shared by every handler in
// this file: it rewrites the request to point at the upstream host, replaces
// the body with the re-marshaled JSON payload, and swaps the client's key for
// one of the server's own upstream keys.
func newDirector(c *gin.Context, remote *url.URL, availableKey string, body []byte) func(*http.Request) {
	return func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(body))
		req.Header.Set("Content-Type", "application/json")
		// Disable compressed upstream responses so ModifyResponse can read them.
		req.Header.Set("Accept-Encoding", "")
		req.Header.Set("Authorization", "Bearer "+pickServerKey(availableKey))
		req.Body = io.NopCloser(bytes.NewReader(body))
	}
}

// Images proxies the image-generation endpoint. The caller's balance is
// pre-charged per requested image (N); on a non-200 upstream reply the
// pre-charge is refunded, otherwise consumption is settled against the
// number of images actually returned.
func Images(c *gin.Context) {
	var imagesRequest model.ImagesRequest
	const modelStr = "images-generations"
	if err := c.ShouldBindJSON(&imagesRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	key, ok := bearerKey(c)
	if !ok {
		c.AbortWithStatusJSON(401, gin.H{"error": "missing or malformed Authorization header"})
		return
	}
	// Pre-deduct balance for the requested number of images.
	serverInfo, err := service.CheckBlanceForImages(key, modelStr, imagesRequest.N)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, modelStr, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, modelStr)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(imagesRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = newDirector(c, remote, serverInfo.AvailableKey, newReqBody)
	// Reuse the already-marshaled body; the old code marshaled a second time,
	// ignored the error, and logged a string with the %d verb.
	log.Printf("开始处理返回逻辑 %s", newReqBody)
	proxy.ModifyResponse = func(resp *http.Response) error {
		if resp.StatusCode != http.StatusOK {
			// Refund the pre-deducted balance on upstream failure.
			return service.CheckBlanceReturnForImages(key, modelStr, imagesRequest.N)
		}
		resp.Header.Set("Openai-Organization", "api2gpt")
		var imagesResponse model.ImagesResponse
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		if err := json.Unmarshal(body, &imagesResponse); err != nil {
			log.Printf("json解析数据出错: %v", err)
			return err
		}
		// Put the consumed bytes back so the proxy can relay them to the client.
		resp.Body = io.NopCloser(bytes.NewReader(body))
		log.Printf("image size: %v", len(imagesResponse.Data))
		timestampID := "img-" + fmt.Sprintf("%d", time.Now().Unix())
		// Settle consumption against the number of images actually generated.
		_, err = service.ConsumptionForImages(key, modelStr, imagesRequest.N, len(imagesResponse.Data), timestampID)
		return err
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}

// Edit proxies the edits endpoint and bills by the token usage reported in
// the upstream (non-streaming) response.
func Edit(c *gin.Context) {
	var chatRequest model.ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	key, ok := bearerKey(c)
	if !ok {
		c.AbortWithStatusJSON(401, gin.H{"error": "missing or malformed Authorization header"})
		return
	}
	// Check the user's balance for this key; max_tokens could later be used
	// to size the pre-charge more precisely.
	serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = newDirector(c, remote, serverInfo.AvailableKey, newReqBody)
	log.Printf("开始处理返回逻辑 %s", newReqBody)
	proxy.ModifyResponse = func(resp *http.Response) error {
		resp.Header.Set("Openai-Organization", "api2gpt")
		var chatResponse model.ChatResponse
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		if err := json.Unmarshal(body, &chatResponse); err != nil {
			log.Printf("json解析数据出错: %v", err)
			return err
		}
		promptTokens := chatResponse.Usage.PromptTokens
		complateTokens := chatResponse.Usage.CompletionTokens
		totalTokens := chatResponse.Usage.TotalTokens
		resp.Body = io.NopCloser(bytes.NewReader(body))
		log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", promptTokens, complateTokens, totalTokens)
		timestampID := "edit-" + fmt.Sprintf("%d", time.Now().Unix())
		// NOTE(review): completion tokens are passed as 0 here even though
		// they are read and logged above — looks intentional (billing by
		// total only?) but should be confirmed.
		_, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
		return err
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}

// Embeddings proxies the embeddings endpoint and bills by the token usage
// reported in the upstream (non-streaming) response.
func Embeddings(c *gin.Context) {
	var chatRequest model.ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	key, ok := bearerKey(c)
	if !ok {
		c.AbortWithStatusJSON(401, gin.H{"error": "missing or malformed Authorization header"})
		return
	}
	serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = newDirector(c, remote, serverInfo.AvailableKey, newReqBody)
	log.Printf("开始处理返回逻辑 %s", newReqBody)
	proxy.ModifyResponse = func(resp *http.Response) error {
		resp.Header.Set("Openai-Organization", "api2gpt")
		var chatResponse model.ChatResponse
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		if err := json.Unmarshal(body, &chatResponse); err != nil {
			log.Printf("json解析数据出错: %v", err)
			return err
		}
		promptTokens := chatResponse.Usage.PromptTokens
		totalTokens := chatResponse.Usage.TotalTokens
		resp.Body = io.NopCloser(bytes.NewReader(body))
		log.Printf("prompt_tokens: %v total_tokens: %v", promptTokens, totalTokens)
		timestampID := "emb-" + fmt.Sprintf("%d", time.Now().Unix())
		// Settle consumption (embeddings have no completion tokens).
		_, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
		return err
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}

// Completions proxies the completions / chat-completions endpoint. For
// streaming requests the SSE body is relayed chunk-by-chunk while the
// completion text is accumulated so tokens can be counted locally; for
// non-streaming requests billing uses the usage block returned upstream.
func Completions(c *gin.Context) {
	var chatRequest model.ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	key, ok := bearerKey(c)
	if !ok {
		c.AbortWithStatusJSON(401, gin.H{"error": "missing or malformed Authorization header"})
		return
	}
	serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = newDirector(c, remote, serverInfo.AvailableKey, newReqBody)
	log.Printf("开始处理返回逻辑: %v", string(newReqBody))
	if chatRequest.Stream {
		// Streaming (SSE) response handling.
		proxy.ModifyResponse = func(resp *http.Response) error {
			log.Printf("流式回应 http status code: %v", resp.StatusCode)
			if resp.StatusCode != http.StatusOK {
				// Refund the pre-deducted balance on upstream failure.
				return service.CheckBlanceReturn(key, chatRequest.Model, chatRequest.MaxTokens)
			}
			chatRequestId := ""
			reqContent := ""
			reader := bufio.NewReader(resp.Body)
			for k, v := range resp.Header {
				c.Writer.Header().Set(k, v[0])
			}
			c.Writer.Header().Set("Openai-Organization", "api2gpt")
			for {
				chunk, err := reader.ReadBytes('\n')
				if err != nil {
					if err != io.EOF {
						log.Printf("流式回应,处理 err %v:", err.Error())
					}
					break
				}
				// Strip the SSE "data: " prefix. The old code used
				// strings.Trim, whose second argument is a character SET,
				// not a prefix — TrimPrefix is the correct call.
				payload := strings.TrimSpace(strings.TrimPrefix(string(chunk), "data: "))
				if payload == "" {
					// Skip blank keep-alive/separator lines; the "\n" we
					// append below restores the SSE blank-line framing.
					continue
				}
				if payload != "[DONE]" {
					// A normal delta event: accumulate the generated text
					// for local token counting. The old code tried to
					// json.Unmarshal "[DONE]" too, which always failed and
					// aborted the stream before consumption was charged.
					var chatResponse model.ChatResponse
					if err := json.Unmarshal([]byte(payload), &chatResponse); err != nil {
						return err
					}
					if len(chatResponse.Choices) > 0 { // len guard: a non-nil empty slice previously panicked
						if chatResponse.Choices[0].Text != "" {
							reqContent += chatResponse.Choices[0].Text
						} else {
							reqContent += chatResponse.Choices[0].Delta.Content
						}
						chatRequestId = chatResponse.Id
					}
				}
				// Relay the raw chunk (plus the framing newline) to the client.
				if _, err := c.Writer.Write([]byte(string(chunk) + "\n")); err != nil {
					log.Printf("写回数据 err: %v", err.Error())
					return err
				}
				c.Writer.(http.Flusher).Flush()
			}
			// Count tokens locally: streamed responses carry no usage block.
			var promptTokens int
			if chatRequest.Model == "text-davinci-003" {
				promptTokens = common.NumTokensFromString(chatRequest.Prompt, chatRequest.Model)
			} else {
				promptTokens = common.NumTokensFromMessages(chatRequest.Messages, chatRequest.Model)
			}
			completionTokens := common.NumTokensFromString(reqContent, chatRequest.Model)
			log.Printf("返回内容:%v", reqContent)
			totalTokens := promptTokens + completionTokens
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
			// Settle consumption.
			_, err := service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatRequestId)
			return err
		}
	} else {
		// Non-streaming response handling: bill by the upstream usage block.
		proxy.ModifyResponse = func(resp *http.Response) error {
			resp.Header.Set("Openai-Organization", "api2gpt")
			var chatResponse model.ChatResponse
			body, err := io.ReadAll(resp.Body)
			if err != nil {
				log.Printf("非流式回应,处理 err: %v", err)
				return err
			}
			if err := json.Unmarshal(body, &chatResponse); err != nil {
				return err
			}
			promptTokens := chatResponse.Usage.PromptTokens
			completionTokens := chatResponse.Usage.CompletionTokens
			totalTokens := chatResponse.Usage.TotalTokens
			resp.Body = io.NopCloser(bytes.NewReader(body))
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
			// Settle consumption.
			_, err = service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatResponse.Id)
			return err
		}
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}

// HandleOptions answers CORS preflight OPTIONS requests.
// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1
func HandleOptions(c *gin.Context) {
	c.Header("Access-Control-Allow-Origin", "*")
	c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
	c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
	c.Status(200)
}