Upgrade pre-commit tools (#12)
This commit is contained in:
Parent: 1bde547068
Commit: 9a5a1418ff
@@ -1,30 +1,32 @@
 repos:
 - repo: https://github.com/python/black
-  rev: 19.3b0
+  rev: 22.1.0
   hooks:
   - id: black
     exclude: >
       (?x)^(
         .*_pb2.py|
         .*_pb2.pyi|
         .*_pb2_grpc.py
       )$

 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.761
+  rev: v0.931
   hooks:
   - id: mypy
     name: mypy-ml-agents
     files: "ml-agents/.*"
-    args: [--ignore-missing-imports, --disallow-incomplete-defs]
+    args: [--ignore-missing-imports, --disallow-incomplete-defs, --no-strict-optional]
+    additional_dependencies: [types-PyYAML, types-attrs, types-protobuf, types-setuptools]
   - id: mypy
     name: mypy-ml-agents-envs
     files: "ml-agents-envs/.*"
     # Exclude protobuf files and don't follow them when imported
     exclude: ".*_pb2.py"
-    args: [--ignore-missing-imports, --disallow-incomplete-defs]
+    args: [--ignore-missing-imports, --disallow-incomplete-defs, --no-strict-optional]
+    additional_dependencies: [types-PyYAML, types-attrs, types-protobuf, types-setuptools]
 - repo: https://gitlab.com/pycqa/flake8
-  rev: 3.8.1
+  rev: 3.9.2
   hooks:
   - id: flake8
     exclude: >

@@ -36,7 +38,7 @@ repos:
     additional_dependencies: [flake8-comprehensions==3.2.2, flake8-tidy-imports==4.1.0, flake8-bugbear==20.1.4]

 - repo: https://github.com/asottile/pyupgrade
-  rev: v2.7.0
+  rev: v2.31.0
   hooks:
   - id: pyupgrade
     args: [--py3-plus, --py36-plus]

@@ -47,7 +49,7 @@ repos:
       )$

 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v2.5.0
+  rev: v4.1.0
   hooks:
   - id: mixed-line-ending
     exclude: >

@@ -68,12 +70,12 @@ repos:
     exclude: \.yamato/.*

 - repo: https://github.com/pre-commit/pygrep-hooks
-  rev: v1.4.2
+  rev: v1.9.0
   hooks:
   - id: python-check-mock-methods

 - repo: https://github.com/mattlqx/pre-commit-search-and-replace
-  rev: v1.0.3
+  rev: v1.0.5
   hooks:
   - id: search-and-replace
     types: [markdown]
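Note on the new `additional_dependencies` (illustrative, not part of the diff): mypy 0.900 and later no longer bundle third-party type stubs, so stub-only packages such as types-PyYAML, types-attrs, types-protobuf, and types-setuptools have to be installed into the hook's isolated environment for the checks to keep passing. A minimal sketch of code that only type-checks cleanly once those stubs are available:

```python
from typing import Any, Dict

import yaml


def load_settings(path: str) -> Dict[str, Any]:
    # Without types-PyYAML in the mypy hook environment, mypy >= 0.900 would
    # report something like: Library stubs not installed for "yaml".
    with open(path) as f:
        return yaml.safe_load(f)
```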
@@ -253,9 +253,15 @@ class UnityPettingzooBaseEnv:
         self._current_action[behavior_name] = self._create_empty_actions(
             behavior_name, len(current_batch[0])
         )
-        agents, obs, dones, rewards, cumulative_rewards, infos, id_map = _unwrap_batch_steps(
-            current_batch, behavior_name
-        )
+        (
+            agents,
+            obs,
+            dones,
+            rewards,
+            cumulative_rewards,
+            infos,
+            id_map,
+        ) = _unwrap_batch_steps(current_batch, behavior_name)
         self._live_agents += agents
         self._agents += agents
         self._observations.update(obs)
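The hunk above is Black 22 exploding a long multi-target assignment by parenthesizing the left-hand side instead of wrapping the call. A standalone sketch of the same layout (names are made up):

```python
def _unpack(values):
    # Black 22 style: when the assignment no longer fits on one line, the
    # tuple of targets is parenthesized and written one target per line.
    (
        first,
        second,
        third,
    ) = values
    return first, second, third


assert _unpack((1, 2, 3)) == (1, 2, 3)
```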
@@ -137,7 +137,7 @@ def download_and_extract_zip(url: str, name: str) -> None:
     try:
         request = urllib.request.urlopen(url, timeout=30)
     except urllib.error.HTTPError as e:  # type: ignore
-        e.msg += " " + url
+        e.reason = f"{e.reason} {url}"
         raise
     zip_size = int(request.headers["content-length"])
     zip_file_path = os.path.join(zip_dir, str(uuid.uuid4()) + ".zip")

@@ -193,7 +193,7 @@ def load_remote_manifest(url: str) -> Dict[str, Any]:
     try:
         request = urllib.request.urlopen(url, timeout=30)
     except urllib.error.HTTPError as e:  # type: ignore
-        e.msg += " " + url
+        e.reason = f"{e.reason} {url}"
         raise
     manifest_path = os.path.join(tmp_dir, str(uuid.uuid4()) + ".yaml")
     with open(manifest_path, "wb") as manifest:
@@ -21,7 +21,7 @@ class SideChannelManager:
            try:
                channel_id = uuid.UUID(bytes_le=bytes(data[offset : offset + 16]))
                offset += 16
-                message_len, = struct.unpack_from("<i", data, offset)
+                (message_len,) = struct.unpack_from("<i", data, offset)
                offset = offset + 4
                message_data = data[offset : offset + message_len]
                offset = offset + message_len

@@ -63,7 +63,7 @@ class SideChannelManager:

     @staticmethod
     def _get_side_channels_dict(
-        side_channels: Optional[List[SideChannel]]
+        side_channels: Optional[List[SideChannel]],
     ) -> Dict[uuid.UUID, SideChannel]:
         """
         Converts a list of side channels into a dictionary of channel_id to SideChannel
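Aside (illustrative): `struct.unpack_from` always returns a tuple, even for a single format code, so the changed line unpacks a one-element tuple; the upgraded Black writes that as `(message_len,) = ...` rather than the older bare trailing-comma form. A self-contained check:

```python
import struct

buf = struct.pack("<i", 42)
# Both spellings unpack the one-element tuple returned by unpack_from;
# the parenthesized form is what Black 22 produces.
(value,) = struct.unpack_from("<i", buf, 0)
assert value == 42
```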
@@ -59,5 +59,6 @@ setup(
         "numpy==1.21.2",
     ],
     python_requires=">=3.7.2,<3.9.10",
-    cmdclass={"verify": VerifyVersionCommand},
+    # TODO: Remove this once mypy stops having spurious setuptools issues.
+    cmdclass={"verify": VerifyVersionCommand},  # type: ignore
 )
@@ -42,11 +42,11 @@ class BehaviorIdentifiers(NamedTuple):

 def create_name_behavior_id(name: str, team_id: int) -> str:
     """
-    Reconstructs fully qualified behavior name from name and team_id
-    :param name: brain name
-    :param team_id: team ID
-    :return: name_behavior_id
-    """
+    Reconstructs fully qualified behavior name from name and team_id
+    :param name: brain name
+    :param team_id: team ID
+    :return: name_behavior_id
+    """
     return name + "?team=" + str(team_id)

@@ -264,9 +264,7 @@ class AgentBuffer(MutableMapping):
         )

     def __str__(self):
-        return ", ".join(
-            ["'{}' : {}".format(k, str(self[k])) for k in self._fields.keys()]
-        )
+        return ", ".join([f"'{k}' : {str(self[k])}" for k in self._fields.keys()])

     def reset_agent(self) -> None:
         """
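The `.format(...)` to f-string rewrites throughout this commit come from the pyupgrade hook running with `--py36-plus`; the two spellings produce identical strings. A minimal sketch with hypothetical data:

```python
fields = {"obs": [1, 2], "rew": [0.5]}
# Old and new forms of the __str__ body above render the same text.
old_style = ", ".join(["'{}' : {}".format(k, str(fields[k])) for k in fields.keys()])
new_style = ", ".join([f"'{k}' : {str(fields[k])}" for k in fields.keys()])
assert old_style == new_style
```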
@@ -165,7 +165,10 @@ class EnvironmentParameterManager:
            ):
                behavior_to_consider = lesson.completion_criteria.behavior
                if behavior_to_consider in trainer_steps:
-                    must_increment, new_smoothing = lesson.completion_criteria.need_increment(
+                    (
+                        must_increment,
+                        new_smoothing,
+                    ) = lesson.completion_criteria.need_increment(
                        float(trainer_steps[behavior_to_consider])
                        / float(trainer_max_steps[behavior_to_consider]),
                        trainer_reward_buffer[behavior_to_consider],
@@ -33,7 +33,7 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str:
         [
             "\t"
             + " " * num_tabs
-            + "{}:\t{}".format(x, _dict_to_str(param_dict[x], num_tabs + 1))
+            + f"{x}:\t{_dict_to_str(param_dict[x], num_tabs + 1)}"
             for x in param_dict
         ]
     )
@@ -71,7 +71,7 @@ def trainer_controller_with_start_learning_mocks(basic_trainer_controller):


 def test_start_learning_trains_forever_if_no_train_model(
-    trainer_controller_with_start_learning_mocks
+    trainer_controller_with_start_learning_mocks,
 ):
     tc, trainer_mock = trainer_controller_with_start_learning_mocks
     tc.train_model = False

@@ -88,7 +88,7 @@ def test_start_learning_trains_forever_if_no_train_model(


 def test_start_learning_trains_until_max_steps_then_saves(
-    trainer_controller_with_start_learning_mocks
+    trainer_controller_with_start_learning_mocks,
 ):
     tc, trainer_mock = trainer_controller_with_start_learning_mocks


@@ -120,7 +120,7 @@ def trainer_controller_with_take_step_mocks(basic_trainer_controller):


 def test_advance_adds_experiences_to_trainer_and_trains(
-    trainer_controller_with_take_step_mocks
+    trainer_controller_with_take_step_mocks,
 ):
     tc, trainer_mock = trainer_controller_with_take_step_mocks

@@ -12,7 +12,7 @@ from mlagents_envs.base_env import ActionSpec


 def create_action_model(inp_size, act_size, deterministic=False):
-    mask = torch.ones([1, act_size ** 2])
+    mask = torch.ones([1, act_size**2])
     action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
     action_model = ActionModel(inp_size, action_spec, deterministic=deterministic)
     return action_model, mask
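This `act_size ** 2` to `act_size**2` change, and the similar ones in the hunks below, follow Black 22's rule of hugging the power operator when both operands are simple (names, numeric literals, attribute access); compound operands keep their spaces. Illustrative only:

```python
std = 0.5
var = std**2              # simple operands: Black 22 removes the spaces
penalty = (var - 1) ** 2  # parenthesized operand is not "simple": spaces stay
print(var, penalty)
```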
@@ -90,7 +90,10 @@ def test_all_masking(mask_value):
     # We make sure that a mask of all zeros or all ones will not trigger an error
     np.random.seed(1336)
     torch.manual_seed(1336)
-    size, n_k, = 3, 5
+    size, n_k, = (
+        3,
+        5,
+    )
     embedding_size = 64
     entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
     entity_embeddings.add_self_embedding(size)

@@ -134,7 +137,10 @@ def test_all_masking(mask_value):
 def test_predict_closest_training():
     np.random.seed(1336)
     torch.manual_seed(1336)
-    size, n_k, = 3, 5
+    size, n_k, = (
+        3,
+        5,
+    )
     embedding_size = 64
     entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
     entity_embeddings.add_self_embedding(size)
@@ -138,7 +138,7 @@ def test_sample_actions(rnn, visual, discrete):

 def test_step_overflow():
     policy = create_policy_mock(TrainerSettings())
-    policy.set_step(2 ** 31 - 1)
-    assert policy.get_current_step() == 2 ** 31 - 1  # step = 2147483647
+    policy.set_step(2**31 - 1)
+    assert policy.get_current_step() == 2**31 - 1  # step = 2147483647
     policy.increment_step(3)
-    assert policy.get_current_step() == 2 ** 31 + 2  # step = 2147483650
+    assert policy.get_current_step() == 2**31 + 2  # step = 2147483650
@@ -39,7 +39,7 @@ def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]:

     # Generate the masking tensors for each entities tensor (mask only if all zeros)
     key_masks: List[torch.Tensor] = [
-        (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in entities
+        (torch.sum(ent**2, axis=2) < 0.01).float() for ent in entities
     ]
     return key_masks


@@ -101,11 +101,11 @@ class MultiHeadAttention(torch.nn.Module):
         qk = torch.matmul(query, key)  # (b, h, n_q, n_k)

         if key_mask is None:
-            qk = qk / (self.embedding_size ** 0.5)
+            qk = qk / (self.embedding_size**0.5)
         else:
             key_mask = key_mask.reshape(b, 1, 1, n_k)
             qk = (1 - key_mask) * qk / (
-                self.embedding_size ** 0.5
+                self.embedding_size**0.5
             ) + key_mask * self.NEG_INF

         att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)
@@ -33,7 +33,9 @@ class BCModule:
         self._anneal_steps = settings.steps
         self.current_lr = policy_learning_rate * settings.strength

-        learning_rate_schedule: ScheduleType = ScheduleType.LINEAR if self._anneal_steps > 0 else ScheduleType.CONSTANT
+        learning_rate_schedule: ScheduleType = (
+            ScheduleType.LINEAR if self._anneal_steps > 0 else ScheduleType.CONSTANT
+        )
         self.decay_learning_rate = ModelUtils.DecayedValue(
             learning_rate_schedule, self.current_lr, 1e-10, self._anneal_steps
         )
@@ -183,10 +183,10 @@ class DiscriminatorNetwork(torch.nn.Module):
         kl_loss = torch.mean(
             -torch.sum(
                 1
-                + (self._z_sigma ** 2).log()
-                - 0.5 * expert_mu ** 2
-                - 0.5 * policy_mu ** 2
-                - (self._z_sigma ** 2),
+                + (self._z_sigma**2).log()
+                - 0.5 * expert_mu**2
+                - 0.5 * policy_mu**2
+                - (self._z_sigma**2),
                 dim=1,
             )
         )

@@ -255,6 +255,6 @@ class DiscriminatorNetwork(torch.nn.Module):
         estimate = self._estimator(hidden).squeeze(1).sum()
         gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]
         # Norm's gradient could be NaN at 0. Use our own safe_norm
-        safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
+        safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
         gradient_mag = torch.mean((safe_norm - 1) ** 2)
         return gradient_mag
@@ -70,7 +70,7 @@ class GaussianDistInstance(DistInstance):
         return self.mean

     def log_prob(self, value):
-        var = self.std ** 2
+        var = self.std**2
         log_scale = torch.log(self.std + EPSILON)
         return (
             -((value - self.mean) ** 2) / (2 * var + EPSILON)

@@ -84,7 +84,7 @@ class GaussianDistInstance(DistInstance):

     def entropy(self):
         return torch.mean(
-            0.5 * torch.log(2 * math.pi * math.e * self.std ** 2 + EPSILON),
+            0.5 * torch.log(2 * math.pi * math.e * self.std**2 + EPSILON),
             dim=1,
             keepdim=True,
         )  # Use equivalent behavior to TF
@@ -137,7 +137,7 @@ class RLTrainer(Trainer):
         return model_saver

     def _policy_mean_reward(self) -> Optional[float]:
-        """ Returns the mean episode reward for the current policy. """
+        """Returns the mean episode reward for the current policy."""
         rewards = self.cumulative_returns_since_policy_update
         if len(rewards) == 0:
             return None
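The docstring edit above reflects newer Black's docstring normalization, which strips the padding spaces just inside the triple quotes. A hypothetical before/after:

```python
def mean_reward_before(rewards):
    """ Returns the mean episode reward (older padded style). """
    return sum(rewards) / len(rewards)


def mean_reward_after(rewards):
    """Returns the mean episode reward (what Black 22 writes)."""
    return sum(rewards) / len(rewards)
```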
@@ -89,5 +89,6 @@ setup(
             "default=mlagents.plugins.stats_writer:get_default_stats_writers"
         ],
     },
-    cmdclass={"verify": VerifyVersionCommand},
+    # TODO: Remove this once mypy stops having spurious setuptools issues.
+    cmdclass={"verify": VerifyVersionCommand},  # type: ignore
 )
@@ -47,9 +47,7 @@ def test_run_environment(env_name):
     print("Is there a visual observation ?", vis_obs)

     # Examine the state space for the first observation for the first agent
-    print(
-        "First Agent observation looks like: \n{}".format(decision_steps.obs[0][0])
-    )
+    print(f"First Agent observation looks like: \n{decision_steps.obs[0][0]}")

     for _episode in range(10):
         env.reset()
@@ -24,7 +24,7 @@ def get_base_path():


 def get_base_output_path():
-    """"
+    """ "
     Returns the artifact folder to use for yamato jobs.
     """
     return os.path.join(get_base_path(), "artifacts")
@@ -188,7 +188,7 @@ def check_file(
                new_file.write(line)
            else:
                bad_lines.append(f"{filename}: {line}")
-                new_line = re.sub(r"release_[0-9]+", fr"{release_tag}", line)
+                new_line = re.sub(r"release_[0-9]+", rf"{release_tag}", line)
                new_line = update_pip_install_line(new_line, package_version)
                new_file.write(new_line)
     if bad_lines:

@@ -235,7 +235,7 @@ def main():
     print(f"Python package version: {package_version}")
     release_allow_pattern = re.compile(f"{release_tag}(_docs)?")
     pip_allow_pattern = re.compile(
-        fr"python -m pip install (-q )?mlagents(_envs)?=={package_version}"
+        rf"python -m pip install (-q )?mlagents(_envs)?=={package_version}"
     )
     bad_lines = check_all_files(
         release_allow_pattern, release_tag, pip_allow_pattern, package_version
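The `fr"..."` to `rf"..."` edits are a string-prefix normalization applied by the upgraded hooks; both prefixes denote the same raw f-string, so behavior does not change. A quick check (the tag value is hypothetical):

```python
release_tag = "release_19"
# fr and rf are two spellings of the same raw f-string literal.
assert fr"{release_tag}" == rf"{release_tag}" == "release_19"
```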