gglib_core/domain/inference.rs
//! Inference configuration types.
//!
//! Defines shared types for configuring LLM inference parameters
//! (temperature, `top_p`, `top_k`, `max_tokens`, `repeat_penalty`).
//!
//! This module provides the core `InferenceConfig` type that is reused across:
//! - Per-model defaults (`Model.inference_defaults`)
//! - Global settings (`Settings.inference_defaults`)
//! - Request-level overrides (flattened in `ChatProxyRequest`)

use serde::{Deserialize, Serialize};

/// Inference parameters for LLM sampling.
///
/// All fields are optional to support partial configuration and fallback chains.
/// Intended to be shared across model defaults, global settings, and request overrides.
///
/// # Hierarchy Resolution
///
/// When making an inference request, parameters are resolved in this order:
/// 1. Request-level override (user specified for this request)
/// 2. Per-model defaults (stored in `Model.inference_defaults`)
/// 3. Global settings (stored in `Settings.inference_defaults`)
/// 4. Hardcoded fallback (e.g., temperature = 0.7)
///
/// # Examples
///
/// ```rust
/// use gglib_core::domain::InferenceConfig;
///
/// // Conservative settings for code generation
/// let code_gen = InferenceConfig {
///     temperature: Some(0.2),
///     top_p: Some(0.9),
///     top_k: Some(40),
///     max_tokens: Some(2048),
///     repeat_penalty: Some(1.1),
/// };
///
/// // Creative writing settings
/// let creative = InferenceConfig {
///     temperature: Some(1.2),
///     top_p: Some(0.95),
///     ..Default::default()
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct InferenceConfig {
    /// Sampling temperature (0.0 - 2.0).
    ///
    /// Controls randomness in token selection:
    /// - Lower values (0.1-0.5): More deterministic, focused
    /// - Medium values (0.7-1.0): Balanced creativity
    /// - Higher values (1.1-2.0): More random, creative
    pub temperature: Option<f32>,

    /// Nucleus sampling threshold (0.0 - 1.0).
    ///
    /// Restricts sampling to the smallest set of most-probable tokens whose
    /// cumulative probability exceeds this threshold.
    /// Common values: 0.9 (default), 0.95 (more diverse)
    pub top_p: Option<f32>,

    /// Top-K sampling limit.
    ///
    /// Considers only the K most likely next tokens.
    /// Common values: 40 (default), 10 (focused), 100 (diverse)
    pub top_k: Option<i32>,

    /// Maximum tokens to generate in response.
    ///
    /// Hard limit on response length. Does not include input tokens.
    pub max_tokens: Option<u32>,

    /// Repetition penalty (> 0.0, typically 1.0 - 1.3).
    ///
    /// Penalizes repeated tokens to reduce repetitive output.
    /// - 1.0: No penalty (default)
    /// - 1.1-1.3: Moderate penalty
    /// - > 1.3: Strong penalty (may hurt coherence)
    pub repeat_penalty: Option<f32>,
}

impl InferenceConfig {
    /// Merge another config into this one, filling in fields that are still `None`.
    ///
    /// For each field, if `self` already has `Some(value)`, keep it; otherwise fall back
    /// to `other`'s value. This is useful for applying fallback chains.
    ///
    /// # Example
    ///
    /// ```rust
    /// use gglib_core::domain::InferenceConfig;
    ///
    /// let mut request = InferenceConfig {
    ///     temperature: Some(0.8),
    ///     ..Default::default()
    /// };
    ///
    /// let model_defaults = InferenceConfig {
    ///     temperature: Some(0.5),
    ///     top_p: Some(0.9),
    ///     ..Default::default()
    /// };
    ///
    /// request.merge_with(&model_defaults);
    /// assert_eq!(request.temperature, Some(0.8)); // Request value wins
    /// assert_eq!(request.top_p, Some(0.9)); // Fallback to model default
    /// ```
    pub const fn merge_with(&mut self, other: &Self) {
        if self.temperature.is_none() {
            self.temperature = other.temperature;
        }
        if self.top_p.is_none() {
            self.top_p = other.top_p;
        }
        if self.top_k.is_none() {
            self.top_k = other.top_k;
        }
        if self.max_tokens.is_none() {
            self.max_tokens = other.max_tokens;
        }
        if self.repeat_penalty.is_none() {
            self.repeat_penalty = other.repeat_penalty;
        }
    }

    /// Create a new config with all fields set to sensible defaults.
    ///
    /// These are the hardcoded fallback values used when no other
    /// defaults are configured.
    #[must_use]
    pub const fn with_hardcoded_defaults() -> Self {
        Self {
            temperature: Some(0.7),
            top_p: Some(0.95),
            top_k: Some(40),
            max_tokens: Some(2048),
            repeat_penalty: Some(1.0),
        }
    }

    /// Convert inference config to llama CLI arguments.
    ///
    /// Returns a vector of argument strings suitable for passing to llama-cli or llama-server.
    /// Uses the same flag names as llama.cpp: `--temp`, `--top-p`, `--top-k`, `-n`, `--repeat-penalty`.
    ///
    /// This is the single source of truth for CLI flag conversion, used by:
    /// - `LlamaCommandBuilder` (for CLI commands)
    /// - GUI server startup (via `ServerConfig.extra_args`)
    ///
    /// # Example
    ///
    /// ```rust
    /// use gglib_core::domain::InferenceConfig;
    ///
    /// let config = InferenceConfig {
    ///     temperature: Some(0.8),
    ///     top_p: Some(0.9),
    ///     top_k: None,
    ///     max_tokens: Some(1024),
    ///     repeat_penalty: None,
    /// };
    ///
    /// let args = config.to_cli_args();
    /// assert_eq!(args, vec!["--temp", "0.8", "--top-p", "0.9", "-n", "1024"]);
    /// ```
    #[must_use]
    pub fn to_cli_args(&self) -> Vec<String> {
        let mut args = Vec::new();

        if let Some(temp) = self.temperature {
            args.push("--temp".to_string());
            args.push(temp.to_string());
        }
        if let Some(top_p) = self.top_p {
            args.push("--top-p".to_string());
            args.push(top_p.to_string());
        }
        if let Some(top_k) = self.top_k {
            args.push("--top-k".to_string());
            args.push(top_k.to_string());
        }
        if let Some(max_tokens) = self.max_tokens {
            args.push("-n".to_string());
            args.push(max_tokens.to_string());
        }
        if let Some(repeat_penalty) = self.repeat_penalty {
            args.push("--repeat-penalty".to_string());
            args.push(repeat_penalty.to_string());
        }

        args
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_is_all_none() {
        let config = InferenceConfig::default();
        assert!(config.temperature.is_none());
        assert!(config.top_p.is_none());
        assert!(config.top_k.is_none());
        assert!(config.max_tokens.is_none());
        assert!(config.repeat_penalty.is_none());
    }

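    // A small sketch of the wire format implied by `#[serde(rename_all = "camelCase")]`:
    // JSON keys are camelCased (e.g. `max_tokens` becomes `maxTokens`). Only key names
    // are asserted here; exact float formatting is left to the round-trip test below.
    #[test]
    fn test_serialization_uses_camel_case_keys() {
        let config = InferenceConfig {
            max_tokens: Some(1024),
            ..Default::default()
        };

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"maxTokens\""));
        assert!(!json.contains("\"max_tokens\""));
    }
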
    #[test]
    fn test_merge_with_prefers_self() {
        let mut request = InferenceConfig {
            temperature: Some(0.8),
            top_p: None,
            ..Default::default()
        };

        let model_defaults = InferenceConfig {
            temperature: Some(0.5),
            top_p: Some(0.9),
            top_k: Some(50),
            ..Default::default()
        };

        request.merge_with(&model_defaults);

        assert_eq!(request.temperature, Some(0.8)); // Request wins
        assert_eq!(request.top_p, Some(0.9)); // Fallback to model
        assert_eq!(request.top_k, Some(50)); // Fallback to model
        assert!(request.max_tokens.is_none()); // Still None
    }

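    // A sketch of the full resolution order documented on `InferenceConfig`
    // (request override -> model defaults -> global settings -> hardcoded fallback),
    // applied by merging lower-priority sources later. The concrete values for the
    // model and global configs are illustrative, not defaults shipped anywhere.
    #[test]
    fn test_merge_with_fallback_chain() {
        let mut resolved = InferenceConfig {
            temperature: Some(0.8), // request-level override
            ..Default::default()
        };
        let model_defaults = InferenceConfig {
            top_p: Some(0.9),
            ..Default::default()
        };
        let global_settings = InferenceConfig {
            top_k: Some(50),
            max_tokens: Some(512),
            ..Default::default()
        };

        resolved.merge_with(&model_defaults);
        resolved.merge_with(&global_settings);
        resolved.merge_with(&InferenceConfig::with_hardcoded_defaults());

        assert_eq!(resolved.temperature, Some(0.8)); // request wins
        assert_eq!(resolved.top_p, Some(0.9)); // from model defaults
        assert_eq!(resolved.top_k, Some(50)); // from global settings
        assert_eq!(resolved.max_tokens, Some(512)); // from global settings
        assert_eq!(resolved.repeat_penalty, Some(1.0)); // hardcoded fallback
    }
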
    #[test]
    fn test_hardcoded_defaults() {
        let config = InferenceConfig::with_hardcoded_defaults();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.top_p, Some(0.95));
        assert_eq!(config.top_k, Some(40));
        assert_eq!(config.max_tokens, Some(2048));
        assert_eq!(config.repeat_penalty, Some(1.0));
    }

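    // A small sketch of `to_cli_args` at the two extremes: an empty config emits no
    // flags, and the hardcoded defaults emit one flag/value pair per field, using the
    // llama.cpp flag names listed in the method docs.
    #[test]
    fn test_to_cli_args_skips_unset_fields() {
        assert!(InferenceConfig::default().to_cli_args().is_empty());

        let args = InferenceConfig::with_hardcoded_defaults().to_cli_args();
        assert_eq!(args.len(), 10); // five flags, each followed by its value
        for flag in ["--temp", "--top-p", "--top-k", "-n", "--repeat-penalty"] {
            assert!(args.contains(&flag.to_string()));
        }
    }
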
    #[test]
    fn test_serialization() {
        let config = InferenceConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: None,
            max_tokens: Some(1024),
            repeat_penalty: None,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: InferenceConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config, deserialized);
    }
}
259}