add pyupgrade to pre-commit and run (#4239)
This commit is contained in:
Родитель
6dc68df348
Коммит
728d4927a8
|
@ -39,6 +39,13 @@ repos:
|
|||
# flake8-tidy-imports is used for banned-modules, not actually tidying
|
||||
additional_dependencies: [flake8-comprehensions==3.2.2, flake8-tidy-imports==4.1.0, flake8-bugbear==20.1.4]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v2.7.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py3-plus]
|
||||
exclude: .*barracuda.py
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v2.5.0
|
||||
hooks:
|
||||
|
|
|
@ -23,7 +23,7 @@ class VerifyVersionCommand(install):
|
|||
tag = os.getenv("CIRCLE_TAG")
|
||||
|
||||
if tag != EXPECTED_TAG:
|
||||
info = "Git tag: {0} does not match the expected tag of this app: {1}".format(
|
||||
info = "Git tag: {} does not match the expected tag of this app: {}".format(
|
||||
tag, EXPECTED_TAG
|
||||
)
|
||||
sys.exit(info)
|
||||
|
|
|
@ -3,7 +3,7 @@ from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto
|
|||
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
|
||||
|
||||
|
||||
class Communicator(object):
|
||||
class Communicator:
|
||||
def __init__(self, worker_id=0, base_port=5005):
|
||||
"""
|
||||
Python side of the communication. Must be used in pair with the right Unity Communicator equivalent.
|
||||
|
|
|
@ -4,7 +4,7 @@ import grpc
|
|||
from mlagents_envs.communicator_objects import unity_message_pb2 as mlagents__envs_dot_communicator__objects_dot_unity__message__pb2
|
||||
|
||||
|
||||
class UnityToExternalProtoStub(object):
|
||||
class UnityToExternalProtoStub:
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
|
||||
|
@ -21,7 +21,7 @@ class UnityToExternalProtoStub(object):
|
|||
)
|
||||
|
||||
|
||||
class UnityToExternalProtoServicer(object):
|
||||
class UnityToExternalProtoServicer:
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
|
||||
|
|
|
@ -326,8 +326,8 @@ class UnityEnvironment(BaseEnv):
|
|||
def _assert_behavior_exists(self, behavior_name: str) -> None:
|
||||
if behavior_name not in self._env_specs:
|
||||
raise UnityActionException(
|
||||
"The group {0} does not correspond to an existing agent group "
|
||||
"in the environment".format(behavior_name)
|
||||
f"The group {behavior_name} does not correspond to an existing "
|
||||
f"agent group in the environment"
|
||||
)
|
||||
|
||||
def set_actions(self, behavior_name: BehaviorName, action: np.ndarray) -> None:
|
||||
|
@ -339,9 +339,9 @@ class UnityEnvironment(BaseEnv):
|
|||
expected_shape = (len(self._env_state[behavior_name][0]), spec.action_size)
|
||||
if action.shape != expected_shape:
|
||||
raise UnityActionException(
|
||||
"The behavior {0} needs an input of dimension {1} for "
|
||||
"(<number of agents>, <action size>) but received input of "
|
||||
"dimension {2}".format(behavior_name, expected_shape, action.shape)
|
||||
f"The behavior {behavior_name} needs an input of dimension "
|
||||
f"{expected_shape} for (<number of agents>, <action size>) but "
|
||||
f"received input of dimension {action.shape}"
|
||||
)
|
||||
if action.dtype != expected_type:
|
||||
action = action.astype(expected_type)
|
||||
|
@ -357,10 +357,9 @@ class UnityEnvironment(BaseEnv):
|
|||
expected_shape = (spec.action_size,)
|
||||
if action.shape != expected_shape:
|
||||
raise UnityActionException(
|
||||
f"The Agent {0} with BehaviorName {1} needs an input of dimension "
|
||||
f"{2} but received input of dimension {3}".format(
|
||||
agent_id, behavior_name, expected_shape, action.shape
|
||||
)
|
||||
f"The Agent {agent_id} with BehaviorName {behavior_name} needs "
|
||||
f"an input of dimension {expected_shape} but received input of "
|
||||
f"dimension {action.shape}"
|
||||
)
|
||||
expected_type = np.float32 if spec.is_action_continuous() else np.int32
|
||||
if action.dtype != expected_type:
|
||||
|
|
|
@ -75,4 +75,4 @@ class UnityWorkerInUseException(UnityException):
|
|||
|
||||
def __init__(self, worker_id):
|
||||
message = self.MESSAGE_TEMPLATE.format(str(worker_id))
|
||||
super(UnityWorkerInUseException, self).__init__(message)
|
||||
super().__init__(message)
|
||||
|
|
|
@ -81,7 +81,7 @@ class RpcCommunicator(Communicator):
|
|||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
try:
|
||||
s.bind(("localhost", port))
|
||||
except socket.error:
|
||||
except OSError:
|
||||
raise UnityWorkerInUseException(self.worker_id)
|
||||
finally:
|
||||
s.close()
|
||||
|
|
|
@ -22,7 +22,7 @@ class EnvironmentParametersChannel(SideChannel):
|
|||
MULTIRANGEUNIFORM = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
channel_id = uuid.UUID(("534c891e-810f-11ea-a9d0-822485860400"))
|
||||
channel_id = uuid.UUID("534c891e-810f-11ea-a9d0-822485860400")
|
||||
super().__init__(channel_id)
|
||||
|
||||
def on_message_received(self, msg: IncomingMessage) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ class FloatPropertiesChannel(SideChannel):
|
|||
def __init__(self, channel_id: uuid.UUID = None) -> None:
|
||||
self._float_properties: Dict[str, float] = {}
|
||||
if channel_id is None:
|
||||
channel_id = uuid.UUID(("60ccf7d0-4f7e-11ea-b238-784f4387d1f7"))
|
||||
channel_id = uuid.UUID("60ccf7d0-4f7e-11ea-b238-784f4387d1f7")
|
||||
super().__init__(channel_id)
|
||||
|
||||
def on_message_received(self, msg: IncomingMessage) -> None:
|
||||
|
|
|
@ -33,7 +33,7 @@ class SideChannelManager:
|
|||
)
|
||||
if len(message_data) != message_len:
|
||||
raise UnityEnvironmentException(
|
||||
"The message received by the side channel {0} was "
|
||||
"The message received by the side channel {} was "
|
||||
"unexpectedly short. Make sure your Unity Environment "
|
||||
"sending side channel data properly.".format(channel_id)
|
||||
)
|
||||
|
|
|
@ -84,8 +84,8 @@ def test_raw_bytes():
|
|||
sender = RawBytesChannel(guid)
|
||||
receiver = RawBytesChannel(guid)
|
||||
|
||||
sender.send_raw_data("foo".encode("ascii"))
|
||||
sender.send_raw_data("bar".encode("ascii"))
|
||||
sender.send_raw_data(b"foo")
|
||||
sender.send_raw_data(b"bar")
|
||||
|
||||
data = SideChannelManager([sender]).generate_side_channel_messages()
|
||||
SideChannelManager([receiver]).process_side_channel_message(data)
|
||||
|
|
|
@ -23,7 +23,7 @@ class VerifyVersionCommand(install):
|
|||
tag = os.getenv("CIRCLE_TAG")
|
||||
|
||||
if tag != EXPECTED_TAG:
|
||||
info = "Git tag: {0} does not match the expected tag of this app: {1}".format(
|
||||
info = "Git tag: {} does not match the expected tag of this app: {}".format(
|
||||
tag, EXPECTED_TAG
|
||||
)
|
||||
sys.exit(info)
|
||||
|
|
|
@ -134,7 +134,7 @@ class AgentBuffer(dict):
|
|||
super().__init__()
|
||||
|
||||
def __str__(self):
|
||||
return ", ".join(["'{0}' : {1}".format(k, str(self[k])) for k in self.keys()])
|
||||
return ", ".join(["'{}' : {}".format(k, str(self[k])) for k in self.keys()])
|
||||
|
||||
def reset_agent(self) -> None:
|
||||
"""
|
||||
|
@ -275,7 +275,7 @@ class AgentBuffer(dict):
|
|||
key_list = list(self.keys())
|
||||
if not self.check_length(key_list):
|
||||
raise BufferException(
|
||||
"The length of the fields {0} were not of same length".format(key_list)
|
||||
"The length of the fields {} were not of same length".format(key_list)
|
||||
)
|
||||
for field_key in key_list:
|
||||
target_buffer[field_key].extend(
|
||||
|
|
|
@ -232,7 +232,7 @@ def load_config(config_path: str) -> Dict[str, Any]:
|
|||
try:
|
||||
with open(config_path) as data_file:
|
||||
return _load_config(data_file)
|
||||
except IOError:
|
||||
except OSError:
|
||||
abs_path = os.path.abspath(config_path)
|
||||
raise TrainerConfigError(f"Config file could not be found at {abs_path}.")
|
||||
except UnicodeDecodeError:
|
||||
|
|
|
@ -3,7 +3,7 @@ from mlagents.tf_utils import tf
|
|||
from mlagents.trainers.policy.tf_policy import TFPolicy
|
||||
|
||||
|
||||
class BCModel(object):
|
||||
class BCModel:
|
||||
def __init__(
|
||||
self, policy: TFPolicy, learning_rate: float = 3e-4, anneal_steps: int = 0
|
||||
):
|
||||
|
|
|
@ -5,7 +5,7 @@ from mlagents.trainers.models import ModelUtils
|
|||
from mlagents.trainers.policy.tf_policy import TFPolicy
|
||||
|
||||
|
||||
class CuriosityModel(object):
|
||||
class CuriosityModel:
|
||||
def __init__(
|
||||
self, policy: TFPolicy, encoding_size: int = 128, learning_rate: float = 3e-4
|
||||
):
|
||||
|
|
|
@ -8,7 +8,7 @@ from mlagents.trainers.models import ModelUtils
|
|||
EPSILON = 1e-7
|
||||
|
||||
|
||||
class GAILModel(object):
|
||||
class GAILModel:
|
||||
def __init__(
|
||||
self,
|
||||
policy: TFPolicy,
|
||||
|
|
|
@ -31,7 +31,7 @@ def create_reward_signal(
|
|||
"""
|
||||
rcls = NAME_TO_CLASS.get(name)
|
||||
if not rcls:
|
||||
raise UnityTrainerException("Unknown reward signal type {0}".format(name))
|
||||
raise UnityTrainerException("Unknown reward signal type {}".format(name))
|
||||
|
||||
class_inst = rcls(policy, settings)
|
||||
return class_inst
|
||||
|
|
|
@ -58,7 +58,7 @@ class GhostTrainer(Trainer):
|
|||
:param artifact_path: Path to store artifacts from this trainer.
|
||||
"""
|
||||
|
||||
super(GhostTrainer, self).__init__(
|
||||
super().__init__(
|
||||
brain_name, trainer_settings, training, artifact_path, reward_buff_cap
|
||||
)
|
||||
|
||||
|
|
|
@ -158,7 +158,7 @@ class TFPolicy(Policy):
|
|||
ckpt = tf.train.get_checkpoint_state(model_path)
|
||||
if ckpt is None:
|
||||
raise UnityPolicyException(
|
||||
"The model {0} could not be loaded. Make "
|
||||
"The model {} could not be loaded. Make "
|
||||
"sure you specified the right "
|
||||
"--run-id and that the previous run you are loading from had the same "
|
||||
"behavior names.".format(model_path)
|
||||
|
@ -167,7 +167,7 @@ class TFPolicy(Policy):
|
|||
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
|
||||
except tf.errors.NotFoundError:
|
||||
raise UnityPolicyException(
|
||||
"The model {0} was found but could not be loaded. Make "
|
||||
"The model {} was found but could not be loaded. Make "
|
||||
"sure the model is from the same version of ML-Agents, has the same behavior parameters, "
|
||||
"and is using the same trainer configuration as the current run.".format(
|
||||
model_path
|
||||
|
|
|
@ -44,7 +44,7 @@ class PPOTrainer(RLTrainer):
|
|||
:param seed: The seed the model will be initialized with
|
||||
:param artifact_path: The directory within which to store artifacts from this trainer.
|
||||
"""
|
||||
super(PPOTrainer, self).__init__(
|
||||
super().__init__(
|
||||
brain_name, trainer_settings, training, artifact_path, reward_buff_cap
|
||||
)
|
||||
self.hyperparameters: PPOSettings = cast(
|
||||
|
|
|
@ -127,7 +127,7 @@ class ConsoleWriter(StatsWriter):
|
|||
) -> None:
|
||||
if property_type == StatsPropertyType.HYPERPARAMETERS:
|
||||
logger.info(
|
||||
"""Hyperparameters for behavior name {0}: \n{1}""".format(
|
||||
"""Hyperparameters for behavior name {}: \n{}""".format(
|
||||
category, self._dict_to_str(value, 0)
|
||||
)
|
||||
)
|
||||
|
@ -150,7 +150,7 @@ class ConsoleWriter(StatsWriter):
|
|||
[
|
||||
"\t"
|
||||
+ " " * num_tabs
|
||||
+ "{0}:\t{1}".format(
|
||||
+ "{}:\t{}".format(
|
||||
x, self._dict_to_str(param_dict[x], num_tabs + 1)
|
||||
)
|
||||
for x in param_dict
|
||||
|
@ -226,7 +226,7 @@ class TensorboardWriter(StatsWriter):
|
|||
s_op = tf.summary.text(
|
||||
name,
|
||||
tf.convert_to_tensor(
|
||||
([[str(x), str(input_dict[x])] for x in input_dict])
|
||||
[[str(x), str(input_dict[x])] for x in input_dict]
|
||||
),
|
||||
)
|
||||
s = sess.run(s_op)
|
||||
|
|
|
@ -20,7 +20,7 @@ def test_globaltrainingstatus(tmpdir):
|
|||
GlobalTrainingStatus.set_parameter_state("Category1", StatusType.LESSON_NUM, 3)
|
||||
GlobalTrainingStatus.save_state(path_dir)
|
||||
|
||||
with open(path_dir, "r") as fp:
|
||||
with open(path_dir) as fp:
|
||||
test_json = json.load(fp)
|
||||
|
||||
assert "Category1" in test_json
|
||||
|
|
|
@ -32,7 +32,7 @@ class RLTrainer(Trainer): # pylint: disable=abstract-method
|
|||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RLTrainer, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
# collected_rewards is a dictionary from name of reward signal to a dictionary of agent_id to cumulative reward
|
||||
# used for reporting only. We always want to report the environment reward to Tensorboard, regardless
|
||||
# of what reward signals are actually present.
|
||||
|
|
|
@ -30,7 +30,7 @@ from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
|
|||
from mlagents.trainers.agent_processor import AgentManager
|
||||
|
||||
|
||||
class TrainerController(object):
|
||||
class TrainerController:
|
||||
def __init__(
|
||||
self,
|
||||
trainer_factory: TrainerFactory,
|
||||
|
|
|
@ -67,7 +67,7 @@ class GlobalTrainingStatus:
|
|||
:param path: Path to the JSON file containing the state.
|
||||
"""
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
with open(path) as f:
|
||||
loaded_dict = json.load(f)
|
||||
# Compare the metadata
|
||||
_metadata = loaded_dict[StatusType.STATS_METADATA.value]
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
from io import open
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
@ -25,7 +24,7 @@ class VerifyVersionCommand(install):
|
|||
tag = os.getenv("CIRCLE_TAG")
|
||||
|
||||
if tag != EXPECTED_TAG:
|
||||
info = "Git tag: {0} does not match the expected tag of this app: {1}".format(
|
||||
info = "Git tag: {} does not match the expected tag of this app: {}".format(
|
||||
tag, EXPECTED_TAG
|
||||
)
|
||||
sys.exit(info)
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
from __future__ import print_function
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ def check_file(filename: str, global_allow_pattern: Pattern) -> List[str]:
|
|||
Validate a single file and return any offending lines.
|
||||
"""
|
||||
bad_lines = []
|
||||
with open(filename, "r") as f:
|
||||
with open(filename) as f:
|
||||
for line in f:
|
||||
if not RELEASE_PATTERN.search(line):
|
||||
continue
|
||||
|
|
|
@ -94,7 +94,7 @@ def set_version(
|
|||
|
||||
|
||||
def set_package_version(new_version: str) -> None:
|
||||
with open(MLAGENTS_PACKAGE_JSON_PATH, "r") as f:
|
||||
with open(MLAGENTS_PACKAGE_JSON_PATH) as f:
|
||||
package_json = json.load(f)
|
||||
if "version" in package_json:
|
||||
package_json["version"] = new_version
|
||||
|
@ -104,7 +104,7 @@ def set_package_version(new_version: str) -> None:
|
|||
|
||||
|
||||
def set_extension_package_version(new_version: str) -> None:
|
||||
with open(MLAGENTS_EXTENSIONS_PACKAGE_JSON_PATH, "r") as f:
|
||||
with open(MLAGENTS_EXTENSIONS_PACKAGE_JSON_PATH) as f:
|
||||
package_json = json.load(f)
|
||||
package_json["dependencies"]["com.unity.ml-agents"] = new_version
|
||||
with open(MLAGENTS_EXTENSIONS_PACKAGE_JSON_PATH, "w") as f:
|
||||
|
|
Загрузка…
Ссылка в новой задаче