添加edit模型和图片模型
continuous-integration/drone/push Build is passing Details

main
Kelvin 3 years ago
parent 96334cc572
commit d15a73f1d9

@ -36,6 +36,14 @@ type ChatRequest struct {
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
Input string `json:"input,omitempty"` Input string `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
N int `json:"n,omitempty"`
}
type ImagesRequest struct {
Prompt string `json:"prompt,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
} }
type ChatResponse struct { type ChatResponse struct {
@ -46,6 +54,14 @@ type ChatResponse struct {
Usage Usage `json:"usage,omitempty"` Usage Usage `json:"usage,omitempty"`
Choices []Choice `json:"choices,omitempty"` Choices []Choice `json:"choices,omitempty"`
} }
// ImagesResponse mirrors the body returned by the upstream image-generation
// endpoint: a creation timestamp plus one Data entry per generated image.
type ImagesResponse struct {
	Created int64 `json:"created,omitempty"`
	Data []ImagesData `json:"data,omitempty"`
}
// ImagesData is one generated image in an ImagesResponse; only the image URL
// is decoded here (other response fields, if any, are ignored).
type ImagesData struct {
	Url string `json:"url,omitempty"`
}
type Usage struct { type Usage struct {
PromptTokens int `json:"prompt_tokens,omitempty"` PromptTokens int `json:"prompt_tokens,omitempty"`
@ -184,6 +200,173 @@ func numTokensFromString(msg string, model string) int {
} }
} }
// images proxies an image-generation request (/v1/images/generations) to the
// backing server. It pre-charges the caller's balance for the requested image
// count, then either settles the real cost from the response or refunds the
// prepayment when the upstream call fails.
func images(c *gin.Context) {
	var imagesRequest ImagesRequest
	var model = "images-generations"
	if err := c.ShouldBindJSON(&imagesRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	// BUGFIX: the upstream API defaults n to 1 when omitted; mirror that so
	// the prepayment below is never computed from n == 0.
	if imagesRequest.N <= 0 {
		imagesRequest.N = 1
	}
	auth := c.Request.Header.Get("Authorization")
	// BUGFIX: guard before slicing off the "Bearer " prefix to avoid a panic
	// on a short/missing header (checkKeyMid presumably validates this —
	// TODO confirm).
	if len(auth) <= 7 {
		c.AbortWithStatusJSON(401, gin.H{"error": "invalid authorization header"})
		return
	}
	key := auth[7:]
	serverInfo, err := checkBlanceForImages(key, model, imagesRequest.N)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	newReqBody, err := json.Marshal(imagesRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		serverKey := serverInfo.AvailableKey
		keyList := strings.Split(serverInfo.AvailableKey, ",")
		if len(keyList) > 1 {
			// Pick one upstream key at random when several are configured.
			rand.Seed(time.Now().UnixNano())
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
	}
	// BUGFIX: was log.Printf("... %d", string(sss)) — %d on a string — with an
	// ignored json.Marshal error; the request is already serialized above.
	log.Printf("开始处理返回逻辑 %s", string(newReqBody))
	proxy.ModifyResponse = func(resp *http.Response) error {
		if resp.StatusCode != http.StatusOK {
			// Refund the pre-deducted balance on upstream failure.
			if err := checkBlanceReturnForImages(key, model, imagesRequest.N); err != nil {
				log.Printf("预扣返还失败 KEY: %v err: %v", key, err)
			}
			return nil
		}
		resp.Header.Set("Openai-Organization", "api2gpt")
		var imagesResponse ImagesResponse
		body, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		// BUGFIX: the Unmarshal error used to be silently dropped; a decode
		// failure would settle the charge as if zero images were generated.
		if err := json.Unmarshal(body, &imagesResponse); err != nil {
			log.Printf("解析返回数据出错: %v", err)
		}
		// Restore the body for the client since ReadAll consumed it.
		resp.Body = ioutil.NopCloser(bytes.NewReader(body))
		log.Printf("image size: %v", len(imagesResponse.Data))
		timestamp := time.Now().Unix()
		timestampID := "img-" + fmt.Sprintf("%d", timestamp)
		// Settle the actual cost against the prepayment.
		if _, err := consumptionForImages(key, model, imagesRequest.N, len(imagesResponse.Data), timestampID); err != nil {
			log.Printf("消费余额出错: %v", err)
		}
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// edit proxies an edits request (/v1/edits) to the backing server, then
// settles the caller's balance from the token usage reported in the response.
// On upstream failure the prepayment made by checkBlance is refunded.
func edit(c *gin.Context) {
	var prompt_tokens int
	var complate_tokens int
	var total_tokens int
	var chatRequest ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	auth := c.Request.Header.Get("Authorization")
	// BUGFIX: guard before slicing off the "Bearer " prefix to avoid a panic
	// on a short/missing header (checkKeyMid presumably validates this —
	// TODO confirm).
	if len(auth) <= 7 {
		c.AbortWithStatusJSON(401, gin.H{"error": "invalid authorization header"})
		return
	}
	key := auth[7:]
	// Pre-charge based on the model; could later take max_tokens into account.
	serverInfo, err := checkBlance(key, chatRequest.Model)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(newReqBody))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		serverKey := serverInfo.AvailableKey
		keyList := strings.Split(serverInfo.AvailableKey, ",")
		if len(keyList) > 1 {
			// Pick one upstream key at random when several are configured.
			rand.Seed(time.Now().UnixNano())
			serverKey = keyList[rand.Intn(len(keyList))]
		}
		req.Header.Set("Authorization", "Bearer "+serverKey)
		req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
	}
	// BUGFIX: was log.Printf("... %d", string(sss)) — %d on a string — with an
	// ignored json.Marshal error; the request is already serialized above.
	log.Printf("开始处理返回逻辑 %s", string(newReqBody))
	proxy.ModifyResponse = func(resp *http.Response) error {
		if resp.StatusCode != http.StatusOK {
			// BUGFIX: refund the prepayment on upstream failure, matching the
			// images handler; previously the deduction was silently kept.
			if err := checkBlanceReturn(key, chatRequest.Model); err != nil {
				log.Printf("预扣返还失败 KEY: %v err: %v", key, err)
			}
			return nil
		}
		resp.Header.Set("Openai-Organization", "api2gpt")
		var chatResponse ChatResponse
		body, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		// BUGFIX: the Unmarshal error used to be silently dropped.
		if err := json.Unmarshal(body, &chatResponse); err != nil {
			log.Printf("解析返回数据出错: %v", err)
		}
		prompt_tokens = chatResponse.Usage.PromptTokens
		complate_tokens = chatResponse.Usage.CompletionTokens
		total_tokens = chatResponse.Usage.TotalTokens
		// Restore the body for the client since ReadAll consumed it.
		resp.Body = ioutil.NopCloser(bytes.NewReader(body))
		log.Printf("prompt_tokens: %v complate_tokens: %v total_tokens: %v", prompt_tokens, complate_tokens, total_tokens)
		timestamp := time.Now().Unix()
		timestampID := "edit-" + fmt.Sprintf("%d", timestamp)
		// BUGFIX: completion tokens were hard-coded to 0 in the consumption
		// call even though the response reports them and they were computed
		// just above.
		if _, err := consumption(key, chatRequest.Model, prompt_tokens, complate_tokens, total_tokens, timestampID); err != nil {
			log.Printf("消费余额出错: %v", err)
		}
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
func embeddings(c *gin.Context) { func embeddings(c *gin.Context) {
var prompt_tokens int var prompt_tokens int
var total_tokens int var total_tokens int
@ -420,7 +603,7 @@ func completions(c *gin.Context) {
func handleGetModels(c *gin.Context) { func handleGetModels(c *gin.Context) {
// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/3 // BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/3
//models := []string{"gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"} //models := []string{"gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"}
models := []string{"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"} models := []string{"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002", "text-davinci-edit-001"}
result := ListModelResponse{ result := ListModelResponse{
Object: "list", Object: "list",
} }
@ -525,6 +708,8 @@ func main() {
r.POST("/v1/chat/completions", checkKeyMid(), completions) r.POST("/v1/chat/completions", checkKeyMid(), completions)
r.POST("/v1/completions", checkKeyMid(), completions) r.POST("/v1/completions", checkKeyMid(), completions)
r.POST("/v1/embeddings", checkKeyMid(), embeddings) r.POST("/v1/embeddings", checkKeyMid(), embeddings)
r.POST("/v1/edits", checkKeyMid(), edit)
r.POST("/v1/images/generations", checkKeyMid(), images)
r.POST("/mock1", mockBalanceInquiry) r.POST("/mock1", mockBalanceInquiry)
r.POST("/mock2", mockBalanceConsumption) r.POST("/mock2", mockBalanceConsumption)
@ -567,6 +752,13 @@ func main() {
// Redis.Set(context.Background(), "model:gpt-3.5-turbo-0301", modelInfoStr, 0) // Redis.Set(context.Background(), "model:gpt-3.5-turbo-0301", modelInfoStr, 0)
// Redis.Set(context.Background(), "model:text-davinci-003", modelInfoStr2, 0) // Redis.Set(context.Background(), "model:text-davinci-003", modelInfoStr2, 0)
// Redis.Set(context.Background(), "model:text-embedding-ada-002", modelInfoStr3, 0) // Redis.Set(context.Background(), "model:text-embedding-ada-002", modelInfoStr3, 0)
// var modelInfo2 ModelInfo = ModelInfo{
// ModelName: "images-generations",
// ModelPrice: 0.01,
// ModelPrepayment: 1000,
// }
// modelInfoStr2, _ := json.Marshal(&modelInfo2)
// Redis.Set(context.Background(), "model:images-generations", modelInfoStr2, 0)
// var userInfo UserInfo = UserInfo{ // var userInfo UserInfo = UserInfo{
// UID: "1", // UID: "1",

@ -96,6 +96,46 @@ func checkBlance(key string, model string) (ServerInfo, error) {
return serverInfo, nil return serverInfo, nil
} }
// 余额查询 for images
// checkBlanceForImages (余额查询 for images) looks up the user, upstream server
// and model pricing in Redis, then pre-deducts the expected cost for
// generating n images (ModelPrepayment tokens per image * unit price). If the
// balance would go negative the deduction is rolled back and an error is
// returned. NOTE(review): "Blance" keeps the file's existing naming.
func checkBlanceForImages(key string, model string, n int) (ServerInfo, error) {
	var serverInfo ServerInfo
	// User record. BUGFIX: these two errors used to be ignored, so an unknown
	// user surfaced as a misleading "服务器信息不存在" further down.
	userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result()
	if err != nil {
		return serverInfo, errors.New("用户信息不存在")
	}
	var userInfo UserInfo
	if err = json.Unmarshal([]byte(userInfoStr), &userInfo); err != nil {
		return serverInfo, errors.New("用户信息解析失败")
	}
	// Upstream server record for this user's SID.
	serverInfoStr, err := Redis.Get(context.Background(), "server:"+userInfo.SID).Result()
	if err != nil {
		return serverInfo, errors.New("服务器信息不存在")
	}
	if err = json.Unmarshal([]byte(serverInfoStr), &serverInfo); err != nil {
		return serverInfo, errors.New("服务器信息解析失败")
	}
	// Model pricing.
	modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result()
	if err != nil {
		return serverInfo, errors.New("模型信息不存在")
	}
	var modelInfo ModelInfo
	if err = json.Unmarshal([]byte(modelPriceStr), &modelInfo); err != nil {
		return serverInfo, errors.New("模型信息解析失败")
	}
	// Pre-deduct the full expected amount atomically.
	prepay := float64(modelInfo.ModelPrepayment*n) * modelInfo.ModelPrice
	balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", -prepay).Result()
	if err != nil {
		return serverInfo, errors.New("余额计算失败")
	}
	log.Printf("用户余额 %f key: %v 预扣了:%f", balance, key, prepay)
	if balance < 0 {
		// Roll the deduction back before rejecting. BUGFIX: the rollback
		// error used to be silently dropped.
		if _, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", prepay).Result(); err != nil {
			log.Printf("预扣返还失败 key: %v err: %v", key, err)
		}
		return serverInfo, errors.New("用户余额不足")
	}
	return serverInfo, nil
}
// 预扣返还 // 预扣返还
func checkBlanceReturn(key string, model string) error { func checkBlanceReturn(key string, model string) error {
var serverInfo ServerInfo var serverInfo ServerInfo
@ -127,6 +167,37 @@ func checkBlanceReturn(key string, model string) error {
return nil return nil
} }
// 预扣返还 for images
func checkBlanceReturnForImages(key string, model string, n int) error {
var serverInfo ServerInfo
//获取用户信息
userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result()
var userInfo UserInfo
err = json.Unmarshal([]byte(userInfoStr), &userInfo)
//获取服务器信息
serverInfoStr, err := Redis.Get(context.Background(), "server:"+userInfo.SID).Result()
if err != nil {
return errors.New("服务器信息不存在")
}
err = json.Unmarshal([]byte(serverInfoStr), &serverInfo)
if err != nil {
return errors.New("服务器信息解析失败")
}
//获取模型价格
modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result()
if err != nil {
return errors.New("模型信息不存在")
}
var modelInfo ModelInfo
err = json.Unmarshal([]byte(modelPriceStr), &modelInfo)
if err != nil {
return errors.New("模型信息解析失败")
}
balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", (float64(modelInfo.ModelPrepayment*n) * modelInfo.ModelPrice)).Result()
log.Printf("用户余额 %f key: %v 返还预扣:%f", balance, key, (float64(modelInfo.ModelPrepayment*n) * modelInfo.ModelPrice))
return nil
}
// 余额消费 // 余额消费
func consumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) { func consumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) {
//获取用户信息 //获取用户信息
@ -160,3 +231,37 @@ func consumption(key string, model string, prompt_tokens int, completion_tokens
} }
return result, nil return result, nil
} }
// 余额消费 for images
// consumptionForImages (余额消费 for images) settles the final cost of an image
// request: it credits back the prepayment made for n requested images and
// debits the real cost of the dataNum images actually generated, then posts a
// consumption-log entry via balanceConsumption.
func consumptionForImages(key string, model string, n int, dataNum int, msg_id string) (string, error) {
	// Each generated image is billed as a flat 1000 tokens.
	const tokensPerImage = 1000
	// User record.
	userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result()
	if err != nil {
		return "", errors.New("用户信息不存在")
	}
	var userInfo UserInfo
	if err = json.Unmarshal([]byte(userInfoStr), &userInfo); err != nil {
		return "", errors.New("用户信息解析失败")
	}
	// Model pricing.
	modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result()
	if err != nil {
		return "", errors.New("模型信息不存在")
	}
	var modelInfo ModelInfo
	if err = json.Unmarshal([]byte(modelPriceStr), &modelInfo); err != nil {
		return "", errors.New("模型信息解析失败")
	}
	// Net adjustment in one atomic op: refund prepayment, charge actual usage.
	refund := float64(modelInfo.ModelPrepayment*n) * modelInfo.ModelPrice
	cost := float64(tokensPerImage*dataNum) * modelInfo.ModelPrice
	balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.UID+":balance", refund-cost).Result()
	if err != nil {
		// BUGFIX: this error used to be overwritten by the balanceConsumption
		// call below before it was ever checked.
		return "", errors.New("余额计算失败")
	}
	// Consumption-log request.
	result, err := balanceConsumption(key, model, 0, tokensPerImage*dataNum, tokensPerImage*dataNum, msg_id)
	log.Printf("用户余额:%f 扣费KEY: %s 扣费token数: %d 扣费:%f 扣费日志发送结果 %s", balance, key, tokensPerImage*dataNum, cost, result)
	if err != nil {
		log.Printf("%s 余额消费日志请求失败 %v", key, err)
		return "", err
	}
	return result, nil
}

Loading…
Cancel
Save