gglib_core/domain/
inference.rs

1//! Inference configuration types.
2//!
3//! Defines shared types for configuring LLM inference parameters
4//! (temperature, `top_p`, `top_k`, `max_tokens`, `repeat_penalty`).
5//!
6//! This module provides the core `InferenceConfig` type that is reused across:
7//! - Per-model defaults (`Model.inference_defaults`)
8//! - Global settings (`Settings.inference_defaults`)
9//! - Request-level overrides (flattened in `ChatProxyRequest`)
10
11use serde::{Deserialize, Serialize};
12
13/// Inference parameters for LLM sampling.
14///
15/// All fields are optional to support partial configuration and fallback chains.
16/// Intended to be shared across model defaults, global settings, and request overrides.
17///
18/// # Hierarchy Resolution
19///
20/// When making an inference request, parameters are resolved in this order:
21/// 1. Request-level override (user specified for this request)
22/// 2. Per-model defaults (stored in `Model.inference_defaults`)
23/// 3. Global settings (stored in `Settings.inference_defaults`)
24/// 4. Hardcoded fallback (e.g., temperature = 0.7)
25///
26/// # Examples
27///
28/// ```rust
29/// use gglib_core::domain::InferenceConfig;
30///
31/// // Conservative settings for code generation
32/// let code_gen = InferenceConfig {
33///     temperature: Some(0.2),
34///     top_p: Some(0.9),
35///     top_k: Some(40),
36///     max_tokens: Some(2048),
37///     repeat_penalty: Some(1.1),
38/// };
39///
40/// // Creative writing settings
41/// let creative = InferenceConfig {
42///     temperature: Some(1.2),
43///     top_p: Some(0.95),
44///     ..Default::default()
45/// };
46/// ```
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
48#[serde(rename_all = "camelCase")]
49pub struct InferenceConfig {
50    /// Sampling temperature (0.0 - 2.0).
51    ///
52    /// Controls randomness in token selection:
53    /// - Lower values (0.1-0.5): More deterministic, focused
54    /// - Medium values (0.7-1.0): Balanced creativity
55    /// - Higher values (1.1-2.0): More random, creative
56    pub temperature: Option<f32>,
57
58    /// Nucleus sampling threshold (0.0 - 1.0).
59    ///
60    /// Considers only the top tokens whose cumulative probability exceeds this threshold.
61    /// Common values: 0.9 (default), 0.95 (more diverse)
62    pub top_p: Option<f32>,
63
64    /// Top-K sampling limit.
65    ///
66    /// Considers only the K most likely next tokens.
67    /// Common values: 40 (default), 10 (focused), 100 (diverse)
68    pub top_k: Option<i32>,
69
70    /// Maximum tokens to generate in response.
71    ///
72    /// Hard limit on response length. Does not include input tokens.
73    pub max_tokens: Option<u32>,
74
75    /// Repetition penalty (> 0.0, typically 1.0 - 1.3).
76    ///
77    /// Penalizes repeated tokens to reduce repetitive output.
78    /// - 1.0: No penalty (default)
79    /// - 1.1-1.3: Moderate penalty
80    /// - > 1.3: Strong penalty (may hurt coherence)
81    pub repeat_penalty: Option<f32>,
82}
83
84impl InferenceConfig {
85    /// Merge another config into this one, preferring values from `other`.
86    ///
87    /// For each field, if `other` has Some(value), use it; otherwise keep self's value.
88    /// This is useful for applying fallback chains.
89    ///
90    /// # Example
91    ///
92    /// ```rust
93    /// use gglib_core::domain::InferenceConfig;
94    ///
95    /// let mut request = InferenceConfig {
96    ///     temperature: Some(0.8),
97    ///     ..Default::default()
98    /// };
99    ///
100    /// let model_defaults = InferenceConfig {
101    ///     temperature: Some(0.5),
102    ///     top_p: Some(0.9),
103    ///     ..Default::default()
104    /// };
105    ///
106    /// request.merge_with(&model_defaults);
107    /// assert_eq!(request.temperature, Some(0.8)); // Request value wins
108    /// assert_eq!(request.top_p, Some(0.9));      // Fallback to model default
109    /// ```
110    pub const fn merge_with(&mut self, other: &Self) {
111        if self.temperature.is_none() {
112            self.temperature = other.temperature;
113        }
114        if self.top_p.is_none() {
115            self.top_p = other.top_p;
116        }
117        if self.top_k.is_none() {
118            self.top_k = other.top_k;
119        }
120        if self.max_tokens.is_none() {
121            self.max_tokens = other.max_tokens;
122        }
123        if self.repeat_penalty.is_none() {
124            self.repeat_penalty = other.repeat_penalty;
125        }
126    }
127
128    /// Create a new config with all fields set to sensible defaults.
129    ///
130    /// These are the hardcoded fallback values used when no other
131    /// defaults are configured.
132    #[must_use]
133    pub const fn with_hardcoded_defaults() -> Self {
134        Self {
135            temperature: Some(0.7),
136            top_p: Some(0.95),
137            top_k: Some(40),
138            max_tokens: Some(2048),
139            repeat_penalty: Some(1.0),
140        }
141    }
142
143    /// Convert inference config to llama CLI arguments.
144    ///
145    /// Returns a vector of argument strings suitable for passing to llama-cli or llama-server.
146    /// Uses the same flag names as llama.cpp: `--temp`, `--top-p`, `--top-k`, `-n`, `--repeat-penalty`.
147    ///
148    /// This is the single source of truth for CLI flag conversion, used by:
149    /// - `LlamaCommandBuilder` (for CLI commands)
150    /// - GUI server startup (via `ServerConfig.extra_args`)
151    ///
152    /// # Example
153    ///
154    /// ```rust
155    /// use gglib_core::domain::InferenceConfig;
156    ///
157    /// let config = InferenceConfig {
158    ///     temperature: Some(0.8),
159    ///     top_p: Some(0.9),
160    ///     top_k: None,
161    ///     max_tokens: Some(1024),
162    ///     repeat_penalty: None,
163    /// };
164    ///
165    /// let args = config.to_cli_args();
166    /// assert_eq!(args, vec!["--temp", "0.8", "--top-p", "0.9", "-n", "1024"]);
167    /// ```
168    #[must_use]
169    pub fn to_cli_args(&self) -> Vec<String> {
170        let mut args = Vec::new();
171
172        if let Some(temp) = self.temperature {
173            args.push("--temp".to_string());
174            args.push(temp.to_string());
175        }
176        if let Some(top_p) = self.top_p {
177            args.push("--top-p".to_string());
178            args.push(top_p.to_string());
179        }
180        if let Some(top_k) = self.top_k {
181            args.push("--top-k".to_string());
182            args.push(top_k.to_string());
183        }
184        if let Some(max_tokens) = self.max_tokens {
185            args.push("-n".to_string());
186            args.push(max_tokens.to_string());
187        }
188        if let Some(repeat_penalty) = self.repeat_penalty {
189            args.push("--repeat-penalty".to_string());
190            args.push(repeat_penalty.to_string());
191        }
192
193        args
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// An all-`None` config, spelled out so the assertions below are explicit.
    const EMPTY: InferenceConfig = InferenceConfig {
        temperature: None,
        top_p: None,
        top_k: None,
        max_tokens: None,
        repeat_penalty: None,
    };

    #[test]
    fn test_default_is_all_none() {
        assert_eq!(InferenceConfig::default(), EMPTY);
    }

    #[test]
    fn test_merge_with_prefers_self() {
        let mut request = InferenceConfig {
            temperature: Some(0.8),
            top_p: None,
            ..Default::default()
        };

        let model_defaults = InferenceConfig {
            temperature: Some(0.5),
            top_p: Some(0.9),
            top_k: Some(50),
            ..Default::default()
        };

        request.merge_with(&model_defaults);

        assert_eq!(request.temperature, Some(0.8)); // Request wins
        assert_eq!(request.top_p, Some(0.9)); // Fallback to model
        assert_eq!(request.top_k, Some(50)); // Fallback to model
        assert!(request.max_tokens.is_none()); // Still None
    }

    #[test]
    fn test_merge_with_full_chain_reaches_hardcoded_defaults() {
        let mut resolved = InferenceConfig {
            temperature: Some(0.3),
            ..Default::default()
        };
        let global = InferenceConfig {
            max_tokens: Some(512),
            ..Default::default()
        };

        resolved.merge_with(&EMPTY); // No model defaults configured.
        resolved.merge_with(&global);
        resolved.merge_with(&InferenceConfig::with_hardcoded_defaults());

        assert_eq!(
            resolved,
            InferenceConfig {
                temperature: Some(0.3),
                top_p: Some(0.95),
                top_k: Some(40),
                max_tokens: Some(512),
                repeat_penalty: Some(1.0),
            }
        );
    }

    #[test]
    fn test_hardcoded_defaults() {
        let config = InferenceConfig::with_hardcoded_defaults();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.top_p, Some(0.95));
        assert_eq!(config.top_k, Some(40));
        assert_eq!(config.max_tokens, Some(2048));
        assert_eq!(config.repeat_penalty, Some(1.0));
    }

    #[test]
    fn test_to_cli_args_empty_config_emits_nothing() {
        assert!(InferenceConfig::default().to_cli_args().is_empty());
    }

    #[test]
    fn test_to_cli_args_all_fields_in_fixed_order() {
        let args = InferenceConfig::with_hardcoded_defaults().to_cli_args();
        assert_eq!(
            args,
            vec![
                "--temp",
                "0.7",
                "--top-p",
                "0.95",
                "--top-k",
                "40",
                "-n",
                "2048",
                "--repeat-penalty",
                "1",
            ]
        );
    }

    #[test]
    fn test_serialization() {
        let config = InferenceConfig {
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: None,
            max_tokens: Some(1024),
            repeat_penalty: None,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: InferenceConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config, deserialized);
    }

    #[test]
    fn test_serialization_uses_camel_case_keys() {
        let config = InferenceConfig {
            top_p: Some(0.9),
            max_tokens: Some(1024),
            ..Default::default()
        };

        let json = serde_json::to_value(&config).unwrap();

        assert!(json.get("topP").is_some());
        assert_eq!(json["maxTokens"], 1024);
        assert!(json["repeatPenalty"].is_null()); // Unset fields serialize as null.
        assert!(json.get("top_p").is_none());
        assert!(json.get("max_tokens").is_none());
    }

    #[test]
    fn test_deserialize_partial_config() {
        let config: InferenceConfig =
            serde_json::from_str(r#"{"temperature":0.5,"maxTokens":256}"#).unwrap();
        assert_eq!(
            config,
            InferenceConfig {
                temperature: Some(0.5),
                max_tokens: Some(256),
                ..Default::default()
            }
        );

        let empty: InferenceConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(empty, EMPTY);
    }
}
259}