package main

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pkoukk/tiktoken-go"
	"github.com/redis/go-redis/v9"
)

// Message is a single chat message in an OpenAI-style request.
type Message struct {
	Role    string `json:"role,omitempty"`
	Name    string `json:"name,omitempty"`
	Content string `json:"content,omitempty"`
}

// ChatRequest is the inbound request body for the chat/completions,
// completions and embeddings endpoints (fields are a superset of all three).
type ChatRequest struct {
	Stream          bool      `json:"stream,omitempty"`
	Model           string    `json:"model,omitempty"`
	PresencePenalty float64   `json:"presence_penalty,omitempty"`
	Temperature     float64   `json:"temperature,omitempty"`
	Messages        []Message `json:"messages,omitempty"`
	Prompt          string    `json:"prompt,omitempty"`
	Input           string    `json:"input,omitempty"`
}

// ChatResponse mirrors the upstream OpenAI response envelope; only the
// fields this proxy reads (Id, Usage, Choices) need to be accurate.
type ChatResponse struct {
	Id      string   `json:"id,omitempty"`
	Object  string   `json:"object,omitempty"`
	Created int64    `json:"created,omitempty"`
	Model   string   `json:"model,omitempty"`
	Usage   Usage    `json:"usage,omitempty"`
	Choices []Choice `json:"choices,omitempty"`
}

// Usage carries the upstream token accounting for non-streaming responses.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens,omitempty"`
	CompletionTokens int `json:"completion_tokens,omitempty"`
	TotalTokens      int `json:"total_tokens,omitempty"`
}

// Choice is one completion alternative; Delta is populated in stream mode.
type Choice struct {
	Message      Message `json:"message,omitempty"`
	Delta        Delta   `json:"delta,omitempty"`
	FinishReason string  `json:"finish_reason,omitempty"`
	Index        int     `json:"index,omitempty"`
	Text         string  `json:"text,omitempty"`
}

// Delta is the incremental content fragment of a streamed chunk.
type Delta struct {
	Content string `json:"content,omitempty"`
}

// CreditSummary is the response shape of the billing/credit_grants endpoint.
type CreditSummary struct {
	Object         string  `json:"object"`
	TotalGranted   float64 `json:"total_granted"`
	TotalUsed      float64 `json:"total_used"`
	TotalRemaining float64 `json:"total_remaining"`
}

// ListModelResponse is the response shape of GET /v1/models.
type ListModelResponse struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// Model describes one entry in the model list (static data in this proxy).
type Model struct {
	ID         string            `json:"id"`
	Object     string            `json:"object"`
	Created    int               `json:"created"`
	OwnedBy    string            `json:"owned_by"`
	Permission []ModelPermission `json:"permission"`
	Root       string            `json:"root"`
	Parent     any               `json:"parent"`
}

// ModelPermission mirrors the OpenAI model-permission object.
type ModelPermission struct {
	ID                 string `json:"id"`
	Object             string `json:"object"`
	Created            int    `json:"created"`
	AllowCreateEngine  bool   `json:"allow_create_engine"`
	AllowSampling      bool   `json:"allow_sampling"`
	AllowLogprobs      bool   `json:"allow_logprobs"`
	AllowSearchIndices bool   `json:"allow_search_indices"`
	AllowView          bool   `json:"allow_view"`
	AllowFineTuning    bool   `json:"allow_fine_tuning"`
	Organization       string `json:"organization"`
	Group              any    `json:"group"`
	IsBlocking         bool   `json:"is_blocking"`
}

// UserInfo is the per-key user record stored in redis.
type UserInfo struct {
	UID string `json:"uid"`
	SID string `json:"sid"`
}

// ServerInfo describes one upstream server and its key pool.
// AvailableKey may hold several comma-separated upstream keys.
type ServerInfo struct {
	ServerAddress string `json:"server_address"`
	AvailableKey  string `json:"available_key"`
}

// ModelInfo is the per-model price record stored in redis.
type ModelInfo struct {
	ModelName       string  `json:"model_name"`
	ModelPrice      float64 `json:"model_price"`
	ModelPrepayment int     `json:"model_prepayment"`
}

var (
	Redis        *redis.Client
	RedisAddress = "localhost:6379"
)

func init() {
	//gin.SetMode(gin.ReleaseMode)
	if v := os.Getenv("REDIS_ADDRESS"); v != "" {
		RedisAddress = v
	}
	log.Printf("loading redis address: %s", RedisAddress)
	// Seed the global source once at startup. The previous code called
	// rand.Seed on every proxied request, which is both deprecated and
	// needlessly reseeds the generator per request.
	rand.Seed(time.Now().UnixNano())
}

