From ab5c0692f946fe7bf3e6cac9b6dc44f362acf9b2 Mon Sep 17 00:00:00 2001
From: Yann Bouteiller <yann.bouteiller@polymtl.ca>
Date: Tue, 25 Jan 2022 17:03:10 -0500
Subject: [PATCH] debugging partial_to_dict usage

---
 readme/tuto_library.md         | 21 ++++++++++++---------
 tmrl/custom/custom_memories.py | 12 ++++++------
 tmrl/tuto/tuto.py              | 12 ++++++------
 3 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/readme/tuto_library.md b/readme/tuto_library.md
index a723f09..32bc9a4 100644
--- a/readme/tuto_library.md
+++ b/readme/tuto_library.md
@@ -44,7 +44,7 @@ The full script for this tutorial is available [here](https://github.com/trackma
 ## Tools
 
 ### partial() method
-We use this method a lot in `tmrl`, it enables partially instantiating a class.
+We use this method a lot in `tmrl`; it enables partially initializing the kwargs of a class.
 Import this method in your script:
 
 ```python
@@ -63,7 +63,7 @@ my_partially_instantiated_class = partial(my_class,
 And the partially instantiated class can then be fully instantiated as:
 
 ```python
-my_object = my_partially_instantiated_class(missing_args_and_kwargs)
+my_object = my_partially_instantiated_class(missing_kwargs)
 ```
 
 ### Constants
@@ -235,8 +235,8 @@ import tmrl.config.config_constants as cfg  # constants from the config.json fil
 class RolloutWorker:
     def __init__(
             self,
-            env_cls,  # class of the Gym environment
-            actor_module_cls,  # class of a module containing the policy
+            env_cls=None,  # class of the Gym environment
+            actor_module_cls=None,  # class of a module containing the policy
             sample_compressor: callable = None,  # compressor for sending samples over the Internet
             device="cpu",  # device on which the policy is running
             server_ip=None,  # ip of the central server
@@ -753,8 +753,8 @@ Thus, we will use the action buffer length as an additional argument to our cust
 
 ```python
     def __init__(self,
-                 device,
-                 nb_steps,
+                 device=None,
+                 nb_steps=None,
                  obs_preprocessor: callable = None,
                  sample_preprocessor: callable = None,
                  memory_size=1000000,
@@ -1010,9 +1010,9 @@ class MyTrainingAgent(TrainingAgent):
     model_nograd = cached_property(lambda self: no_grad(copy_shared(self.model)))
     
     def __init__(self,
-                 observation_space,
-                 action_space,
-                 device,
+                 observation_space=None,
+                 action_space=None,
+                 device=None,
                  model_cls=MyActorCriticModule,  # an actor-critic module, encapsulating our ActorModule
                  gamma=0.99,  # discount factor
                  polyak=0.995,  # exponential averaging factor for the target critic
@@ -1244,6 +1244,9 @@ my_trainer.run_with_wandb(entity=my_wandb_entity,
                           key=my_wandb_key)
 ```
 
+_(**WARNING**: when using `run_with_wandb`, make sure all the partially instantiated classes that are part of the `Trainer` have kwargs only, no args; otherwise you will get an error complaining about invalid keywords.
+When it does not make sense to have default values, just set the default values to `None` as done in, e.g., `MyMemoryDataloading`)_
+
 But as for the `RolloutWorker`, this would block the code here until all `epochs` are complete, which in itself would require the `RolloutWorker` to also be running.
 
 In fact, the `RolloutWorker`, `Trainer` and `Server` are best run in separate terminals (see TrackMania) because currently they are all quite verbose.
diff --git a/tmrl/custom/custom_memories.py b/tmrl/custom/custom_memories.py
index b45b63c..aa9751a 100644
--- a/tmrl/custom/custom_memories.py
+++ b/tmrl/custom/custom_memories.py
@@ -70,8 +70,8 @@ def replace_hist_before_done(hist, done_idx_in_hist):
 
 class MemoryTMNF(MemoryDataloading):
     def __init__(self,
-                 memory_size,
-                 batch_size,
+                 memory_size=None,
+                 batch_size=None,
                  dataset_path="",
                  imgs_obs=4,
                  act_buf_len=1,
@@ -219,8 +219,8 @@ def append_buffer(self, buffer):
 
 class TrajMemoryTMNF(TrajMemoryDataloading):
     def __init__(self,
-                 memory_size,
-                 batch_size,
+                 memory_size=None,
+                 batch_size=None,
                  dataset_path="",
                  imgs_obs=4,
                  act_buf_len=1,
@@ -358,8 +358,8 @@ def append_buffer(self, buffer):
 
 class MemoryTM2020(MemoryDataloading):  # TODO: reset transitions
     def __init__(self,
-                 memory_size,
-                 batch_size,
+                 memory_size=None,
+                 batch_size=None,
                  dataset_path="",
                  imgs_obs=4,
                  act_buf_len=1,
diff --git a/tmrl/tuto/tuto.py b/tmrl/tuto/tuto.py
index 49eeb6e..716fe1f 100644
--- a/tmrl/tuto/tuto.py
+++ b/tmrl/tuto/tuto.py
@@ -304,9 +304,9 @@ def my_observation_preprocessor(obs):
 
 class MyMemoryDataloading(MemoryDataloading):
     def __init__(self,
-                 act_buf_len,
-                 device,
-                 nb_steps,
+                 act_buf_len=None,
+                 device=None,
+                 nb_steps=None,
                  obs_preprocessor: callable = None,
                  sample_preprocessor: callable = None,
                  memory_size=1000000,
@@ -457,9 +457,9 @@ class MyTrainingAgent(TrainingAgent):
     model_nograd = cached_property(lambda self: no_grad(copy_shared(self.model)))
 
     def __init__(self,
-                 observation_space,
-                 action_space,
-                 device,
+                 observation_space=None,
+                 action_space=None,
+                 device=None,
                  model_cls=MyActorCriticModule,  # an actor-critic module, encapsulating our ActorModule
                  gamma=0.99,  # discount factor
                  polyak=0.995,  # exponential averaging factor for the target critic