Updated the library tutorial
yannbouteiller committed Jan 11, 2023
1 parent 990cf31 commit b2ef3f9
Showing 5 changed files with 73 additions and 24 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -210,10 +210,12 @@ This works on any track, using any (sensible) camera configuration.
{
"ENV": {
"RTGYM_INTERFACE": "TM20FULL", // TrackMania 2020 with full screenshots
"WINDOW_WIDTH": 256, // width of the game window and screenshots (min: 256)
"WINDOW_HEIGHT": 128, // height of the game window and screenshots (min: 128)
"WINDOW_WIDTH": 256, // width of the game window (min: 256)
"WINDOW_HEIGHT": 128, // height of the game window (min: 128)
"SLEEP_TIME_AT_RESET": 1.5, // the environment sleeps for this amount of time after each reset
"IMG_HIST_LEN": 4, // length of the history of images in observations (set to 1 for RNNs)
"IMG_WIDTH": 64, // actual (resized) width of the images in observations
"IMG_HEIGHT": 64, // actual (resized) height of the images in observations
"IMG_GRAYSCALE": true, // true for grayscale images, false for color images
"RTGYM_CONFIG": {
"time_step_duration": 0.05, // duration of a time step
69 changes: 54 additions & 15 deletions readme/tuto_library.md
Expand Up @@ -33,6 +33,7 @@ If you think an option to install `tmrl` without support for TrackMania should e
- [Tools](#tools)
- [partial() method](#partial-method)
- [Constants](#constants)
- [Security](#network-and-internet-security)
- [Server](#server)
- [Environment](#environment)
- [RolloutWorker](#rollout-workers)
@@ -95,6 +96,35 @@ print(f"Run name: {cfg.RUN_NAME}")
_(NB: read the code to find the available constants)_
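For instance, a quick way to list the constants defined in this module (plain Python introspection, nothing `tmrl`-specific) is:

```python
# List the upper-case constant names defined in tmrl's config_constants module.
import tmrl.config.config_constants as cfg

print([name for name in dir(cfg) if name.isupper()])
```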


## Network and Internet security

In this tutorial, everything happens on localhost, so Internet security is not really a concern.
In real applications, though, you may want several `tmrl` entities communicating over the Internet.

Security-wise, `tmrl` is based on [tlspyo](https://github.com/MISTLab/tls-python-object).
This enables authentication and encryption of your communications via a TLS key, which you first need to generate if you wish to use this option (see the `tlspyo` documentation for how to do so in a couple of easy steps).

For your safety, please strongly consider using this feature when training over a public network, as it protects you against possible attacks from malicious users.

In this tutorial, we will not enable this feature, as we assume our local network is safe.
Instead, we will rely only on the weaker password-based security that is always enabled in `tmrl`:

```python
security = None # change this to "TLS" for TLS encryption (requires a TLS key)
password = "A Secure Password" # change this to a random password
```

_(NB: When training over a public network, you should use both a secure password and TLS encryption.
Please read the `tlspyo` security instructions to understand why this is important.)_
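If you need a strong random password, the standard library can generate one for you (this is plain Python, not a `tmrl` API); paste the result wherever you set the password for your `tmrl` entities:

```python
# Generate a random, URL-safe password with 32 bytes of entropy (standard library only).
import secrets

print(secrets.token_urlsafe(32))
```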

In this tutorial, we will use the localhost IP (i.e., `127.0.0.1`) and port `6666` for communication between our `tmrl` entities.
In an Internet application, you would adapt these to your network setup:

```python
server_ip = "127.0.0.1" # IP of the machine where we will run our TMRL Server
server_port = 6666 # port through which our Server will be accessible
```
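As a quick sanity check (not part of the tutorial), you can verify from another machine that the chosen port is reachable once the `Server` is running; this sketch only uses the standard library:

```python
# Hypothetical connectivity check: try to open a TCP connection to the Server's
# port and report whether it is reachable from this machine.
import socket

try:
    with socket.create_connection((server_ip, server_port), timeout=5.0):
        print(f"{server_ip}:{server_port} is reachable.")
except OSError as e:
    print(f"Could not reach {server_ip}:{server_port}: {e}")
```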


## Server

@@ -115,7 +145,15 @@ Instantiating a `Server` object is straightforward:
```python
from tmrl.networking import Server

my_server = Server()
# tmrl Server

# (NB: when you omit arguments,
# tmrl retrieves the defaults from your config.json file.
# Read the documentation of each class for more info.)

my_server = Server(security=security,
password=password,
port=server_port)
```

As soon as the server is instantiated, it listens for incoming connections from the `Trainer` and the `RolloutWorkers`.
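If you prefer to run the `Server` in its own dedicated script rather than in a single tutorial script, you may want to keep the main process alive explicitly; the sleep loop below is just a placeholder sketch, not a `tmrl` requirement:

```python
# Minimal standalone server script (sketch): instantiate the Server and keep the
# main thread alive so that networking keeps running in the background.
import time

from tmrl.networking import Server

if __name__ == "__main__":
    my_server = Server(security=None,                 # "TLS" for encrypted communication
                       password="A Secure Password",  # same password on all tmrl entities
                       port=6666)                     # same port on all tmrl entities
    while True:
        time.sleep(1.0)
```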
@@ -256,6 +294,8 @@ class RolloutWorker:
actor_module_cls=None, # class of a module containing the policy
sample_compressor: callable = None, # compressor for sending samples over the Internet
server_ip=None, # ip of the central server
server_port=cfg.PORT, # port of the server
password=cfg.PASSWORD, # password of the server
max_samples_per_episode=np.inf, # if an episode gets longer than this, it is reset
model_path=cfg.MODEL_PATH_WORKER, # path where a local copy of the policy will be stored
obs_preprocessor: callable = None, # utility for modifying observations returned by the environment
@@ -459,17 +499,12 @@ device = "cpu"

`RolloutWorkers` behave as Internet clients, and must therefore know the IP address of the `Server` to be able to communicate.
Typically, the `Server` lives on a machine to which you can forward ports behind your router.
Default ports to forward are `55556` (for `RolloutWorkers`) and `55555` (for the `Trainer`). If these ports are not available for you, you can change them in the `config.json` file.

It is of course possible to work locally by hosting the `Server`, `RolloutWorkers`, and `Trainer` on localhost.
This is done by setting the `Server` IP as the localhost IP, i.e., `"127.0.0.1"`:

```python
server_ip = "127.0.0.1"
```
Nevertheless, it is of course possible to work locally by hosting the `Server`, `RolloutWorkers`, and `Trainer` on localhost.
This is done by setting the `Server` IP to the localhost IP, i.e., `"127.0.0.1"`, which we did.
_(NB: We have set the values for `server_ip` and `server_port` earlier in this tutorial.)_

In the current iteration of `tmrl`, samples are gathered locally in a buffer by the `RolloutWorker` and are sent to the `Server` only at the end of an episode.
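Since full observations can be fairly heavy (e.g., histories of screenshots), the `sample_compressor` argument of the `RolloutWorker` lets you shrink each sample before it is sent. The identity compressor below is only a sketch of the expected signature; the assumption that the same six fields are returned is ours, so check `my_sample_compressor` in the tutorial script for the real thing:

```python
# Sketch of a sample compressor (assumption: it returns the same six fields it
# receives, possibly lightened). This identity version sends samples unchanged;
# a real compressor would, e.g., drop redundant frames from the image history in obs.
def identity_sample_compressor(prev_act, obs, rew, terminated, truncated, info):
    return prev_act, obs, rew, terminated, truncated, info
```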

In case your Gym environment is never `terminated` (or only after too long), `tmrl` lets you force a reset after a threshold number of time steps.
For instance, let us say we don't want an episode to last more than 1000 time-steps:

@@ -539,6 +574,8 @@ my_worker = RolloutWorker(
sample_compressor=sample_compressor,
device=device,
server_ip=server_ip,
server_port=server_port,
password=password,
max_samples_per_episode=max_samples_per_episode,
model_path=model_path,
model_path_history=model_path_history,
@@ -584,6 +621,8 @@ class Trainer:
def __init__(self,
training_cls=cfg_obj.TRAINER,
server_ip=cfg.SERVER_IP_FOR_TRAINER,
server_port=cfg.PORT,
password=cfg.PASSWORD,
model_path=cfg.MODEL_PATH_TRAINER,
checkpoint_path=cfg.CHECKPOINT_PATH,
dump_run_instance_fn: callable = None,
@@ -593,13 +632,11 @@
### Networking and files

`server_ip` is the public IP address of the `Server`.
Since both the `Trainer` and `RolloutWorker` will run on the same machine as the `Server` in this tutorial, the `server_ip` will also be localhost here, i.e., `"127.0.0.1"`:

```python
server_ip = "127.0.0.1"
```
Since both the `Trainer` and `RolloutWorker` will run on the same machine as the `Server` in this tutorial, the `server_ip` will also be localhost here, i.e., `"127.0.0.1"`.
The `server_port` and `password` we defined earlier remain valid for our `Trainer`.

`model_path` is similar to the one of the `RolloutWorker`. The trainer will keep a local copy of its model that acts as a saving file.
`model_path` is similar to that of the `RolloutWorker`.
The trainer will keep a local copy of its model that acts as a save file.

`checkpoints_path` is similar, but this will save the whole `training_cls` instance (including the replay buffer).
If set to `None`, training will not be checkpointed.
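As an illustration of how these paths might be built (the folder constants appear in the diff of `tmrl/tuto/tuto.py` below; the file names themselves are hypothetical placeholders, not the tutorial's):

```python
# Illustrative path construction: only the folders come from tmrl's config
# constants; the file names below are arbitrary placeholders.
import os

import tmrl.config.config_constants as cfg

my_run_name = "tutorial"
weights_folder = cfg.WEIGHTS_FOLDER          # folder where model weights are saved
checkpoints_folder = cfg.CHECKPOINTS_FOLDER  # folder where training checkpoints are saved

model_path = os.path.join(weights_folder, my_run_name + "_model.tmod")
checkpoints_path = os.path.join(checkpoints_folder, my_run_name + "_checkpoint.tcpt")
```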
@@ -1231,6 +1268,8 @@ from tmrl.networking import Trainer
my_trainer = Trainer(
training_cls=training_cls,
server_ip=server_ip,
server_port=server_port,
password=password,
model_path=model_path,
checkpoint_path=checkpoints_path) # None for not saving training checkpoints
```
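Training is then typically launched by calling the `Trainer`'s `run()` method; the call below is an assumption based on the full tutorial script, which is not shown in this diff excerpt:

```python
# Launch the (blocking) training loop on this machine (assumed API; see the full tutorial script).
my_trainer.run()
```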
2 changes: 1 addition & 1 deletion tmrl/networking.py
@@ -124,7 +124,7 @@ def __init__(self,
password (str): tlspyo password
local_port (int): tlspyo local communication port
header_size (int): tlspyo header size (bytes)
security (str): tlspyo security type (None or "TLS")
security (Union[str, None]): tlspyo security type (None or "TLS")
keys_dir (str): tlspyo credentials directory
max_workers (int): max number of accepted workers
"""
2 changes: 0 additions & 2 deletions tmrl/training_offline.py
@@ -7,11 +7,9 @@
from pandas import DataFrame

# local imports
import tmrl.config.config_constants as cfg
from tmrl.util import pandas_dict

import logging
# import pybullet_envs


__docformat__ = "google"
18 changes: 14 additions & 4 deletions tmrl/tuto/tuto.py
@@ -23,10 +23,19 @@

CRC_DEBUG = False

# === Networking parameters ============================================================================================

security = None
password = cfg.PASSWORD

server_ip = "127.0.0.1"
server_port = 6666


# === Server ===========================================================================================================

if __name__ == "__main__":
my_server = Server()
my_server = Server(security=security, password=password, port=server_port)


# === Environment ======================================================================================================
@@ -242,7 +251,6 @@ def my_sample_compressor(prev_act, obs, rew, terminated, truncated, info):

# Networking

server_ip = "127.0.0.1"
max_samples_per_episode = 1000


Expand All @@ -265,6 +273,8 @@ def my_sample_compressor(prev_act, obs, rew, terminated, truncated, info):
sample_compressor=sample_compressor,
device=device,
server_ip=server_ip,
server_port=server_port,
password=password,
max_samples_per_episode=max_samples_per_episode,
model_path=model_path,
model_path_history=model_path_history,
@@ -278,8 +288,6 @@ def my_sample_compressor(prev_act, obs, rew, terminated, truncated, info):

# --- Networking and files ---

server_ip = "127.0.0.1"

weights_folder = cfg.WEIGHTS_FOLDER # path to the weights folder
checkpoints_folder = cfg.CHECKPOINTS_FOLDER
my_run_name = "tutorial"
@@ -596,6 +604,8 @@ def train(self, batch):
my_trainer = Trainer(
training_cls=training_cls,
server_ip=server_ip,
server_port=server_port,
password=password,
model_path=model_path,
checkpoint_path=checkpoints_path) # None for not saving training checkpoints

