add pyupgrade to pre-commit and run (#4239)

This commit is contained in:
Chris Elion 2020-07-16 18:00:56 -07:00 коммит произвёл GitHub
Родитель 6dc68df348
Коммит 728d4927a8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
30 изменённых файлов: 49 добавлений и 45 удалений

Просмотреть файл

@ -39,6 +39,13 @@ repos:
# flake8-tidy-imports is used for banned-modules, not actually tidying
additional_dependencies: [flake8-comprehensions==3.2.2, flake8-tidy-imports==4.1.0, flake8-bugbear==20.1.4]
- repo: https://github.com/asottile/pyupgrade
rev: v2.7.0
hooks:
- id: pyupgrade
args: [--py3-plus]
exclude: .*barracuda.py
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.5.0
hooks:

Просмотреть файл

@ -23,7 +23,7 @@ class VerifyVersionCommand(install):
tag = os.getenv("CIRCLE_TAG")
if tag != EXPECTED_TAG:
info = "Git tag: {0} does not match the expected tag of this app: {1}".format(
info = "Git tag: {} does not match the expected tag of this app: {}".format(
tag, EXPECTED_TAG
)
sys.exit(info)

Просмотреть файл

@ -3,7 +3,7 @@ from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
class Communicator(object):
class Communicator:
def __init__(self, worker_id=0, base_port=5005):
"""
Python side of the communication. Must be used in pair with the right Unity Communicator equivalent.

Просмотреть файл

@ -4,7 +4,7 @@ import grpc
from mlagents_envs.communicator_objects import unity_message_pb2 as mlagents__envs_dot_communicator__objects_dot_unity__message__pb2
class UnityToExternalProtoStub(object):
class UnityToExternalProtoStub:
# missing associated documentation comment in .proto file
pass
@ -21,7 +21,7 @@ class UnityToExternalProtoStub(object):
)
class UnityToExternalProtoServicer(object):
class UnityToExternalProtoServicer:
# missing associated documentation comment in .proto file
pass

Просмотреть файл

@ -326,8 +326,8 @@ class UnityEnvironment(BaseEnv):
def _assert_behavior_exists(self, behavior_name: str) -> None:
if behavior_name not in self._env_specs:
raise UnityActionException(
"The group {0} does not correspond to an existing agent group "
"in the environment".format(behavior_name)
f"The group {behavior_name} does not correspond to an existing "
f"agent group in the environment"
)
def set_actions(self, behavior_name: BehaviorName, action: np.ndarray) -> None:
@ -339,9 +339,9 @@ class UnityEnvironment(BaseEnv):
expected_shape = (len(self._env_state[behavior_name][0]), spec.action_size)
if action.shape != expected_shape:
raise UnityActionException(
"The behavior {0} needs an input of dimension {1} for "
"(<number of agents>, <action size>) but received input of "
"dimension {2}".format(behavior_name, expected_shape, action.shape)
f"The behavior {behavior_name} needs an input of dimension "
f"{expected_shape} for (<number of agents>, <action size>) but "
f"received input of dimension {action.shape}"
)
if action.dtype != expected_type:
action = action.astype(expected_type)
@ -357,10 +357,9 @@ class UnityEnvironment(BaseEnv):
expected_shape = (spec.action_size,)
if action.shape != expected_shape:
raise UnityActionException(
f"The Agent {0} with BehaviorName {1} needs an input of dimension "
f"{2} but received input of dimension {3}".format(
agent_id, behavior_name, expected_shape, action.shape
)
f"The Agent {agent_id} with BehaviorName {behavior_name} needs "
f"an input of dimension {expected_shape} but received input of "
f"dimension {action.shape}"
)
expected_type = np.float32 if spec.is_action_continuous() else np.int32
if action.dtype != expected_type:

Просмотреть файл

@ -75,4 +75,4 @@ class UnityWorkerInUseException(UnityException):
def __init__(self, worker_id):
message = self.MESSAGE_TEMPLATE.format(str(worker_id))
super(UnityWorkerInUseException, self).__init__(message)
super().__init__(message)

Просмотреть файл

@ -81,7 +81,7 @@ class RpcCommunicator(Communicator):
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
s.bind(("localhost", port))
except socket.error:
except OSError:
raise UnityWorkerInUseException(self.worker_id)
finally:
s.close()

Просмотреть файл

@ -22,7 +22,7 @@ class EnvironmentParametersChannel(SideChannel):
MULTIRANGEUNIFORM = 2
def __init__(self) -> None:
channel_id = uuid.UUID(("534c891e-810f-11ea-a9d0-822485860400"))
channel_id = uuid.UUID("534c891e-810f-11ea-a9d0-822485860400")
super().__init__(channel_id)
def on_message_received(self, msg: IncomingMessage) -> None:

Просмотреть файл

@ -13,7 +13,7 @@ class FloatPropertiesChannel(SideChannel):
def __init__(self, channel_id: uuid.UUID = None) -> None:
self._float_properties: Dict[str, float] = {}
if channel_id is None:
channel_id = uuid.UUID(("60ccf7d0-4f7e-11ea-b238-784f4387d1f7"))
channel_id = uuid.UUID("60ccf7d0-4f7e-11ea-b238-784f4387d1f7")
super().__init__(channel_id)
def on_message_received(self, msg: IncomingMessage) -> None:

Просмотреть файл

@ -33,7 +33,7 @@ class SideChannelManager:
)
if len(message_data) != message_len:
raise UnityEnvironmentException(
"The message received by the side channel {0} was "
"The message received by the side channel {} was "
"unexpectedly short. Make sure your Unity Environment "
"sending side channel data properly.".format(channel_id)
)

Просмотреть файл

@ -84,8 +84,8 @@ def test_raw_bytes():
sender = RawBytesChannel(guid)
receiver = RawBytesChannel(guid)
sender.send_raw_data("foo".encode("ascii"))
sender.send_raw_data("bar".encode("ascii"))
sender.send_raw_data(b"foo")
sender.send_raw_data(b"bar")
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)

Просмотреть файл

@ -23,7 +23,7 @@ class VerifyVersionCommand(install):
tag = os.getenv("CIRCLE_TAG")
if tag != EXPECTED_TAG:
info = "Git tag: {0} does not match the expected tag of this app: {1}".format(
info = "Git tag: {} does not match the expected tag of this app: {}".format(
tag, EXPECTED_TAG
)
sys.exit(info)

Просмотреть файл

@ -134,7 +134,7 @@ class AgentBuffer(dict):
super().__init__()
def __str__(self):
return ", ".join(["'{0}' : {1}".format(k, str(self[k])) for k in self.keys()])
return ", ".join(["'{}' : {}".format(k, str(self[k])) for k in self.keys()])
def reset_agent(self) -> None:
"""
@ -275,7 +275,7 @@ class AgentBuffer(dict):
key_list = list(self.keys())
if not self.check_length(key_list):
raise BufferException(
"The length of the fields {0} were not of same length".format(key_list)
"The length of the fields {} were not of same length".format(key_list)
)
for field_key in key_list:
target_buffer[field_key].extend(

Просмотреть файл

@ -232,7 +232,7 @@ def load_config(config_path: str) -> Dict[str, Any]:
try:
with open(config_path) as data_file:
return _load_config(data_file)
except IOError:
except OSError:
abs_path = os.path.abspath(config_path)
raise TrainerConfigError(f"Config file could not be found at {abs_path}.")
except UnicodeDecodeError:

Просмотреть файл

@ -3,7 +3,7 @@ from mlagents.tf_utils import tf
from mlagents.trainers.policy.tf_policy import TFPolicy
class BCModel(object):
class BCModel:
def __init__(
self, policy: TFPolicy, learning_rate: float = 3e-4, anneal_steps: int = 0
):

Просмотреть файл

@ -5,7 +5,7 @@ from mlagents.trainers.models import ModelUtils
from mlagents.trainers.policy.tf_policy import TFPolicy
class CuriosityModel(object):
class CuriosityModel:
def __init__(
self, policy: TFPolicy, encoding_size: int = 128, learning_rate: float = 3e-4
):

Просмотреть файл

@ -8,7 +8,7 @@ from mlagents.trainers.models import ModelUtils
EPSILON = 1e-7
class GAILModel(object):
class GAILModel:
def __init__(
self,
policy: TFPolicy,

Просмотреть файл

@ -31,7 +31,7 @@ def create_reward_signal(
"""
rcls = NAME_TO_CLASS.get(name)
if not rcls:
raise UnityTrainerException("Unknown reward signal type {0}".format(name))
raise UnityTrainerException("Unknown reward signal type {}".format(name))
class_inst = rcls(policy, settings)
return class_inst

Просмотреть файл

@ -58,7 +58,7 @@ class GhostTrainer(Trainer):
:param artifact_path: Path to store artifacts from this trainer.
"""
super(GhostTrainer, self).__init__(
super().__init__(
brain_name, trainer_settings, training, artifact_path, reward_buff_cap
)

Просмотреть файл

@ -158,7 +158,7 @@ class TFPolicy(Policy):
ckpt = tf.train.get_checkpoint_state(model_path)
if ckpt is None:
raise UnityPolicyException(
"The model {0} could not be loaded. Make "
"The model {} could not be loaded. Make "
"sure you specified the right "
"--run-id and that the previous run you are loading from had the same "
"behavior names.".format(model_path)
@ -167,7 +167,7 @@ class TFPolicy(Policy):
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
except tf.errors.NotFoundError:
raise UnityPolicyException(
"The model {0} was found but could not be loaded. Make "
"The model {} was found but could not be loaded. Make "
"sure the model is from the same version of ML-Agents, has the same behavior parameters, "
"and is using the same trainer configuration as the current run.".format(
model_path

Просмотреть файл

@ -44,7 +44,7 @@ class PPOTrainer(RLTrainer):
:param seed: The seed the model will be initialized with
:param artifact_path: The directory within which to store artifacts from this trainer.
"""
super(PPOTrainer, self).__init__(
super().__init__(
brain_name, trainer_settings, training, artifact_path, reward_buff_cap
)
self.hyperparameters: PPOSettings = cast(

Просмотреть файл

@ -127,7 +127,7 @@ class ConsoleWriter(StatsWriter):
) -> None:
if property_type == StatsPropertyType.HYPERPARAMETERS:
logger.info(
"""Hyperparameters for behavior name {0}: \n{1}""".format(
"""Hyperparameters for behavior name {}: \n{}""".format(
category, self._dict_to_str(value, 0)
)
)
@ -150,7 +150,7 @@ class ConsoleWriter(StatsWriter):
[
"\t"
+ " " * num_tabs
+ "{0}:\t{1}".format(
+ "{}:\t{}".format(
x, self._dict_to_str(param_dict[x], num_tabs + 1)
)
for x in param_dict
@ -226,7 +226,7 @@ class TensorboardWriter(StatsWriter):
s_op = tf.summary.text(
name,
tf.convert_to_tensor(
([[str(x), str(input_dict[x])] for x in input_dict])
[[str(x), str(input_dict[x])] for x in input_dict]
),
)
s = sess.run(s_op)

Просмотреть файл

@ -20,7 +20,7 @@ def test_globaltrainingstatus(tmpdir):
GlobalTrainingStatus.set_parameter_state("Category1", StatusType.LESSON_NUM, 3)
GlobalTrainingStatus.save_state(path_dir)
with open(path_dir, "r") as fp:
with open(path_dir) as fp:
test_json = json.load(fp)
assert "Category1" in test_json

Просмотреть файл

@ -32,7 +32,7 @@ class RLTrainer(Trainer): # pylint: disable=abstract-method
"""
def __init__(self, *args, **kwargs):
super(RLTrainer, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
# collected_rewards is a dictionary from name of reward signal to a dictionary of agent_id to cumulative reward
# used for reporting only. We always want to report the environment reward to Tensorboard, regardless
# of what reward signals are actually present.

Просмотреть файл

@ -30,7 +30,7 @@ from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.agent_processor import AgentManager
class TrainerController(object):
class TrainerController:
def __init__(
self,
trainer_factory: TrainerFactory,

Просмотреть файл

@ -67,7 +67,7 @@ class GlobalTrainingStatus:
:param path: Path to the JSON file containing the state.
"""
try:
with open(path, "r") as f:
with open(path) as f:
loaded_dict = json.load(f)
# Compare the metadata
_metadata = loaded_dict[StatusType.STATS_METADATA.value]

Просмотреть файл

@ -1,4 +1,3 @@
from io import open
import os
import sys
@ -25,7 +24,7 @@ class VerifyVersionCommand(install):
tag = os.getenv("CIRCLE_TAG")
if tag != EXPECTED_TAG:
info = "Git tag: {0} does not match the expected tag of this app: {1}".format(
info = "Git tag: {} does not match the expected tag of this app: {}".format(
tag, EXPECTED_TAG
)
sys.exit(info)

Просмотреть файл

@ -1,4 +1,3 @@
from __future__ import print_function
import sys
import os

Просмотреть файл

@ -77,7 +77,7 @@ def check_file(filename: str, global_allow_pattern: Pattern) -> List[str]:
Validate a single file and return any offending lines.
"""
bad_lines = []
with open(filename, "r") as f:
with open(filename) as f:
for line in f:
if not RELEASE_PATTERN.search(line):
continue

Просмотреть файл

@ -94,7 +94,7 @@ def set_version(
def set_package_version(new_version: str) -> None:
with open(MLAGENTS_PACKAGE_JSON_PATH, "r") as f:
with open(MLAGENTS_PACKAGE_JSON_PATH) as f:
package_json = json.load(f)
if "version" in package_json:
package_json["version"] = new_version
@ -104,7 +104,7 @@ def set_package_version(new_version: str) -> None:
def set_extension_package_version(new_version: str) -> None:
with open(MLAGENTS_EXTENSIONS_PACKAGE_JSON_PATH, "r") as f:
with open(MLAGENTS_EXTENSIONS_PACKAGE_JSON_PATH) as f:
package_json = json.load(f)
package_json["dependencies"]["com.unity.ml-agents"] = new_version
with open(MLAGENTS_EXTENSIONS_PACKAGE_JSON_PATH, "w") as f: