Проблема в обучении агента Unity MLAgents
У меня есть сцена с трассой и машиной, и я хочу чтобы машина научилась сама проходить трассу. Всё работает правильно, только машина странно обучается. Такое чувство, что она не использует ray perception sensor 3d. На сцене есть чекпоинты, если машина проходит их в правильном направлении, то она получает награду +0.01 , а если в неправильном то -0.01. Скрипт агента:
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using System.Runtime.CompilerServices;
using UnityEngine.UI;
using TMPro;
using System;
using Random = UnityEngine.Random;
public class Car : Agent
{
[SerializeField] private TrackCheckpoints trackCheckpoints;
[SerializeField] private Transform spwanPosition;
private MSVehicleControllerFree carDriver; //контроллер машины
private MSSceneControllerFree sceneController; //контроллер сцены, нужен для вывода скорости
public GameObject scene;
public Text distance_text;
double distanceTravelled = 0;
Vector3 lastPosition;
//в корутине даём награду за расстоянии которое проехала машина
IEnumerator distance_reward()
{
while (true)
{
yield return new WaitForSeconds(1.0f);
AddReward((float)(0.001 * distanceTravelled));
distanceTravelled = 0;
}
}
public void Awake()
{
carDriver = GetComponent<MSVehicleControllerFree>();
lastPosition = transform.position;
sceneController = scene.GetComponent<MSSceneControllerFree>();
StartCoroutine(distance_reward());
}
private void FixedUpdate()
{
//штрафуем машину за низкую скорость
if (sceneController.car_speed < 10) {
AddReward(-0.00001f);
}
//изменяем дистацию которая проехала машина
if (sceneController.car_speed >= 0) {
distanceTravelled += Vector3.Distance(transform.position, lastPosition);
}
else if (sceneController.car_speed < 0) {
distanceTravelled -= Vector3.Distance(transform.position, lastPosition);
}
lastPosition = transform.position;
distanceTravelled = Math.Round(distanceTravelled, 3);
distance_text.text = "Distance: " + distanceTravelled;
//награждаем машину за скорость
AddReward((float)(0.000002 * sceneController.car_speed));
}
public override void OnEpisodeBegin()
{
//спавн машины
transform.position = spwanPosition.position + new Vector3 (Random.Range(-1f,+1f), 0, Random.Range(-1f, +1f));
transform.forward = spwanPosition.forward;
trackCheckpoints.ResetCheckpoints(transform);
carDriver.GetComponent<Rigidbody>().velocity = Vector3.zero;
//сбрасываем дистанцию которая прошла машина
lastPosition = transform.position;
distanceTravelled = 0;
StartCoroutine(distance_reward());
}
public override void CollectObservations(VectorSensor sensor)
{
Vector3 checkpointForward = trackCheckpoints.GetNextCheckpoint(transform).transform.forward;
float directionDot = Vector3.Dot(transform.forward, checkpointForward);
//передаем на вход ИИ направление следующего чекпоинта, скорости машины, дистанции которую прошла в течении 1 секунды
sensor.AddObservation(directionDot);
sensor.AddObservation(sceneController.car_speed);
sensor.AddObservation((float)distanceTravelled);
}
public override void OnActionReceived(ActionBuffers actions)
{
float forwardAmount = 0f;
float turnAmount = 0f;
switch (actions.DiscreteActions[0])
{
case 0: forwardAmount = 0f; break;
case 1: forwardAmount = +1f; break;
case 2: forwardAmount = -1f; break;
}
switch (actions.DiscreteActions[1])
{
case 0: turnAmount = 0f; break;
case 1: turnAmount = -1f; break;
case 2: turnAmount = +1f; break;
}
sceneController.Update_Controls(forwardAmount, turnAmount);
carDriver.GetComponent<MSVehicleControllerFree>().Update_Controls(forwardAmount, turnAmount);
//штрафуем машину каждый шаг
AddReward(-0.000007f);
}
public override void Heuristic(in ActionBuffers actionsOut)
{
int forwardAction = 0;
if(Input.GetKey(KeyCode.Keypad8)) forwardAction = 1;
if (Input.GetKey(KeyCode.Keypad5)) forwardAction = 2;
int turnAction = 0;
if (Input.GetKey(KeyCode.Keypad4)) turnAction = 1;
if (Input.GetKey(KeyCode.Keypad6)) turnAction = 2;
ActionSegment<int> discreteActions = actionsOut.DiscreteActions;
discreteActions[0] = forwardAction;
discreteActions[1] = turnAction;
}
//даём отрицательную награду если машина врезалась в стену
private void OnTriggerEnter(Collider other)
{
if (other.gameObject.TryGetComponent<Wall>(out Wall wall))
{
AddReward(-0.01f);
EndEpisode();
}
}
}
Источник: Stack Overflow на русском