// InitRedis connects to redis and verifies the connection with a PING.
// It returns nil when the server is unreachable.
func InitRedis() *redis.Client {
	rdb := redis.NewClient(&redis.Options{
		Addr:     RedisAddress,
		Password: "", // no password set
		DB:       0,  // use default DB
		PoolSize: 10,
	})
	result := rdb.Ping(context.Background())
	fmt.Println("redis ping:", result.Val())
	if result.Val() != "PONG" {
		// connection problem
		return nil
	}
	return rdb
}

// numTokensFromMessages counts the tokens in a message list using the
// model's tiktoken encoding (+6 per message, +3 reply priming — a rough
// approximation of the OpenAI counting rules).
// NOTE(review): panics on an unknown model; gin's Recovery middleware
// turns that into a 500 for the request.
func numTokensFromMessages(messages []Message, model string) int {
	tkm, err := tiktoken.EncodingForModel(model)
	if err != nil {
		panic(fmt.Errorf("getEncoding: %v", err))
	}
	numTokens := 0
	for _, message := range messages {
		numTokens += len(tkm.Encode(message.Content, nil, nil))
		numTokens += 6
	}
	numTokens += 3
	return numTokens
}

// numTokensFromString counts the tokens in a plain string, with a
// model-dependent fixed overhead. Panics on an unknown model (see
// numTokensFromMessages).
func numTokensFromString(msg string, model string) int {
	tkm, err := tiktoken.EncodingForModel(model)
	if err != nil {
		panic(fmt.Errorf("getEncoding: %v", err))
	}
	if model == "text-davinci-003" {
		return len(tkm.Encode(msg, nil, nil)) + 1
	}
	return len(tkm.Encode(msg, nil, nil)) + 9
}

// bearerKey extracts the client API key from the Authorization header.
// BUGFIX: the previous code used strings.Trim(auth, "Bearer "), which
// treats "Bearer " as a character SET and strips any of 'B','e','a','r'
// or space from BOTH ends, corrupting keys that begin or end with those
// characters. TrimPrefix removes exactly the scheme prefix.
func bearerKey(c *gin.Context) string {
	auth := c.Request.Header.Get("Authorization")
	return strings.TrimSpace(strings.TrimPrefix(auth, "Bearer "))
}

// pickServerKey returns one upstream key, chosen at random when the
// configured value holds several comma-separated keys.
func pickServerKey(available string) string {
	keys := strings.Split(available, ",")
	if len(keys) <= 1 {
		return available
	}
	return keys[rand.Intn(len(keys))]
}

// newProxyDirector builds the reverse-proxy Director shared by the
// embeddings and completions handlers: it points the request at the
// upstream host, swaps the client key for one of our upstream keys and
// replaces the body with the re-marshalled request.
func newProxyDirector(c *gin.Context, remote *url.URL, availableKey string, body []byte) func(*http.Request) {
	return func(req *http.Request) {
		req.Header = c.Request.Header
		req.Host = remote.Host
		req.URL.Scheme = remote.Scheme
		req.URL.Host = remote.Host
		req.URL.Path = c.Request.URL.Path
		req.ContentLength = int64(len(body))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept-Encoding", "")
		req.Header.Set("Authorization", "Bearer "+pickServerKey(availableKey))
		req.Body = io.NopCloser(bytes.NewReader(body))
	}
}

