From fd8cb3c976a83cb69a834811b8a46166e8366b32 Mon Sep 17 00:00:00 2001
From: lvxiu_ext
Date: Sat, 29 Apr 2023 10:11:58 +0800
Subject: [PATCH] init project

---
 .gitignore |   1 +
 Dockerfile |  18 ++
 README.md  |  22 +++
 api.go     | 102 ++++++++++
 go.mod     |  37 ++++
 go.sum     |  87 +++++++++
 main.go    | 562 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 service.go | 125 ++++++++++++
 8 files changed, 954 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 Dockerfile
 create mode 100644 README.md
 create mode 100644 api.go
 create mode 100644 go.mod
 create mode 100644 go.sum
 create mode 100644 main.go
 create mode 100644 service.go

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e996274
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+gin.log
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..a4ab424
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,18 @@
+# Build step
+FROM golang:1.18 AS builder
+ENV GOPROXY=https://goproxy.cn,direct
+RUN mkdir -p /build
+WORKDIR /build
+COPY . .
+RUN go build
+
+# Final step
+FROM debian:buster-slim
+RUN set -x && apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+    ca-certificates && \
+    rm -rf /var/lib/apt/lists/*
+
+EXPOSE 8080
+WORKDIR /app
+COPY --from=builder /build/main /app/api2u-go
+ENTRYPOINT ["/app/api2u-go"]
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..2409e1f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,22 @@
+## docker 打包镜像
+docker stop api2u-go
+docker rm api2u-go
+docker build -t api2u-go .
+
+## docker 运行镜像
+docker run -d -p 8081:8080 --name=api2u-go --env REDIS_ADDRESS=172.17.0.1:6379 api2u-go
+docker run -p 8081:8080 --name=api2u-go --env REDIS_ADDRESS=172.17.0.1:6379 api2u-go
+
+
+
+
+## nginx 配置
+proxy_set_header Host $host;
+proxy_set_header X-Real-IP $remote_addr;
+proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+proxy_cache off;
+proxy_cache_bypass $http_pragma;
+proxy_cache_revalidate on;
+proxy_http_version 1.1;
+proxy_buffering off;
+proxy_pass http://localhost:8081/;
\ No newline at end of file
diff --git a/api.go b/api.go
new file mode 100644
index 0000000..4d6e567
--- /dev/null
+++ b/api.go
@@ -0,0 +1,123 @@
+package main
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"net/url"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+// apiClient is shared by the outbound billing calls. The timeout keeps a hung
+// upstream from blocking the calling request indefinitely.
+var apiClient = &http.Client{Timeout: 30 * time.Second}
+
+// BalanceInfo is the JSON document decoded from the balance-inquiry response.
+type BalanceInfo struct {
+	ServerAddress string  `json:"server_address"`
+	AvailableKey  string  `json:"available_key"`
+	UserBalance   float64 `json:"user_balance"`
+	TokenRatio    float64 `json:"token_ratio"`
+}
+
+// Consumption is the JSON body posted to the usage-record endpoint.
+type Consumption struct {
+	SecretKey        string `json:"secretKey"`
+	Model            string `json:"model"`
+	MsgId            string `json:"msgId"`
+	PromptTokens     int    `json:"promptTokens"`
+	CompletionTokens int    `json:"completionTokens"`
+	TotalTokens      int    `json:"totalTokens"`
+}
+
+// mockBalanceInquiry answers every request with a fixed balance document.
+//
+// NOTE(review): available_key looks like a real credential checked into the
+// repository; confirm it is a placeholder, otherwise rotate it.
+func mockBalanceInquiry(c *gin.Context) {
+	c.JSON(http.StatusOK, gin.H{
+		"server_address": "https://gptp.any-door.cn",
+		"available_key":  "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",
+		"user_balance":   10000,
+		"token_ratio":    1000,
+	})
+}
+
+// mockBalanceConsumption answers every request with {"success": "true"}.
+func mockBalanceConsumption(c *gin.Context) {
+	c.JSON(http.StatusOK, gin.H{
+		"success": "true",
+	})
+}
+
+// postJSON sends payload as an application/json POST to endpoint and returns
+// the response body. Transport failures and non-2xx statuses are returned as
+// errors instead of panicking inside a request handler.
+func postJSON(endpoint string, payload []byte) ([]byte, error) {
+	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
+	if err != nil {
+		return nil, fmt.Errorf("building request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := apiClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("sending request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("reading response: %w", err)
+	}
+	if resp.StatusCode < 200 || resp.StatusCode > 299 {
+		return nil, fmt.Errorf("unexpected status %s: %s", resp.Status, body)
+	}
+	return body, nil
+}
+
+// balanceInquiry asks the balance service for the state of key under model.
+// The query string is built with url.Values so that reserved characters in
+// key or model cannot corrupt the URL.
+func balanceInquiry(key string, model string) (*BalanceInfo, error) {
+	query := url.Values{}
+	query.Set("key", key)
+	query.Set("model", model)
+
+	body, err := postJSON("http://localhost:8080/mock1?"+query.Encode(), nil)
+	if err != nil {
+		return nil, fmt.Errorf("balance inquiry: %w", err)
+	}
+
+	var balanceInfo BalanceInfo
+	if err := json.Unmarshal(body, &balanceInfo); err != nil {
+		return nil, fmt.Errorf("decoding balance inquiry response: %w", err)
+	}
+	return &balanceInfo, nil
+}
+
+// balanceConsumption reports the token usage of one request to the
+// usage-record service and returns the raw response body.
+func balanceConsumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) {
+	payload, err := json.Marshal(Consumption{
+		SecretKey:        key,
+		Model:            model,
+		MsgId:            msg_id,
+		PromptTokens:     prompt_tokens,
+		CompletionTokens: completion_tokens,
+		TotalTokens:      total_tokens,
+	})
+	if err != nil {
+		return "", fmt.Errorf("encoding usage record: %w", err)
+	}
+
+	body, err := postJSON("http://121.4.100.155:8080/other/usageRecord", payload)
+	if err != nil {
+		return "", fmt.Errorf("usage record: %w", err)
+	}
+	return string(body), nil
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..44d03b7
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,37 @@
+module main
+
+go 1.20
+
+require (
+	github.com/bytedance/sonic v1.8.0 // indirect
+	github.com/cespare/xxhash/v2 v2.2.0 // indirect
+	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
+	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
+	github.com/dlclark/regexp2 v1.8.1 // indirect
+	github.com/gin-contrib/sse v0.1.0 // indirect
+	github.com/gin-gonic/gin v1.9.0 // indirect
+	github.com/go-playground/locales v0.14.1 // indirect
+	github.com/go-playground/universal-translator v0.18.1 // indirect
+	github.com/go-playground/validator/v10 v10.11.2 // indirect
+	github.com/go-redis/redis v6.15.9+incompatible // indirect
+	github.com/goccy/go-json v0.10.0 // indirect
+	github.com/google/uuid v1.3.0 // indirect
+	github.com/json-iterator/go v1.1.12 //
indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/leodido/go-urn v1.2.1 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect + github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.6 // indirect + github.com/pkoukk/tiktoken-go v0.1.0 // indirect + github.com/redis/go-redis/v9 v9.0.3 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.9 // indirect + golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + golang.org/x/crypto v0.5.0 // indirect + golang.org/x/net v0.7.0 // indirect + golang.org/x/sys v0.5.0 // indirect + golang.org/x/text v0.7.0 // indirect + google.golang.org/protobuf v1.28.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..920c395 --- /dev/null +++ b/go.sum @@ -0,0 +1,87 @@ +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.8.0 h1:ea0Xadu+sHlu7x5O3gKhRpQ1IKiMrSiHttPF0ybECuA= +github.com/bytedance/sonic v1.8.0/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous 
v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= +github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.0 h1:OjyFBKICoexlu99ctXNR2gg+c5pKrKMuyjgARg9qeY8= +github.com/gin-gonic/gin v1.9.0/go.mod h1:W1Me9+hsUSyj3CePGrd1/QrKJMSJ1Tu/0hFEH89961k= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.11.2 h1:q3SHpufmypg+erIExEKUmsgmhDTyhcJ38oeKGACXohU= +github.com/go-playground/validator/v10 v10.11.2/go.mod h1:NieE624vt4SCTJtD87arVLvdmjPAeV8BQlHtMnw9D7s= +github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= +github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/goccy/go-json v0.10.0 h1:mXKd9Qw4NuzShiRlOXKews24ufknHO7gx30lsDyokKA= +github.com/goccy/go-json v0.10.0/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.3.0 
h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= +github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU= +github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek= +github.com/pkoukk/tiktoken-go v0.1.0 h1:X1uP3+Nd8C3xe6AIGRWjchrylyaye0FDDTG22cxNQZs= +github.com/pkoukk/tiktoken-go v0.1.0/go.mod h1:BijIqAP84FMYC4XbdJgjyMpiSjusU8x0Y0W9K2t0QtU= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis v6.15.9+incompatible h1:F+tnlesQSl3h9V8DdmtcYFdvkHLhbb7AgcLW6UJxnC4= +github.com/redis/go-redis v6.15.9+incompatible/go.mod h1:ic6dLmR0d9rkHSzaa0Ab3QVRZcjopJ9hSSPCrecj/+s= +github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k= 
+github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.9 h1:rmenucSohSTiyL09Y+l2OCk+FrMxGMzho2+tjr5ticU= +github.com/ugorji/go/codec v1.2.9/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= +golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys 
v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
+golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
+google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
+google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..bdc1a65
--- /dev/null
+++ b/main.go
@@ -0,0 +1,594 @@
+package main
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"log"
+	"math/rand"
+	"net/http"
+	"net/http/httputil"
+	"net/url"
+	"os"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/pkoukk/tiktoken-go"
+	"github.com/redis/go-redis/v9"
+)
+
+// Message is one chat message of an OpenAI-style request or response.
+type Message struct {
+	Role    string `json:"role,omitempty"`
+	Name    string `json:"name,omitempty"`
+	Content string `json:"content,omitempty"`
+}
+
+// ChatRequest is the subset of the request body that the proxy decodes and
+// forwards upstream.
+//
+// NOTE(review): the handlers re-marshal this struct, so request fields that
+// are not declared here are not forwarded upstream; confirm this is intended.
+type ChatRequest struct {
+	Stream          bool      `json:"stream,omitempty"`
+	Model           string    `json:"model,omitempty"`
+	PresencePenalty float64   `json:"presence_penalty,omitempty"`
+	Temperature     float64   `json:"temperature,omitempty"`
+	Messages        []Message `json:"messages,omitempty"`
+	Prompt          string    `json:"prompt,omitempty"`
+	Input           string    `json:"input,omitempty"`
+}
+
+// ChatResponse is the subset of the upstream response used for billing.
+type ChatResponse struct {
+	Id      string   `json:"id,omitempty"`
+	Object  string   `json:"object,omitempty"`
+	Created int64    `json:"created,omitempty"`
+	Model   string   `json:"model,omitempty"`
+	Usage   Usage    `json:"usage,omitempty"`
+	Choices []Choice `json:"choices,omitempty"`
+}
+
+// Usage carries the token counts reported by the upstream API.
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens,omitempty"`
+	CompletionTokens int `json:"completion_tokens,omitempty"`
+	TotalTokens      int `json:"total_tokens,omitempty"`
+}
+
+// Choice is one completion alternative of a response or streamed chunk.
+type Choice struct {
+	Message      Message `json:"message,omitempty"`
+	Delta        Delta   `json:"delta,omitempty"`
+	FinishReason string  `json:"finish_reason,omitempty"`
+	Index        int     `json:"index,omitempty"`
+	Text         string  `json:"text,omitempty"`
+}
+
+// Delta is the incremental content of one streamed chunk.
+type Delta struct {
+	Content string `json:"content,omitempty"`
+}
+
+// CreditSummary is the body returned by the credit_grants endpoints.
+type CreditSummary struct {
+	Object         string  `json:"object"`
+	TotalGranted   float64 `json:"total_granted"`
+	TotalUsed      float64 `json:"total_used"`
+	TotalRemaining float64 `json:"total_remaining"`
+}
+
+// ListModelResponse is the body returned by the model-list endpoints.
+type ListModelResponse struct {
+	Object string  `json:"object"`
+	Data   []Model `json:"data"`
+}
+
+// Model describes one entry of the model list.
+type Model struct {
+	ID         string            `json:"id"`
+	Object     string            `json:"object"`
+	Created    int               `json:"created"`
+	OwnedBy    string            `json:"owned_by"`
+	Permission []ModelPermission `json:"permission"`
+	Root       string            `json:"root"`
+	Parent     any               `json:"parent"`
+}
+
+// ModelPermission describes the permission entry attached to a Model.
+type ModelPermission struct {
+	ID                 string `json:"id"`
+	Object             string `json:"object"`
+	Created            int    `json:"created"`
+	AllowCreateEngine  bool   `json:"allow_create_engine"`
+	AllowSampling      bool   `json:"allow_sampling"`
+	AllowLogprobs      bool   `json:"allow_logprobs"`
+	AllowSearchIndices bool   `json:"allow_search_indices"`
+	AllowView          bool   `json:"allow_view"`
+	AllowFineTuning    bool   `json:"allow_fine_tuning"`
+	Organization       string `json:"organization"`
+	Group              any    `json:"group"`
+	IsBlocking         bool   `json:"is_blocking"`
+}
+
+// UserInfo is the JSON value stored in Redis under "user:<key>".
+type UserInfo struct {
+	UID string `json:"uid"`
+	SID string `json:"sid"`
+}
+
+// ServerInfo is the JSON value stored in Redis under "server:<SID>".
+type ServerInfo struct {
+	ServerAddress string `json:"server_address"`
+	AvailableKey  string `json:"available_key"`
+}
+
+// ModelInfo is the JSON value stored in Redis under "model:<name>".
+type ModelInfo struct {
+	ModelName       string  `json:"model_name"`
+	ModelPrice      float64 `json:"model_price"`
+	ModelPrepayment int     `json:"model_prepayment"`
+}
+
+var (
+	// Redis is the shared client; main assigns it before serving.
+	Redis *redis.Client
+	// RedisAddress is overridden by the REDIS_ADDRESS environment variable.
+	RedisAddress = "localhost:6379"
+)
+
+func init() {
+	//gin.SetMode(gin.ReleaseMode)
+	if v := os.Getenv("REDIS_ADDRESS"); v != "" {
+		RedisAddress = v
+	}
+	log.Printf("loading redis address: %s", RedisAddress)
+	// Seed once at start-up rather than inside every proxied request.
+	rand.Seed(time.Now().UnixNano())
+}
+
+// InitRedis creates the Redis client and pings the server. It returns nil
+// when the ping is not answered with "PONG".
+func InitRedis() *redis.Client {
+	rdb := redis.NewClient(&redis.Options{
+		Addr:     RedisAddress,
+		Password: "", // no password set
+		DB:       0,  // use default DB
+		PoolSize: 10,
+	})
+	result := rdb.Ping(context.Background())
+	fmt.Println("redis ping:", result.Val())
+	if result.Val() != "PONG" {
+		// The connection is not usable.
+		return nil
+	}
+	return rdb
+}
+
+// numTokensFromMessages estimates the prompt size of a chat request: the
+// encoded length of every message content, plus 6 per message, plus 3. It
+// panics when tiktoken has no encoding for model.
+func numTokensFromMessages(messages []Message, model string) int {
+	tkm, err := tiktoken.EncodingForModel(model)
+	if err != nil {
+		err = fmt.Errorf("getEncoding: %v", err)
+		panic(err)
+	}
+	numTokens := 0
+	for _, message := range messages {
+		numTokens += len(tkm.Encode(message.Content, nil, nil))
+		numTokens += 6
+	}
+	numTokens += 3
+	return numTokens
+}
+
+// numTokensFromString estimates the size of msg: its encoded length plus 1
+// for text-davinci-003 and plus 9 for any other model. It panics when
+// tiktoken has no encoding for model.
+func numTokensFromString(msg string, model string) int {
+	tkm, err := tiktoken.EncodingForModel(model)
+	if err != nil {
+		err = fmt.Errorf("getEncoding: %v", err)
+		panic(err)
+	}
+	if model == "text-davinci-003" {
+		return len(tkm.Encode(msg, nil, nil)) + 1
+	}
+	return len(tkm.Encode(msg, nil, nil)) + 9
+}
+
+// bearerKey extracts the API key from an Authorization header value by
+// removing a leading "Bearer " scheme. strings.Trim is wrong for this: it
+// treats "Bearer " as a character set and also strips B, e, a, r and spaces
+// from both ends of the key itself.
+func bearerKey(auth string) string {
+	return strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(auth), "Bearer "))
+}
+
+// pickServerKey returns one entry of a comma-separated key list chosen at
+// random, or the input unchanged when it holds a single key.
+func pickServerKey(available string) string {
+	keyList := strings.Split(available, ",")
+	if len(keyList) > 1 {
+		return keyList[rand.Intn(len(keyList))]
+	}
+	return available
+}
+
+// prepareProxy performs the steps shared by embeddings and completions: it
+// decodes the request, reserves the prepayment through checkBlance and builds
+// a reverse proxy that replays the request against the user's upstream server
+// with one of that server's keys. When ok is false an error response has
+// already been written and the caller must return.
+func prepareProxy(c *gin.Context) (key string, chatRequest ChatRequest, proxy *httputil.ReverseProxy, ok bool) {
+	if err := c.ShouldBindJSON(&chatRequest); err != nil {
+		c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
+		return "", chatRequest, nil, false
+	}
+
+	key = bearerKey(c.Request.Header.Get("Authorization"))
+	// Check through the key that the user can afford the request; max_tokens
+	// could be taken into account here later.
+	serverInfo, err := checkBlance(key, chatRequest.Model)
+	if err != nil {
+		c.AbortWithStatusJSON(403, gin.H{"error": err.Error()})
+		log.Printf("请求出错 KEY: %v Model: %v ERROR: %v", key, chatRequest.Model, err)
+		return key, chatRequest, nil, false
+	}
+
+	log.Printf("请求的KEY: %v Model: %v", key, chatRequest.Model)
+
+	remote, err := url.Parse(serverInfo.ServerAddress)
+	if err != nil {
+		// A malformed stored address is a server-side fault; answer 500
+		// instead of panicking.
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return key, chatRequest, nil, false
+	}
+
+	newReqBody, err := json.Marshal(chatRequest)
+	if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return key, chatRequest, nil, false
+	}
+
+	proxy = httputil.NewSingleHostReverseProxy(remote)
+	proxy.Director = func(req *http.Request) {
+		req.Header = c.Request.Header
+		req.Host = remote.Host
+		req.URL.Scheme = remote.Scheme
+		req.URL.Host = remote.Host
+		req.URL.Path = c.Request.URL.Path
+		req.ContentLength = int64(len(newReqBody))
+		req.Header.Set("Content-Type", "application/json")
+		req.Header.Set("Accept-Encoding", "")
+		req.Header.Set("Authorization", "Bearer "+pickServerKey(serverInfo.AvailableKey))
+		req.Body = ioutil.NopCloser(bytes.NewReader(newReqBody))
+	}
+	log.Printf("开始处理返回逻辑 %s", newReqBody)
+	return key, chatRequest, proxy, true
+}
+
+// readUpstreamUsage buffers the upstream body, restores it on resp so the
+// proxy can still relay it, and returns the decoded response.
+func readUpstreamUsage(resp *http.Response) (ChatResponse, error) {
+	var chatResponse ChatResponse
+	body, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		return chatResponse, err
+	}
+	resp.Body.Close()
+	resp.Body = ioutil.NopCloser(bytes.NewReader(body))
+	if err := json.Unmarshal(body, &chatResponse); err != nil {
+		// Billing proceeds with whatever was decoded, as before; the
+		// failure is only made visible.
+		log.Printf("decoding upstream response: %v", err)
+	}
+	return chatResponse, nil
+}
+
+// settle records the usage through consumption and logs a failure, which the
+// response path cannot report to anyone else.
+func settle(key string, model string, promptTokens int, completionTokens int, totalTokens int, msgID string) {
+	if _, err := consumption(key, model, promptTokens, completionTokens, totalTokens, msgID); err != nil {
+		log.Printf("consumption KEY: %v Model: %v ERROR: %v", key, model, err)
+	}
+}
+
+// embeddings proxies an embeddings request upstream and bills the usage
+// reported in the response.
+func embeddings(c *gin.Context) {
+	key, chatRequest, proxy, ok := prepareProxy(c)
+	if !ok {
+		return
+	}
+
+	proxy.ModifyResponse = func(resp *http.Response) error {
+		chatResponse, err := readUpstreamUsage(resp)
+		if err != nil {
+			return err
+		}
+		promptTokens := chatResponse.Usage.PromptTokens
+		totalTokens := chatResponse.Usage.TotalTokens
+		log.Printf("prompt_tokens: %v total_tokens: %v", promptTokens, totalTokens)
+		// The record id is derived from the current Unix time.
+		timestampID := "emb-" + fmt.Sprintf("%d", time.Now().Unix())
+		settle(key, chatRequest.Model, promptTokens, 0, totalTokens, timestampID)
+		return nil
+	}
+
+	proxy.ServeHTTP(c.Writer, c.Request)
+}
+
+// completions proxies a chat or text completion request upstream and bills
+// it. Buffered responses are billed from the usage block of the response;
+// streamed responses are relayed line by line and billed from a local token
+// count.
+func completions(c *gin.Context) {
+	key, chatRequest, proxy, ok := prepareProxy(c)
+	if !ok {
+		return
+	}
+
+	if chatRequest.Stream {
+		proxy.ModifyResponse = func(resp *http.Response) error {
+			return relayStream(c, resp, key, chatRequest)
+		}
+	} else {
+		proxy.ModifyResponse = func(resp *http.Response) error {
+			chatResponse, err := readUpstreamUsage(resp)
+			if err != nil {
+				return err
+			}
+			usage := chatResponse.Usage
+			log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
+			settle(key, chatRequest.Model, usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens, chatResponse.Id)
+			return nil
+		}
+	}
+
+	proxy.ServeHTTP(c.Writer, c.Request)
+}
+
+// relayStream copies a server-sent-event response to the client one line at a
+// time, collecting the generated text so the request can be billed once the
+// stream ends.
+func relayStream(c *gin.Context, resp *http.Response, key string, chatRequest ChatRequest) error {
+	chatRequestID := ""
+	var generated strings.Builder
+	for k, v := range resp.Header {
+		if len(v) > 0 {
+			c.Writer.Header().Set(k, v[0])
+		}
+	}
+
+	reader := bufio.NewReader(resp.Body)
+	for {
+		chunk, err := reader.ReadBytes('\n')
+		if err != nil {
+			if err == io.EOF {
+				break
+			}
+			return err
+		}
+		// Drop the "data:" field name to get at the JSON payload.
+		payload := strings.TrimLeft(strings.TrimPrefix(string(chunk), "data:"), " ")
+		if payload == "\n" {
+			// Blank separator lines are re-created by the write below.
+			continue
+		}
+
+		var chatResponse ChatResponse
+		// Lines that are not JSON (the final "[DONE]") simply decode to nothing.
+		_ = json.Unmarshal([]byte(payload), &chatResponse)
+		// len guards against an empty choices array, which a nil check lets through.
+		if len(chatResponse.Choices) > 0 {
+			// Chat chunks carry delta.content, text completion chunks carry text.
+			generated.WriteString(chatResponse.Choices[0].Delta.Content)
+			generated.WriteString(chatResponse.Choices[0].Text)
+			chatRequestID = chatResponse.Id
+		}
+
+		// Write the line back followed by the separator line.
+		if _, err := c.Writer.Write(append(chunk, '\n')); err != nil {
+			return err
+		}
+		c.Writer.Flush()
+	}
+
+	var promptTokens int
+	if chatRequest.Model == "text-davinci-003" {
+		promptTokens = numTokensFromString(chatRequest.Prompt, chatRequest.Model)
+	} else {
+		promptTokens = numTokensFromMessages(chatRequest.Messages, chatRequest.Model)
+	}
+	completionTokens := numTokensFromString(generated.String(), chatRequest.Model)
+	totalTokens := promptTokens + completionTokens
+	log.Printf("prompt_tokens: %v completion_tokens: %v total_tokens: %v", promptTokens, completionTokens, totalTokens)
+	settle(key, chatRequest.Model, promptTokens, completionTokens, totalTokens, chatRequestID)
+	return nil
+}
+
+// handleGetModels answers the model-list endpoints with a fixed list.
+func handleGetModels(c *gin.Context) {
+	// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/3
+	//models := []string{"gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "text-davinci-003", "text-embedding-ada-002"}
+	models := []string{"gpt-3.5-turbo", "gpt-3.5-turbo-0301"}
+	result := ListModelResponse{
+		Object: "list",
+		Data:   make([]Model, 0, len(models)),
+	}
+	for _, model := range models {
+		result.Data = append(result.Data, Model{
+			ID:      model,
+			Object:  "model",
+			Created: 1677649963,
+			OwnedBy: "openai",
+			Permission: []ModelPermission{
+				{
+					ID:                 "",
+					Object:             "model",
+					Created:            1679602087,
+					AllowCreateEngine:  true,
+					AllowSampling:      true,
+					AllowLogprobs:      true,
+					AllowSearchIndices: true,
+					AllowView:          true,
+					AllowFineTuning:    true,
+					Organization:       "*",
+					Group:              nil,
+					IsBlocking:         false,
+				},
+			},
+			Root:   model,
+			Parent: nil,
+		})
+	}
+	c.JSON(200, result)
+}
+
+// handleOptions answers CORS preflight requests.
+func handleOptions(c *gin.Context) {
+	// BUGFIX: fix options request, see https://github.com/diemus/azure-openai-proxy/issues/1
+	c.Header("Access-Control-Allow-Origin", "*")
+	c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
+	c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
+	c.Status(200)
+}
+
+// balance reports the caller's remaining balance as a credit summary.
+func balance(c *gin.Context) {
+	key := bearerKey(c.Request.Header.Get("Authorization"))
+	balance, err := queryBlance(key)
+	if err != nil {
+		c.JSON(400, gin.H{"error": err.Error()})
+		// Stop here; falling through would append a second JSON body.
+		return
+	}
+	c.JSON(200, CreditSummary{
+		Object:         "credit_grant",
+		TotalGranted:   999999,
+		TotalUsed:      999999 - balance,
+		TotalRemaining: balance,
+	})
+}
+
+// checkKeyMid returns middleware that rejects requests without a known key
+// or over the rate limit before the handler runs.
+func checkKeyMid() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		auth := c.Request.Header.Get("Authorization")
+		if auth == "" {
+			c.AbortWithStatusJSON(401, gin.H{"code": 40001})
+			return
+		}
+		if status, err := checkKeyAndTimeCount(bearerKey(auth)); err != nil {
+			c.AbortWithStatusJSON(status, gin.H{"code": err.Error()})
+			return
+		}
+		c.Next()
+	}
+}
+
+// seedJSON stores v as JSON under key without expiry and logs any failure.
+func seedJSON(key string, v any) {
+	data, err := json.Marshal(v)
+	if err != nil {
+		log.Printf("seeding %s: %v", key, err)
+		return
+	}
+	if err := Redis.Set(context.Background(), key, data, 0).Err(); err != nil {
+		log.Printf("seeding %s: %v", key, err)
+	}
+}
+
+// seedTestData writes the fixed servers, model prices and users the proxy
+// needs into Redis and credits both test users.
+//
+// NOTE(review): the upstream keys below look like real credentials checked
+// into the repository; confirm they are placeholders, otherwise rotate them
+// and load them from configuration. The balance credit is also applied on
+// every start.
+func seedTestData() {
+	seedJSON("server:1", ServerInfo{
+		ServerAddress: "https://gptp.any-door.cn",
+		AvailableKey:  "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,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,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",
+	})
+	seedJSON("server:2", ServerInfo{
+		ServerAddress: "https://azure.any-door.cn",
+		AvailableKey:  "6c4d2c65970b40e482e7cd27adb0d119",
+	})
+
+	chatModel := ModelInfo{ModelName: "gpt-3.5-turbo", ModelPrice: 0.0001, ModelPrepayment: 4000}
+	seedJSON("model:gpt-3.5-turbo", chatModel)
+	seedJSON("model:gpt-3.5-turbo-0301", chatModel)
+	seedJSON("model:text-davinci-003", ModelInfo{ModelName: "text-davinci-003", ModelPrice: 0.001, ModelPrepayment: 4000})
+	// The embedding entry used to be labelled "text-davinci-003".
+	seedJSON("model:text-embedding-ada-002", ModelInfo{ModelName: "text-embedding-ada-002", ModelPrice: 0.001, ModelPrepayment: 4000})
+
+	seedJSON("user:8aeb3747-715c-48e8-8b80-aec815949f22", UserInfo{UID: "1", SID: "1"})
+	seedJSON("user:key2", UserInfo{UID: "2", SID: "2"})
+
+	for _, balanceKey := range []string{"user:1:balance", "user:2:balance"} {
+		if err := Redis.IncrByFloat(context.Background(), balanceKey, 1000).Err(); err != nil {
+			log.Printf("seeding %s: %v", balanceKey, err)
+		}
+	}
+}
+
+func main() {
+	// Console colours are of no use in a log file.
+	gin.DisableConsoleColor()
+
+	// Log to gin.log and to stdout. When the file cannot be created keep the
+	// default writer: a nil *os.File in the MultiWriter would fail every write.
+	f, err := os.Create("gin.log")
+	if err != nil {
+		log.Printf("creating gin.log: %v", err)
+	} else {
+		gin.DefaultWriter = io.MultiWriter(f, os.Stdout)
+	}
+	log.SetOutput(gin.DefaultWriter)
+
+	r := gin.Default()
+
+	r.GET("/v1/models", handleGetModels)
+	r.GET("/proxy/v1/models", handleGetModels)
+
+	r.GET("/dashboard/billing/credit_grants", checkKeyMid(), balance)
+	r.GET("/proxy/dashboard/billing/credit_grants", checkKeyMid(), balance)
+
+	r.OPTIONS("/v1/*path", handleOptions)
+	r.OPTIONS("/proxy/v1/*path", handleOptions)
+
+	r.POST("/v1/chat/completions", checkKeyMid(), completions)
+	r.POST("/proxy/v1/chat/completions", checkKeyMid(), completions)
+
+	r.POST("/v1/completions", checkKeyMid(), completions)
+	r.POST("/proxy/v1/completions", checkKeyMid(), completions)
+
+	r.POST("/v1/embeddings", checkKeyMid(), embeddings)
+	r.POST("/proxy/v1/embeddings", checkKeyMid(), embeddings)
+
+	r.POST("/mock1", mockBalanceInquiry)
+	r.POST("/mock2", mockBalanceConsumption)
+
+	Redis = InitRedis()
+	if Redis == nil {
+		// Every handler needs Redis; continuing would only panic on first use.
+		log.Fatalf("redis at %s is not reachable", RedisAddress)
+	}
+	seedTestData()
+
+	//r.Run("127.0.0.1:8080")
+	// Listen on all interfaces so the server is reachable inside docker.
+	if err := r.Run("0.0.0.0:8080"); err != nil {
+		log.Fatal(err)
+	}
+}
diff --git a/service.go b/service.go
new file mode 100644
index 0000000..90d8832
--- /dev/null
+++ b/service.go
@@ -0,0 +1,163 @@
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"log"
+	"time"
+)
+
+// checkKeyAndTimeCount verifies that key belongs to a stored user and applies
+// a limit of 30 requests per 60-second window. It returns the HTTP status to
+// answer with and, when the request must be rejected, an error whose text is
+// the code or message sent to the client.
+func checkKeyAndTimeCount(key string) (int, error) {
+	const (
+		window   = 60 * time.Second
+		maxCount = 30
+	)
+	ctx := context.Background()
+	userInfoStr, err := Redis.Get(ctx, "user:"+key).Result()
+	if err != nil {
+		// Unknown user.
+		return 401, errors.New("40003")
+	}
+	var userInfo UserInfo
+	if err := json.Unmarshal([]byte(userInfoStr), &userInfo); err != nil {
+		// The stored user record is malformed.
+		return 401, errors.New("40004")
+	}
+
+	counterKey := "user:count:" + key
+	count, err := Redis.Incr(ctx, counterKey).Result()
+	if err != nil {
+		return 500, errors.New("系统计数器设置异常")
+	}
+	if count == 1 {
+		// First request of a window: start the expiry clock.
+		if _, err := Redis.Expire(ctx, counterKey, window).Result(); err != nil {
+			return 500, errors.New("系统计数器异常")
+		}
+	}
+	if count > maxCount {
+		// The account exceeded its per-minute quota.
+		return 429, errors.New("42901")
+	}
+	return 200, nil
+}
+
+// loadUserInfo reads and decodes the record stored under "user:<key>".
+func loadUserInfo(ctx context.Context, key string) (UserInfo, error) {
+	userInfoStr, err := Redis.Get(ctx, "user:"+key).Result()
+	if err != nil {
+		return UserInfo{}, errors.New("用户信息不存在")
+	}
+	var userInfo UserInfo
+	if err := json.Unmarshal([]byte(userInfoStr), &userInfo); err != nil {
+		return UserInfo{}, errors.New("用户信息解析失败")
+	}
+	return userInfo, nil
+}
+
+// loadModelInfo reads and decodes the record stored under "model:<model>".
+func loadModelInfo(ctx context.Context, model string) (ModelInfo, error) {
+	modelPriceStr, err := Redis.Get(ctx, "model:"+model).Result()
+	if err != nil {
+		return ModelInfo{}, errors.New("模型信息不存在")
+	}
+	var modelInfo ModelInfo
+	if err := json.Unmarshal([]byte(modelPriceStr), &modelInfo); err != nil {
+		return ModelInfo{}, errors.New("模型信息解析失败")
+	}
+	return modelInfo, nil
+}
+
+// queryBlance returns the balance of the account that key belongs to.
+// The lookup errors used to be discarded, which made an unknown key read the
+// balance stored under an empty account id.
+func queryBlance(key string) (float64, error) {
+	ctx := context.Background()
+	userInfo, err := loadUserInfo(ctx, key)
+	if err != nil {
+		return 0, err
+	}
+	// Incrementing by 0 returns the current value.
+	balance, err := Redis.IncrByFloat(ctx, "user:"+userInfo.SID+":balance", 0).Result()
+	if err != nil {
+		return 0, errors.New("余额计算失败")
+	}
+	return balance, nil
+}
+
+// checkBlance reserves the prepayment for one request of model against the
+// account of key and returns the upstream server assigned to that account.
+// When the reservation would make the balance negative it is rolled back
+// and an error is returned.
+func checkBlance(key string, model string) (ServerInfo, error) {
+	ctx := context.Background()
+	// Load the user record.
+	userInfo, err := loadUserInfo(ctx, key)
+	if err != nil {
+		return ServerInfo{}, err
+	}
+	// Load the upstream server assigned to the user.
+	serverInfoStr, err := Redis.Get(ctx, "server:"+userInfo.SID).Result()
+	if err != nil {
+		return ServerInfo{}, errors.New("服务器信息不存在")
+	}
+	var serverInfo ServerInfo
+	if err := json.Unmarshal([]byte(serverInfoStr), &serverInfo); err != nil {
+		return ServerInfo{}, errors.New("服务器信息解析失败")
+	}
+	// Load the model price.
+	modelInfo, err := loadModelInfo(ctx, model)
+	if err != nil {
+		return ServerInfo{}, err
+	}
+	// Deduct the prepayment first; consumption refunds the unused part.
+	balanceKey := "user:" + userInfo.SID + ":balance"
+	prepayment := float64(modelInfo.ModelPrepayment) * modelInfo.ModelPrice
+	balance, err := Redis.IncrByFloat(ctx, balanceKey, -prepayment).Result()
+	if err != nil {
+		return ServerInfo{}, errors.New("余额计算失败")
+	}
+	if balance < 0 {
+		if _, err := Redis.IncrByFloat(ctx, balanceKey, prepayment).Result(); err != nil {
+			log.Printf("rolling back prepayment on %s: %v", balanceKey, err)
+		}
+		return ServerInfo{}, errors.New("用户余额不足")
+	}
+
+	return serverInfo, nil
+}
+
+// consumption settles one finished request: it refunds the prepayment
+// reserved by checkBlance minus the real cost of total_tokens, then reports
+// the usage through balanceConsumption and returns that call's response.
+func consumption(key string, model string, prompt_tokens int, completion_tokens int, total_tokens int, msg_id string) (string, error) {
+	ctx := context.Background()
+	// Load the user record.
+	userInfo, err := loadUserInfo(ctx, key)
+	if err != nil {
+		return "", err
+	}
+	// Load the model price.
+	modelInfo, err := loadModelInfo(ctx, model)
+	if err != nil {
+		return "", err
+	}
+	refund := float64(modelInfo.ModelPrepayment)*modelInfo.ModelPrice - (float64(total_tokens) * modelInfo.ModelPrice)
+	if _, err := Redis.IncrByFloat(ctx, "user:"+userInfo.SID+":balance", refund).Result(); err != nil {
+		// Keep going so the usage is still recorded, but do not hide the failure.
+		log.Printf("settling balance for KEY: %s: %v", key, err)
+	}
+
+	// Report the usage to the usage-record service.
+	result, err := balanceConsumption(key, model, prompt_tokens, completion_tokens, total_tokens, msg_id)
+	if err != nil {
+		return "", err
+	}
+	log.Printf("扣费KEY: %s 扣费token数: %d 扣费结果 %s", key, total_tokens, result)
+	return result, nil
+}