diff --git a/uefi-test-runner/src/proto/shell.rs b/uefi-test-runner/src/proto/shell.rs index 9ce449007..f4aa88c50 100644 --- a/uefi-test-runner/src/proto/shell.rs +++ b/uefi-test-runner/src/proto/shell.rs @@ -100,6 +100,48 @@ pub fn test_current_dir(shell: &ScopedProtocol) { assert_eq!(cur_fs_str, expected_fs_str); } +/// Test `var()`, `vars()`, and `set_var()` +pub fn test_var(shell: &ScopedProtocol) { + /* Test retrieving list of environment variable names */ + let mut cur_env_vec = shell.vars(); + assert_eq!(cur_env_vec.next().unwrap().0, cstr16!("path")); + // check pre-defined shell variables; see UEFI Shell spec + assert_eq!(cur_env_vec.next().unwrap().0, cstr16!("nonesting")); + let cur_env_vec = shell.vars(); + let default_len = cur_env_vec.count(); + + /* Test setting and getting a specific environment variable */ + let test_var = cstr16!("test_var"); + let test_val = cstr16!("test_val"); + + let found_var = shell.vars().any(|(env_var, _)| env_var == test_var); + assert!(!found_var); + assert!(shell.var(test_var).is_none()); + + let status = shell.set_var(test_var, test_val, false); + assert!(status.is_ok()); + let cur_env_str = shell + .var(test_var) + .expect("Could not get environment variable"); + assert_eq!(cur_env_str, test_val); + + let found_var = shell.vars().any(|(env_var, _)| env_var == test_var); + assert!(found_var); + let cur_env_vec = shell.vars(); + assert_eq!(cur_env_vec.count(), default_len + 1); + + /* Test deleting environment variable */ + let test_val = cstr16!(""); + let status = shell.set_var(test_var, test_val, false); + assert!(status.is_ok()); + assert!(shell.var(test_var).is_none()); + + let found_var = shell.vars().any(|(env_var, _)| env_var == test_var); + assert!(!found_var); + let cur_env_vec = shell.vars(); + assert_eq!(cur_env_vec.count(), default_len); +} + pub fn test() { info!("Running shell protocol tests"); @@ -109,4 +151,5 @@ pub fn test() { boot::open_protocol_exclusive::(handle).expect("Failed to open Shell protocol"); test_current_dir(&shell); + test_var(&shell); } diff --git a/uefi/src/proto/mod.rs b/uefi/src/proto/mod.rs index 4451aa113..a1b5f6495 100644 --- a/uefi/src/proto/mod.rs +++ b/uefi/src/proto/mod.rs @@ -51,6 +51,7 @@ pub mod rng; #[cfg(feature = "alloc")] pub mod scsi; pub mod security; +#[cfg(feature = "alloc")] pub mod shell; pub mod shell_params; pub mod shim; diff --git a/uefi/src/proto/shell/mod.rs b/uefi/src/proto/shell/mod.rs index 41940f340..3ad5098dc 100644 --- a/uefi/src/proto/shell/mod.rs +++ b/uefi/src/proto/shell/mod.rs @@ -4,6 +4,8 @@ use crate::proto::unsafe_protocol; use crate::{CStr16, Char16, Error, Result, Status, StatusExt}; + +use core::marker::PhantomData; use core::ptr; use uefi_raw::protocol::shell::ShellProtocol; @@ -13,6 +15,45 @@ use uefi_raw::protocol::shell::ShellProtocol; #[unsafe_protocol(ShellProtocol::GUID)] pub struct Shell(ShellProtocol); +/// Trait for implementing the var function +pub trait ShellVarProvider { + /// Gets the value of the specified environment variable + fn var(&self, name: &CStr16) -> Option<&CStr16>; +} + +/// Iterator over the names of environmental variables obtained from the Shell protocol. +#[derive(Debug)] +pub struct Vars<'a, T: ShellVarProvider> { + /// Char16 containing names of environment variables + names: *const Char16, + /// Reference to Shell Protocol + protocol: *const T, + /// Marker to attach a lifetime to `Vars` + _marker: PhantomData<&'a CStr16>, +} + +impl<'a, T: ShellVarProvider + 'a> Iterator for Vars<'a, T> { + type Item = (&'a CStr16, Option<&'a CStr16>); + // We iterate a list of NUL terminated CStr16s. + // The list is terminated with a double NUL. + fn next(&mut self) -> Option { + let s = unsafe { CStr16::from_ptr(self.names) }; + if s.is_empty() { + None + } else { + self.names = unsafe { self.names.add(s.num_chars() + 1) }; + Some((s, unsafe { self.protocol.as_ref().unwrap().var(s) })) + } + } +} + +impl ShellVarProvider for Shell { + /// Gets the value of the specified environment variable + fn var(&self, name: &CStr16) -> Option<&CStr16> { + self.var(name) + } +} + impl Shell { /// Returns the current directory on the specified device. /// @@ -54,4 +95,159 @@ impl Shell { let dir_ptr: *const Char16 = directory.map_or(ptr::null(), |x| x.as_ptr()); unsafe { (self.0.set_cur_dir)(fs_ptr.cast(), dir_ptr.cast()) }.to_result() } + + /// Gets the value of the specified environment variable + /// + /// # Arguments + /// + /// * `name` - The environment variable name of which to retrieve the + /// value. + /// + /// # Returns + /// + /// * `Some()` - &CStr16 containing the value of the + /// environment variable + /// * `None` - If environment variable does not exist + #[must_use] + pub fn var(&self, name: &CStr16) -> Option<&CStr16> { + let name_ptr: *const Char16 = name.as_ptr(); + let var_val = unsafe { (self.0.get_env)(name_ptr.cast()) }; + if var_val.is_null() { + None + } else { + unsafe { Some(CStr16::from_ptr(var_val.cast())) } + } + } + + /// Gets an iterator over the names of all environment variables + /// + /// # Returns + /// + /// * `Vars` - Iterator over the names of the environment variables + #[must_use] + pub fn vars(&self) -> Vars<'_, Self> { + let env_ptr = unsafe { (self.0.get_env)(ptr::null()) }; + Vars { + names: env_ptr.cast::(), + protocol: self, + _marker: PhantomData, + } + } + + /// Sets the environment variable + /// + /// # Arguments + /// + /// * `name` - The environment variable for which to set the value + /// * `value` - The new value of the environment variable + /// * `volatile` - Indicates whether the variable is volatile or + /// not + /// + /// # Returns + /// + /// * `Status::SUCCESS` - The variable was successfully set + pub fn set_var(&self, name: &CStr16, value: &CStr16, volatile: bool) -> Result { + let name_ptr: *const Char16 = name.as_ptr(); + let value_ptr: *const Char16 = value.as_ptr(); + unsafe { (self.0.set_env)(name_ptr.cast(), value_ptr.cast(), volatile.into()) }.to_result() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use alloc::collections::BTreeMap; + use alloc::vec::Vec; + use uefi::cstr16; + + struct ShellMock<'a> { + inner: BTreeMap<&'a CStr16, &'a CStr16>, + } + + impl<'a> ShellMock<'a> { + fn new(pairs: impl IntoIterator) -> ShellMock<'a> { + let mut inner_map = BTreeMap::new(); + for (name, val) in pairs.into_iter() { + inner_map.insert(name, val); + } + ShellMock { inner: inner_map } + } + } + impl<'a> ShellVarProvider for ShellMock<'a> { + fn var(&self, name: &CStr16) -> Option<&CStr16> { + if let Some(val) = self.inner.get(name) { + Some(*val) + } else { + None + } + } + } + + /// Testing Vars struct + #[test] + fn test_vars() { + // Empty Vars + let mut vars_mock = Vec::::new(); + vars_mock.extend_from_slice( + b"\0\0" + .into_iter() + .map(|&x| x as u16) + .collect::>() + .as_slice(), + ); + let mut vars = Vars { + names: vars_mock.as_ptr().cast(), + protocol: &ShellMock::new(Vec::new()), + _marker: PhantomData, + }; + + assert!(vars.next().is_none()); + + // One environment variable in Vars + let mut vars_mock = Vec::::new(); + vars_mock.extend_from_slice( + b"foo\0\0" + .into_iter() + .map(|&x| x as u16) + .collect::>() + .as_slice(), + ); + let vars = Vars { + names: vars_mock.as_ptr().cast(), + protocol: &ShellMock::new(Vec::from([(cstr16!("foo"), cstr16!("value"))])), + _marker: PhantomData, + }; + assert_eq!( + vars.collect::>(), + Vec::from([(cstr16!("foo"), Some(cstr16!("value")))]) + ); + + // Multiple environment variables in Vars + let mut vars_mock = Vec::::new(); + vars_mock.extend_from_slice( + b"foo1\0bar\0baz2\0\0" + .into_iter() + .map(|&x| x as u16) + .collect::>() + .as_slice(), + ); + + let vars = Vars { + names: vars_mock.as_ptr().cast(), + protocol: &ShellMock::new(Vec::from([ + (cstr16!("foo1"), cstr16!("value")), + (cstr16!("bar"), cstr16!("one")), + (cstr16!("baz2"), cstr16!("two")), + ])), + _marker: PhantomData, + }; + assert_eq!( + vars.collect::>(), + Vec::from([ + (cstr16!("foo1"), Some(cstr16!("value"))), + (cstr16!("bar"), Some(cstr16!("one"))), + (cstr16!("baz2"), Some(cstr16!("two"))) + ]) + ); + } }