commit fd8cb3c976a83cb69a834811b8a46166e8366b32 Author: lvxiu_ext Date: Sat Apr 29 10:11:58 2023 +0800 init project diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e996274 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +gin.log \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..a4ab424 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,18 @@ +# Build step +FROM golang:1.18 AS builder +ENV GOPROXY=https://goproxy.cn,direct +RUN mkdir -p /build +WORKDIR /build +COPY . . +RUN go build + +# Final step +FROM debian:buster-slim +RUN set -x && apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + ca-certificates && \ + rm -rf /var/lib/apt/lists/* \ + +EXPOSE 8080 +WORKDIR /app +COPY --from=builder /build/main /app/api2u-go +ENTRYPOINT ["/app/api2u-go"] \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..2409e1f --- /dev/null +++ b/README.md @@ -0,0 +1,22 @@ +## docker 打包镜像 +docker stop api2u-go +docker rm api2u-go +docker build -t api2u-go . + +## docker 运行镜像 +docker run -d -p 8081:8080 --name=api2u-go --env REDIS_ADDRESS=172.17.0.1:6379 api2u-go +docker run -p 8081:8080 --name=api2u-go --env REDIS_ADDRESS=172.17.0.1:6379 api2u-go + + + + +## nginx 配置 +proxy_set_header Host $host; +proxy_set_header X-Real-IP $remote_addr; +proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; +proxy_cache off; +proxy_cache_bypass $http_pragam; +proxy_cache_revalidate on; +proxy_http_version 1.1; +proxy_buffering off; +proxy_pass http://localhost:8081/; \ No newline at end of file diff --git a/api.go b/api.go new file mode 100644 index 0000000..4d6e567 --- /dev/null +++ b/api.go @@ -0,0 +1,102 @@ +package main + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + + "github.com/gin-gonic/gin" +) + +type BalanceInfo struct { + ServerAddress string `json:"server_address"` + AvailableKey string `json:"available_key"` + UserBalance float64 `json:"user_balance"` + TokenRatio float64 `json:"token_ratio"` +} + +type Consumption struct { + SecretKey string `json:"secretKey"` + Model string `json:"model"` + MsgId string `json:"msgId"` + PromptTokens int `json:"promptTokens"` + CompletionTokens int `json:"completionTokens"` + TotalTokens int `json:"totalTokens"` +} + +func mockBalanceInquiry(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "server_address": "https://gptp.any-door.cn", + "available_key": "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", + "user_balance": 10000, + "token_ratio": 1000, + }) +} + +func mockBalanceConsumption(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "success": "true", + }) +} + +// 余额查询api调用 +func balanceInquiry(key string, model string) (*BalanceInfo, error) { + url := "http://localhost:8080/mock1?key=" + key + "&model=" + model + req, err := http.NewRequest("POST", url, nil) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + panic(err) + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var balanceInfo BalanceInfo + if err := json.Unmarshal(body, &balanceInfo); err != nil { + return nil, err + } + + return &balanceInfo, nil +} + +// 余额消费 +func balanceConsumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) { + var data = Consumption{ + SecretKey: key, + Model: model, + MsgId: msg_id, + PromptTokens: prompt_tokens, + CompletionTokens: completion_tokens, + TotalTokens: total_tokens, + } + + jsonData, err := json.Marshal(data) + // 构造post请求的body + reqBody := bytes.NewBuffer(jsonData) + + url := "http://121.4.100.155:8080/other/usageRecord" + req2, err := http.NewRequest("POST", url, reqBody) + + // 设置http请求的header + req2.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req2) + if err != nil { + panic(err) + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + panic(err) + } + return string(body), nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..44d03b7 --- /dev/null +++ b/go.mod @@ -0,0 +1,37 @@ +module main + +go 1.20 + +require ( + github.com/bytedance/sonic v1.8.0 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.8.1 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/gin-gonic/gin v1.9.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.11.2 // indirect + github.com/go-redis/redis v6.15.9+incompatible // indirect + github.com/goccy/go-json v0.10.0 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/leodido/go-urn v1.2.1 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect + github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.6 // indirect + github.com/pkoukk/tiktoken-go v0.1.0 // indirect + github.com/redis/go-redis/v9 v9.0.3 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.9 // indirect + golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + golang.org/x/crypto v0.5.0 // indirect + golang.org/x/net v0.7.0 // indirect + golang.org/x/sys v0.5.0 // indirect + golang.org/x/text v0.7.0 // indirect + google.golang.org/protobuf v1.28.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..920c395 --- /dev/null +++ b/go.sum @@ -0,0 +1,87 @@ +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.8.0 h1:ea0Xadu+sHlu7x5O3gKhRpQ1IKiMrSiHttPF0ybECuA= +github.com/bytedance/sonic v1.8.0/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= +github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.0 h1:OjyFBKICoexlu99ctXNR2gg+c5pKrKMuyjgARg9qeY8= +github.com/gin-gonic/gin v1.9.0/go.mod h1:W1Me9+hsUSyj3CePGrd1/QrKJMSJ1Tu/0hFEH89961k= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.11.2 h1:q3SHpufmypg+erIExEKUmsgmhDTyhcJ38oeKGACXohU= +github.com/go-playground/validator/v10 v10.11.2/go.mod h1:NieE624vt4SCTJtD87arVLvdmjPAeV8BQlHtMnw9D7s= +github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= +github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/goccy/go-json v0.10.0 h1:mXKd9Qw4NuzShiRlOXKews24ufknHO7gx30lsDyokKA= +github.com/goccy/go-json v0.10.0/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= +github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU= +github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek= +github.com/pkoukk/tiktoken-go v0.1.0 h1:X1uP3+Nd8C3xe6AIGRWjchrylyaye0FDDTG22cxNQZs= +github.com/pkoukk/tiktoken-go v0.1.0/go.mod h1:BijIqAP84FMYC4XbdJgjyMpiSjusU8x0Y0W9K2t0QtU= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis v6.15.9+incompatible h1:F+tnlesQSl3h9V8DdmtcYFdvkHLhbb7AgcLW6UJxnC4= +github.com/redis/go-redis v6.15.9+incompatible/go.mod h1:ic6dLmR0d9rkHSzaa0Ab3QVRZcjopJ9hSSPCrecj/+s= +github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k= +github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.9 h1:rmenucSohSTiyL09Y+l2OCk+FrMxGMzho2+tjr5ticU= +github.com/ugorji/go/codec v1.2.9/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= +golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= +google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/main.go b/main.go new file mode 100644 index 0000000..bdc1a65 --- /dev/null +++ b/main.go @@ -0,0 +1,562 @@ +package main + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "log" + "math/rand" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkoukk/tiktoken-go" + "github.com/redis/go-redis/v9" +) + +type Message struct { + Role string `json:"role,omitempty"` + Name string `json:"name,omitempty"` + Content string `json:"content,omitempty"` +} + +type ChatRequest struct { + Stream bool `json:"stream,omitempty"` + Model string `json:"model,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt string `json:"prompt,omitempty"` + Input string `json:"input,omitempty"` +} + +type ChatResponse struct { + Id string `json:"id,omitempty"` + Object string `json:"object,omitempty"` + Created int64 `json:"created,omitempty"` + Model string `json:"model,omitempty"` + Usage Usage `json:"usage,omitempty"` + Choices []Choice `json:"choices,omitempty"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` +} + +type Choice struct { + Message Message `json:"message,omitempty"` + Delta Delta `json:"delta,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Index int `json:"index,omitempty"` + Text string `json:"text,omitempty"` +} + +type Delta struct { + Content string `json:"content,omitempty"` +} + +type CreditSummary struct { + Object string `json:"object"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` + TotalRemaining float64 `json:"total_remaining"` +} + +type ListModelResponse struct { + Object string `json:"object"` + Data []Model `json:"data"` +} + +type Model struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + OwnedBy string `json:"owned_by"` + Permission []ModelPermission `json:"permission"` + Root string `json:"root"` + Parent any `json:"parent"` +} + +type ModelPermission struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + AllowCreateEngine bool `json:"allow_create_engine"` + AllowSampling bool `json:"allow_sampling"` + AllowLogprobs bool `json:"allow_logprobs"` + AllowSearchIndices bool `json:"allow_search_indices"` + AllowView bool `json:"allow_view"` + AllowFineTuning bool `json:"allow_fine_tuning"` + Organization string `json:"organization"` + Group any `json:"group"` + IsBlocking bool `json:"is_blocking"` +} + +type UserInfo struct { + UID string `json:"uid"` + SID string `json:"sid"` +} + +type ServerInfo struct { + ServerAddress string `json:"server_address"` + AvailableKey string `json:"available_key"` +} + +type ModelInfo struct { + ModelName string `json:"model_name"` + ModelPrice float64 `json:"model_price"` + ModelPrepayment int `json:"model_prepayment"` +} + +var ( + Redis *redis.Client + RedisAddress = "localhost:6379" +) + +func init() { + //gin.SetMode(gin.ReleaseMode) + if v := os.Getenv("REDIS_ADDRESS"); v != "" { + RedisAddress = v + } + log.Printf("loading redis address: %s", RedisAddress) +} + +// redis 初始化 +func InitRedis() *redis.Client { + rdb := redis.NewClient(&redis.Options{ + Addr: RedisAddress, + Password: "", // no password set + DB: 0, // use default DB + PoolSize: 10, + }) + result := rdb.Ping(context.Background()) + fmt.Println("redis ping:", result.Val()) + if result.Val() != "PONG" { + // 连接有问题 + return nil + } + return rdb +} + +// 计算Messages中的token数量 +func numTokensFromMessages(messages []Message, model string) int { + tkm, err := tiktoken.EncodingForModel(model) + if err != nil { + err = fmt.Errorf("getEncoding: %v", err) + panic(err) + } + numTokens := 0 + for _, message := range messages { + numTokens += len(tkm.Encode(message.Content, nil, nil)) + numTokens += 6 + } + numTokens += 3 + return numTokens +} + +// 计算String中的token数量 +func numTokensFromString(msg string, model string) int { + tkm, err := tiktoken.EncodingForModel(model) + if err != nil { + err = fmt.Errorf("getEncoding: %v", err) + panic(err) + } + if model == "text-davinci-003" { + return len(tkm.Encode(msg, nil, nil)) + 1 + } else { + return len(tkm.Encode(msg, nil, nil)) + 9 + } +} + +func embeddings(c *gin.Context) { + var prompt_tokens int + var total_tokens int + var chatRequest ChatRequest + if err := c.ShouldBindJSON(&chatRequest); err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + auth := c.Request.Header.Get("Authorization") + key := strings.Trim(auth, "Bearer ") + //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 + serverInfo, err := checkBlance(key, chatRequest.Model) + if err != nil { + c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) + log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) + return + } + + log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model) + + remote, err := url.Parse(serverInfo.ServerAddress) + if err != nil { + panic(err) + } + + proxy := httputil.NewSingleHostReverseProxy(remote) + newReqBody, err := json.Marshal(chatRequest) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + proxy.Director = func(req *http.Request) { + req.Header = c.Request.Header + req.Host = remote.Host + req.URL.Scheme = remote.Scheme + req.URL.Host = remote.Host + req.URL.Path = c.Request.URL.Path + req.ContentLength = int64(len(newReqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Encoding", "") + serverKey := serverInfo.AvailableKey + keyList := strings.Split(serverInfo.AvailableKey, ",") + if len(keyList) > 1 { + // 随机数种子 + rand.Seed(time.Now().UnixNano()) + // 从数组中随机选择一个元素 + serverKey = keyList[rand.Intn(len(keyList))] + } + req.Header.Set("Authorization", "Bearer "+serverKey) + req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + } + sss, err := json.Marshal(chatRequest) + log.Printf("开始处理返回逻辑 %d", string(sss)) + + proxy.ModifyResponse = func(resp *http.Response) error { + var chatResponse ChatResponse + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + json.Unmarshal(body, &chatResponse) + prompt_tokens = chatResponse.Usage.PromptTokens + total_tokens = chatResponse.Usage.TotalTokens + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + log.Printf("prompt_tokens: %v total_tokens: %v", prompt_tokens, total_tokens) + timestamp := time.Now().Unix() + timestampID := "emb-" + fmt.Sprintf("%d", timestamp) + //消费余额 + consumption(key, chatRequest.Model, prompt_tokens, 0, total_tokens, timestampID) + return nil + } + + proxy.ServeHTTP(c.Writer, c.Request) + +} + +func completions(c *gin.Context) { + var prompt_tokens int + var completion_tokens int + var total_tokens int + var chatRequest ChatRequest + if err := c.ShouldBindJSON(&chatRequest); err != nil { + c.AbortWithStatusJSON(400, gin.H{"error": err.Error()}) + return + } + + auth := c.Request.Header.Get("Authorization") + key := strings.Trim(auth, "Bearer ") + //根据KEY调用用户余额接口,判断是否有足够的余额, 后期可考虑判断max_tokens参数来调整 + serverInfo, err := checkBlance(key, chatRequest.Model) + if err != nil { + c.AbortWithStatusJSON(403, gin.H{"error": err.Error()}) + log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err) + return + } + + log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model) + + remote, err := url.Parse(serverInfo.ServerAddress) + if err != nil { + panic(err) + } + + proxy := httputil.NewSingleHostReverseProxy(remote) + newReqBody, err := json.Marshal(chatRequest) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + proxy.Director = func(req *http.Request) { + req.Header = c.Request.Header + req.Host = remote.Host + req.URL.Scheme = remote.Scheme + req.URL.Host = remote.Host + req.URL.Path = c.Request.URL.Path + req.ContentLength = int64(len(newReqBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Encoding", "") + serverKey := serverInfo.AvailableKey + keyList := strings.Split(serverInfo.AvailableKey, ",") + if len(keyList) > 1 { + // 随机数种子 + rand.Seed(time.Now().UnixNano()) + // 从数组中随机选择一个元素 + serverKey = keyList[rand.Intn(len(keyList))] + } + req.Header.Set("Authorization", "Bearer "+serverKey) + req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody)) + } + sss, err := json.Marshal(chatRequest) + log.Printf("开始处理返回逻辑 %d", string(sss)) + if chatRequest.Stream { + // 流式回应,处理 + proxy.ModifyResponse = func(resp *http.Response) error { + chatRequestId := "" + reqContent := "" + reader := bufio.NewReader(resp.Body) + headers := resp.Header + for k, v := range headers { + c.Writer.Header().Set(k, v[0]) + } + + for { + chunk, err := reader.ReadBytes('\n') + if err != nil { + if err == io.EOF { + break + } + return err + } + var chatResponse ChatResponse + //去除回应中的data:前缀 + var trimStr = strings.Trim(string(chunk), "data: ") + if trimStr != "\n" { + json.Unmarshal([]byte(trimStr), &chatResponse) + //log.Printf("trimStr:" + trimStr) + if chatResponse.Choices != nil { + reqContent += chatResponse.Choices[0].Delta.Content + chatRequestId = chatResponse.Id + } + + // 写回数据 + _, err = c.Writer.Write([]byte(string(chunk) + "\n")) + if err != nil { + return err + } + c.Writer.(http.Flusher).Flush() + } + } + if chatRequest.Model == "text-davinci-003" { + prompt_tokens = numTokensFromString(chatRequest.Prompt, chatRequest.Model) + } else { + prompt_tokens = numTokensFromMessages(chatRequest.Messages, chatRequest.Model) + } + completion_tokens = numTokensFromString(reqContent, chatRequest.Model) + total_tokens = prompt_tokens + completion_tokens + log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens) + //消费余额 + consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatRequestId) + return nil + } + } else { + // 非流式回应,处理 + proxy.ModifyResponse = func(resp *http.Response) error { + var chatResponse ChatResponse + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + json.Unmarshal(body, &chatResponse) + prompt_tokens = chatResponse.Usage.PromptTokens + completion_tokens = chatResponse.Usage.CompletionTokens + total_tokens = chatResponse.Usage.TotalTokens + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", prompt_tokens, completion_tokens, total_tokens) + //消费余额 + consumption(key, chatRequest.Model, prompt_tokens, completion_tokens, total_tokens, chatResponse.Id) + return nil + } + } + + proxy.ServeHTTP(c.Writer, c.Request) +} + +// model查询列表 +func handleGetModels(c *gin.Context) { + // BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/3 + //models := []string{"gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"} + models := []string{"gpt-3.5-turbo", "gpt-3.5-turbo-0301"} + result := ListModelResponse{ + Object: "list", + } + for _, model := range models { + result.Data = append(result.Data, Model{ + ID: model, + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: []ModelPermission{ + { + ID: "", + Object: "model", + Created: 1679602087, + AllowCreateEngine: true, + AllowSampling: true, + AllowLogprobs: true, + AllowSearchIndices: true, + AllowView: true, + AllowFineTuning: true, + Organization: "*", + Group: nil, + IsBlocking: false, + }, + }, + Root: model, + Parent: nil, + }) + } + c.JSON(200, result) +} + +// 针对接口预检OPTIONS的处理 +func handleOptions(c *gin.Context) { + // BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1 + c.Header("Access-Control-Allow-Origin", "*") + c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") + c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization") + c.Status(200) + return +} + +// 余额查询 +func balance(c *gin.Context) { + auth := c.Request.Header.Get("Authorization") + key := strings.Trim(auth, "Bearer ") + balance, err := queryBlance(key) + if err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + } + var creditSummary CreditSummary + creditSummary.Object = "credit_grant" + creditSummary.TotalGranted = 999999 + creditSummary.TotalUsed = 999999 - balance + creditSummary.TotalRemaining = balance + c.JSON(200, creditSummary) +} + +// key 检测中间件 +func checkKeyMid() gin.HandlerFunc { + return func(c *gin.Context) { + auth := c.Request.Header.Get("Authorization") + if auth == "" { + c.AbortWithStatusJSON(401, gin.H{"code": 40001}) + } else { + key := strings.Trim(auth, "Bearer ") + msg, err := checkKeyAndTimeCount(key) + if err != nil { + c.AbortWithStatusJSON(msg, gin.H{"code": err.Error()}) + } + } + // 执行函数 + c.Next() + } +} + +func main() { + + // 禁用控制台颜色,将日志写入文件时不需要控制台颜色。 + gin.DisableConsoleColor() + + // 记录到文件。 + f, _ := os.Create("gin.log") + //gin.DefaultWriter = io.MultiWriter(f) + // 如果需要同时将日志写入文件和控制台,请使用以下代码。 + gin.DefaultWriter = io.MultiWriter(f, os.Stdout) + + log.SetOutput(gin.DefaultWriter) + + r := gin.Default() + + r.GET("/v1/models", handleGetModels) + r.GET("/proxy/v1/models", handleGetModels) + + r.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance) + r.GET("/proxy/dashboard/billing/credit_grants", checkKeyMid(), balance) + + r.OPTIONS("/v1/*path", handleOptions) + r.OPTIONS("/proxy/v1/*path", handleOptions) + + r.POST("/v1/chat/completions", checkKeyMid(), completions) + r.POST("/proxy/v1/chat/completions", checkKeyMid(), completions) + + r.POST("/v1/completions", checkKeyMid(), completions) + r.POST("/proxy/v1/completions", checkKeyMid(), completions) + + r.POST("/v1/embeddings", checkKeyMid(), embeddings) + r.POST("/proxy/v1/embeddings", checkKeyMid(), embeddings) + + r.POST("/mock1", mockBalanceInquiry) + r.POST("/mock2", mockBalanceConsumption) + + Redis = InitRedis() + + //添加reids测试数据 + var serverInfo ServerInfo = ServerInfo{ + ServerAddress: "https://gptp.any-door.cn", + AvailableKey: "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,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,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", + } + var serverInfo2 ServerInfo = ServerInfo{ + ServerAddress: "https://azure.any-door.cn", + AvailableKey: "6c4d2c65970b40e482e7cd27adb0d119", + } + serverInfoStr, _ := json.Marshal(&serverInfo) + serverInfoStr2, _ := json.Marshal(&serverInfo2) + Redis.Set(context.Background(), "server:1", serverInfoStr, 0) + Redis.Set(context.Background(), "server:2", serverInfoStr2, 0) + + var modelInfo ModelInfo = ModelInfo{ + ModelName: "gpt-3.5-turbo", + ModelPrice: 0.0001, + ModelPrepayment: 4000, + } + modelInfoStr, _ := json.Marshal(&modelInfo) + var modelInfo2 ModelInfo = ModelInfo{ + ModelName: "text-davinci-003", + ModelPrice: 0.001, + ModelPrepayment: 4000, + } + modelInfoStr2, _ := json.Marshal(&modelInfo2) + var modelInfo3 ModelInfo = ModelInfo{ + ModelName: "text-davinci-003", + ModelPrice: 0.001, + ModelPrepayment: 4000, + } + modelInfoStr3, _ := json.Marshal(&modelInfo3) + Redis.Set(context.Background(), "model:gpt-3.5-turbo", modelInfoStr, 0) + Redis.Set(context.Background(), "model:gpt-3.5-turbo-0301", modelInfoStr, 0) + Redis.Set(context.Background(), "model:text-davinci-003", modelInfoStr2, 0) + Redis.Set(context.Background(), "model:text-embedding-ada-002", modelInfoStr3, 0) + + var userInfo UserInfo = UserInfo{ + UID: "1", + SID: "1", + } + var userInfo2 UserInfo = UserInfo{ + UID: "2", + SID: "2", + } + userInfoStr, _ := json.Marshal(&userInfo) + userInfoStr2, _ := json.Marshal(&userInfo2) + Redis.Set(context.Background(), "user:8aeb3747-715c-48e8-8b80-aec815949f22", userInfoStr, 0) + Redis.Set(context.Background(), "user:key2", userInfoStr2, 0) + + Redis.IncrByFloat(context.Background(), "user:1:balance", 1000).Result() + Redis.IncrByFloat(context.Background(), "user:2:balance", 1000).Result() + + //r.Run("127.0.0.1:8080") + //docker下使用 + r.Run("0.0.0.0:8080") +} diff --git a/service.go b/service.go new file mode 100644 index 0000000..90d8832 --- /dev/null +++ b/service.go @@ -0,0 +1,125 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "log" + "time" +) + +// 检测key是否存在,是否超出每分钟请求次数 +func checkKeyAndTimeCount(key string) (int, error) { + var timeOut = 60 * time.Second + var timeCount = 30 + userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() + if err != nil { + //用户不存在 + return 401, errors.New("40003") + } + var userInfo UserInfo + err = json.Unmarshal([]byte(userInfoStr), &userInfo) + if err != nil { + //用户状态异常 + return 401, errors.New("40004") + } + count, err := Redis.Incr(context.Background(), "user:count:"+key).Result() + if err != nil { + return 500, errors.New("系统计数器设置异常") + } + if count == 1 { + _, err := Redis.Expire(context.Background(), "user:count:"+key, timeOut).Result() + if err != nil { + return 500, errors.New("系统计数器异常") + } + } + // 如果请求次数超出限制,则中断请求 + if count > int64(timeCount) { + //您的账户请求次数过多,超过分钟配额 + return 429, errors.New("42901") + } + return 200, nil +} + +func queryBlance(key string) (float64, error) { + userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() + var userInfo UserInfo + err = json.Unmarshal([]byte(userInfoStr), &userInfo) + balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.SID+":balance", 0).Result() + if err != nil { + return 0, errors.New("余额计算失败") + } + return balance, nil +} + +// 余额查询 +func checkBlance(key string, model string) (ServerInfo, error) { + var serverInfo ServerInfo + //获取用户信息 + userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() + var userInfo UserInfo + err = json.Unmarshal([]byte(userInfoStr), &userInfo) + //获取服务器信息 + serverInfoStr, err := Redis.Get(context.Background(), "server:"+userInfo.SID).Result() + if err != nil { + return serverInfo, errors.New("服务器信息不存在") + } + err = json.Unmarshal([]byte(serverInfoStr), &serverInfo) + if err != nil { + return serverInfo, errors.New("服务器信息解析失败") + } + //获取模型价格 + modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() + if err != nil { + return serverInfo, errors.New("模型信息不存在") + } + var modelInfo ModelInfo + err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) + if err != nil { + return serverInfo, errors.New("模型信息解析失败") + } + //计算余额-先扣除指定金额 + balance, err := Redis.IncrByFloat(context.Background(), "user:"+userInfo.SID+":balance", -(float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice)).Result() + if err != nil { + return serverInfo, errors.New("余额计算失败") + } + if balance < 0 { + Redis.IncrByFloat(context.Background(), "user:"+userInfo.SID+":balance", float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice).Result() + return serverInfo, errors.New("用户余额不足") + } + + return serverInfo, nil +} + +// 余额消费 +func consumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) { + //获取用户信息 + userInfoStr, err := Redis.Get(context.Background(), "user:"+key).Result() + if err != nil { + return "", errors.New("用户信息不存在") + } + var userInfo UserInfo + err = json.Unmarshal([]byte(userInfoStr), &userInfo) + if err != nil { + return "", errors.New("用户信息解析失败") + } + //获取模型价格 + modelPriceStr, err := Redis.Get(context.Background(), "model:"+model).Result() + if err != nil { + return "", errors.New("模型信息不存在") + } + var modelInfo ModelInfo + err = json.Unmarshal([]byte(modelPriceStr), &modelInfo) + if err != nil { + return "", errors.New("模型信息解析失败") + } + Redis.IncrByFloat(context.Background(), "user:"+userInfo.SID+":balance", float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice-(float64(total_tokens)*modelInfo.ModelPrice)).Result() + + // 余额消费日志请求 + result, err := balanceConsumption(key, model, prompt_tokens, completion_tokens, total_tokens, msg_id) + if err != nil { + return "", err + } + log.Printf("扣费KEY: %s 扣费token数: %d 扣费结果 %s", key, total_tokens, result) + return result, nil +}