Stripped out ui/api
This commit is contained in:
3
src/bot/oai/mod.rs
Normal file
3
src/bot/oai/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod model;
|
||||
|
||||
pub use model::*;
|
||||
160
src/bot/oai/model.rs
Normal file
160
src/bot/oai/model.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde_json::Value;
|
||||
use siren::ServiceError;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum GPTRole {
|
||||
#[serde(rename = "system")]
|
||||
System,
|
||||
#[serde(rename = "user")]
|
||||
User,
|
||||
#[serde(rename = "assistant")]
|
||||
Assistant,
|
||||
#[serde(rename = "function")]
|
||||
Function,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatCompletionMessage>,
|
||||
/// Value between 0 and 2
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
/// Value between 0 and 1
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub n: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<i64>,
|
||||
/// Value between -2.0 and 2.0
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f64>,
|
||||
/// Value between -2.0 and 2.0
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionMessage {
|
||||
pub role: GPTRole,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub system_fingerprint: Option<String>,
|
||||
pub created: i64,
|
||||
pub model: String,
|
||||
pub usage: Usage,
|
||||
pub choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: i64,
|
||||
pub completion_tokens: i64,
|
||||
pub total_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub message: ChatCompletionMessage,
|
||||
pub finish_reason: String,
|
||||
pub index: i64,
|
||||
pub logprobs: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ResponseEvent {
|
||||
ChatCompletionResponse(ChatCompletionResponse),
|
||||
ResponseError(ResponseError),
|
||||
// ChatCompletionResponse {
|
||||
// id: String,
|
||||
// object: String,
|
||||
// system_fingerprint: Option<String>,
|
||||
// created: i64,
|
||||
// model: String,
|
||||
// usage: Usage,
|
||||
// choices: Vec<Choice>,
|
||||
// },
|
||||
// ResponseError {
|
||||
// error: Option<ErrorDetails>,
|
||||
// message: Option<String>,
|
||||
// param: Option<String>,
|
||||
// #[serde(rename = "type")]
|
||||
// error_type: Option<String>,
|
||||
// },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ResponseError {
|
||||
error: Option<ErrorDetails>,
|
||||
message: Option<String>,
|
||||
param: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
error_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct ErrorDetails {
|
||||
code: Option<String>,
|
||||
message: Option<String>,
|
||||
}
|
||||
|
||||
pub struct OAI {
|
||||
pub client: reqwest::Client,
|
||||
pub base_url: String,
|
||||
pub service_url: String,
|
||||
pub max_attempts: i64,
|
||||
pub token: String,
|
||||
pub max_tokens: i64,
|
||||
pub default_model: String,
|
||||
pub max_context_questions: i64,
|
||||
}
|
||||
|
||||
impl OAI {
|
||||
pub async fn chat_completion(
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, ServiceError> {
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(&self.token)
|
||||
.header("Content-Type", "application/json".to_string())
|
||||
.json(&request)
|
||||
.send()
|
||||
.await;
|
||||
match response {
|
||||
Ok(response) => {
|
||||
let value = response.json::<Value>().await?;
|
||||
let event: ResponseEvent = serde_json::from_value::<ResponseEvent>(value)?;
|
||||
match event {
|
||||
ResponseEvent::ChatCompletionResponse(response) => {
|
||||
return Ok(response);
|
||||
},
|
||||
ResponseEvent::ResponseError(error) => {
|
||||
return Err(ServiceError {
|
||||
status: 500,
|
||||
message: format!("Error: {}", error.message.unwrap()),
|
||||
});
|
||||
},
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(ServiceError {
|
||||
status: 500,
|
||||
message: format!("Error: {}", err),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user