728x90
Author: KangchanRoh
Team: Reinforcement Learning Team @ CAI Lab
Date: 2022/11/30
train.py
- train.py를 실행하면 main() 함수가 실행됨
- Line 24. from crowd_nav.configs.config import Config에서 다음 항목에 관한 특성과 설정들을 다룰 수 있다
- env
- reward function
- human
- robot
- noise
- orca, social force
- ppo
- SRNN
- training details
- Line 93. actor_critic 식별자로 아래 그림의 DSRNN 모델이 구현된 Policy 클래스 객체 생성
- Line 99. rollouts 식별자로 학습 데이터 저장소 역할의 RolloutStorage 클래스 객체 생성
- Line 107. 학습이 중단되고 다시 실행할 때 기존 학습 체크포인트에서 계속할 것인 지에 대한 여부 처리 (Default = False)
- Line 117. agent 식별자로 강화학습 알고리즘 PPO 클래스 객체 생성
- Line 130-135. 환경 초기화, rollouts에 obs 복사
- Line 145. for 문, num_updates만큼 update
- Line 147. Learning rate decaying에 대한 여부 처리
- Line 151-183. for 문, config.ppo.num_steps만큼 actor_critic(DSRNN)과 환경이 상호작용하며 action 샘플하여 rollouts에 저장
- Line 185-204. rollouts에 저장된 데이터로 return값들을 계산하고 모델 업데이트
pytorchBaselines/a2c_ppo_acktr/srnn_model.py
- 아래 그림과 코드를 켜고 보면서 보면 도움이 됨
- Line 8. RNNBase 클래스, RNN에 대해 잘 모르지만 RNN네트워크를 생성하는 데에 요해지는 기본적인 요소들로 추측.
- Line 102. HumanNodeRNN 클래스, RNNBase 클래스를 상속 받는다. 위 그림에서 $\operatorname{R}_N$으로 쓰인다. input과 hidden state를 forward 하면 다음 hidden state와 어떤 출력 x 를 출력한다.
- Line 171. HumanHumanRNN 클래스, RNNBase 클래스를 상속 받는다. 위 그림에서 $\operatorname{R}_S,\ \operatorname{R}_T$로 쓰인다. input과 hidden state를 forward 하면 다음 hidden state와 어떤 출력 x 를 출력한다.
- st-graph에서 노드의 정보는 HumanNodeRNN 클래스로 다루고 노드 간 정보는 HumanHumanRNN 클래스로 다룬다.
- Line 213. EdgeAttention 클래스, 위 그림에서 Attention Module에 해당한다. temporal edge와 spatial edge을 forward하면 weighted된 spatial edge list와 attention weight를 출력한다.
- Line 219. SRNN 클래스, 위 그림 모델 전체이다. input으로 inputs(temporal edge+spatial edge)과 $\operatorname{R}_N,\ \operatorname{R}_S,\ \operatorname{R}_T$의 hidden state가 forward 되면 input에는 temporal edge와 spatial edge가 함께 포함되어있어 reshape와 slicing을 거친 후 위 그림과 같은 프로세스가 진행되어 critic value, action, 다시 RNN들의 input으로 들어갈 hidden state를 출력한다.
pytorchBaselines/a2c_ppo_acktr/storage.py
- class RolloutStorage 자료구조는 다음과 같다.
- 더불어 데이터 insert, 학습 시 필요한 return 값 계산, 업데이트 후 메모리 초기화 등 필요한 관련 메서드가 선언되어있다.
- 참고로 num_step에 +1이 붙는 이유는 그 곳의 0번째 index에는 최초 state에서의 값들이 먼저 들어가기 때문이다
반응형