You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

466 lines
14 KiB

package controller
import (
"api2gpt-mid/common"
"api2gpt-mid/model"
"api2gpt-mid/service"
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
)
func Images(c *gin.Context) {
var imagesRequest model.ImagesRequest
var modelStr = "images-generations"
if err := c.ShouldBindJSON(&imagesRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
serverInfo, err := service.CheckBlanceForImages(key, modelStr, imagesRequest.N)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
3 years ago
fmt.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, modelStr, err)
return
}
3 years ago
fmt.Printf("请求的KEY: %v Model: %v", key, modelStr)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(imagesRequest)
if err != nil {
3 years ago
fmt.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// 随机数种子
source := rand.NewSource(time.Now().UnixNano())
random := rand.New(source)
// 从数组中随机选择一个元素
serverKey = keyList[random.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(imagesRequest)
3 years ago
fmt.Printf("开始处理返回逻辑 %d", string(sss))
proxy.ModifyResponse = func(resp *http.Response) error {
if resp.StatusCode != http.StatusOK {
//退回预扣除的余额
err := service.CheckBlanceReturnForImages(key, modelStr, imagesRequest.N)
if err != nil {
return err
}
return nil
}
resp.Header.Set("Openai-Organization", "api2gpt")
var imagesResponse model.ImagesResponse
body, err := io.ReadAll(resp.Body)
if err != nil {
3 years ago
fmt.Printf("读取返回数据出错: %v", err)
return err
}
err = json.Unmarshal(body, &imagesResponse)
if err != nil {
3 years ago
fmt.Printf("json解析数据出错: %v", err)
return err
}
resp.Body = io.NopCloser(bytes.NewReader(body))
3 years ago
fmt.Printf("image size: %v", len(imagesResponse.Data))
timestamp := time.Now().Unix()
timestampID := "img-" + fmt.Sprintf("%d", timestamp)
//消费余额
_, err = service.ConsumptionForImages(key, modelStr, imagesRequest.N, len(imagesResponse.Data), timestampID)
if err != nil {
return err
}
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}
func Edit(c *gin.Context) {
var promptTokens int
var complateTokens int
var totalTokens int
var chatRequest model.ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
//根据KEY调用用户余额接口判断是否有足够的余额 后期可考虑判断max_tokens参数来调整
serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
3 years ago
fmt.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
return
}
3 years ago
fmt.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(chatRequest)
if err != nil {
3 years ago
fmt.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// 随机数种子
source := rand.NewSource(time.Now().UnixNano())
random := rand.New(source)
// 从数组中随机选择一个元素
serverKey = keyList[random.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
3 years ago
fmt.Printf("开始处理返回逻辑 %d", string(sss))
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse model.ChatResponse
body, err := io.ReadAll(resp.Body)
if err != nil {
3 years ago
fmt.Printf("读取返回数据出错: %v", err)
return err
}
err = json.Unmarshal(body, &chatResponse)
if err != nil {
3 years ago
fmt.Printf("json解析数据出错: %v", err)
return err
}
promptTokens = chatResponse.Usage.PromptTokens
complateTokens = chatResponse.Usage.CompletionTokens
totalTokens = chatResponse.Usage.TotalTokens
resp.Body = io.NopCloser(bytes.NewReader(body))
3 years ago
fmt.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", promptTokens, complateTokens, totalTokens)
timestamp := time.Now().Unix()
timestampID := "edit-" + fmt.Sprintf("%d", timestamp)
//消费余额
_, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
if err != nil {
return err
}
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}
func Embeddings(c *gin.Context) {
var promptTokens int
var totalTokens int
var chatRequest model.ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
3 years ago
fmt.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
return
}
3 years ago
fmt.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(chatRequest)
if err != nil {
3 years ago
fmt.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// 随机数种子
source := rand.NewSource(time.Now().UnixNano())
random := rand.New(source)
// 从数组中随机选择一个元素
serverKey = keyList[random.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
3 years ago
fmt.Printf("开始处理返回逻辑 %d", string(sss))
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse model.ChatResponse
body, err := io.ReadAll(resp.Body)
if err != nil {
3 years ago
fmt.Printf("读取返回数据出错: %v", err)
return err
}
err = json.Unmarshal(body, &chatResponse)
if err != nil {
3 years ago
fmt.Printf("json解析数据出错: %v", err)
return err
}
promptTokens = chatResponse.Usage.PromptTokens
totalTokens = chatResponse.Usage.TotalTokens
resp.Body = io.NopCloser(bytes.NewReader(body))
3 years ago
fmt.Printf("prompt_tokens: %v total_tokens: %v", promptTokens, totalTokens)
timestamp := time.Now().Unix()
timestampID := "emb-" + fmt.Sprintf("%d", timestamp)
//消费余额
_, err = service.Consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
if err != nil {
return err
}
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}
func Completions(c *gin.Context) {
var promptTokens int
var completionTokens int
var totalTokens int
var chatRequest model.ChatRequest
if err := c.ShouldBindJSON(&chatRequest); err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
auth := c.Request.Header.Get("Authorization")
key := auth[7:]
serverInfo, err := service.CheckBlance(key, chatRequest.Model, chatRequest.MaxTokens)
if err != nil {
c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
3 years ago
fmt.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
return
}
3 years ago
fmt.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
remote, err := url.Parse(serverInfo.ServerAddress)
if err != nil {
c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
return
}
proxy := httputil.NewSingleHostReverseProxy(remote)
newReqBody, err := json.Marshal(chatRequest)
if err != nil {
3 years ago
fmt.Printf("http request err: %v", err)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
proxy.Director = func(req *http.Request) {
req.Header = c.Request.Header
req.Host = remote.Host
req.URL.Scheme = remote.Scheme
req.URL.Host = remote.Host
req.URL.Path = c.Request.URL.Path
req.ContentLength = int64(len(newReqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept-Encoding", "")
serverKey := serverInfo.AvailableKey
keyList := strings.Split(serverInfo.AvailableKey, ",")
if len(keyList) > 1 {
// 随机数种子
source := rand.NewSource(time.Now().UnixNano())
random := rand.New(source)
// 从数组中随机选择一个元素
serverKey = keyList[random.Intn(len(keyList))]
}
req.Header.Set("Authorization", "Bearer "+serverKey)
req.Body = io.NopCloser(bytes.NewReader(newReqBody))
}
sss, err := json.Marshal(chatRequest)
if err != nil {
3 years ago
fmt.Printf("chatRequest 转化出错 %v", err)
}
3 years ago
fmt.Printf("开始处理返回逻辑: %v", string(sss))
if chatRequest.Stream {
// 流式回应,处理
proxy.ModifyResponse = func(resp *http.Response) error {
3 years ago
fmt.Printf("流式回应 http status code: %v", resp.StatusCode)
if resp.StatusCode != http.StatusOK {
//退回预扣除的余额
err = service.CheckBlanceReturn(key, chatRequest.Model, chatRequest.MaxTokens)
if err != nil {
return err
}
return nil
}
chatRequestId := ""
reqContent := ""
reader := bufio.NewReader(resp.Body)
headers := resp.Header
for k, v := range headers {
c.Writer.Header().Set(k, v[0])
}
c.Writer.Header().Set("Openai-Organization", "api2gpt")
for {
chunk, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
3 years ago
fmt.Printf("流式回应,处理 err %v:", err.Error())
break
//return err
}
var chatResponse model.ChatResponse
//去除回应中的data:前缀
var trimStr = strings.Trim(string(chunk), "data: ")
if trimStr != "\n" {
err := json.Unmarshal([]byte(trimStr), &chatResponse)
if err != nil {
return err
}
if chatResponse.Choices != nil {
if chatResponse.Choices[0].Text != "" {
reqContent += chatResponse.Choices[0].Text
} else {
reqContent += chatResponse.Choices[0].Delta.Content
}
chatRequestId = chatResponse.Id
}
// 写回数据
_, err = c.Writer.Write([]byte(string(chunk) + "\n"))
if err != nil {
3 years ago
fmt.Printf("写回数据 err: %v", err.Error())
return err
}
c.Writer.(http.Flusher).Flush()
}
}
if chatRequest.Model == "text-davinci-003" {
promptTokens = common.NumTokensFromString(chatRequest.Prompt, chatRequest.Model)
} else {
promptTokens = common.NumTokensFromMessages(chatRequest.Messages, chatRequest.Model)
}
completionTokens = common.NumTokensFromString(reqContent, chatRequest.Model)
3 years ago
fmt.Printf("返回内容:%v", reqContent)
totalTokens = promptTokens + completionTokens
3 years ago
fmt.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
//消费余额
_, err := service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatRequestId)
if err != nil {
return err
}
return nil
}
} else {
// 非流式回应,处理
proxy.ModifyResponse = func(resp *http.Response) error {
resp.Header.Set("Openai-Organization", "api2gpt")
var chatResponse model.ChatResponse
body, err := io.ReadAll(resp.Body)
if err != nil {
3 years ago
fmt.Printf("非流式回应,处理 err: %v", err)
return err
}
err = json.Unmarshal(body, &chatResponse)
if err != nil {
return err
}
promptTokens = chatResponse.Usage.PromptTokens
completionTokens = chatResponse.Usage.CompletionTokens
totalTokens = chatResponse.Usage.TotalTokens
resp.Body = io.NopCloser(bytes.NewReader(body))
3 years ago
fmt.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
//消费余额
_, err = service.Consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatResponse.Id)
if err != nil {
return err
}
return nil
}
}
proxy.ServeHTTP(c.Writer, c.Request)
}
// HandleOptions answers CORS preflight (OPTIONS) requests: it sets permissive
// Access-Control-* response headers and replies 200 with an empty body.
// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1
func HandleOptions(c *gin.Context) {
	preflightHeaders := [][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE"},
		{"Access-Control-Allow-Headers", "Content-Type, Authorization"},
	}
	for _, h := range preflightHeaders {
		c.Header(h[0], h[1])
	}
	c.Status(http.StatusOK)
}