diff --git a/embassy-stm32/src/usb/usb.rs b/embassy-stm32/src/usb/usb.rs index b9a16bbf1..31ab8f76d 100644 --- a/embassy-stm32/src/usb/usb.rs +++ b/embassy-stm32/src/usb/usb.rs @@ -80,10 +80,10 @@ impl interrupt::typelevel::Handler for InterruptHandl if istr.ctr() { let index = istr.ep_id() as usize; - CTR_TRIGGERED[index].store(true, Ordering::Relaxed); let mut epr = regs.epr(index).read(); if epr.ctr_rx() { + CTR_RX_TRIGGERED[index].store(true, Ordering::Relaxed); if index == 0 && epr.setup() { EP0_SETUP.store(true, Ordering::Relaxed); } @@ -91,6 +91,7 @@ impl interrupt::typelevel::Handler for InterruptHandl EP_OUT_WAKERS[index].wake(); } if epr.ctr_tx() { + CTR_TX_TRIGGERED[index].store(true, Ordering::Relaxed); //trace!("EP {} TX", index); EP_IN_WAKERS[index].wake(); } @@ -122,7 +123,8 @@ const USBRAM_ALIGN: usize = 4; static BUS_WAKER: AtomicWaker = AtomicWaker::new(); static EP0_SETUP: AtomicBool = AtomicBool::new(false); -static CTR_TRIGGERED: [AtomicBool; EP_COUNT] = [const { AtomicBool::new(false) }; EP_COUNT]; +static CTR_TX_TRIGGERED: [AtomicBool; EP_COUNT] = [const { AtomicBool::new(false) }; EP_COUNT]; +static CTR_RX_TRIGGERED: [AtomicBool; EP_COUNT] = [const { AtomicBool::new(false) }; EP_COUNT]; static EP_IN_WAKERS: [AtomicWaker; EP_COUNT] = [const { AtomicWaker::new() }; EP_COUNT]; static EP_OUT_WAKERS: [AtomicWaker; EP_COUNT] = [const { AtomicWaker::new() }; EP_COUNT]; static IRQ_RESET: AtomicBool = AtomicBool::new(false); @@ -209,10 +211,12 @@ mod btable { pub(super) fn write_in_rx(_index: usize, _addr: u16) {} pub(super) fn write_in_len_tx(index: usize, addr: u16, len: u16) { + assert_eq!(addr & 0b11, 0); USBRAM.mem(index * 2).write_value((addr as u32) | ((len as u32) << 16)); } pub(super) fn write_in_len_rx(index: usize, addr: u16, len: u16) { + assert_eq!(addr & 0b11, 0); USBRAM .mem(index * 2 + 1) .write_value((addr as u32) | ((len as u32) << 16)); @@ -640,22 +644,25 @@ impl<'d, T: Instance> driver::Bus for Bus<'d, T> { fn endpoint_set_enabled(&mut self, ep_addr: EndpointAddress, enabled: bool) { trace!("set_enabled {:?} {}", ep_addr, enabled); // This can race, so do a retry loop. - let reg = T::regs().epr(ep_addr.index() as _); - trace!("EPR before: {:04x}", reg.read().0); + let epr = T::regs().epr(ep_addr.index() as _); + trace!("EPR before: {:04x}", epr.read().0); match ep_addr.direction() { Direction::In => { loop { let want_stat = match enabled { false => Stat::DISABLED, - true => Stat::NAK, + true => match epr.read().ep_type() { + EpType::ISO => Stat::VALID, + _ => Stat::NAK, + }, }; - let r = reg.read(); + let r = epr.read(); if r.stat_tx() == want_stat { break; } let mut w = invariant(r); w.set_stat_tx(Stat::from_bits(r.stat_tx().to_bits() ^ want_stat.to_bits())); - reg.write_value(w); + epr.write_value(w); } EP_IN_WAKERS[ep_addr.index()].wake(); } @@ -665,18 +672,18 @@ impl<'d, T: Instance> driver::Bus for Bus<'d, T> { false => Stat::DISABLED, true => Stat::VALID, }; - let r = reg.read(); + let r = epr.read(); if r.stat_rx() == want_stat { break; } let mut w = invariant(r); w.set_stat_rx(Stat::from_bits(r.stat_rx().to_bits() ^ want_stat.to_bits())); - reg.write_value(w); + epr.write_value(w); } EP_OUT_WAKERS[ep_addr.index()].wake(); } } - trace!("EPR after: {:04x}", reg.read().0); + trace!("EPR after: {:04x}", epr.read().0); } async fn enable(&mut self) {} @@ -836,7 +843,8 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { if self.info.ep_type == EndpointType::Isochronous { // The isochronous endpoint does not change its `STAT_RX` field to `NAK` when receiving a packet. // Therefore, this instead waits until the `CTR` interrupt was triggered. - if matches!(stat, Stat::DISABLED) || CTR_TRIGGERED[index].load(Ordering::Relaxed) { + if matches!(stat, Stat::DISABLED) || CTR_RX_TRIGGERED[index].load(Ordering::Relaxed) { + assert!(matches!(stat, Stat::VALID | Stat::DISABLED)); Poll::Ready(stat) } else { Poll::Pending @@ -851,7 +859,7 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { }) .await; - CTR_TRIGGERED[index].store(false, Ordering::Relaxed); + CTR_RX_TRIGGERED[index].store(false, Ordering::Relaxed); if stat == Stat::DISABLED { return Err(EndpointError::Disabled); @@ -895,18 +903,17 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { if buf.len() > self.info.max_packet_size as usize { return Err(EndpointError::BufferOverflow); } - + trace!("WRITE WAITING, buf.len() = {}", buf.len()); let index = self.info.addr.index(); - - trace!("WRITE WAITING"); let stat = poll_fn(|cx| { EP_IN_WAKERS[index].register(cx.waker()); let regs = T::regs(); let stat = regs.epr(index).read().stat_tx(); if self.info.ep_type == EndpointType::Isochronous { - // The isochronous endpoint does not change its `STAT_RX` field to `NAK` when receiving a packet. + // The isochronous endpoint does not change its `STAT_TX` field to `NAK` after sending a packet. // Therefore, this instead waits until the `CTR` interrupt was triggered. - if matches!(stat, Stat::DISABLED) || CTR_TRIGGERED[index].load(Ordering::Relaxed) { + if matches!(stat, Stat::DISABLED) || CTR_TX_TRIGGERED[index].load(Ordering::Relaxed) { + assert!(matches!(stat, Stat::VALID | Stat::DISABLED)); Poll::Ready(stat) } else { Poll::Pending @@ -921,7 +928,7 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { }) .await; - CTR_TRIGGERED[index].store(false, Ordering::Relaxed); + CTR_TX_TRIGGERED[index].store(false, Ordering::Relaxed); if stat == Stat::DISABLED { return Err(EndpointError::Disabled); @@ -942,7 +949,6 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { self.write_data_double_buffered(buf, packet_buffer); - let regs = T::regs(); regs.epr(index).write(|w| { w.set_ep_type(convert_type(self.info.ep_type)); w.set_ea(self.info.addr.index() as _); @@ -955,7 +961,6 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { w.set_ctr_rx(true); // don't clear w.set_ctr_tx(true); // don't clear }); - trace!("WRITE OK"); Ok(())