Skip to content

Commit

Permalink
refactor: extract rusb-specific fns (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
louib authored Aug 31, 2024
1 parent ea22878 commit cd3007c
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 96 deletions.
6 changes: 2 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@ use configure::DeviceModeConfig;
use error::ChallengeResponseError;
use hmacmode::Hmac;
use otpmode::Aes128Block;
use rusb::{Context, UsbContext};
use rusb::UsbContext;
use sec::{crc16, CRC_RESIDUAL_OK};
use usb::{Flags, Frame};

use manager::{close_device, open_device, read_response, wait, write_frame};
use usb::{close_device, open_device, read_response, wait, write_frame, Context, Flags, Frame};

const VENDOR_ID: [u16; 3] = [
0x1050, // Yubico ( Yubikeys )
Expand Down
99 changes: 7 additions & 92 deletions src/manager.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use error::ChallengeResponseError;
use rusb::{request_type, Context, DeviceHandle, Direction, Recipient, RequestType, UsbContext};
use rusb::{request_type, Context, Direction, Recipient, RequestType, UsbContext};
use std::time::Duration;
use std::{slice, thread};
use usb::{Flags, Frame, HID_GET_REPORT, HID_SET_REPORT, REPORT_TYPE_FEATURE, WRITE_RESET_PAYLOAD};
use usb::{DeviceHandle, HID_GET_REPORT, HID_SET_REPORT, REPORT_TYPE_FEATURE};

pub fn open_device(
context: &mut Context,
bus_id: u8,
address_id: u8,
) -> Result<(DeviceHandle<Context>, Vec<u8>), ChallengeResponseError> {
) -> Result<(DeviceHandle, Vec<u8>), ChallengeResponseError> {
let devices = match context.devices() {
Ok(device) => device,
Err(_) => {
Expand Down Expand Up @@ -66,71 +65,27 @@ pub fn open_device(
}

#[cfg(any(target_os = "macos", target_os = "windows"))]
pub fn close_device(
_handle: DeviceHandle<Context>,
_interfaces: Vec<u8>,
) -> Result<(), ChallengeResponseError> {
pub fn close_device(_handle: DeviceHandle, _interfaces: Vec<u8>) -> Result<(), ChallengeResponseError> {
Ok(())
}

#[cfg(not(any(target_os = "macos", target_os = "windows")))]
pub fn close_device(handle: DeviceHandle<Context>, interfaces: Vec<u8>) -> Result<(), ChallengeResponseError> {
pub fn close_device(handle: DeviceHandle, interfaces: Vec<u8>) -> Result<(), ChallengeResponseError> {
for interface in interfaces {
handle.release_interface(interface)?;
handle.attach_kernel_driver(interface)?;
}
Ok(())
}

pub fn wait<F: Fn(Flags) -> bool>(
handle: &mut DeviceHandle<Context>,
f: F,
buf: &mut [u8],
) -> Result<(), ChallengeResponseError> {
loop {
read(handle, buf)?;
let flags = Flags::from_bits_truncate(buf[7]);
if flags.contains(Flags::SLOT_WRITE_FLAG) || flags.is_empty() {
// Should store the version
}

if f(flags) {
return Ok(());
}
thread::sleep(Duration::new(0, 1000000));
}
}

pub fn read(handle: &mut DeviceHandle<Context>, buf: &mut [u8]) -> Result<usize, ChallengeResponseError> {
pub fn read(handle: &mut DeviceHandle, buf: &mut [u8]) -> Result<usize, ChallengeResponseError> {
assert_eq!(buf.len(), 8);
let reqtype = request_type(Direction::In, RequestType::Class, Recipient::Interface);
let value = REPORT_TYPE_FEATURE << 8;
Ok(handle.read_control(reqtype, HID_GET_REPORT, value, 0, buf, Duration::new(2, 0))?)
}

pub fn write_frame(handle: &mut DeviceHandle<Context>, frame: &Frame) -> Result<(), ChallengeResponseError> {
let mut data = unsafe { slice::from_raw_parts(frame as *const Frame as *const u8, 70) };

let mut seq = 0;
let mut buf = [0; 8];
while !data.is_empty() {
let (a, b) = data.split_at(7);

if seq == 0 || b.is_empty() || a.iter().any(|&x| x != 0) {
let mut packet = [0; 8];
(&mut packet[..7]).copy_from_slice(a);

packet[7] = Flags::SLOT_WRITE_FLAG.bits() + seq;
wait(handle, |x| !x.contains(Flags::SLOT_WRITE_FLAG), &mut buf)?;
raw_write(handle, &packet)?
}
data = b;
seq += 1
}
Ok(())
}

pub fn raw_write(handle: &mut DeviceHandle<Context>, packet: &[u8]) -> Result<(), ChallengeResponseError> {
pub fn raw_write(handle: &mut DeviceHandle, packet: &[u8]) -> Result<(), ChallengeResponseError> {
let reqtype = request_type(Direction::Out, RequestType::Class, Recipient::Interface);
let value = REPORT_TYPE_FEATURE << 8;
if handle.write_control(reqtype, HID_SET_REPORT, value, 0, &packet, Duration::new(2, 0))? != 8 {
Expand All @@ -139,43 +94,3 @@ pub fn raw_write(handle: &mut DeviceHandle<Context>, packet: &[u8]) -> Result<()
Ok(())
}
}

/// Reset the write state after a read.
pub fn write_reset(handle: &mut DeviceHandle<Context>) -> Result<(), ChallengeResponseError> {
raw_write(handle, &WRITE_RESET_PAYLOAD)?;
let mut buf = [0; 8];
wait(handle, |x| !x.contains(Flags::SLOT_WRITE_FLAG), &mut buf)?;
Ok(())
}

pub fn read_response(
handle: &mut DeviceHandle<Context>,
response: &mut [u8],
) -> Result<usize, ChallengeResponseError> {
let mut r0 = 0;
wait(
handle,
|f| f.contains(Flags::RESP_PENDING_FLAG),
&mut response[..8],
)?;
r0 += 7;
loop {
if read(handle, &mut response[r0..r0 + 8])? < 8 {
break;
}
let flags = Flags::from_bits_truncate(response[r0 + 7]);
if flags.contains(Flags::RESP_PENDING_FLAG) {
let seq = response[r0 + 7] & 0b00011111;
if r0 > 0 && seq == 0 {
// If the sequence number is 0, and we have read at
// least one packet, stop.
break;
}
} else {
break;
}
r0 += 7;
}
write_reset(handle)?;
Ok(r0)
}
90 changes: 90 additions & 0 deletions src/usb.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
use rusb::{Context as RUSBContext, DeviceHandle as RUSBDeviceHandle};
use std::time::Duration;
use std::{slice, thread};

use config::Command;
use error::ChallengeResponseError;
use manager::{raw_write, read};
use sec::crc16;

pub use manager::{close_device, open_device};

/// The size of the payload when writing a request to the usb interface.
pub(crate) const PAYLOAD_SIZE: usize = 64;
/// The size of the response after writing a request to the usb interface.
Expand Down Expand Up @@ -42,3 +50,85 @@ impl Frame {
f
}
}

pub type Context = RUSBContext;

pub(crate) type DeviceHandle = RUSBDeviceHandle<Context>;

pub fn write_frame(handle: &mut DeviceHandle, frame: &Frame) -> Result<(), ChallengeResponseError> {
let mut data = unsafe { slice::from_raw_parts(frame as *const Frame as *const u8, 70) };

let mut seq = 0;
let mut buf = [0; 8];
while !data.is_empty() {
let (a, b) = data.split_at(7);

if seq == 0 || b.is_empty() || a.iter().any(|&x| x != 0) {
let mut packet = [0; 8];
(&mut packet[..7]).copy_from_slice(a);

packet[7] = Flags::SLOT_WRITE_FLAG.bits() + seq;
wait(handle, |x| !x.contains(Flags::SLOT_WRITE_FLAG), &mut buf)?;
raw_write(handle, &packet)?;
}
data = b;
seq += 1
}
Ok(())
}

pub fn wait<F: Fn(Flags) -> bool>(
handle: &mut DeviceHandle,
f: F,
buf: &mut [u8],
) -> Result<(), ChallengeResponseError> {
loop {
read(handle, buf)?;
let flags = Flags::from_bits_truncate(buf[7]);
if flags.contains(Flags::SLOT_WRITE_FLAG) || flags.is_empty() {
// Should store the version
}

if f(flags) {
return Ok(());
}
thread::sleep(Duration::new(0, 1000000));
}
}

/// Reset the write state after a read.
pub fn write_reset(handle: &mut DeviceHandle) -> Result<(), ChallengeResponseError> {
raw_write(handle, &WRITE_RESET_PAYLOAD)?;
let mut buf = [0; 8];
wait(handle, |x| !x.contains(Flags::SLOT_WRITE_FLAG), &mut buf)?;
Ok(())
}

pub fn read_response(handle: &mut DeviceHandle, response: &mut [u8]) -> Result<usize, ChallengeResponseError> {
let mut r0 = 0;
wait(
handle,
|f| f.contains(Flags::RESP_PENDING_FLAG),
&mut response[..8],
)?;
r0 += 7;
loop {
if read(handle, &mut response[r0..r0 + 8])? < 8 {
break;
}
let flags = Flags::from_bits_truncate(response[r0 + 7]);
if flags.contains(Flags::RESP_PENDING_FLAG) {
let seq = response[r0 + 7] & 0b00011111;
if r0 > 0 && seq == 0 {
// If the sequence number is 0, and we have read at
// least one packet, stop.
break;
}
} else {
break;
}
r0 += 7;
}
write_reset(handle)?;
Ok(r0)
}

0 comments on commit cd3007c

Please sign in to comment.