package common import ( "api2gpt-mid/model" "fmt" "github.com/pkoukk/tiktoken-go" "strings" ) // 计算Messages中的token数量 func NumTokensFromMessages(messages []model.Message, model string) int { if strings.Contains(model, "gpt-3.5") { model = "gpt-3.5-turbo" } if strings.Contains(model, "gpt-4") { model = "gpt-4" } tkm, err := tiktoken.EncodingForModel(model) if err != nil { err = fmt.Errorf("getEncoding: %v", err) return 0 } numTokens := 0 for _, message := range messages { numTokens += len(tkm.Encode(message.Content, nil, nil)) numTokens += 6 } numTokens += 3 return numTokens } // 计算String中的token数量 func NumTokensFromString(msg string, model string) int { if strings.Contains(model, "gpt-3.5") { model = "gpt-3.5-turbo" } if strings.Contains(model, "gpt-4") { model = "gpt-4" } tkm, err := tiktoken.EncodingForModel(model) if err != nil { err = fmt.Errorf("getEncoding: %v", err) return 0 } if model == "text-davinci-003" { return len(tkm.Encode(msg, nil, nil)) + 1 } else { return len(tkm.Encode(msg, nil, nil)) + 9 } }