// embeddings proxies POST /v1/embeddings to the upstream server and
// charges the caller based on the usage block of the upstream response.
func embeddings(c *gin.Context) {
	var chatRequest ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	key := bearerKey(c)
	// Check the caller's balance before forwarding; checkBlance also
	// selects the upstream server for this model.
	serverInfo, err := checkBlance(key, chatRequest.Model)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		// BUGFIX: previously panicked inside the handler; answer 500 instead.
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = newProxyDirector(c, remote, serverInfo.AvailableKey, newReqBody)
	// BUGFIX: was %d with a string argument (vet error); also reuse the
	// already-marshalled body instead of marshalling a second time.
	log.Printf("开始处理返回逻辑 %s", string(newReqBody))
	proxy.ModifyResponse = func(resp *http.Response) error {
		resp.Header.Set("Openai-Organization", "api2gpt")
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			log.Printf("读取返回数据出错: %v", err)
			return err
		}
		resp.Body.Close()
		var chatResponse ChatResponse
		if err := json.Unmarshal(body, &chatResponse); err != nil {
			// Previously ignored; a decode failure means we bill zero tokens.
			log.Printf("读取返回数据出错: %v", err)
		}
		promptTokens := chatResponse.Usage.PromptTokens
		totalTokens := chatResponse.Usage.TotalTokens
		// Restore the body so the proxy can forward it to the client.
		resp.Body = io.NopCloser(bytes.NewReader(body))
		log.Printf("prompt_tokens: %v total_tokens: %v", promptTokens, totalTokens)
		timestampID := "emb-" + fmt.Sprintf("%d", time.Now().Unix())
		// charge the balance
		consumption(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}

// completions proxies POST /v1/completions and /v1/chat/completions.
// Non-streaming responses are billed from the upstream usage block;
// streaming responses are billed from locally counted tokens because
// stream mode carries no usage information.
func completions(c *gin.Context) {
	var promptTokens, completionTokens, totalTokens int
	var chatRequest ChatRequest
	if err := c.ShouldBindJSON(&chatRequest); err != nil {
		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
		return
	}
	key := bearerKey(c)
	// Check the caller's balance before forwarding (pre-charges a deposit;
	// see checkBlanceReturn for the refund on upstream failure).
	serverInfo, err := checkBlance(key, chatRequest.Model)
	if err != nil {
		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
		return
	}
	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
	remote, err := url.Parse(serverInfo.ServerAddress)
	if err != nil {
		// BUGFIX: previously panicked inside the handler; answer 500 instead.
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	newReqBody, err := json.Marshal(chatRequest)
	if err != nil {
		log.Printf("http request err: %v", err)
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	proxy := httputil.NewSingleHostReverseProxy(remote)
	proxy.Director = newProxyDirector(c, remote, serverInfo.AvailableKey, newReqBody)
	log.Printf("开始处理返回逻辑: %v", string(newReqBody))
	if chatRequest.Stream {
		// Streaming response: relay SSE chunks to the client while
		// accumulating the generated text for token counting.
		proxy.ModifyResponse = func(resp *http.Response) error {
			log.Printf("流式回应 http status code: %v", resp.StatusCode)
			if resp.StatusCode != http.StatusOK {
				// refund the pre-charged deposit
				checkBlanceReturn(key, chatRequest.Model)
				return nil
			}
			chatRequestId := ""
			reqContent := ""
			reader := bufio.NewReader(resp.Body)
			for k, v := range resp.Header {
				c.Writer.Header().Set(k, v[0])
			}
			c.Writer.Header().Set("Openai-Organization", "api2gpt")
			for {
				chunk, err := reader.ReadBytes('\n')
				if err != nil {
					if err == io.EOF {
						break
					}
					log.Printf("流式回应,处理 err %v:", err.Error())
					break
				}
				// Strip the SSE "data: " prefix before decoding.
				// BUGFIX: was strings.Trim with a cutset, which also ate
				// leading 'd','a','t',':',' ' characters of non-data lines.
				trimStr := strings.TrimPrefix(string(chunk), "data: ")
				if trimStr != "\n" {
					var chatResponse ChatResponse
					// Best-effort decode: non-JSON lines (e.g. "[DONE]")
					// simply leave chatResponse zero and are relayed as-is.
					_ = json.Unmarshal([]byte(trimStr), &chatResponse)
					if chatResponse.Choices != nil {
						reqContent += chatResponse.Choices[0].Delta.Content
						chatRequestId = chatResponse.Id
					}
					// relay the raw chunk to the client immediately
					if _, err = c.Writer.Write([]byte(string(chunk) + "\n")); err != nil {
						log.Printf("写回数据 err: %v", err.Error())
						return err
					}
					c.Writer.(http.Flusher).Flush()
				}
			}
			// Stream mode has no upstream usage block: count tokens locally.
			if chatRequest.Model == "text-davinci-003" {
				promptTokens = numTokensFromString(chatRequest.Prompt, chatRequest.Model)
			} else {
				promptTokens = numTokensFromMessages(chatRequest.Messages, chatRequest.Model)
			}
			completionTokens = numTokensFromString(reqContent, chatRequest.Model)
			totalTokens = promptTokens + completionTokens
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
			// charge the balance
			consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatRequestId)
			return nil
		}
	} else {
		// Non-streaming response: bill from the upstream usage block.
		// NOTE(review): unlike the stream branch, a non-200 status here is
		// not refunded via checkBlanceReturn — confirm whether that is
		// intentional.
		proxy.ModifyResponse = func(resp *http.Response) error {
			resp.Header.Set("Openai-Organization", "api2gpt")
			body, err := io.ReadAll(resp.Body)
			if err != nil {
				log.Printf("非流式回应,处理 err: %v", err)
				return err
			}
			resp.Body.Close()
			var chatResponse ChatResponse
			if err := json.Unmarshal(body, &chatResponse); err != nil {
				// Previously ignored; a decode failure means we bill zero tokens.
				log.Printf("非流式回应,处理 err: %v", err)
			}
			promptTokens = chatResponse.Usage.PromptTokens
			completionTokens = chatResponse.Usage.CompletionTokens
			totalTokens = chatResponse.Usage.TotalTokens
			// Restore the body so the proxy can forward it to the client.
			resp.Body = io.NopCloser(bytes.NewReader(body))
			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
			// charge the balance
			consumption(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatResponse.Id)
			return nil
		}
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}

// handleGetModels answers GET /v1/models with a static model list.
func handleGetModels(c *gin.Context) {
	// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/3
	//models := []string{"gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"}
	models := []string{"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"}
	result := ListModelResponse{
		Object: "list",
	}
	for _, model := range models {
		result.Data = append(result.Data, Model{
			ID:      model,
			Object:  "model",
			Created: 1677649963,
			OwnedBy: "openai",
			Permission: []ModelPermission{
				{
					ID:                 "",
					Object:             "model",
					Created:            1679602087,
					AllowCreateEngine:  true,
					AllowSampling:      true,
					AllowLogprobs:      true,
					AllowSearchIndices: true,
					AllowView:          true,
					AllowFineTuning:    true,
					Organization:       "*",
					Group:              nil,
					IsBlocking:         false,
				},
			},
			Root:   model,
			Parent: nil,
		})
	}
	c.JSON(200, result)
}

// handleOptions answers CORS preflight requests for /v1/*.
func handleOptions(c *gin.Context) {
	// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1
	c.Header("Access-Control-Allow-Origin", "*")
	c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
	c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
	c.Status(200)
}

// balance answers the dashboard credit_grants query for the caller's key.
func balance(c *gin.Context) {
	key := bearerKey(c)
	bal, err := queryBlance(key)
	if err != nil {
		c.JSON(400, gin.H{"error": err.Error()})
		// BUGFIX: previously fell through and wrote a second (200) body.
		return
	}
	c.JSON(200, CreditSummary{
		Object:         "credit_grant",
		TotalGranted:   999999,
		TotalUsed:      999999 - bal,
		TotalRemaining: bal,
	})
}

// checkKeyMid is the authentication middleware: it validates the API key
// and its rate/time limits before the handler runs.
func checkKeyMid() gin.HandlerFunc {
	return func(c *gin.Context) {
		auth := c.Request.Header.Get("Authorization")
		log.Printf("auth: %v", auth)
		if auth == "" {
			c.AbortWithStatusJSON(401, gin.H{"code": 40001})
			// BUGFIX: previously fell through to c.Next() after aborting.
			return
		}
		key := bearerKey(c)
		log.Printf("key: %v", key)
		if msg, err := checkKeyAndTimeCount(key); err != nil {
			log.Printf("checkKeyMid err: %v", err)
			c.AbortWithStatusJSON(msg, gin.H{"code": err.Error()})
			return
		}
		log.Printf("auth check end")
		c.Next()
	}
}

func main() {
	// Disable console color: logs go to a file as well as stdout.
	gin.DisableConsoleColor()
	// BUGFIX: ensure the log directory exists and stop ignoring the
	// os.Create error (a missing logs/ dir previously produced a nil-file
	// writer and panics on the first log line).
	if err := os.MkdirAll("logs", 0o755); err != nil {
		log.Fatalf("create log dir: %v", err)
	}
	filename := time.Now().Format("20060102150405") + ".log"
	f, err := os.Create("logs/gin" + filename)
	if err != nil {
		log.Fatalf("create log file: %v", err)
	}
	// Write logs to both the file and the console.
	gin.DefaultWriter = io.MultiWriter(f, os.Stdout)
	log.SetOutput(gin.DefaultWriter)

	r := gin.Default()
	r.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance)
	r.GET("/v1/models", handleGetModels)
	r.OPTIONS("/v1/*path", handleOptions)
	r.POST("/v1/chat/completions", checkKeyMid(), completions)
	r.POST("/v1/completions", checkKeyMid(), completions)
	r.POST("/v1/embeddings", checkKeyMid(), embeddings)
	r.POST("/mock1", mockBalanceInquiry)
	r.POST("/mock2", mockBalanceConsumption)

	Redis = InitRedis()
	if Redis == nil {
		// Keep serving (original behavior), but make the failure visible.
		log.Printf("redis unavailable at %s; key/balance lookups will fail", RedisAddress)
	}
	// NOTE(review): a large commented-out block that seeded redis with
	// sample server/model/user records was removed here — it embedded a
	// real-looking upstream API key, which must not live in source control.

	// Simple liveness probe.
	r.GET("/ping", func(c *gin.Context) {
		c.JSON(200, gin.H{
			"message": "pong from api2gpt",
		})
	})
	//r.Run("127.0.0.1:8080")
	// listen on all interfaces (required under docker)
	r.Run("0.0.0.0:8080")
}