Henry Peteet 2022-02-07 21:18:16 -05:00 committed by GitHub Enterprise
Parent 1bde547068
Commit 9a5a1418ff
22 changed files with 73 additions and 56 deletions

View file

@ -1,30 +1,32 @@
repos:
- repo: https://github.com/python/black
rev: 19.3b0
rev: 22.1.0
hooks:
- id: black
exclude: >
(?x)^(
.*_pb2.py|
.*_pb2.pyi|
.*_pb2_grpc.py
)$
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.761
rev: v0.931
hooks:
- id: mypy
name: mypy-ml-agents
files: "ml-agents/.*"
args: [--ignore-missing-imports, --disallow-incomplete-defs]
args: [--ignore-missing-imports, --disallow-incomplete-defs, --no-strict-optional]
additional_dependencies: [types-PyYAML, types-attrs, types-protobuf, types-setuptools]
- id: mypy
name: mypy-ml-agents-envs
files: "ml-agents-envs/.*"
# Exclude protobuf files and don't follow them when imported
exclude: ".*_pb2.py"
args: [--ignore-missing-imports, --disallow-incomplete-defs]
args: [--ignore-missing-imports, --disallow-incomplete-defs, --no-strict-optional]
additional_dependencies: [types-PyYAML, types-attrs, types-protobuf, types-setuptools]
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.1
rev: 3.9.2
hooks:
- id: flake8
exclude: >
@ -36,7 +38,7 @@ repos:
additional_dependencies: [flake8-comprehensions==3.2.2, flake8-tidy-imports==4.1.0, flake8-bugbear==20.1.4]
- repo: https://github.com/asottile/pyupgrade
rev: v2.7.0
rev: v2.31.0
hooks:
- id: pyupgrade
args: [--py3-plus, --py36-plus]
@ -47,7 +49,7 @@ repos:
)$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.5.0
rev: v4.1.0
hooks:
- id: mixed-line-ending
exclude: >
@ -68,12 +70,12 @@ repos:
exclude: \.yamato/.*
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.4.2
rev: v1.9.0
hooks:
- id: python-check-mock-methods
- repo: https://github.com/mattlqx/pre-commit-search-and-replace
rev: v1.0.3
rev: v1.0.5
hooks:
- id: search-and-replace
types: [markdown]
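
For anyone applying these bumped hook revisions locally, the standard pre-commit workflow is enough to exercise them; a usage sketch (ordinary pre-commit CLI commands, not part of this commit):

pre-commit install            # (re)register the git hooks
pre-commit run --all-files    # run every updated hook against the whole tree
pre-commit autoupdate         # optionally, let pre-commit propose rev bumps like these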

View file

@ -253,9 +253,15 @@ class UnityPettingzooBaseEnv:
self._current_action[behavior_name] = self._create_empty_actions(
behavior_name, len(current_batch[0])
)
agents, obs, dones, rewards, cumulative_rewards, infos, id_map = _unwrap_batch_steps(
current_batch, behavior_name
)
(
agents,
obs,
dones,
rewards,
cumulative_rewards,
infos,
id_map,
) = _unwrap_batch_steps(current_batch, behavior_name)
self._live_agents += agents
self._agents += agents
self._observations.update(obs)

View file

@ -137,7 +137,7 @@ def download_and_extract_zip(url: str, name: str) -> None:
try:
request = urllib.request.urlopen(url, timeout=30)
except urllib.error.HTTPError as e: # type: ignore
e.msg += " " + url
e.reason = f"{e.reason} {url}"
raise
zip_size = int(request.headers["content-length"])
zip_file_path = os.path.join(zip_dir, str(uuid.uuid4()) + ".zip")
@ -193,7 +193,7 @@ def load_remote_manifest(url: str) -> Dict[str, Any]:
try:
request = urllib.request.urlopen(url, timeout=30)
except urllib.error.HTTPError as e: # type: ignore
e.msg += " " + url
e.reason = f"{e.reason} {url}"
raise
manifest_path = os.path.join(tmp_dir, str(uuid.uuid4()) + ".yaml")
with open(manifest_path, "wb") as manifest:

View file

@ -21,7 +21,7 @@ class SideChannelManager:
try:
channel_id = uuid.UUID(bytes_le=bytes(data[offset : offset + 16]))
offset += 16
message_len, = struct.unpack_from("<i", data, offset)
(message_len,) = struct.unpack_from("<i", data, offset)
offset = offset + 4
message_data = data[offset : offset + message_len]
offset = offset + message_len
@ -63,7 +63,7 @@ class SideChannelManager:
@staticmethod
def _get_side_channels_dict(
side_channels: Optional[List[SideChannel]]
side_channels: Optional[List[SideChannel]],
) -> Dict[uuid.UUID, SideChannel]:
"""
Converts a list of side channels into a dictionary of channel_id to SideChannel
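
Context for the tuple-unpack change above: struct.unpack_from always returns a tuple, even for a single format code, so the one-element target pulls the int out directly; black 22 simply prefers the parenthesized spelling over the bare trailing comma. A minimal sketch:

import struct

buf = struct.pack("<i", 42)                         # four little-endian bytes
(message_len,) = struct.unpack_from("<i", buf, 0)   # unpack_from returns a 1-tuple
assert message_len == 42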

View file

@ -59,5 +59,6 @@ setup(
"numpy==1.21.2",
],
python_requires=">=3.7.2,<3.9.10",
cmdclass={"verify": VerifyVersionCommand},
# TODO: Remove this once mypy stops having spurious setuptools issues.
cmdclass={"verify": VerifyVersionCommand}, # type: ignore
)

View file

@ -42,11 +42,11 @@ class BehaviorIdentifiers(NamedTuple):
def create_name_behavior_id(name: str, team_id: int) -> str:
"""
Reconstructs fully qualified behavior name from name and team_id
:param name: brain name
:param team_id: team ID
:return: name_behavior_id
"""
Reconstructs fully qualified behavior name from name and team_id
:param name: brain name
:param team_id: team ID
:return: name_behavior_id
"""
return name + "?team=" + str(team_id)

View file

@ -264,9 +264,7 @@ class AgentBuffer(MutableMapping):
)
def __str__(self):
return ", ".join(
["'{}' : {}".format(k, str(self[k])) for k in self._fields.keys()]
)
return ", ".join([f"'{k}' : {str(self[k])}" for k in self._fields.keys()])
def reset_agent(self) -> None:
"""

View file

@ -165,7 +165,10 @@ class EnvironmentParameterManager:
):
behavior_to_consider = lesson.completion_criteria.behavior
if behavior_to_consider in trainer_steps:
must_increment, new_smoothing = lesson.completion_criteria.need_increment(
(
must_increment,
new_smoothing,
) = lesson.completion_criteria.need_increment(
float(trainer_steps[behavior_to_consider])
/ float(trainer_max_steps[behavior_to_consider]),
trainer_reward_buffer[behavior_to_consider],

View file

@ -33,7 +33,7 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str:
[
"\t"
+ " " * num_tabs
+ "{}:\t{}".format(x, _dict_to_str(param_dict[x], num_tabs + 1))
+ f"{x}:\t{_dict_to_str(param_dict[x], num_tabs + 1)}"
for x in param_dict
]
)

View file

@ -71,7 +71,7 @@ def trainer_controller_with_start_learning_mocks(basic_trainer_controller):
def test_start_learning_trains_forever_if_no_train_model(
trainer_controller_with_start_learning_mocks
trainer_controller_with_start_learning_mocks,
):
tc, trainer_mock = trainer_controller_with_start_learning_mocks
tc.train_model = False
@ -88,7 +88,7 @@ def test_start_learning_trains_forever_if_no_train_model(
def test_start_learning_trains_until_max_steps_then_saves(
trainer_controller_with_start_learning_mocks
trainer_controller_with_start_learning_mocks,
):
tc, trainer_mock = trainer_controller_with_start_learning_mocks
@ -120,7 +120,7 @@ def trainer_controller_with_take_step_mocks(basic_trainer_controller):
def test_advance_adds_experiences_to_trainer_and_trains(
trainer_controller_with_take_step_mocks
trainer_controller_with_take_step_mocks,
):
tc, trainer_mock = trainer_controller_with_take_step_mocks

View file

@ -12,7 +12,7 @@ from mlagents_envs.base_env import ActionSpec
def create_action_model(inp_size, act_size, deterministic=False):
mask = torch.ones([1, act_size ** 2])
mask = torch.ones([1, act_size**2])
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
action_model = ActionModel(inp_size, action_spec, deterministic=deterministic)
return action_model, mask
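
Context for the ** hunks in this and several of the files below: black 22.1.0 removes the spaces around the power operator when both operands are simple (names, numeric literals, attribute access) and keeps them otherwise; a small illustration, assuming that rule:

x = 4
y = x**2          # both operands simple: black hugs the operator
z = (x + 1) ** 2  # compound operand: the spaces stay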

View file

@ -90,7 +90,10 @@ def test_all_masking(mask_value):
# We make sure that a mask of all zeros or all ones will not trigger an error
np.random.seed(1336)
torch.manual_seed(1336)
size, n_k, = 3, 5
size, n_k, = (
3,
5,
)
embedding_size = 64
entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
entity_embeddings.add_self_embedding(size)
@ -134,7 +137,10 @@ def test_all_masking(mask_value):
def test_predict_closest_training():
np.random.seed(1336)
torch.manual_seed(1336)
size, n_k, = 3, 5
size, n_k, = (
3,
5,
)
embedding_size = 64
entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
entity_embeddings.add_self_embedding(size)
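
A side observation on the two hunks above (not something this commit changes): the stray trailing comma in `size, n_k, = 3, 5` appears to be what makes black 22 explode the assignment across four lines; dropping it keeps the one-liner:

size, n_k = 3, 5  # no trailing comma after the targets, so black leaves this on one line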

View file

@ -138,7 +138,7 @@ def test_sample_actions(rnn, visual, discrete):
def test_step_overflow():
policy = create_policy_mock(TrainerSettings())
policy.set_step(2 ** 31 - 1)
assert policy.get_current_step() == 2 ** 31 - 1 # step = 2147483647
policy.set_step(2**31 - 1)
assert policy.get_current_step() == 2**31 - 1 # step = 2147483647
policy.increment_step(3)
assert policy.get_current_step() == 2 ** 31 + 2 # step = 2147483650
assert policy.get_current_step() == 2**31 + 2 # step = 2147483650

View file

@ -39,7 +39,7 @@ def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]:
# Generate the masking tensors for each entities tensor (mask only if all zeros)
key_masks: List[torch.Tensor] = [
(torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in entities
(torch.sum(ent**2, axis=2) < 0.01).float() for ent in entities
]
return key_masks
@ -101,11 +101,11 @@ class MultiHeadAttention(torch.nn.Module):
qk = torch.matmul(query, key) # (b, h, n_q, n_k)
if key_mask is None:
qk = qk / (self.embedding_size ** 0.5)
qk = qk / (self.embedding_size**0.5)
else:
key_mask = key_mask.reshape(b, 1, 1, n_k)
qk = (1 - key_mask) * qk / (
self.embedding_size ** 0.5
self.embedding_size**0.5
) + key_mask * self.NEG_INF
att = torch.softmax(qk, dim=3) # (b, h, n_q, n_k)
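
For readers skimming the reformatted hunk above: it is standard scaled dot-product attention, scores = QK^T / sqrt(d), with masked keys pushed toward a large negative value before the softmax. A self-contained sketch where the shapes and the NEG_INF constant are illustrative stand-ins, not the module's actual values:

import torch

NEG_INF = -1e9                          # illustrative stand-in for the class constant
b, h, n_q, n_k, d = 2, 4, 3, 5, 8
query = torch.randn(b, h, n_q, d)
key = torch.randn(b, h, d, n_k)
key_mask = torch.zeros(b, 1, 1, n_k)    # 1.0 would mark a padded (all-zero) entity
qk = torch.matmul(query, key)           # (b, h, n_q, n_k)
qk = (1 - key_mask) * qk / d**0.5 + key_mask * NEG_INF
att = torch.softmax(qk, dim=3)          # each row sums to 1 over the n_k keys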

View file

@ -33,7 +33,9 @@ class BCModule:
self._anneal_steps = settings.steps
self.current_lr = policy_learning_rate * settings.strength
learning_rate_schedule: ScheduleType = ScheduleType.LINEAR if self._anneal_steps > 0 else ScheduleType.CONSTANT
learning_rate_schedule: ScheduleType = (
ScheduleType.LINEAR if self._anneal_steps > 0 else ScheduleType.CONSTANT
)
self.decay_learning_rate = ModelUtils.DecayedValue(
learning_rate_schedule, self.current_lr, 1e-10, self._anneal_steps
)

View file

@ -183,10 +183,10 @@ class DiscriminatorNetwork(torch.nn.Module):
kl_loss = torch.mean(
-torch.sum(
1
+ (self._z_sigma ** 2).log()
- 0.5 * expert_mu ** 2
- 0.5 * policy_mu ** 2
- (self._z_sigma ** 2),
+ (self._z_sigma**2).log()
- 0.5 * expert_mu**2
- 0.5 * policy_mu**2
- (self._z_sigma**2),
dim=1,
)
)
@ -255,6 +255,6 @@ class DiscriminatorNetwork(torch.nn.Module):
estimate = self._estimator(hidden).squeeze(1).sum()
gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]
# Norm's gradient could be NaN at 0. Use our own safe_norm
safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
gradient_mag = torch.mean((safe_norm - 1) ** 2)
return gradient_mag
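
For reference, the kl_loss term reformatted above is the closed-form KL divergence to the standard normal, applied to both encodings: per dimension, KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - 1 - log sigma^2), and the inner sum works out to KL(expert) + KL(policy) with a shared sigma. A numeric sketch with illustrative shapes and names:

import torch

mu_expert = torch.randn(16, 4)
mu_policy = torch.randn(16, 4)
sigma = torch.rand(4) + 0.1

def kl_to_standard_normal(mu, sigma):
    # per-sample KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions
    return 0.5 * torch.sum(mu**2 + sigma**2 - 1 - (sigma**2).log(), dim=1)

combined = -torch.sum(
    1 + (sigma**2).log() - 0.5 * mu_expert**2 - 0.5 * mu_policy**2 - sigma**2,
    dim=1,
)
expected = kl_to_standard_normal(mu_expert, sigma) + kl_to_standard_normal(mu_policy, sigma)
assert torch.allclose(combined, expected)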

View file

@ -70,7 +70,7 @@ class GaussianDistInstance(DistInstance):
return self.mean
def log_prob(self, value):
var = self.std ** 2
var = self.std**2
log_scale = torch.log(self.std + EPSILON)
return (
-((value - self.mean) ** 2) / (2 * var + EPSILON)
@ -84,7 +84,7 @@ class GaussianDistInstance(DistInstance):
def entropy(self):
return torch.mean(
0.5 * torch.log(2 * math.pi * math.e * self.std ** 2 + EPSILON),
0.5 * torch.log(2 * math.pi * math.e * self.std**2 + EPSILON),
dim=1,
keepdim=True,
) # Use equivalent behavior to TF
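
As a cross-check on the expressions above (not part of the commit): aside from the EPSILON terms added for numerical stability, the hand-rolled log_prob and entropy agree with torch.distributions.Normal; a minimal sketch:

import math
import torch
from torch.distributions import Normal

mean, std = torch.zeros(3), torch.ones(3) * 0.5
value = torch.tensor([0.5, -1.0, 2.0])
var = std**2

log_prob = -((value - mean) ** 2) / (2 * var) - std.log() - math.log(math.sqrt(2 * math.pi))
entropy = 0.5 * torch.log(2 * math.pi * math.e * std**2)

assert torch.allclose(log_prob, Normal(mean, std).log_prob(value))
assert torch.allclose(entropy, Normal(mean, std).entropy())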

View file

@ -137,7 +137,7 @@ class RLTrainer(Trainer):
return model_saver
def _policy_mean_reward(self) -> Optional[float]:
""" Returns the mean episode reward for the current policy. """
"""Returns the mean episode reward for the current policy."""
rewards = self.cumulative_returns_since_policy_update
if len(rewards) == 0:
return None

View file

@ -89,5 +89,6 @@ setup(
"default=mlagents.plugins.stats_writer:get_default_stats_writers"
],
},
cmdclass={"verify": VerifyVersionCommand},
# TODO: Remove this once mypy stops having spurious setuptools issues.
cmdclass={"verify": VerifyVersionCommand}, # type: ignore
)

View file

@ -47,9 +47,7 @@ def test_run_environment(env_name):
print("Is there a visual observation ?", vis_obs)
# Examine the state space for the first observation for the first agent
print(
"First Agent observation looks like: \n{}".format(decision_steps.obs[0][0])
)
print(f"First Agent observation looks like: \n{decision_steps.obs[0][0]}")
for _episode in range(10):
env.reset()

View file

@ -24,7 +24,7 @@ def get_base_path():
def get_base_output_path():
""""
""" "
Returns the artifact folder to use for yamato jobs.
"""
return os.path.join(get_base_path(), "artifacts")

View file

@ -188,7 +188,7 @@ def check_file(
new_file.write(line)
else:
bad_lines.append(f"{filename}: {line}")
new_line = re.sub(r"release_[0-9]+", fr"{release_tag}", line)
new_line = re.sub(r"release_[0-9]+", rf"{release_tag}", line)
new_line = update_pip_install_line(new_line, package_version)
new_file.write(new_line)
if bad_lines:
@ -235,7 +235,7 @@ def main():
print(f"Python package version: {package_version}")
release_allow_pattern = re.compile(f"{release_tag}(_docs)?")
pip_allow_pattern = re.compile(
fr"python -m pip install (-q )?mlagents(_envs)?=={package_version}"
rf"python -m pip install (-q )?mlagents(_envs)?=={package_version}"
)
bad_lines = check_all_files(
release_allow_pattern, release_tag, pip_allow_pattern, package_version